diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 6450b2ad878..c190b883a42 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -39,7 +39,7 @@ config_setting( config_setting( name = "android_armeabi", values = { - "cc_target_os": "android", + "crosstool_top": "//external:android/crosstool", "cpu": "armeabi", }, visibility = ["//visibility:public"], @@ -218,7 +218,9 @@ filegroup( "//tensorflow/compiler/jit/ops:all_files", "//tensorflow/compiler/tests:all_files", "//tensorflow/compiler/tf2xla:all_files", + "//tensorflow/compiler/tf2xla/cc:all_files", "//tensorflow/compiler/tf2xla/kernels:all_files", + "//tensorflow/compiler/tf2xla/ops:all_files", "//tensorflow/compiler/xla:all_files", "//tensorflow/compiler/xla/client:all_files", "//tensorflow/compiler/xla/client/lib:all_files", @@ -253,7 +255,7 @@ filegroup( "//tensorflow/contrib/data/python/kernel_tests:all_files", "//tensorflow/contrib/data/python/ops:all_files", "//tensorflow/contrib/data/python/util:all_files", - "//tensorflow/contrib/decision_trees:all_files", + "//tensorflow/contrib/decision_trees/proto:all_files", "//tensorflow/contrib/distributions:all_files", "//tensorflow/contrib/factorization:all_files", "//tensorflow/contrib/factorization/kernels:all_files", @@ -284,6 +286,8 @@ filegroup( "//tensorflow/contrib/ndlstm:all_files", "//tensorflow/contrib/nn:all_files", "//tensorflow/contrib/opt:all_files", + "//tensorflow/contrib/predictor:all_files", + "//tensorflow/contrib/remote_fused_graph/pylib:all_files", "//tensorflow/contrib/rnn:all_files", "//tensorflow/contrib/saved_model:all_files", "//tensorflow/contrib/saved_model/cc/saved_model:all_files", @@ -302,10 +306,13 @@ filegroup( "//tensorflow/contrib/stateless:all_files", "//tensorflow/contrib/tensor_forest:all_files", "//tensorflow/contrib/tensor_forest/hybrid:all_files", + "//tensorflow/contrib/tensor_forest/kernels/v4:all_files", + "//tensorflow/contrib/tensor_forest/proto:all_files", "//tensorflow/contrib/tensorboard:all_files", 
"//tensorflow/contrib/testing:all_files", "//tensorflow/contrib/text:all_files", "//tensorflow/contrib/tfprof/python/tools/tfprof:all_files", + "//tensorflow/contrib/tpu:all_files", "//tensorflow/contrib/training:all_files", "//tensorflow/contrib/util:all_files", "//tensorflow/contrib/verbs:all_files", @@ -353,70 +360,6 @@ filegroup( "//tensorflow/python/ops/distributions:all_files", "//tensorflow/python/saved_model:all_files", "//tensorflow/python/tools:all_files", - "//tensorflow/tensorboard:all_files", - "//tensorflow/tensorboard/backend:all_files", - "//tensorflow/tensorboard/backend/event_processing:all_files", - "//tensorflow/tensorboard/components:all_files", - "//tensorflow/tensorboard/components/tf_audio_dashboard:all_files", - "//tensorflow/tensorboard/components/tf_audio_dashboard/test:all_files", - "//tensorflow/tensorboard/components/tf_backend:all_files", - "//tensorflow/tensorboard/components/tf_backend/test:all_files", - "//tensorflow/tensorboard/components/tf_color_scale:all_files", - "//tensorflow/tensorboard/components/tf_color_scale/test:all_files", - "//tensorflow/tensorboard/components/tf_dashboard_common:all_files", - "//tensorflow/tensorboard/components/tf_dashboard_common/test:all_files", - "//tensorflow/tensorboard/components/tf_distribution_dashboard:all_files", - "//tensorflow/tensorboard/components/tf_globals:all_files", - "//tensorflow/tensorboard/components/tf_graph:all_files", - "//tensorflow/tensorboard/components/tf_graph/demo:all_files", - "//tensorflow/tensorboard/components/tf_graph_app:all_files", - "//tensorflow/tensorboard/components/tf_graph_app/demo:all_files", - "//tensorflow/tensorboard/components/tf_graph_board:all_files", - "//tensorflow/tensorboard/components/tf_graph_board/demo:all_files", - "//tensorflow/tensorboard/components/tf_graph_common:all_files", - "//tensorflow/tensorboard/components/tf_graph_controls:all_files", - "//tensorflow/tensorboard/components/tf_graph_controls/demo:all_files", - 
"//tensorflow/tensorboard/components/tf_graph_dashboard:all_files", - "//tensorflow/tensorboard/components/tf_graph_dashboard/demo:all_files", - "//tensorflow/tensorboard/components/tf_graph_debugger_data_card:all_files", - "//tensorflow/tensorboard/components/tf_graph_debugger_data_card/demo:all_files", - "//tensorflow/tensorboard/components/tf_graph_info:all_files", - "//tensorflow/tensorboard/components/tf_graph_info/demo:all_files", - "//tensorflow/tensorboard/components/tf_graph_loader:all_files", - "//tensorflow/tensorboard/components/tf_graph_loader/demo:all_files", - "//tensorflow/tensorboard/components/tf_histogram_dashboard:all_files", - "//tensorflow/tensorboard/components/tf_image_dashboard:all_files", - "//tensorflow/tensorboard/components/tf_imports:all_files", - "//tensorflow/tensorboard/components/tf_option_selector:all_files", - "//tensorflow/tensorboard/components/tf_profile_dashboard:all_files", - "//tensorflow/tensorboard/components/tf_profile_dashboard/demo:all_files", - "//tensorflow/tensorboard/components/tf_runs_selector:all_files", - "//tensorflow/tensorboard/components/tf_scalar_dashboard:all_files", - "//tensorflow/tensorboard/components/tf_scalar_dashboard/demo:all_files", - "//tensorflow/tensorboard/components/tf_storage:all_files", - "//tensorflow/tensorboard/components/tf_storage/test:all_files", - "//tensorflow/tensorboard/components/tf_tensorboard:all_files", - "//tensorflow/tensorboard/components/tf_text_dashboard:all_files", - "//tensorflow/tensorboard/components/tf_trace_viewer:all_files", - "//tensorflow/tensorboard/components/vz_distribution_chart:all_files", - "//tensorflow/tensorboard/components/vz_histogram_timeseries:all_files", - "//tensorflow/tensorboard/components/vz_line_chart:all_files", - "//tensorflow/tensorboard/components/vz_projector:all_files", - "//tensorflow/tensorboard/components/vz_projector/test:all_files", - "//tensorflow/tensorboard/components/vz_sorting:all_files", - 
"//tensorflow/tensorboard/components/vz_sorting/test:all_files", - "//tensorflow/tensorboard/demo:all_files", - "//tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize:all_files", - "//tensorflow/tensorboard/plugins:all_files", - "//tensorflow/tensorboard/plugins/audio:all_files", - "//tensorflow/tensorboard/plugins/distributions:all_files", - "//tensorflow/tensorboard/plugins/graphs:all_files", - "//tensorflow/tensorboard/plugins/histograms:all_files", - "//tensorflow/tensorboard/plugins/images:all_files", - "//tensorflow/tensorboard/plugins/projector:all_files", - "//tensorflow/tensorboard/plugins/scalars:all_files", - "//tensorflow/tensorboard/plugins/text:all_files", - "//tensorflow/tensorboard/scripts:all_files", "//tensorflow/tools/api/golden:all_files", "//tensorflow/tools/api/lib:all_files", "//tensorflow/tools/api/tests:all_files", diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 77faa475ed4..b1568c5f634 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -628,7 +628,7 @@ void TF_PRunSetup(TF_DeprecatedSession* s, // Target nodes const char** c_target_oper_names, int ntargets, const char** handle, TF_Status* status) { - status->status = Status::OK(); + *handle = nullptr; std::vector input_names(ninputs); std::vector output_names(noutputs); @@ -643,16 +643,12 @@ void TF_PRunSetup(TF_DeprecatedSession* s, target_oper_names[i] = c_target_oper_names[i]; } tensorflow::string new_handle; - Status result; - result = s->session->PRunSetup(input_names, output_names, target_oper_names, - &new_handle); - if (result.ok()) { + status->status = s->session->PRunSetup(input_names, output_names, + target_oper_names, &new_handle); + if (status->status.ok()) { char* buf = new char[new_handle.size() + 1]; memcpy(buf, new_handle.c_str(), new_handle.size() + 1); *handle = buf; - } else { - *handle = nullptr; - status->status = result; } } @@ -2326,6 +2322,8 @@ void TF_SessionPRunSetup(TF_Session* session, const TF_Output* inputs, int 
ninputs, const TF_Output* outputs, int noutputs, const TF_Operation* const* target_opers, int ntargets, const char** handle, TF_Status* status) { + *handle = nullptr; + if (!ExtendSessionGraphHelper(session, status)) { return; } diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index 15139a47acf..3aeafb46855 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -1101,8 +1101,7 @@ TF_CAPI_EXPORT extern void TF_SessionRun( // needed. // // On failure, out_status contains a tensorflow::Status with an error -// message. -// NOTE: This is EXPERIMENTAL and subject to change. +// message. *handle is set to nullptr. TF_CAPI_EXPORT extern void TF_SessionPRunSetup( TF_Session*, // Input names @@ -1118,7 +1117,6 @@ TF_CAPI_EXPORT extern void TF_SessionPRunSetup( // Continue to run the graph with additional feeds and fetches. The // execution state is uniquely identified by the handle. -// NOTE: This is EXPERIMENTAL and subject to change. TF_CAPI_EXPORT extern void TF_SessionPRun( TF_Session*, const char* handle, // Input tensors diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index f89cc6384b3..9801add1dac 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -61,7 +61,6 @@ cc_library( ":gradients", ":ops", ":scope", - "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", @@ -274,10 +273,6 @@ cc_library( deps = [ ":cc_ops", ":grad_op_registry", - ":ops", - ":scope", - "//tensorflow/core:core_cpu", - "//tensorflow/core:framework", ], ) @@ -305,10 +300,6 @@ cc_library( ":cc_ops", ":cc_ops_internal", ":grad_op_registry", - ":ops", - ":scope", - "//tensorflow/core:core_cpu", - "//tensorflow/core:framework", ], ) @@ -527,7 +518,6 @@ cc_library( deps = [ ":coordinator", "//tensorflow/core:core_cpu", - "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", @@ -560,8 +550,6 @@ cc_library( srcs = 
["training/coordinator.cc"], hdrs = ["training/coordinator.h"], deps = [ - "//tensorflow/core:core_cpu", - "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", diff --git a/tensorflow/cc/saved_model/constants.h b/tensorflow/cc/saved_model/constants.h index 94a3b3cf465..c940df8a876 100644 --- a/tensorflow/cc/saved_model/constants.h +++ b/tensorflow/cc/saved_model/constants.h @@ -21,6 +21,9 @@ namespace tensorflow { /// SavedModel assets directory. constexpr char kSavedModelAssetsDirectory[] = "assets"; +/// SavedModel assets.extra directory. +constexpr char kSavedModelAssetsExtraDirectory[] = "assets.extra"; + /// SavedModel assets key for graph collection-def. constexpr char kSavedModelAssetsKey[] = "saved_model_assets"; diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc index 807f5904afc..f98abc8a817 100644 --- a/tensorflow/cc/saved_model/loader.cc +++ b/tensorflow/cc/saved_model/loader.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/cc/saved_model/constants.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/monitoring/counter.h" +#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/protobuf_internal.h" #include "tensorflow/core/protobuf/saved_model.pb.h" @@ -76,8 +77,16 @@ Status FindMetaGraphDefToLoad(const SavedModel& saved_model_proto, return Status::OK(); } } + string tags_as_string = "{ "; + for (const string& tag : tags) { + tags_as_string = strings::StrCat(tags_as_string, tag, " "); + } + tags_as_string = strings::StrCat(tags_as_string, "}"); return Status(error::Code::NOT_FOUND, - "Could not find meta graph def matching supplied tags."); + "Could not find meta graph def matching supplied tags: " + + tags_as_string + + ". 
To inspect available tag-sets in the SavedModel, please " + "use the SavedModel CLI: `saved_model_cli`"); } Status LoadMetaGraphIntoSession(const MetaGraphDef& meta_graph_def, diff --git a/tensorflow/cc/saved_model/loader_test.cc b/tensorflow/cc/saved_model/loader_test.cc index cef29e7b071..0ad6b33bba5 100644 --- a/tensorflow/cc/saved_model/loader_test.cc +++ b/tensorflow/cc/saved_model/loader_test.cc @@ -133,9 +133,9 @@ TEST_F(LoaderTest, NoTagMatch) { Status st = LoadSavedModel(session_options, run_options, export_dir, {"missing-tag"}, &bundle); EXPECT_FALSE(st.ok()); - EXPECT_TRUE( - StringPiece(st.error_message()) - .contains("Could not find meta graph def matching supplied tags.")) + EXPECT_TRUE(StringPiece(st.error_message()) + .contains("Could not find meta graph def matching supplied " + "tags: { missing-tag }")) << st.error_message(); } @@ -151,7 +151,7 @@ TEST_F(LoaderTest, NoTagMatchMultiple) { EXPECT_FALSE(st.ok()); EXPECT_TRUE( StringPiece(st.error_message()) - .contains("Could not find meta graph def matching supplied tags.")) + .contains("Could not find meta graph def matching supplied tags: ")) << st.error_message(); } diff --git a/tensorflow/cc/saved_model/tag_constants.h b/tensorflow/cc/saved_model/tag_constants.h index 48ab1158e46..2b0b2d5c7fb 100644 --- a/tensorflow/cc/saved_model/tag_constants.h +++ b/tensorflow/cc/saved_model/tag_constants.h @@ -18,10 +18,13 @@ limitations under the License. namespace tensorflow { +/// Tag for the `gpu` graph. +constexpr char kSavedModelTagGpu[] = "gpu"; + /// Tag for the `serving` graph. constexpr char kSavedModelTagServe[] = "serve"; -/// Tag for the `training` graph.` +/// Tag for the `training` graph. 
constexpr char kSavedModelTagTrain[] = "train"; } // namespace tensorflow diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD index 1f6fe28188c..31637358c30 100644 --- a/tensorflow/compiler/aot/BUILD +++ b/tensorflow/compiler/aot/BUILD @@ -126,14 +126,11 @@ cc_library( deps = [ ":tfcompile_lib", ":tfcompile_proto", - "//tensorflow/compiler/xla/legacy_flags:alias_analysis_flags", "//tensorflow/compiler/xla/legacy_flags:buffer_assignment_flags", "//tensorflow/compiler/xla/legacy_flags:compiler_functor_flags", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", "//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/legacy_flags:hlo_graph_dumper_flags", - "//tensorflow/compiler/xla/legacy_flags:llvm_util_flags", "//tensorflow/compiler/xla/legacy_flags:service_flags", "//tensorflow/compiler/xla/legacy_flags:util_flags", "//tensorflow/compiler/xla/service:compiler", diff --git a/tensorflow/compiler/aot/tfcompile_main.cc b/tensorflow/compiler/aot/tfcompile_main.cc index 6fed46b4329..e03d28cd96f 100644 --- a/tensorflow/compiler/aot/tfcompile_main.cc +++ b/tensorflow/compiler/aot/tfcompile_main.cc @@ -23,14 +23,11 @@ limitations under the License. 
#include "tensorflow/compiler/aot/flags.h" #include "tensorflow/compiler/aot/tfcompile.pb.h" #include "tensorflow/compiler/aot/tfcompile_util.h" -#include "tensorflow/compiler/xla/legacy_flags/alias_analysis_flags.h" #include "tensorflow/compiler/xla/legacy_flags/buffer_assignment_flags.h" #include "tensorflow/compiler/xla/legacy_flags/compiler_functor_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/legacy_flags/hlo_graph_dumper_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/llvm_util_flags.h" #include "tensorflow/compiler/xla/legacy_flags/service_flags.h" #include "tensorflow/compiler/xla/legacy_flags/util_flags.h" #include "tensorflow/compiler/xla/service/compiler.h" @@ -136,14 +133,11 @@ int main(int argc, char** argv) { std::vector flag_list; AppendMainFlags(&flag_list, &flags); - xla::legacy_flags::AppendAliasAnalysisFlags(&flag_list); xla::legacy_flags::AppendBufferAssignmentFlags(&flag_list); xla::legacy_flags::AppendCompilerFunctorFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::legacy_flags::AppendCpuRuntimeFlags(&flag_list); xla::legacy_flags::AppendHloGraphDumperFlags(&flag_list); xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendLlvmUtilFlags(&flag_list); xla::legacy_flags::AppendServiceFlags(&flag_list); xla::legacy_flags::AppendUtilFlags(&flag_list); diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 5f857191da7..306e704415b 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -22,20 +22,6 @@ load("//tensorflow:tensorflow.bzl", "cc_header_only_library") load("//tensorflow:tensorflow.bzl", "tf_kernel_library") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") -# This target can be used by XLA device plugins to 
prevent circular -# dependencies, and provides access to all of the required headers -# for building a device library. -cc_header_only_library( - name = "xla_jit_headers_lib", - visibility = ["//visibility:public"], - deps = [ - ":xla_cpu_device", - ":xla_cpu_jit", - ":xla_gpu_device", - ":xla_gpu_jit", - ], -) - # Target that bundles up the XLA CPU and GPU JIT devices. cc_library( name = "jit", @@ -283,3 +269,15 @@ filegroup( ), visibility = ["//tensorflow:__subpackages__"], ) + +# This target can be used by XLA device plugins to prevent circular dependencies, and provides access to all of the required headers for building a device library. +cc_header_only_library( + name = "xla_jit_headers_lib", + visibility = ["//visibility:public"], + deps = [ + ":xla_cpu_device", + ":xla_cpu_jit", + ":xla_gpu_device", + ":xla_gpu_jit", + ], +) diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD index c4116cb8b52..ed204b81821 100644 --- a/tensorflow/compiler/jit/kernels/BUILD +++ b/tensorflow/compiler/jit/kernels/BUILD @@ -38,6 +38,7 @@ cc_library( "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:lib", diff --git a/tensorflow/compiler/jit/kernels/xla_device_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_device_launch_op.cc index 29c5ff72429..bd051d06ae9 100644 --- a/tensorflow/compiler/jit/kernels/xla_device_launch_op.cc +++ b/tensorflow/compiler/jit/kernels/xla_device_launch_op.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/jit/xla_device_context.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/core/common_runtime/dma_helper.h" #include "tensorflow/core/framework/allocator.h" @@ -149,6 +150,8 @@ void XlaDeviceLaunchOp::Compute(OpKernelContext* ctx) { xla::ExecutionOptions execution_options; *execution_options.mutable_shape_with_output_layout() = kernel->xla_output_shape; + *execution_options.mutable_debug_options() = + xla::legacy_flags::GetDebugOptionsFromFlags(); Env* env = Env::Default(); auto start_time = env->NowMicros(); VLOG(1) << "Executing XLA Computation..."; @@ -202,8 +205,8 @@ void XlaDeviceLaunchOp::Compute(OpKernelContext* ctx) { // Apply variable updates, if any. VLOG(2) << "Applying variable updates"; - for (int i = 0; i < kernel->variable_updates.size(); ++i) { - const XlaCompiler::VariableUpdate& write = kernel->variable_updates[i]; + for (int i = 0; i < kernel->resource_updates.size(); ++i) { + const XlaCompiler::ResourceUpdate& write = kernel->resource_updates[i]; OP_REQUIRES(ctx, write.input_index >= 0 && write.input_index < ctx->num_inputs(), errors::Internal("Invalid input index for variable write.")); diff --git a/tensorflow/compiler/jit/ops/BUILD b/tensorflow/compiler/jit/ops/BUILD index 8d1fa03cc0d..e5787ca4c8c 100644 --- a/tensorflow/compiler/jit/ops/BUILD +++ b/tensorflow/compiler/jit/ops/BUILD @@ -1,32 +1,20 @@ licenses(["notice"]) # Apache 2.0 package( - default_visibility = [ - "//tensorflow/compiler/tf2xla:internal", - ], + default_visibility = ["//tensorflow/compiler/tf2xla:internal"], ) cc_library( name = "xla_ops", - srcs = [ - "xla_ops.cc", - ], - deps = [ - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - ], + srcs = ["xla_ops.cc"], + deps = ["//tensorflow/core:framework"], 
alwayslink = 1, ) cc_library( name = "parallel_check_op", srcs = ["parallel_check_op.cc"], - deps = [ - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - ], + deps = ["//tensorflow/core:framework"], alwayslink = 1, ) diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index 63ca77f9a91..2325217b973 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -182,17 +182,18 @@ Status BuildArguments(int num_constant_args, XlaCompiler::Argument& arg = (*args)[input_num]; arg.name = variable_args[variable_id].name; + arg.kind = XlaCompiler::Argument::kVariable; if (variable_args[variable_id].present) { const Tensor& value = variable_args[variable_id].value; - arg.kind = XlaCompiler::Argument::kVariable; arg.type = value.dtype(); arg.shape = value.shape(); + arg.initialized = true; } else { // The values of uninitialized variables are not passed as inputs, since // they are meaningless. However, it is legal to assign to a resource // variable for the first time inside the XLA computation, so we do permit // uninitialized variables. 
- arg.kind = XlaCompiler::Argument::kUninitializedVariable; + arg.initialized = false; arg.type = DT_INVALID; arg.shape = TensorShape(); } diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index f329e83e14d..0ab81ebd5ff 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -137,7 +137,7 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor, done(result.status()); return; } - const void* src_ptr = xla::LiteralUtil::InternalData(*result.ValueOrDie()); + const void* src_ptr = result.ValueOrDie()->InternalData(); void* dst_ptr = DMAHelper::base(cpu_tensor); size_t total_bytes = cpu_tensor->TotalBytes(); memcpy(dst_ptr, src_ptr, total_bytes); diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 4bbb2767ac0..c42d7d754c4 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -40,6 +40,7 @@ py_library( "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:platform", + "//tensorflow/python:random_seed", "//tensorflow/python:variables", ], ) @@ -323,7 +324,7 @@ tf_xla_py_test( tf_xla_py_test( name = "reverse_ops_test", - size = "small", + size = "medium", srcs = ["reverse_ops_test.py"], deps = [ ":xla_test", diff --git a/tensorflow/compiler/tests/spacetobatch_op_test.py b/tensorflow/compiler/tests/spacetobatch_op_test.py index 9c3b86c84b2..c013f4b50a4 100644 --- a/tensorflow/compiler/tests/spacetobatch_op_test.py +++ b/tensorflow/compiler/tests/spacetobatch_op_test.py @@ -228,34 +228,40 @@ class SpaceToBatchNDTest(XLATestCase): outputs=[[[0, 0], [2, 21]], [[0, 0], [5, 51]], [[1, 11], [3, 31]], [[4, 41], [6, 61]]]) - def testDirect(self): + def testDirect0(self): # Test with zero-size remaining dimension. 
self._testDirect( input_shape=[3, 1, 2, 0], block_shape=[3], paddings=[[0, 2]]) + def testDirect1(self): # Test with zero-size blocked dimension. self._testDirect( input_shape=[3, 0, 2, 5], block_shape=[3], paddings=[[0, 0]]) + def testDirect2(self): # Test with padding up from zero size. self._testDirect( input_shape=[3, 0, 2, 5], block_shape=[3], paddings=[[1, 2]]) + def testDirect3(self): self._testDirect( input_shape=[3, 3, 4, 5, 2], block_shape=[3, 4, 2], paddings=[[1, 2], [0, 0], [3, 0]]) + def testDirect4(self): self._testDirect( input_shape=[3, 3, 4, 5, 2], block_shape=[3, 4, 2, 2], paddings=[[1, 2], [0, 0], [3, 0], [0, 0]]) + def testDirect5(self): self._testDirect( input_shape=[3, 2, 2, 3, 4, 5, 2, 5], block_shape=[1, 1, 3, 4, 2, 2], paddings=[[0, 0], [0, 0], [1, 2], [0, 0], [3, 0], [0, 0]]) + def testDirect6(self): self._testDirect( input_shape=[3, 2, 2, 3, 4, 5, 2, 5], block_shape=[1, 1, 3, 4, 2, 2, 1], diff --git a/tensorflow/compiler/tests/tensor_array_ops_test.py b/tensorflow/compiler/tests/tensor_array_ops_test.py index 27a29773053..b3067be51dd 100644 --- a/tensorflow/compiler/tests/tensor_array_ops_test.py +++ b/tensorflow/compiler/tests/tensor_array_ops_test.py @@ -335,7 +335,7 @@ class TensorArrayTest(xla_test.XLATestCase): r0_bad = gen_data_flow_ops._tensor_array_read_v3( handle=w0.handle, index=0, dtype=dtypes.float64, flow_in=w0.flow) with self.assertRaisesOpError( - "TensorArray dtype is float but Op requested dtype double."): + "TensorArray dtype is float but op has dtype double."): r0_bad.eval() # Test reading from a different index than the one we wrote to @@ -573,13 +573,12 @@ class TensorArrayTest(xla_test.XLATestCase): [2000.0, -2000.0]], grad_vals[0]) - # TODO(phawkins): implement TensorArrayClose - # def testCloseTensorArray(self): - # with self.test_session() as session, self.test_scope(): - # ta = tensor_array_ops.TensorArray( - # dtype=dtypes.float32, tensor_array_name="foo", size=3) - # c1 = ta.close() - # session.run(c1) + def 
testCloseTensorArray(self): + with self.test_session() as session, self.test_scope(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, tensor_array_name="foo", size=3) + c1 = ta.close() + session.run(c1) def testSizeTensorArray(self): with self.test_session(), self.test_scope(): @@ -588,17 +587,16 @@ class TensorArrayTest(xla_test.XLATestCase): s = ta.size() self.assertAllEqual(3, s.eval()) - # TODO(phawkins): implement TensorArrayClose - # def testWriteCloseTensorArray(self): - # with self.test_session(), self.test_scope(): - # ta = tensor_array_ops.TensorArray( - # dtype=dtypes.float32, - # tensor_array_name="foo", - # size=3, - # infer_shape=False) - # w0 = ta.write(0, [[4.0, 5.0]]) - # w1 = w0.write(1, [3.0]) - # w1.close().run() # Expected to run without problems + def testWriteCloseTensorArray(self): + with self.test_session(), self.test_scope(): + ta = tensor_array_ops.TensorArray( + dtype=dtypes.float32, + tensor_array_name="foo", + size=3, + infer_shape=False) + w0 = ta.write(0, [[4.0, 5.0]]) + w1 = w0.write(1, [3.0]) + w1.close().run() # Expected to run without problems # TODO(phawkins): implement while loops. 
# def _testWhileLoopWritePackGradients(self, dynamic_size, dtype): diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 93c484ca7a0..f4d3bc96354 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -42,6 +42,7 @@ cc_library( deps = [ ":common", ":dump_graph", + ":functionalize_control_flow", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", @@ -152,7 +153,6 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", - "//tensorflow/core:protos_all_cc", ], ) @@ -165,13 +165,10 @@ cc_test( "//tensorflow/cc:cc_ops", "//tensorflow/cc:function_ops", "//tensorflow/cc:ops", - "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", - "//tensorflow/core:framework", "//tensorflow/core:ops", "//tensorflow/core:test", "//tensorflow/core:test_main", - "//tensorflow/core:testlib", ], ) @@ -203,6 +200,58 @@ cc_library( ], ) +cc_library( + name = "functionalize_control_flow", + srcs = ["functionalize_control_flow.cc"], + hdrs = ["functionalize_control_flow.h"], + deps = [ + "//tensorflow/compiler/jit:graph_to_functiondef", + "//tensorflow/compiler/tf2xla:dump_graph", + "//tensorflow/compiler/tf2xla/ops:functional_ops", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + ], +) + +cc_test( + name = "functionalize_control_flow_test", + srcs = ["functionalize_control_flow_test.cc"], + deps = [ + ":functionalize_control_flow", + ":test_util", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:cc_ops_internal", + "//tensorflow/cc:function_ops", + "//tensorflow/cc:ops", + "//tensorflow/cc:resource_variable_ops", + "//tensorflow/compiler/tf2xla/cc:functional_ops", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/core:core_cpu", + 
"//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:ops", + "//tensorflow/core:resource_variable_ops_op_lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +cc_library( + name = "test_util", + testonly = 1, + srcs = ["test_util.cc"], + hdrs = ["test_util.h"], + deps = [ + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ], +) + # ----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/tf2xla/cc/BUILD b/tensorflow/compiler/tf2xla/cc/BUILD new file mode 100644 index 00000000000..599265ba449 --- /dev/null +++ b/tensorflow/compiler/tf2xla/cc/BUILD @@ -0,0 +1,44 @@ +package( + default_visibility = ["//tensorflow/compiler/tf2xla:internal"], +) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_cc") + +tf_gen_op_wrapper_cc( + name = "functional_ops_gen", + include_internal_ops = 1, + out_ops_file = "ops/functional_ops", + deps = ["//tensorflow/compiler/tf2xla/ops:functional_ops"], +) + +cc_library( + name = "functional_ops", + srcs = ["ops/functional_ops.cc"], + hdrs = ["ops/functional_ops.h"], + deps = [ + "//tensorflow/cc:const_op", + "//tensorflow/cc:ops", + "//tensorflow/cc:scope", + "//tensorflow/compiler/tf2xla/ops:functional_ops", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], +) + +# ----------------------------------------------------------------------------- + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc new file mode 100644 index 00000000000..623c52120fd --- /dev/null 
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -0,0 +1,566 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h" + +#include +#include +#include +#include + +#include "tensorflow/compiler/jit/graph_to_functiondef.h" +#include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/graph/control_flow.h" + +namespace tensorflow { + +namespace { + +const char* const kArgOp = "_Arg"; +const char* const kRetValOp = "_Retval"; + +// Information about a loop argument. +struct Arg { + // Every loop argument has an Enter node. + Node* enter; + + // Is the loop argument a loop-invariant value? Taken from the `is_constant` + // attribute on the Enter node. + bool is_loop_invariant; + + // If 'is_loop_invariant' is true, the following are all nullptr. Non-constant + // arguments must have all of the following nodes: + Node* merge = nullptr; + Node* switch_node = nullptr; + Node* next_iteration = nullptr; + Node* exit = nullptr; +}; + +// Information about a loop frame. +struct Frame { + string name; + + // Pointer to the parent frame. The root frame has a pointer to itself. 
+  Frame* parent = nullptr;
+  int num_children = 0;
+
+  // Arguments to this loop.
+  std::vector<Arg> args;
+
+  // The loop condition of the loop. There should be exactly one loop condition
+  // in every loop.
+  Node* loop_cond = nullptr;
+
+  // Set of nodes that belong to the loop frame.
+  std::unordered_set<Node*> nodes;
+};
+
+// Copies a subgraph from `graph` to `output` by performing a reverse DFS
+// starting at nodes in vector `stack`.
+// `node_map` is a vector indexed by source node ID to dest nodes.
+// Does not traverse into nodes in `node_map`, so by adding nodes to `node_map`
+// before the traversal clients can cut the graph. Returns an error if the
+// traversal leaves 'frame'; the client must add enough nodes to `node_map` to
+// cut the graph and prevent the traversal from escaping.
+//
+// `squash_src_outputs` contains a bool for each source node ID. If true, then
+// the source output on that node will be replaced by zero when copied. This is
+// used when replacing a Switch node with an _Arg node. The output we are
+// taking from the Switch node was not necessarily the first output, but _Arg
+// nodes only have one output. By adding the Switch node to `squash_src_outputs`
+// we rewrite the src_output of the corresponding edge to be 0.
+Status CopySubgraph(const Graph& graph, const Frame& frame,
+                    std::vector<Node*> stack,
+                    const std::vector<bool>& squash_src_outputs,
+                    std::vector<Node*>* node_map, Graph* output) {
+  std::vector<bool> visited(graph.num_node_ids(), false);
+  while (!stack.empty()) {
+    Node* n = stack.back();
+    stack.pop_back();
+
+    VLOG(3) << "Copying node " << n->name();
+
+    if (visited[n->id()]) continue;
+    visited[n->id()] = true;
+
+    for (const Edge* e : n->in_edges()) {
+      Node* src = e->src();
+      if (frame.nodes.find(src) == frame.nodes.end()) {
+        // We traversed out of the loop frame, without encountering a cut node.
+        return errors::Internal("Graph traversal of loop frame ", frame.name,
+                                " escaped frame at ", src->name(),
+                                " without encountering an argument node.");
+      }
+      if ((*node_map)[src->id()] == nullptr) {
+        (*node_map)[src->id()] = output->CopyNode(src);
+        stack.push_back(src);
+      }
+      Node* src_copy = (*node_map)[e->src()->id()];
+      int src_output = squash_src_outputs[e->src()->id()] ? 0 : e->src_output();
+      Node* dst_copy = (*node_map)[e->dst()->id()];
+      output->AddEdge(src_copy, src_output, dst_copy, e->dst_input());
+    }
+  }
+  return Status::OK();
+}
+
+Status BuildArgNode(Graph* graph, DataType type, int index, Node** arg_node) {
+  NodeDef arg_def;
+  NodeDefBuilder builder(strings::StrCat("_Arg", index), kArgOp);
+  builder.Attr("T", type);
+  builder.Attr("index", index);
+  TF_RETURN_IF_ERROR(builder.Finalize(&arg_def));
+  Status status;
+  *arg_node = graph->AddNode(arg_def, &status);
+  return status;
+}
+
+Status BuildRetvalNode(Graph* graph, DataType type, int index,
+                       Node** retval_node) {
+  NodeDef ret_def;
+  ret_def.set_op(kRetValOp);
+  ret_def.set_name(strings::StrCat("_Retval", index));
+  AddNodeAttr("T", type, &ret_def);
+  AddNodeAttr("index", index, &ret_def);
+  Status status;
+  *retval_node = graph->AddNode(ret_def, &status);
+  return status;
+}
+
+// Builds a graph for the loop condition.
+Status BuildLoopCondition(const Graph& graph, Frame* frame,
+                          std::unique_ptr<Graph>* cond_output) {
+  VLOG(2) << "Building loop condition for " << frame->name;
+  *cond_output = xla::MakeUnique<Graph>(graph.op_registry());
+  Graph* output = cond_output->get();
+
+  // Map from nodes in the original graph to the condition graph.
+  std::vector<Node*> node_map(graph.num_node_ids(), nullptr);
+  std::vector<bool> squash_src_outputs(graph.num_node_ids(), false);
+
+  // Build one _Arg node for each Enter node.
+  for (int i = 0; i < frame->args.size(); ++i) {
+    const Arg& arg = frame->args[i];
+
+    Node* arg_node;
+    TF_RETURN_IF_ERROR(
+        BuildArgNode(output, arg.enter->input_type(0), i, &arg_node));
+    if (arg.is_loop_invariant) {
+      node_map[arg.enter->id()] = arg_node;
+    } else {
+      node_map[arg.merge->id()] = arg_node;
+    }
+  }
+
+  // Build a Retval node for the loop condition. The LoopCond nodes are always
+  // boolean because of the type constraints on the LoopCond op.
+  TF_RETURN_IF_ERROR(
+      BuildRetvalNode(output, DT_BOOL, 0, &node_map[frame->loop_cond->id()]));
+
+  // Performs a reverse DFS, copying nodes and edges to the output graph.
+  // The _Arg and _Retval nodes were added unconditionally above, so we are
+  // guaranteed to get the correct function signature.
+  TF_RETURN_IF_ERROR(CopySubgraph(graph, *frame, {frame->loop_cond},
+                                  squash_src_outputs, &node_map, output));
+
+  return Status::OK();
+}
+
+// Builds a graph for the loop body.
+Status BuildLoopBody(const Graph& graph, Frame* frame,
+                     DataTypeVector* arg_types,
+                     std::unique_ptr<Graph>* body_output) {
+  VLOG(2) << "Building loop body for " << frame->name;
+  *body_output = xla::MakeUnique<Graph>(graph.op_registry());
+  Graph* output = body_output->get();
+
+  // Map from nodes in the original graph to the condition graph.
+  std::vector<Node*> node_map(graph.num_node_ids(), nullptr);
+  std::vector<bool> squash_src_outputs(graph.num_node_ids(), false);
+
+  // Build one _Arg node for each Enter node.
+  std::vector<Node*> next_iterations;
+  next_iterations.reserve(frame->args.size());
+  arg_types->reserve(frame->args.size());
+  for (int i = 0; i < frame->args.size(); ++i) {
+    const Arg& arg = frame->args[i];
+
+    DataType dtype = arg.enter->input_type(0);
+    arg_types->push_back(dtype);
+    Node* arg_node;
+    TF_RETURN_IF_ERROR(BuildArgNode(output, dtype, i, &arg_node));
+
+    if (dtype == DT_RESOURCE) {
+      // The convention of the XLA bridge is that resource variable arguments
+      // are only inputs to the loop body and have no corresponding output.
+ // TODO(b/37741920): change the convention so that DT_RESOURCE variables + // are both inputs and outputs, and then remove this case. + TF_RET_CHECK(arg.is_loop_invariant); + node_map[arg.enter->id()] = arg_node; + } else { + Node* retval_node; + TF_RETURN_IF_ERROR(BuildRetvalNode(output, dtype, i, &retval_node)); + + if (arg.is_loop_invariant) { + // Argument is loop-invariant. Forward it from the Arg to the Retval. + node_map[arg.enter->id()] = arg_node; + output->AddEdge(arg_node, 0, retval_node, 0); + } else { + // Argument is loop-varying. + node_map[arg.switch_node->id()] = arg_node; + // The Switch node has two outputs, but _Arg only has one. This tells + // the CopySubgraph function to rewrite the output number of edges from + // the _Arg node to be 0 rather than copying the output number from the + // Switch node. + squash_src_outputs[arg.switch_node->id()] = true; + node_map[arg.next_iteration->id()] = retval_node; + next_iterations.push_back(arg.next_iteration); + } + } + } + + // Performs a reverse DFS, copying nodes and edges to the output graph. + // The _Arg and _Retval nodes were added unconditionally above, so we are + // guaranteed to get the correct function signature. + TF_RETURN_IF_ERROR(CopySubgraph(graph, *frame, std::move(next_iterations), + squash_src_outputs, &node_map, output)); + + return Status::OK(); +} + +Status FunctionalizeLoop(Graph* graph, Frame* frame, + FunctionLibraryDefinition* library) { + VLOG(2) << "Frame " << frame->name << " before: " + << dump_graph::DumpGraphToFile("functionalize_before", *graph); + + // Split loop-varying Enter nodes with multiple successors. If the same + // Tensor is fed as input to multiple loop arguments, we may end up with a + // shared Enter node. We clone Enter nodes with multiple successors to + // maintain the invariant of a unique Enter node per argument of the final + // loop. 
+  std::vector<Arg> args;
+  for (const Arg& arg : frame->args) {
+    if (arg.is_loop_invariant) {
+      args.push_back(arg);
+    } else {
+      std::vector<const Edge*> edges(arg.enter->out_edges().begin(),
+                                     arg.enter->out_edges().end());
+      for (int i = 0; i < edges.size(); ++i) {
+        TF_RET_CHECK(!edges[i]->IsControlEdge());
+        Arg new_arg;
+        new_arg.is_loop_invariant = false;
+        if (i == 0) {
+          new_arg.enter = arg.enter;
+        } else {
+          new_arg.enter = graph->CopyNode(arg.enter);
+          frame->nodes.insert(new_arg.enter);
+          for (Edge const* e : arg.enter->in_edges()) {
+            graph->AddEdge(e->src(), e->src_output(), new_arg.enter,
+                           e->IsControlEdge() ? Graph::kControlSlot : 0);
+          }
+          Node* dst = edges[i]->dst();
+          int dst_input = edges[i]->dst_input();
+          graph->RemoveEdge(edges[i]);
+          graph->AddEdge(new_arg.enter, 0, dst, dst_input);
+        }
+        args.push_back(new_arg);
+      }
+    }
+  }
+  frame->args = std::move(args);
+
+  // Order the arguments so that:
+  // a) resource variables are last, and
+  // b) sort lexicographically by name (for deterministic output).
+  std::sort(frame->args.begin(), frame->args.end(),
+            [](const Arg& a, const Arg& b) {
+              bool a_is_resource = (a.enter->input_type(0) == DT_RESOURCE);
+              bool b_is_resource = (b.enter->input_type(0) == DT_RESOURCE);
+              return std::tie(a_is_resource, a.enter->name()) <
+                     std::tie(b_is_resource, b.enter->name());
+            });
+
+  if (frame->loop_cond == nullptr) {
+    return errors::InvalidArgument("Loop ", frame->name,
+                                   " has no LoopCond node");
+  }
+
+  // Find the set of Switch nodes that are successors of the LoopCond.
+  std::unordered_set<Node*> switches;
+  for (const Edge* edge : frame->loop_cond->out_edges()) {
+    if (!edge->IsControlEdge() && IsSwitch(edge->dst()) &&
+        edge->dst_input() == 1) {
+      switches.insert(edge->dst());
+    }
+  }
+
+  // For each non-constant argument, looks for the following pattern of nodes:
+  // Enter ----> Merge --------> Switch --> Exit
+  //               ^                ^
+  //               |                |
+  //         NextIteration      LoopCond
+  //               ^                ^
+  //               |                |
+  //              ...              ...
+ for (Arg& arg : frame->args) { + if (!arg.is_loop_invariant) { + // Follow the edge from the Enter to Merge. + if (arg.enter->out_edges().size() != 1) { + return errors::Internal("Enter node for loop-varying argument ", + arg.enter->name(), + " does not have exactly one successor"); + } + const Edge* enter_merge = *arg.enter->out_edges().begin(); + arg.merge = enter_merge->dst(); + if (!IsMerge(arg.merge)) { + return errors::InvalidArgument( + "Successor of Enter node for loop-varying argument ", + arg.merge->name(), + " is not a Merge node; got: ", arg.merge->type_string()); + } + + // Find the NextIteration from the merge. There should be two inputs to + // the Merge and the NextIteration should be the other input. + if (arg.merge->input_types().size() != 2) { + return errors::InvalidArgument( + "Unexpected number of inputs to Merge node for loop-varying " + "argument ", + arg.merge->name(), "; expected 2, got ", + arg.merge->input_types().size()); + } + TF_RETURN_IF_ERROR(arg.merge->input_node(1 - enter_merge->dst_input(), + &arg.next_iteration)); + if (!IsNextIteration(arg.next_iteration)) { + return errors::InvalidArgument( + "Expected NextIteration node as input to Merge node; got node ", + arg.next_iteration->name(), " with kind ", + arg.next_iteration->type_string()); + } + + // Find the Switch successor of the Merge. There should be exactly one + // Switch node that is a successor of both the Merge and the LoopCond. + for (const Edge* edge : arg.merge->out_edges()) { + if (edge->dst_input() == 0 && IsSwitch(edge->dst()) && + switches.find(edge->dst()) != switches.end()) { + if (arg.switch_node != nullptr) { + return errors::InvalidArgument("Duplicate Switch successors to ", + arg.merge->name()); + } + arg.switch_node = edge->dst(); + } + } + if (arg.switch_node == nullptr) { + return errors::InvalidArgument("Missing Switch successor to ", + arg.merge->name()); + } + + // Find the Exit successor of the Switch. 
+    for (const Edge* edge : arg.switch_node->out_edges()) {
+      if (edge->src_output() == 0 && IsExit(edge->dst())) {
+        if (arg.exit != nullptr) {
+          return errors::InvalidArgument("Duplicate Exit successors to ",
+                                         arg.switch_node->name());
+        }
+        arg.exit = edge->dst();
+      }
+    }
+    if (arg.exit == nullptr) {
+      return errors::InvalidArgument("Missing Exit successor to ",
+                                     arg.switch_node->name());
+    }
+    }
+  }
+
+  // Builds the condition and body functions.
+  std::unique_ptr<Graph> cond_graph;
+  TF_RETURN_IF_ERROR(BuildLoopCondition(*graph, frame, &cond_graph));
+  DataTypeVector arg_types;
+  std::unique_ptr<Graph> body_graph;
+  TF_RETURN_IF_ERROR(BuildLoopBody(*graph, frame, &arg_types, &body_graph));
+
+  VLOG(2) << "Frame " << frame->name << " condition: "
+          << dump_graph::DumpGraphToFile("loop_condition", *cond_graph)
+          << " body: " << dump_graph::DumpGraphToFile("loop_body", *body_graph);
+
+  static std::atomic<int64> sequence_num(0LL);
+  int64 id = ++sequence_num;
+  NameAttrList cond_name;
+  cond_name.set_name(strings::StrCat("_functionalize_cond_", id));
+  NameAttrList body_name;
+  body_name.set_name(strings::StrCat("_functionalize_body_", id));
+  FunctionDef cond_fdef;
+  TF_RETURN_IF_ERROR(
+      GraphToFunctionDef(*cond_graph, cond_name.name(), &cond_fdef));
+  FunctionDef body_fdef;
+  TF_RETURN_IF_ERROR(
+      GraphToFunctionDef(*body_graph, body_name.name(), &body_fdef));
+
+  TF_RETURN_IF_ERROR(library->AddFunctionDef(cond_fdef));
+  TF_RETURN_IF_ERROR(library->AddFunctionDef(body_fdef));
+
+  // Builds a While operator.
+  NodeDef while_def;
+  NodeDefBuilder builder(frame->loop_cond->name(), "XlaWhile");
+  builder.Attr("T", arg_types);
+  builder.Attr("cond", cond_name);
+  builder.Attr("body", body_name);
+  std::vector<NodeDefBuilder::NodeOut> inputs;
+  for (int i = 0; i < frame->args.size(); ++i) {
+    const Arg& arg = frame->args[i];
+    const Edge* in_edge;
+    TF_RETURN_IF_ERROR(arg.enter->input_edge(0, &in_edge));
+    if (in_edge->IsControlEdge()) {
+      builder.ControlInput(in_edge->src()->name());
+    } else {
+      inputs.push_back(NodeDefBuilder::NodeOut(
+          in_edge->src()->name(), in_edge->src_output(), arg_types[i]));
+    }
+  }
+  builder.Input(inputs);
+  TF_RETURN_IF_ERROR(builder.Finalize(&while_def));
+
+  Status status;
+  Node* while_node = graph->AddNode(while_def, &status);
+  if (!status.ok()) {
+    return status;
+  }
+
+  // Copies edges to the Enter nodes and from the Exit nodes onto the While.
+  for (int i = 0; i < frame->args.size(); ++i) {
+    const Arg& arg = frame->args[i];
+    const Edge* in_edge;
+    TF_RETURN_IF_ERROR(arg.enter->input_edge(0, &in_edge));
+    if (in_edge->IsControlEdge()) {
+      graph->AddControlEdge(in_edge->src(), while_node);
+    } else {
+      graph->AddEdge(in_edge->src(), in_edge->src_output(), while_node, i);
+    }
+
+    if (!arg.is_loop_invariant) {
+      std::vector<const Edge*> edges(arg.exit->out_edges().begin(),
+                                     arg.exit->out_edges().end());
+      for (const Edge* edge : edges) {
+        Node* dst = edge->dst();
+        int dst_input = edge->dst_input();
+        graph->RemoveEdge(edge);
+
+        int src_output =
+            dst_input == Graph::kControlSlot ? Graph::kControlSlot : i;
+        graph->AddEdge(while_node, src_output, dst, dst_input);
+      }
+    }
+  }
+
+  // Remove the old nodes from the graph, and add the while node to the parent
+  // frame.
+  for (Node* node : frame->nodes) {
+    graph->RemoveNode(node);
+  }
+  frame->parent->nodes.insert(while_node);
+
+  VLOG(2) << "Frame " << frame->name << " after: "
+          << dump_graph::DumpGraphToFile("functionalize_after", *graph);
+
+  return Status::OK();
+}
+
+}  // namespace
+
+// Transformation that converts Tensorflow's graph control flow constructs into
+// functional equivalents.
+Status FunctionalizeControlFlow(Graph* graph,
+                                FunctionLibraryDefinition* library) {
+  VLOG(2) << "FunctionalizeControlFlow: "
+          << dump_graph::DumpGraphToFile("functionalize_initial", *graph);
+  // Note: BuildControlFlowInfo() requires that the graph's source node is
+  // connected to all source nodes in the graph. Many graphs violate this
+  // invariant.
+  std::vector<ControlFlowInfo> cf_info;
+  TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph, &cf_info));
+
+  // Builds Frames, indexed by name.
+  std::unordered_map<string, Frame> frames;
+  for (Node* node : graph->op_nodes()) {
+    const ControlFlowInfo& cf = cf_info[node->id()];
+
+    VLOG(2) << "node: " << node->name() << " frame_name: " << cf.frame_name
+            << " frame: " << (cf.frame ? cf.frame->name() : "---")
+            << " parent_frame: "
+            << (cf.parent_frame ? cf.parent_frame->name() : "---");
+    TF_RET_CHECK(cf.frame != nullptr && cf.parent_frame != nullptr);
+
+    Frame& frame = frames[cf.frame_name];
+    Frame* parent = &frames[cf_info[cf.parent_frame->id()].frame_name];
+    if (frame.parent == nullptr) {
+      frame.parent = parent;
+      frame.name = cf.frame_name;
+      ++parent->num_children;
+    } else if (frame.parent != parent) {
+      return errors::InvalidArgument("Mismatched parent frames for ",
+                                     cf.frame->id(), ": ", parent->name, " vs ",
+                                     frame.parent->name);
+    }
+
+    if (IsEnter(node)) {
+      Arg arg;
+      arg.enter = node;
+      TF_RETURN_IF_ERROR(GetNodeAttr(arg.enter->attrs(), "is_constant",
+                                     &arg.is_loop_invariant));
+      frame.args.push_back(arg);
+    } else if (IsLoopCond(node)) {
+      if (frame.loop_cond) {
+        return errors::InvalidArgument(
+            "Loop ", cf.frame_name,
+            " has more than one LoopCond node: ", node->name(), " and ",
+            frame.loop_cond->name());
+      }
+      frame.loop_cond = node;
+    }
+    frame.nodes.insert(node);
+  }
+
+  // Adds frames with no children (i.e., the innermost frames) to a worklist.
+  std::deque<Frame*> worklist;
+  for (auto& frame : frames) {
+    if (frame.second.num_children == 0) {
+      worklist.push_back(&frame.second);
+    }
+  }
+
+  // Eliminate loops from innermost to outermost.
+  while (!worklist.empty()) {
+    Frame* frame = worklist.front();
+    worklist.pop_front();
+    if (frame->parent == frame) {
+      // Skip the root frame.
+      continue;
+    }
+
+    TF_RETURN_IF_ERROR(FunctionalizeLoop(graph, frame, library));
+
+    // If the parent has no remaining children, add it to the worklist.
+ --frame->parent->num_children; + if (frame->parent->num_children == 0) { + worklist.push_back(frame->parent); + } + } + + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.h b/tensorflow/compiler/tf2xla/functionalize_control_flow.h new file mode 100644 index 00000000000..1535dc80b0c --- /dev/null +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.h @@ -0,0 +1,32 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_H_ +#define TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_H_ + +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { + +// Transformation that converts tf.while_loop() loops into functional While +// operators, suitable for XLA compilation. +// TODO(b/36470387): add support for conditionals. 
+Status FunctionalizeControlFlow(Graph* graph, + FunctionLibraryDefinition* library); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_FUNCTIONALIZE_CONTROL_FLOW_H_ diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc new file mode 100644 index 00000000000..7f6717ffafd --- /dev/null +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc @@ -0,0 +1,647 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h" + +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/ops/control_flow_ops_internal.h" +#include "tensorflow/cc/ops/function_ops.h" +#include "tensorflow/cc/ops/resource_variable_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/tf2xla/cc/ops/functional_ops.h" +#include "tensorflow/compiler/tf2xla/test_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/util/equal_graph_def.h" + +namespace tensorflow { +namespace { + +// Returns the names of the "cond" and "body" functions for the While node +// in a graph. 
+Status FindWhileCondAndBody(const GraphDef& graph, NameAttrList* cond, + NameAttrList* body) { + for (const NodeDef& node : graph.node()) { + if (node.op() == "XlaWhile") { + const NameAttrList* result; + TF_RETURN_IF_ERROR(GetNodeAttr(node, "cond", &result)); + *cond = *result; + TF_RETURN_IF_ERROR(GetNodeAttr(node, "body", &result)); + *body = *result; + return Status::OK(); + } + } + return errors::NotFound("No XlaWhile node found in graph"); +} + +// Graph: +// x = array_ops.placeholder(dtypes.int32) +// y = control_flow_ops.while_loop(lambda i: i < 10, lambda i: i + 1, [x]) +TEST(FunctionalizeControlFlow, OneLoopVar) { + Graph graph(OpRegistry::Global()); + { + Scope scope = Scope::NewRootScope().ExitOnError(); + + auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32); + + auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); + auto enter = + ops::internal::Enter(scope.WithOpName("while/Enter"), source, "aloop"); + auto merge = ops::Merge(scope.WithOpName("while/Merge"), + std::initializer_list{enter, dummy}); + auto ten = ops::Const( + scope.WithOpName("while/Less/y").WithControlDependencies(merge.output), + 10); + auto less = ops::Less(scope.WithOpName("while/Less"), merge.output, ten); + auto loop_cond = ops::LoopCond(scope.WithOpName("while/LoopCond"), less); + auto switch_ = + ops::Switch(scope.WithOpName("while/Switch"), merge.output, loop_cond); + auto exit = ops::internal::Exit(scope.WithOpName("while/Exit"), + switch_.output_false); + auto identity = + ops::Identity(scope.WithOpName("while/Identity"), switch_.output_true); + auto one = ops::Const( + scope.WithOpName("while/add/y").WithControlDependencies(identity), 1); + auto add = ops::Add(scope.WithOpName("while/add"), identity, one); + auto next_iteration = + ops::NextIteration(scope.WithOpName("while/NextIteration"), add); + + auto sink = ops::Identity(scope.WithOpName("sink"), exit); + + // Remove the dummy node and add the loop backedge. 
+ scope.graph()->RemoveNode(dummy.node()); + scope.graph()->AddEdge(next_iteration.node(), 0, merge.output.node(), 1); + + TF_EXPECT_OK(scope.ToGraph(&graph)); + } + + FunctionLibraryDefinition library(OpRegistry::Global(), {}); + TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library)); + + GraphDef graph_def; + graph.ToGraphDef(&graph_def); + + NameAttrList cond_fn, body_fn; + TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn)); + + // Outer graph + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); + auto while_op = + ops::XlaWhile(scope.WithOpName("while/LoopCond"), + std::initializer_list{source}, cond_fn, body_fn); + auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]); + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + TF_EXPECT_GRAPH_EQ(expected, graph_def); + } + + // Condition graph + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto ten = ops::Const( + scope.WithOpName("while/Less/y").WithControlDependencies(arg), 10); + auto less = ops::Less(scope.WithOpName("while/Less"), arg, ten); + auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK(InstantiateFunctionForTest(cond_fn.name(), library, &result)); + + EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types); + EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } + + // Body graph. 
+ { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg); + auto one = ops::Const( + scope.WithOpName("while/add/y").WithControlDependencies(identity), 1); + auto add = ops::Add(scope.WithOpName("while/add"), identity, one); + auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK(InstantiateFunctionForTest(body_fn.name(), library, &result)); + + EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types); + EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } +} + +// Graph: +// x = array_ops.placeholder(dtypes.int32) +// y = array_ops.placeholder(dtypes.int32) +// cond = lambda (i, j): i + 3 < 10 +// body = lambda (i, j): (i < 10, j * 2) +// z = control_flow_ops.while_loop(cond, body, [x, y]) +TEST(FunctionalizeControlFlow, TwoLoopVars) { + Graph graph(OpRegistry::Global()); + { + Scope scope = Scope::NewRootScope().ExitOnError(); + + auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32); + + auto x = ops::Placeholder(scope.WithOpName("Placeholder/x"), DT_INT32); + auto y = ops::Placeholder(scope.WithOpName("Placeholder/y"), DT_INT32); + auto enter_x = + ops::internal::Enter(scope.WithOpName("while/Enter/x"), x, "aloop"); + auto enter_y = + ops::internal::Enter(scope.WithOpName("while/Enter/y"), y, "aloop"); + auto merge_x = ops::Merge(scope.WithOpName("while/Merge/x"), + std::initializer_list{enter_x, dummy}); + auto merge_y = ops::Merge(scope.WithOpName("while/Merge/y"), + std::initializer_list{enter_y, dummy}); + + // Loop condition + auto three = ops::Const(scope.WithOpName("while/cond/three") + .WithControlDependencies(merge_x.output), + 3); + auto cond_add = + ops::Add(scope.WithOpName("while/cond/Add"), merge_x.output, three); 
+ auto ten = ops::Const(scope.WithOpName("while/cond/ten") + .WithControlDependencies(merge_x.output), + 10); + auto less = ops::Less(scope.WithOpName("while/cond/Less"), cond_add, ten); + auto loop_cond = ops::LoopCond(scope.WithOpName("while/LoopCond"), less); + + auto switch_x = ops::Switch(scope.WithOpName("while/Switch/x"), + merge_x.output, loop_cond); + auto switch_y = ops::Switch(scope.WithOpName("while/Switch/y"), + merge_y.output, loop_cond); + + auto exit_x = ops::internal::Exit(scope.WithOpName("while/Exit/x"), + switch_x.output_false); + auto exit_y = ops::internal::Exit(scope.WithOpName("while/Exit/y"), + switch_y.output_false); + + auto identity_x = ops::Identity(scope.WithOpName("while/Identity/x"), + switch_x.output_true); + auto identity_y = ops::Identity(scope.WithOpName("while/Identity/y"), + switch_y.output_true); + + auto one = ops::Const( + scope.WithOpName("while/add/one").WithControlDependencies(identity_x), + 1); + auto two = ops::Const( + scope.WithOpName("while/mul/two").WithControlDependencies(identity_x), + 2); + + auto add = ops::Add(scope.WithOpName("while/add"), identity_x, one); + auto mul = ops::Add(scope.WithOpName("while/mul"), identity_y, two); + auto next_iteration_x = + ops::NextIteration(scope.WithOpName("while/NextIteration/x"), add); + auto next_iteration_y = + ops::NextIteration(scope.WithOpName("while/NextIteration/y"), mul); + + auto sink_x = ops::Identity(scope.WithOpName("sink_x"), exit_x); + auto sink_y = ops::Identity(scope.WithOpName("sink_y"), exit_y); + + // Remove the dummy node and add the loop backedges. 
+ scope.graph()->RemoveNode(dummy.node()); + scope.graph()->AddEdge(next_iteration_x.node(), 0, merge_x.output.node(), + 1); + scope.graph()->AddEdge(next_iteration_y.node(), 0, merge_y.output.node(), + 1); + + TF_EXPECT_OK(scope.ToGraph(&graph)); + } + + FunctionLibraryDefinition library(OpRegistry::Global(), {}); + TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library)); + + GraphDef graph_def; + graph.ToGraphDef(&graph_def); + + NameAttrList cond_fn, body_fn; + TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn)); + + // Outer graph. + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto x = ops::Placeholder(scope.WithOpName("Placeholder/x"), DT_INT32); + auto y = ops::Placeholder(scope.WithOpName("Placeholder/y"), DT_INT32); + auto while_op = + ops::XlaWhile(scope.WithOpName("while/LoopCond"), + std::initializer_list{x, y}, cond_fn, body_fn); + auto sink_x = ops::Identity(scope.WithOpName("sink_x"), while_op[0]); + auto sink_y = ops::Identity(scope.WithOpName("sink_y"), while_op[1]); + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + TF_EXPECT_GRAPH_EQ(expected, graph_def); + } + + // Condition graph. 
+ { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); + auto three = ops::Const(scope.WithOpName("while/cond/three") + .WithControlDependencies(arg0.output), + 3); + auto cond_add = + ops::Add(scope.WithOpName("while/cond/Add"), arg0.output, three); + auto ten = ops::Const( + scope.WithOpName("while/cond/ten").WithControlDependencies(arg0.output), + 10); + auto less = ops::Less(scope.WithOpName("while/cond/Less"), cond_add, ten); + auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK(InstantiateFunctionForTest(cond_fn.name(), library, &result)); + + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.arg_types); + EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } + + // Body graph. 
+ { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); + + auto identity_x = ops::Identity(scope.WithOpName("while/Identity/x"), arg0); + auto identity_y = ops::Identity(scope.WithOpName("while/Identity/y"), arg1); + + auto one = ops::Const( + scope.WithOpName("while/add/one").WithControlDependencies(identity_x), + 1); + auto two = ops::Const( + scope.WithOpName("while/mul/two").WithControlDependencies(identity_x), + 2); + + auto add = ops::Add(scope.WithOpName("while/add"), identity_x, one); + auto mul = ops::Add(scope.WithOpName("while/mul"), identity_y, two); + auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0); + auto retval1 = ops::_Retval(scope.WithOpName("_retval1_RetVal"), mul, 1); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK(InstantiateFunctionForTest(body_fn.name(), library, &result)); + + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.arg_types); + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32}), result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } +} + +// Example with nesting, loop-invariant arguments, and resource variables. 
+// +// accum = resource_variable_ops.ResourceVariable(1) +// x = array_ops.placeholder(2, dtype=dtypes.int32) +// y = 3 + x +// +// def inner_body(j, k): +// add = state_ops.assign_add(accum, k * j + x) +// with ops.control_dependencies([add]): +// return [j + 1, k] +// +// def body(i): +// m = control_flow_ops.while_loop(lambda j, k: j < 5, inner_body, +// [1, y], name="inner") +// with ops.control_dependencies(m): +// return [i + 1] +// +// z = control_flow_ops.while_loop(lambda i: i < 10, body, [0], name="outer") +TEST(FunctionalizeControlFlow, Complex) { + Graph graph(OpRegistry::Global()); + { + Scope scope = Scope::NewRootScope().ExitOnError(); + + auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32); + + auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32); + auto three = ops::Const(scope.WithOpName("three"), 3); + auto y = ops::Add(scope.WithOpName("y"), x, three); + + auto var = ops::VarHandleOp(scope.WithOpName("Variable"), DT_INT32, + TensorShape({})); + + // Outer loop + auto zero = ops::Const(scope.WithOpName("outer/Const"), 0); + auto enter_i = + ops::internal::Enter(scope.WithOpName("outer/Enter_i"), zero, "outer"); + auto merge_i = ops::Merge(scope.WithOpName("outer/Merge_i"), + std::initializer_list{enter_i, dummy}); + auto ten = ops::Const(scope.WithOpName("outer/Less/y") + .WithControlDependencies(merge_i.output), + 10); + auto less_i = + ops::Less(scope.WithOpName("outer/Less_i"), merge_i.output, ten); + auto outer_loop_cond = + ops::LoopCond(scope.WithOpName("outer/LoopCond"), less_i); + auto switch_i = ops::Switch(scope.WithOpName("outer/Switch"), + merge_i.output, outer_loop_cond); + auto exit_i = ops::internal::Exit(scope.WithOpName("outer/Exit"), + switch_i.output_false); + auto identity_i = + ops::Identity(scope.WithOpName("outer/Identity"), switch_i.output_true); + + auto enter_x_outer = + ops::internal::Enter(scope.WithOpName("outer/Enter_x"), x, "outer", + ops::internal::Enter::Attrs().IsConstant(true)); + auto 
enter_k_outer = + ops::internal::Enter(scope.WithOpName("outer/Enter_k"), y, "outer", + ops::internal::Enter::Attrs().IsConstant(true)); + auto enter_var_outer = + ops::internal::Enter(scope.WithOpName("outer/Enter_var"), var, "outer", + ops::internal::Enter::Attrs().IsConstant(true)); + + // Inner loop + auto one_j = ops::Const( + scope.WithOpName("outer/j").WithControlDependencies(identity_i), 1); + auto enter_j = ops::internal::Enter(scope.WithOpName("outer/inner/Enter_j"), + one_j, "inner"); + auto enter_k = + ops::internal::Enter(scope.WithOpName("outer/inner/Enter_k") + .WithControlDependencies(identity_i), + enter_k_outer, "inner"); + auto enter_x = ops::internal::Enter( + scope.WithOpName("outer/inner/Enter_x"), enter_x_outer, "inner", + ops::internal::Enter::Attrs().IsConstant(true)); + auto enter_var = ops::internal::Enter( + scope.WithOpName("outer/inner/Enter_var"), enter_var_outer, "inner", + ops::internal::Enter::Attrs().IsConstant(true)); + + auto merge_j = ops::Merge(scope.WithOpName("outer/inner/Merge_j"), + std::initializer_list{enter_j, dummy}); + auto merge_k = ops::Merge(scope.WithOpName("outer/inner/Merge_k"), + std::initializer_list{enter_k, dummy}); + + auto five = ops::Const(scope.WithOpName("outer/inner/Five") + .WithControlDependencies(merge_j.output), + 5); + auto less_j = + ops::Less(scope.WithOpName("outer/inner/Less_j"), merge_j.output, five); + auto loop_cond = ops::LoopCond(scope.WithOpName("outer/LoopCond"), less_j); + + auto switch_j = ops::Switch(scope.WithOpName("outer/inner/Switch_j"), + merge_j.output, loop_cond); + auto switch_k = ops::Switch(scope.WithOpName("outer/inner/Switch_k"), + merge_k.output, loop_cond); + auto exit_j = ops::internal::Exit(scope.WithOpName("outer/inner/Exit_j"), + switch_j.output_false); + auto exit_k = ops::internal::Exit(scope.WithOpName("outer/inner/Exit_k"), + switch_k.output_false); + auto identity_j = ops::Identity(scope.WithOpName("outer/inner/Identity_j"), + switch_j.output_true); + auto 
identity_k = ops::Identity(scope.WithOpName("outer/inner/Identity_k"), + switch_k.output_true); + + // Variable update + auto mul_jk = + ops::Mul(scope.WithOpName("outer/inner/mul"), identity_j, identity_k); + auto add_jkx = + ops::Add(scope.WithOpName("outer/inner/add"), mul_jk, enter_x); + auto assign = ops::AssignAddVariableOp( + scope.WithOpName("outer/inner/assign_add"), enter_var, add_jkx); + + auto one = + ops::Const(scope.WithOpName("outer/inner/One") + .WithControlDependencies( + gtl::ArraySlice{assign.operation}), + 1); + auto add_j = + ops::Add(scope.WithOpName("outer/inner/add_j"), identity_j, one); + + auto next_iteration_j = ops::NextIteration( + scope.WithOpName("outer/inner/NextIteration_j"), add_j); + auto next_iteration_k = ops::NextIteration( + scope.WithOpName("outer/inner/NextIteration_k"), identity_k); + + // Body and backedge for outer loop. + auto one_outer = ops::Const( + scope.WithOpName("outer/add/y").WithControlDependencies(identity_i), 1); + auto add_i = + ops::Add(scope.WithOpName("outer/add") + .WithControlDependencies(gtl::ArraySlice{ + exit_j.output.op(), exit_k.output.op()}), + identity_i, one_outer); + auto next_iteration_i = + ops::NextIteration(scope.WithOpName("outer/NextIteration"), add_i); + + auto sink = ops::Identity(scope.WithOpName("sink"), exit_i); + + // Remove the dummy node and add the loop backedge. 
+ scope.graph()->RemoveNode(dummy.node()); + scope.graph()->AddEdge(next_iteration_i.node(), 0, merge_i.output.node(), + 1); + scope.graph()->AddEdge(next_iteration_j.node(), 0, merge_j.output.node(), + 1); + scope.graph()->AddEdge(next_iteration_k.node(), 0, merge_k.output.node(), + 1); + + TF_EXPECT_OK(scope.ToGraph(&graph)); + } + + FunctionLibraryDefinition library(OpRegistry::Global(), {}); + TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library)); + + GraphDef graph_def; + graph.ToGraphDef(&graph_def); + + NameAttrList outer_cond_fn, outer_body_fn; + TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &outer_cond_fn, &outer_body_fn)); + + // Outer graph. + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32); + auto three = ops::Const(scope.WithOpName("three"), 3); + auto y = ops::Add(scope.WithOpName("y"), x, three); + + auto var = ops::VarHandleOp(scope.WithOpName("Variable"), DT_INT32, + TensorShape({})); + + auto zero = ops::Const(scope.WithOpName("outer/Const"), 0); + + auto while_op = ops::XlaWhile(scope.WithOpName("outer/LoopCond"), + std::initializer_list{zero, y, x, var}, + outer_cond_fn, outer_body_fn); + auto sink = ops::Identity(scope.WithOpName("sink"), while_op[0]); + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + TF_EXPECT_GRAPH_EQ(expected, graph_def); + } + + // Outer condition graph. 
+ { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); + auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); + auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3); + + auto ten = ops::Const( + scope.WithOpName("outer/Less/y").WithControlDependencies(arg0.output), + 10); + auto less = ops::Less(scope.WithOpName("outer/Less_i"), arg0, ten); + auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK( + InstantiateFunctionForTest(outer_cond_fn.name(), library, &result)); + + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}), + result.arg_types); + EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } + + // Outer body graph. + NameAttrList inner_cond_fn, inner_body_fn; + { + InstantiationResultForTest result; + TF_EXPECT_OK( + InstantiateFunctionForTest(outer_body_fn.name(), library, &result)); + + // Find the inner condition and body names. 
+ TF_EXPECT_OK( + FindWhileCondAndBody(result.gdef, &inner_cond_fn, &inner_body_fn)); + + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); + auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); + auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3); + + auto identity_i = ops::Identity(scope.WithOpName("outer/Identity"), arg0); + auto one_j = ops::Const( + scope.WithOpName("outer/j").WithControlDependencies(identity_i), 1); + auto while_op = + ops::XlaWhile(scope.WithOpName("outer/LoopCond_1"), + std::initializer_list{one_j, arg1, arg2, arg3}, + inner_cond_fn, inner_body_fn); + + auto one_outer = ops::Const( + scope.WithOpName("outer/add/y").WithControlDependencies(identity_i), 1); + auto add_i = + ops::Add(scope.WithOpName("outer/add") + .WithControlDependencies(gtl::ArraySlice{ + while_op[0].op(), while_op[1].op()}), + identity_i, one_outer); + + auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add_i, 0); + auto retval1 = ops::_Retval(scope.WithOpName("_retval1_RetVal"), arg1, 1); + auto retval2 = ops::_Retval(scope.WithOpName("_retval2_RetVal"), arg2, 2); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}), + result.arg_types); + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32}), result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } + + // Inner condition graph. 
+ { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); + auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); + auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3); + + auto five = ops::Const( + scope.WithOpName("outer/inner/Five").WithControlDependencies(arg0), 5); + auto less_j = ops::Less(scope.WithOpName("outer/inner/Less_j"), arg0, five); + auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less_j, 0); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK( + InstantiateFunctionForTest(inner_cond_fn.name(), library, &result)); + + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}), + result.arg_types); + EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } + + // Inner body graph. + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg0 = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto arg1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); + auto arg2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); + auto arg3 = ops::_Arg(scope.WithOpName("_arg3"), DT_RESOURCE, 3); + + auto identity_j = + ops::Identity(scope.WithOpName("outer/inner/Identity_j"), arg0); + auto identity_k = + ops::Identity(scope.WithOpName("outer/inner/Identity_k"), arg1); + + auto mul_jk = + ops::Mul(scope.WithOpName("outer/inner/mul"), identity_j, identity_k); + auto add_jkx = ops::Add(scope.WithOpName("outer/inner/add"), mul_jk, arg2); + auto assign = ops::AssignAddVariableOp( + scope.WithOpName("outer/inner/assign_add"), arg3, add_jkx); + + auto one = + ops::Const(scope.WithOpName("outer/inner/One") + .WithControlDependencies( + gtl::ArraySlice{assign.operation}), + 1); + auto add_j = + ops::Add(scope.WithOpName("outer/inner/add_j"), identity_j, one); + + auto 
retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add_j, 0); + auto retval1 = + ops::_Retval(scope.WithOpName("_retval1_RetVal"), identity_k, 1); + auto retval2 = ops::_Retval(scope.WithOpName("_retval2_RetVal"), arg2, 2); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK( + InstantiateFunctionForTest(inner_body_fn.name(), library, &result)); + + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32, DT_RESOURCE}), + result.arg_types); + EXPECT_EQ((DataTypeVector{DT_INT32, DT_INT32, DT_INT32}), result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index a434c746809..96b4fdfec6d 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -68,6 +68,7 @@ tf_kernel_library( "reduction_ops.h", ], deps = [ + ":while_op", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/xla:literal_util", @@ -91,6 +92,21 @@ tf_kernel_library( ], ) +tf_kernel_library( + name = "while_op", + srcs = ["while_op.cc"], + hdrs = ["while_op.h"], + deps = [ + "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla/ops:functional_ops", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:tensorflow_opensource", + ], +) + # Kernels that only work on CPU, because they use XLA custom calls. # Only link this when using the CPU backend for XLA. 
# diff --git a/tensorflow/compiler/tf2xla/kernels/arg_op.cc b/tensorflow/compiler/tf2xla/kernels/arg_op.cc index 620fc844378..6ad72c6219e 100644 --- a/tensorflow/compiler/tf2xla/kernels/arg_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/arg_op.cc @@ -51,13 +51,26 @@ class ArgOp : public XlaOpKernel { XlaContext& xc = XlaContext::Get(ctx); const XlaContext::Argument& arg = xc.args()[index_]; - if (arg.is_variable) { + if (arg.is_resource) { + XlaResource::Kind kind; + switch (arg.kind) { + case XlaCompiler::Argument::kVariable: + kind = XlaResource::kVariable; + break; + case XlaCompiler::Argument::kTensorArray: + kind = XlaResource::kTensorArray; + break; + default: + CHECK(false); + } + // TODO(phawkins): this code assumes that variables do not alias. - XlaVariable* var; - OP_REQUIRES_OK(ctx, xc.CreateVariable(index_, arg.name, arg.value.type, - arg.value.handle, &var)); - var->tensor_array_size = arg.tensor_array_size; - ctx->SetVariableOutput(0, var); + XlaResource* resource; + OP_REQUIRES_OK(ctx, + xc.CreateResource(kind, index_, arg.name, arg.value.type, + arg.value.handle, &resource)); + resource->tensor_array_size = arg.tensor_array_size; + ctx->SetResourceOutput(0, resource); } else if (arg.value.is_constant) { ctx->SetConstantOutput(0, arg.value.constant_value); } else { diff --git a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc index 8642cbf2a92..21d3e64872e 100644 --- a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc @@ -127,8 +127,8 @@ void BatchToSpace(XlaOpKernelContext* ctx, std::vector end_indices = reshaped_permuted_shape; std::vector strides(input_rank, 1); for (int i = 0; i < block_rank; ++i) { - int64 crop_start = xla::LiteralUtil::Get(crops, {i, 0}); - int64 crop_end = xla::LiteralUtil::Get(crops, {i, 1}); + int64 crop_start = crops.Get({i, 0}); + int64 crop_end = crops.Get({i, 1}); OP_REQUIRES(ctx, crop_start >= 0 
&& crop_end >= 0, errors::InvalidArgument("Crops must be non-negative")); start_indices[1 + i] = crop_start; diff --git a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc index b0fee5e4bca..bc2cd31230d 100644 --- a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc @@ -55,7 +55,7 @@ class BCastGradArgsOp : public XlaOpKernel { BCast::Vec vec; for (int64 i = 0; i < in_shape.num_elements(); ++i) { - vec.push_back(xla::LiteralUtil::Get(literal, {i})); + vec.push_back(literal.Get({i})); } shapes.push_back(vec); } diff --git a/tensorflow/compiler/tf2xla/kernels/concat_op.cc b/tensorflow/compiler/tf2xla/kernels/concat_op.cc index e2eacb3839d..73a4740e29a 100644 --- a/tensorflow/compiler/tf2xla/kernels/concat_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/concat_op.cc @@ -52,7 +52,7 @@ class ConcatBaseOp : public XlaOpKernel { xla::Literal literal; OP_REQUIRES_OK(ctx, ctx->ConstantInput(axis_index_, &literal)); // TODO(annarev): add a helper to support int64 input. - const int32 concat_dim = xla::LiteralUtil::Get(literal, {}); + const int32 concat_dim = literal.Get({}); std::vector values; std::vector shapes; @@ -163,7 +163,7 @@ class ConcatOffsetOp : public XlaOpKernel { xla::Literal concat_dim_literal; OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &concat_dim_literal)); - const int64 cdim = xla::LiteralUtil::Get(concat_dim_literal, {}); + const int64 cdim = concat_dim_literal.Get({}); VLOG(1) << "ConcatOffset " << cdim << "," << dims; int32 axis = cdim < 0 ? 
cdim + dims : cdim; @@ -185,12 +185,10 @@ class ConcatOffsetOp : public XlaOpKernel { for (int64 j = 0; j < dims; ++j) { if (j == axis) { out_vec(j) = offset; - offset += xla::LiteralUtil::Get(inp_literal, {j}); + offset += inp_literal.Get({j}); } else { - const int32 inp0_element = - xla::LiteralUtil::Get(inp0_literal, {j}); - const int32 inp_element = - xla::LiteralUtil::Get(inp_literal, {j}); + const int32 inp0_element = inp0_literal.Get({j}); + const int32 inp_element = inp_literal.Get({j}); OP_REQUIRES( ctx, (inp0_element == inp_element), errors::InvalidArgument("input[", i, ",", j, "] mismatch: ", diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc index 107c673f4a7..0330e34c98d 100644 --- a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc @@ -103,8 +103,7 @@ class DynamicStitchOp : public XlaOpKernel { int max_index = -1; for (int input_num = 0; input_num < indices.size(); input_num++) { for (int i = 0; i < indices[input_num].shape().dimensions(0); ++i) { - max_index = std::max( - max_index, xla::LiteralUtil::Get(indices[input_num], {i})); + max_index = std::max(max_index, indices[input_num].Get({i})); } } int number_of_indices = max_index + 1; @@ -118,7 +117,7 @@ class DynamicStitchOp : public XlaOpKernel { int index_used_count = 0; for (int input_num = 0; input_num < indices.size(); input_num++) { for (int i = 0; i < indices[input_num].shape().dimensions(0); ++i) { - int index = xla::LiteralUtil::Get(indices[input_num], {i}); + int index = indices[input_num].Get({i}); src_input_vector[index] = input_num; src_slice_vector[index] = i; if (!src_index_used[index]) { diff --git a/tensorflow/compiler/tf2xla/kernels/fill_op.cc b/tensorflow/compiler/tf2xla/kernels/fill_op.cc index 1e1d2a1b4b3..9e090fe01cb 100644 --- a/tensorflow/compiler/tf2xla/kernels/fill_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/fill_op.cc @@ -52,7 
+52,7 @@ class FillOp : public XlaOpKernel { std::vector broadcast; broadcast.reserve(dims_literal.shape().dimensions(0)); for (int i = 0; i < dims_literal.shape().dimensions(0); ++i) { - broadcast.push_back(xla::LiteralUtil::Get(dims_literal, {i})); + broadcast.push_back(dims_literal.Get({i})); } // Look up the value input, reshaping to a scalar if it was a // 'legacy' scalar (secretly a vector). diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc b/tensorflow/compiler/tf2xla/kernels/gather_op.cc index 49eadaf9d1f..3c1cdef5f80 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc @@ -66,10 +66,10 @@ class GatherOp : public XlaOpKernel { std::vector args; args.push_back(tc.GetOrCreateRuntimeContextParameter()); args.push_back(b.ConstantLiteral( - *xla::LiteralUtil::CreateR0(indices_shape.num_elements()))); + *xla::Literal::CreateR0(indices_shape.num_elements()))); args.push_back(b.ConstantLiteral( - *xla::LiteralUtil::CreateR0(params_shape.dim_size(0)))); - args.push_back(b.ConstantLiteral(*xla::LiteralUtil::CreateR0( + *xla::Literal::CreateR0(params_shape.dim_size(0)))); + args.push_back(b.ConstantLiteral(*xla::Literal::CreateR0( params_shape.num_elements() / params_shape.dim_size(0)))); args.push_back(ctx->Input(0)); args.push_back(ctx->Input(1)); diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops.cc b/tensorflow/compiler/tf2xla/kernels/index_ops.cc index df002dddd04..6be66cf66ec 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops.cc @@ -69,7 +69,7 @@ class ArgMaxOp : public XlaOpKernel { // XLA op would have the same requirement. 
xla::Literal literal; OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &literal)); - const int32 dim = xla::LiteralUtil::Get(literal, {}); + const int32 dim = literal.Get({}); OP_REQUIRES(ctx, dim >= 0, errors::InvalidArgument("dim must be >= 0")); OP_REQUIRES( ctx, dim < input_shape.dims(), @@ -97,14 +97,13 @@ class ArgMaxOp : public XlaOpKernel { std::vector args; args.push_back(ctx->Input(0)); args.push_back(b.ConstantLiteral( - *xla::LiteralUtil::CreateR1(input_shape.dim_sizes()))); + *xla::Literal::CreateR1(input_shape.dim_sizes()))); if (input_shape.dims() > 1) { // Don't bother passing the output shape and dim for the 1d case, since // the shape is always a scalar and the dim is always 0. args.push_back(b.ConstantLiteral( - *xla::LiteralUtil::CreateR1(output_shape.dim_sizes()))); - args.push_back( - b.ConstantLiteral(*xla::LiteralUtil::CreateR0(dim))); + *xla::Literal::CreateR1(output_shape.dim_sizes()))); + args.push_back(b.ConstantLiteral(*xla::Literal::CreateR0(dim))); } xla::Shape xla_shape = diff --git a/tensorflow/compiler/tf2xla/kernels/pad_op.cc b/tensorflow/compiler/tf2xla/kernels/pad_op.cc index 22476f4a0c5..cc13ab02034 100644 --- a/tensorflow/compiler/tf2xla/kernels/pad_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/pad_op.cc @@ -60,8 +60,8 @@ class PadOp : public XlaOpKernel { xla::PaddingConfig config; for (int i = 0; i < fixed_dims; ++i) { auto* dim = config.add_dimensions(); - int before = xla::LiteralUtil::Get(pad_literal, {i, 0}); - int after = xla::LiteralUtil::Get(pad_literal, {i, 1}); + int before = pad_literal.Get({i, 0}); + int after = pad_literal.Get({i, 1}); OP_REQUIRES(ctx, before >= 0 && after >= 0, errors::InvalidArgument("Paddings must be non-negative: ", before, " ", after)); diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc index 518a9372c4f..dae2eb9d2a9 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc +++ 
b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc @@ -63,7 +63,7 @@ class MinOp : public XlaReductionOp { xla::ComputationBuilder* builder) override { xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(input_type(0), &type)); - return builder->ConstantLiteral(xla::LiteralUtil::MaxValue(type)); + return builder->ConstantLiteral(xla::Literal::MaxValue(type)); } void BuildReducer(xla::ComputationBuilder* builder, @@ -83,7 +83,7 @@ class MaxOp : public XlaReductionOp { xla::ComputationBuilder* builder) override { xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(input_type(0), &type)); - return builder->ConstantLiteral(xla::LiteralUtil::MinValue(type)); + return builder->ConstantLiteral(xla::Literal::MinValue(type)); } void BuildReducer(xla::ComputationBuilder* builder, diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc index 8798c80ad53..4b5d09eb9fd 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops_common.cc @@ -66,13 +66,13 @@ void XlaReductionOp::Compile(XlaOpKernelContext* ctx) { 1, {axes_tensor_shape.num_elements()}, &axes_literal)); VLOG(1) << "data shape: " << data_shape.DebugString(); - VLOG(1) << "axes : " << xla::LiteralUtil::ToString(axes_literal); + VLOG(1) << "axes : " << axes_literal.ToString(); gtl::InlinedVector bitmap(data_shape.dims(), false); std::vector xla_axes; int64 num_elements_reduced = 1LL; for (int64 i = 0; i < axes_tensor_shape.num_elements(); ++i) { - int32 index = xla::LiteralUtil::Get(axes_literal, {i}); + int32 index = axes_literal.Get({i}); OP_REQUIRES(ctx, !(index < -data_shape.dims() || index >= data_shape.dims()), errors::InvalidArgument("Invalid reduction dimension (", index, diff --git a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc index df542350b44..5952e752724 100644 --- 
a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc @@ -50,7 +50,7 @@ class ReshapeOp : public XlaOpKernel { int64 product = 1; int unknown_index = -1; for (int d = 0; d < num_dims; ++d) { - const int32 size = xla::LiteralUtil::Get(literal, {d}); + const int32 size = literal.Get({d}); if (size == -1) { OP_REQUIRES( ctx, unknown_index == -1, diff --git a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc index 5b6fa64fa82..c2b0e1bb4c1 100644 --- a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc @@ -32,7 +32,7 @@ template Status GetValue(int index, XlaOpKernelContext* ctx, T* value) { xla::Literal literal; TF_RETURN_IF_ERROR(ctx->ConstantInput(index, &literal)); - *value = xla::LiteralUtil::Get(literal, {}); + *value = literal.Get({}); return Status::OK(); } @@ -41,10 +41,10 @@ Status GetIntValue(int index, XlaOpKernelContext* ctx, int64* value) { TF_RETURN_IF_ERROR(ctx->ConstantInput(index, &literal)); switch (literal.shape().element_type()) { case xla::S32: - *value = xla::LiteralUtil::Get(literal, {}); + *value = literal.Get({}); break; case xla::S64: - *value = xla::LiteralUtil::Get(literal, {}); + *value = literal.Get({}); break; default: return errors::InvalidArgument("Invalid argument type for argument", @@ -58,9 +58,9 @@ template Status CreateRangeTensor(const xla::Literal& start_literal, const xla::Literal& limit_literal, const xla::Literal& delta_literal, Tensor* output) { - T start = xla::LiteralUtil::Get(start_literal, {}); - T limit = xla::LiteralUtil::Get(limit_literal, {}); - T delta = xla::LiteralUtil::Get(delta_literal, {}); + T start = start_literal.Get({}); + T limit = limit_literal.Get({}); + T delta = delta_literal.Get({}); if (delta == 0) { return errors::InvalidArgument("Requires delta != 0: ", delta); diff --git a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc 
b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc index f15b354cb26..83a87f19a71 100644 --- a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc @@ -56,8 +56,8 @@ void SpaceToBatch(XlaOpKernelContext* ctx, padding_config.add_dimensions(); // Don't pad the batch dimension. for (int i = 0; i < block_rank; ++i) { auto* dim = padding_config.add_dimensions(); - int64 pad_start = xla::LiteralUtil::Get(paddings, {i, 0}); - int64 pad_end = xla::LiteralUtil::Get(paddings, {i, 1}); + int64 pad_start = paddings.Get({i, 0}); + int64 pad_end = paddings.Get({i, 1}); OP_REQUIRES(ctx, pad_start >= 0 && pad_end >= 0, errors::InvalidArgument("Paddings must be non-negative")); dim->set_edge_padding_low(pad_start); diff --git a/tensorflow/compiler/tf2xla/kernels/split_op.cc b/tensorflow/compiler/tf2xla/kernels/split_op.cc index 42bde900422..44ee81461e5 100644 --- a/tensorflow/compiler/tf2xla/kernels/split_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/split_op.cc @@ -39,7 +39,7 @@ class SplitOp : public XlaOpKernel { int32 split_dim; if (index_shape.dims() == 0) { - split_dim = xla::LiteralUtil::Get(literal_index, {}); + split_dim = literal_index.Get({}); } else { OP_REQUIRES( ctx, index_shape.dims() == 1, @@ -49,7 +49,7 @@ class SplitOp : public XlaOpKernel { ctx, index_shape.dim_size(0) == 1, errors::InvalidArgument("split_index input to Split Op must be a " "scalar or a vector with 1 element")); - split_dim = xla::LiteralUtil::Get(literal_index, {0}); + split_dim = literal_index.Get({0}); } const int32 num_split = num_outputs(); const TensorShape input_shape = ctx->InputShape(1); @@ -115,7 +115,7 @@ class SplitVOp : public XlaOpKernel { OP_REQUIRES(ctx, index_shape.dims() == 0, errors::InvalidArgument("split_dim input to Split Op must be a " "scalar")); - split_dim = xla::LiteralUtil::Get(literal_index, {}); + split_dim = literal_index.Get({}); xla::ComputationDataHandle input = ctx->Input(0); const TensorShape 
input_shape = ctx->InputShape(0); @@ -152,7 +152,7 @@ class SplitVOp : public XlaOpKernel { for (int i = 0; i < num_split; ++i) { int slice_size; - slice_size = xla::LiteralUtil::Get(split_size_literal, {i}); + slice_size = split_size_literal.Get({i}); if (slice_size == -1) { OP_REQUIRES( ctx, neg_one_dim == -1, diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc index deee7dd44db..9367c1ef22c 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc @@ -41,32 +41,36 @@ namespace { // Since the element shape is not always provided to the TensorArrayV3 operator, // we must support lazily initialization of the TensorArray at the time of the // first write. -// If a TensorArray `var` has not been initialized, constructs storage for the -// TensorArray with elements of `elem_shape`. For both initialized and +// If a TensorArray `resource` has not been initialized, constructs storage for +// the TensorArray with elements of `elem_shape`. For both initialized and // uninitialized TensorArrays, checks that the tensor has a type compatible with // 'dtype' and shape compatible with 'elem_shape'. 
Status MaybeInitializeTensorArray(xla::ComputationBuilder* builder, - XlaVariable* var, DataType dtype, + XlaResource* resource, DataType dtype, const TensorShape& elem_shape) { - if (var->type != dtype) { + if (resource->kind != XlaResource::kTensorArray) { + return errors::InvalidArgument("Unexpected non-TensorArray resource"); + } + + if (resource->type != dtype) { return errors::InvalidArgument( - "TensorArray dtype is ", DataTypeString(var->type), + "TensorArray dtype is ", DataTypeString(resource->type), " but op has dtype ", DataTypeString(dtype), "."); } - TF_RET_CHECK(var->tensor_array_size >= 0) - << var->name << " size " << var->tensor_array_size; + TF_RET_CHECK(resource->tensor_array_size >= 0) + << resource->name << " size " << resource->tensor_array_size; TensorShape ta_shape; - ta_shape.AddDim(var->tensor_array_size); + ta_shape.AddDim(resource->tensor_array_size); ta_shape.AppendShape(elem_shape); - if (var->value.handle() == 0) { + if (resource->value.handle() == 0) { // TensorArray has not been initialized. - xla::ComputationDataHandle zero = XlaHelpers::Zero(builder, var->type); - var->value = builder->Broadcast(zero, ta_shape.dim_sizes()); + xla::ComputationDataHandle zero = XlaHelpers::Zero(builder, resource->type); + resource->value = builder->Broadcast(zero, ta_shape.dim_sizes()); } else { // Checks the elem_shape matches the TensorArray shape. - auto shape_or_status = builder->GetShape(var->value); + auto shape_or_status = builder->GetShape(resource->value); if (!shape_or_status.ok()) { return shape_or_status.status(); } @@ -80,6 +84,44 @@ Status MaybeInitializeTensorArray(xla::ComputationBuilder* builder, return Status::OK(); } +// Checks that the TensorArray 'resource' has been initialized, and has type +// 'dtype'. 
Sets 'shape' to the shape +Status CheckTensorArrayIsInitialized(const string& op_name, + const XlaResource* resource, + DataType dtype) { + if (resource->kind != XlaResource::kTensorArray) { + return errors::InvalidArgument( + "Unexpected non-TensorArray resource passed " + "to ", + op_name); + } + if (resource->value.handle() == 0) { + return errors::InvalidArgument("Uninitialized TensorArray passed to ", + op_name); + } + if (resource->type != dtype) { + return errors::InvalidArgument( + "TensorArray dtype is ", DataTypeString(resource->type), + " but op has dtype ", DataTypeString(dtype), "."); + } + + return Status::OK(); +} + +Status GetTensorArrayShape(const XlaResource* resource, + xla::ComputationBuilder* builder, + TensorShape* shape) { + auto shape_or_status = builder->GetShape(resource->value); + if (!shape_or_status.ok()) { + return shape_or_status.status(); + } + *shape = XLAShapeToTensorShape(*shape_or_status.ValueOrDie()); + if (shape->dims() < 1) { + return errors::InvalidArgument("TensorArray rank must be >= 1"); + } + return Status::OK(); +} + // Pads 'x' with 'count' zero indices. 'x' must have 1 element. xla::ComputationDataHandle PadIndexWithZeros( xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, @@ -125,7 +167,6 @@ class TensorArrayOp : public XlaOpKernel { errors::InvalidArgument("TensorArray size must be >= 0")); xla::ComputationBuilder* b = ctx->builder(); - b->set_die_immediately_on_error(true); // Initializes the TensorArray value if we know the element shape. // Otherwise, defer initialization to the first write. 
@@ -141,12 +182,13 @@ class TensorArrayOp : public XlaOpKernel { } XlaContext& xc = XlaContext::Get(ctx); - XlaVariable* var; + XlaResource* var; string name = strings::StrCat("TensorArray: ", tensor_array_name_); - OP_REQUIRES_OK(ctx, - xc.CreateVariable(-1, std::move(name), dtype_, value, &var)); + OP_REQUIRES_OK( + ctx, xc.CreateResource(XlaResource::kTensorArray, -1, std::move(name), + dtype_, value, &var)); var->tensor_array_size = size; - ctx->SetVariableOutput(0, var); + ctx->SetResourceOutput(0, var); ctx->SetConstantOutput(1, Tensor(DT_FLOAT)); } @@ -173,11 +215,12 @@ class TensorArrayWriteOp : public XlaOpKernel { // Initializes the TensorArray, if the element shape was not known at // construction time. - XlaVariable* var; - OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var)); - OP_REQUIRES_OK(ctx, MaybeInitializeTensorArray(b, var, dtype_, elem_shape)); + XlaResource* resource; + OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); + OP_REQUIRES_OK(ctx, + MaybeInitializeTensorArray(b, resource, dtype_, elem_shape)); - xla::ComputationDataHandle ta = var->value; + xla::ComputationDataHandle ta = resource->value; xla::ComputationDataHandle index = ctx->Input(1); xla::ComputationDataHandle value = ctx->Input(2); @@ -191,7 +234,7 @@ class TensorArrayWriteOp : public XlaOpKernel { xla::ComputationDataHandle written = DynamicAddSlice(b, ta, update, slice_shape.dim_sizes(), start_indices); - OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, written)); + resource->value = written; ctx->SetConstantOutput(0, Tensor(DT_FLOAT)); } @@ -210,20 +253,17 @@ class TensorArrayReadOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - DataType ta_type; - TensorShape ta_shape; - OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &ta_type, &ta_shape)); - OP_REQUIRES(ctx, ta_type == dtype_, - errors::InvalidArgument( - "TensorArray dtype is ", DataTypeString(ta_type), - " but Op requested dtype ", DataTypeString(dtype_), ".")); - 
OP_REQUIRES(ctx, ta_shape.dims() >= 1, - errors::InvalidArgument("TensorArray rank must be >= 1")); - xla::ComputationBuilder* b = ctx->builder(); - xla::ComputationDataHandle ta; - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &ta)); + XlaResource* resource; + OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); + + OP_REQUIRES_OK(ctx, + CheckTensorArrayIsInitialized(name(), resource, dtype_)); + TensorShape ta_shape; + OP_REQUIRES_OK(ctx, GetTensorArrayShape(resource, b, &ta_shape)); + + xla::ComputationDataHandle ta = resource->value; xla::ComputationDataHandle index = ctx->Input(1); // start_indices of the DynamicSlice are [index, 0, 0, ..., 0]. @@ -255,13 +295,15 @@ class TensorArrayGatherOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - DataType ta_type; + xla::ComputationBuilder* b = ctx->builder(); + + XlaResource* resource; + OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); + + OP_REQUIRES_OK(ctx, + CheckTensorArrayIsInitialized(name(), resource, dtype_)); TensorShape ta_shape; - OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &ta_type, &ta_shape)); - OP_REQUIRES(ctx, ta_type == dtype_, - errors::InvalidArgument("TensorArray type mismatch")); - OP_REQUIRES(ctx, ta_shape.dims() >= 1, - errors::InvalidArgument("TensorArray rank must be >= 1")); + OP_REQUIRES_OK(ctx, GetTensorArrayShape(resource, b, &ta_shape)); const TensorShape indices_shape = ctx->InputShape(1); OP_REQUIRES(ctx, indices_shape.dims() >= 1, @@ -269,10 +311,7 @@ class TensorArrayGatherOp : public XlaOpKernel { const int num_indices = indices_shape.dim_size(0); auto indices = ctx->Input(1); - xla::ComputationBuilder* b = ctx->builder(); - - xla::ComputationDataHandle ta; - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &ta)); + xla::ComputationDataHandle ta = resource->value; // For each index in `indices`, add the corresponding slice to `slices`. 
std::vector slices(num_indices); @@ -320,11 +359,12 @@ class TensorArrayScatterOp : public XlaOpKernel { const TensorShape value_shape = ctx->InputShape(2); - XlaVariable* var; - OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var)); + XlaResource* resource; + OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); TensorShape elem_shape = value_shape; elem_shape.RemoveDim(0); - OP_REQUIRES_OK(ctx, MaybeInitializeTensorArray(b, var, dtype_, elem_shape)); + OP_REQUIRES_OK(ctx, + MaybeInitializeTensorArray(b, resource, dtype_, elem_shape)); const TensorShape indices_shape = ctx->InputShape(1); OP_REQUIRES(ctx, indices_shape.dims() >= 1, @@ -332,7 +372,7 @@ class TensorArrayScatterOp : public XlaOpKernel { const int num_indices = indices_shape.dim_size(0); const xla::ComputationDataHandle indices = ctx->Input(1); - xla::ComputationDataHandle ta = var->value; + xla::ComputationDataHandle ta = resource->value; const xla::ComputationDataHandle value = ctx->Input(2); auto slice_dims = value_shape.dim_sizes(); @@ -357,7 +397,7 @@ class TensorArrayScatterOp : public XlaOpKernel { ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices); } - OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, ta)); + resource->value = ta; ctx->SetConstantOutput(0, Tensor(DT_FLOAT)); } @@ -376,18 +416,17 @@ class TensorArrayConcatOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - DataType ta_type; - TensorShape ta_shape; - OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &ta_type, &ta_shape)); - OP_REQUIRES(ctx, ta_type == dtype_, - errors::InvalidArgument("TensorArray type mismatch")); - OP_REQUIRES(ctx, ta_shape.dims() >= 1, - errors::InvalidArgument("TensorArray rank must be >= 1")); - xla::ComputationBuilder* b = ctx->builder(); - xla::ComputationDataHandle ta; - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &ta)); + XlaResource* resource; + OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); + + OP_REQUIRES_OK(ctx, + 
CheckTensorArrayIsInitialized(name(), resource, dtype_)); + TensorShape ta_shape; + OP_REQUIRES_OK(ctx, GetTensorArrayShape(resource, b, &ta_shape)); + + xla::ComputationDataHandle ta = resource->value; auto ta_dims = ta_shape.dim_sizes(); std::vector shape(ta_dims.begin() + 1, ta_dims.end()); @@ -438,19 +477,20 @@ class TensorArraySplitOp : public XlaOpKernel { elem_shape.set_dim(0, length); xla::ComputationBuilder* b = ctx->builder(); - XlaVariable* var; - OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var)); - OP_REQUIRES_OK(ctx, MaybeInitializeTensorArray(b, var, dtype_, elem_shape)); - xla::ComputationDataHandle ta = var->value; + XlaResource* resource; + OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); + OP_REQUIRES_OK(ctx, + MaybeInitializeTensorArray(b, resource, dtype_, elem_shape)); + xla::ComputationDataHandle ta = resource->value; TensorShape ta_shape; - ta_shape.AddDim(var->tensor_array_size); + ta_shape.AddDim(resource->tensor_array_size); ta_shape.AppendShape(elem_shape); - OP_REQUIRES(ctx, lengths.size() == var->tensor_array_size, + OP_REQUIRES(ctx, lengths.size() == resource->tensor_array_size, errors::InvalidArgument( "TensorArray's size is not equal to the size of lengths (", - lengths.size(), " vs. ", var->tensor_array_size, ")")); + lengths.size(), " vs. ", resource->tensor_array_size, ")")); const xla::ComputationDataHandle value = ctx->Input(1); @@ -459,8 +499,7 @@ class TensorArraySplitOp : public XlaOpKernel { value_shape.DebugString(), " vs. 
", ta_shape.DebugString())); - ta = b->Add(ta, b->Reshape(value, ta_shape.dim_sizes())); - OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, ta)); + resource->value = b->Add(ta, b->Reshape(value, ta_shape.dim_sizes())); ctx->SetConstantOutput(0, Tensor(DT_FLOAT)); } @@ -478,8 +517,8 @@ class TensorArraySizeOp : public XlaOpKernel { explicit TensorArraySizeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - XlaVariable* var; - OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var)); + XlaResource* var; + OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &var)); Tensor size_tensor(DT_INT32, {}); size_tensor.scalar()() = static_cast(var->tensor_array_size); ctx->SetConstantOutput(0, size_tensor); @@ -500,31 +539,31 @@ class TensorArrayGradOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { xla::ComputationBuilder* b = ctx->builder(); - XlaVariable* var; - OP_REQUIRES_OK(ctx, ctx->GetVariableInput(0, &var)); + XlaResource* resource; + OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); - DataType ta_type; + OP_REQUIRES_OK( + ctx, CheckTensorArrayIsInitialized(name(), resource, resource->type)); TensorShape ta_shape; - OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &ta_type, &ta_shape)); - OP_REQUIRES(ctx, ta_shape.dims() >= 1, - errors::InvalidArgument("TensorArray rank must be >= 1")); + OP_REQUIRES_OK(ctx, GetTensorArrayShape(resource, b, &ta_shape)); // Finds or looks up the corresponding gradient TensorArray, which stores // gradients computed during backpropagation. 
- XlaVariable*& gradient = var->tensor_array_gradient[source_]; + XlaResource*& gradient = resource->tensor_array_gradient[source_]; if (!gradient) { - xla::ComputationDataHandle zero = XlaHelpers::Zero(b, ta_type); + xla::ComputationDataHandle zero = XlaHelpers::Zero(b, resource->type); xla::ComputationDataHandle value = b->Broadcast(zero, ta_shape.dim_sizes()); XlaContext& xc = XlaContext::Get(ctx); - string name = strings::StrCat("TensorArrayGrad: ", var->name); - OP_REQUIRES_OK(ctx, xc.CreateVariable(-1, std::move(name), var->type, - value, &gradient)); - gradient->tensor_array_size = var->tensor_array_size; + string name = strings::StrCat("TensorArrayGrad: ", resource->name); + OP_REQUIRES_OK( + ctx, xc.CreateResource(XlaResource::kTensorArray, -1, std::move(name), + resource->type, value, &gradient)); + gradient->tensor_array_size = resource->tensor_array_size; } - ctx->SetVariableOutput(0, gradient); + ctx->SetResourceOutput(0, gradient); ctx->SetConstantOutput(1, Tensor(DT_FLOAT)); } @@ -536,5 +575,19 @@ class TensorArrayGradOp : public XlaOpKernel { REGISTER_XLA_OP(Name("TensorArrayGradV3"), TensorArrayGradOp); +class TensorArrayCloseOp : public XlaOpKernel { + public: + explicit TensorArrayCloseOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + // Do nothing; XLA handles resource management. 
+ } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayCloseOp); +}; + +REGISTER_XLA_OP(Name("TensorArrayCloseV3"), TensorArrayCloseOp); + } // anonymous namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc index 4cc2eb8f877..9ee6bd89250 100644 --- a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc @@ -68,7 +68,7 @@ class TileOp : public XlaOpKernel { bool all_multiples_are_one = true; bool one_dimension_is_broadcasted_without_multiple = true; for (int i = 0; i < input_dims; ++i) { - int multiple = xla::LiteralUtil::Get(literal, {i}); + int multiple = literal.Get({i}); OP_REQUIRES(ctx, multiple, errors::InvalidArgument("Expected multiples[", i, "] >= 0, but got ", multiple)); diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc index abe4949f5db..07ca5961504 100644 --- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc @@ -44,6 +44,7 @@ namespace { // Return x if x>0, otherwise -x. XLAJIT_MAKE_UNARY(Abs, b->Abs(x)); XLAJIT_MAKE_UNARY(Ceil, b->Ceil(x)); +XLAJIT_MAKE_UNARY(Cos, b->Cos(x)); XLAJIT_MAKE_UNARY(Exp, b->Exp(x)); XLAJIT_MAKE_UNARY(Floor, b->Floor(x)); // Returns 0 if x is 0, -1 if x < 0 and 1 if x > 0. diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc new file mode 100644 index 00000000000..0caa9c5f378 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc @@ -0,0 +1,265 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/kernels/while_op.h" + +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { + +namespace { + +// Builds XlaCompiler argument descriptions `args` from `ctx`. +Status MakeXlaCompilerArgumentsFromInputs( + XlaOpKernelContext* ctx, std::vector* args, + bool* has_uninitialized_vars) { + VLOG(2) << "Num inputs " << ctx->num_inputs(); + args->resize(ctx->num_inputs()); + *has_uninitialized_vars = false; + for (int i = 0; i < ctx->num_inputs(); ++i) { + VLOG(2) << " Input " << i + << " type: " << DataTypeString(ctx->input_type(i)) + << " shape: " << ctx->InputShape(i).DebugString(); + XlaCompiler::Argument& arg = (*args)[i]; + DataType type = ctx->input_type(i); + // When reading a resource input, use the type and shape of the resource's + // current value. 
+ if (type == DT_RESOURCE) { + XlaResource* resource; + TF_RETURN_IF_ERROR(ctx->GetResourceInput(i, &resource)); + + arg.initialized = resource->value.handle() > 0; + switch (resource->kind) { + case XlaResource::kVariable: + arg.kind = XlaCompiler::Argument::kVariable; + break; + case XlaResource::kTensorArray: + arg.kind = XlaCompiler::Argument::kTensorArray; + break; + case XlaResource::kInvalid: + CHECK(false); + } + arg.type = resource->type; + if (arg.initialized) { + auto shape = ctx->builder()->GetShape(resource->value); + TF_RETURN_IF_ERROR(shape.status()); + arg.shape = XLAShapeToTensorShape(*shape.ValueOrDie()); + } else { + *has_uninitialized_vars = true; + } + arg.tensor_array_size = resource->tensor_array_size; + arg.name = resource->name; + // TODO(phawkins): propagate TensorArray gradients into loops. + VLOG(2) << " resource " << resource->name + << " type: " << DataTypeString(arg.type) + << " shape: " << arg.shape.DebugString() + << " initialized: " << arg.initialized; + + } else { + arg.kind = XlaCompiler::Argument::kParameter; + arg.type = ctx->input_type(i); + arg.shape = ctx->InputShape(i); + } + } + return Status::OK(); +} + +} // anonymous namespace + +XlaWhileOp::XlaWhileOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + const NameAttrList* name_attr; + OP_REQUIRES_OK(ctx, ctx->GetAttr("cond", &name_attr)); + cond_name_attr_ = *name_attr; + OP_REQUIRES_OK(ctx, ctx->GetAttr("body", &name_attr)); + body_name_attr_ = *name_attr; +} + +void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { + VLOG(1) << "WhileOp::Compile"; + + std::vector arguments; + bool has_uninitialized_vars; + OP_REQUIRES_OK(ctx, MakeXlaCompilerArgumentsFromInputs( + ctx, &arguments, &has_uninitialized_vars)); + + const bool use_tuple_arg = (arguments.size() != 1); + + xla::ComputationBuilder* builder = ctx->builder(); + XlaCompiler* compiler = ctx->compiler(); + + VLOG(1) << "Compiling body"; + + // All resource that are inputs to the loop's body must also be + // present as 
loop body outputs; the signature of the loop's input and + // output must match. We ensure this by asking the compiler to include the + // current values of all resources, even if they haven't been updated by the + // computation. + // TODO(phawkins): consider adding loop-invariant inputs to XLA's While() + // operator. + XlaCompiler::CompileOptions body_options; + body_options.use_tuple_arg = use_tuple_arg; + body_options.return_updated_values_for_all_resources = true; + XlaCompiler::CompilationResult body; + OP_REQUIRES_OK(ctx, compiler->CompileFunction(body_options, body_name_attr_, + arguments, &body)); + + // We must use a static shape for parameters to an XLA compilation. However, + // we may not know the shape of a TensorArray if it is first written inside + // the loop. Ideally we would require the user to provide a static shape, + // but this is not always easy. + // So if uninitialized resource are used by the loop body, we compile the + // body function twice: + // 1) once with uninitialized resource inputs. We discard the computation + // but we assume resource shapes reach a fixpoint after one iteration. + // So we can use the output shapes of the resource as the "true" shapes. + // 2) again with the "correct" input shapes determined by (1). + if (has_uninitialized_vars) { + // Initializes any uninitialized resource with zero values of the + // shape determined by the first compilation. 
+ for (int i = 0; i < body.resource_updates.size(); ++i) { + const XlaCompiler::ResourceUpdate& update = body.resource_updates[i]; + XlaCompiler::Argument& arg = arguments[update.input_index]; + if (!arg.initialized) { + arg.initialized = true; + arg.shape = update.shape; + + XlaResource* resource; + OP_REQUIRES_OK(ctx, + ctx->GetResourceInput(update.input_index, &resource)); + + xla::ComputationDataHandle zero = XlaHelpers::Zero(builder, arg.type); + resource->value = builder->Broadcast(zero, update.shape.dim_sizes()); + } + } + // Recompile the body with the "correct" shapes. + body = {}; + OP_REQUIRES_OK(ctx, compiler->CompileFunction(body_options, body_name_attr_, + arguments, &body)); + } + + VLOG(1) << "Compiling condition"; + + XlaCompiler::CompileOptions cond_options; + cond_options.use_tuple_arg = use_tuple_arg; + XlaCompiler::CompilationResult cond; + OP_REQUIRES_OK(ctx, compiler->CompileFunction(cond_options, cond_name_attr_, + arguments, &cond)); + + xla::Shape body_input_shape, cond_input_shape; + if (use_tuple_arg) { + body_input_shape = xla::ShapeUtil::MakeTupleShape(body.xla_input_shapes); + cond_input_shape = xla::ShapeUtil::MakeTupleShape(cond.xla_input_shapes); + } else { + CHECK(!body.xla_input_shapes.empty()); + body_input_shape = body.xla_input_shapes[0]; + CHECK(!body.xla_input_shapes.empty()); + cond_input_shape = cond.xla_input_shapes[0]; + } + + VLOG(2) << "Body shape: " << xla::ShapeUtil::HumanString(body_input_shape) + << " -> " << xla::ShapeUtil::HumanString(body.xla_output_shape); + VLOG(2) << "Cond shape: " << xla::ShapeUtil::HumanString(cond_input_shape) + << " -> " << xla::ShapeUtil::HumanString(cond.xla_output_shape); + + OP_REQUIRES(ctx, + xla::ShapeUtil::Compatible(body_input_shape, cond_input_shape), + errors::InvalidArgument( + "Input shapes of loop body and condition do not match: ", + xla::ShapeUtil::HumanString(body_input_shape), " vs. 
", + xla::ShapeUtil::HumanString(cond_input_shape))); + OP_REQUIRES( + ctx, xla::ShapeUtil::Compatible(body_input_shape, body.xla_output_shape), + errors::InvalidArgument( + "Input and output shapes of loop body do not match: ", + xla::ShapeUtil::HumanString(body_input_shape), " vs. ", + xla::ShapeUtil::HumanString(body.xla_output_shape))); + + xla::ComputationDataHandle data; + + int num_inputs = body.input_mapping.size(); + + std::vector inputs(num_inputs); + for (int i = 0; i < num_inputs; ++i) { + int input_num = body.input_mapping[i]; + if (ctx->input_type(input_num) == DT_RESOURCE) { + XlaResource* resource; + OP_REQUIRES_OK(ctx, ctx->GetResourceInput(input_num, &resource)); + inputs[i] = resource->value; + } else { + inputs[i] = ctx->Input(i); + } + } + + xla::ComputationDataHandle init; + if (use_tuple_arg) { + init = builder->Tuple(inputs); + } else { + init = inputs[0]; + } + + VLOG(1) << "Building while loop"; + + xla::ComputationDataHandle while_result = + builder->While(*cond.computation, *body.computation, init); + + auto get_loop_output = [&](int i) { + if (use_tuple_arg) { + return builder->GetTupleElement(while_result, i); + } else { + return while_result; + } + }; + + // Sets non-variable outputs. + for (int i = 0; i < ctx->num_outputs(); ++i) { + if (ctx->input_type(i) != DT_RESOURCE) { + ctx->SetOutput(body.input_mapping[i], get_loop_output(i)); + } + } + + // Updates the values of any resource variables modified by the loop. 
+ for (int i = 0; i < body.resource_updates.size(); ++i) { + const XlaCompiler::ResourceUpdate& update = body.resource_updates[i]; + XlaResource* resource; + OP_REQUIRES_OK(ctx, ctx->GetResourceInput(update.input_index, &resource)); + if (update.modified) { + int pos = body.outputs.size() + i; + resource->value = get_loop_output(pos); + } + VLOG(2) << "Loop-carried variable: pos: " << update.input_index + << " name: " << resource->name << " modified: " << update.modified + << " type: " << DataTypeString(update.type) + << " shape: " << update.shape.DebugString(); + // Copies the identity of the resource variable from input to output + // unchanged, even if the variable was not modified. + ctx->op_kernel_context()->set_output( + update.input_index, + ctx->op_kernel_context()->input(update.input_index)); + } + + VLOG(1) << "Done building while loop"; +} + +REGISTER_XLA_OP(Name("XlaWhile").AllowResourceTypes(), XlaWhileOp); + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.h b/tensorflow/compiler/tf2xla/kernels/while_op.h new file mode 100644 index 00000000000..67edebabf9f --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/while_op.h @@ -0,0 +1,65 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_WHILE_OP_H_ +#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_WHILE_OP_H_ + +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/core/framework/attr_value.pb.h" + +namespace tensorflow { + +// This TensorFlow op provides a functional iteration primitive. +// +// The inputs and outputs of the loop body must agree on the number, types, and +// shapes of the Tensors carried around the loop body. +// +// Computations in while loops may read from and write to resource variables. +// Resource variables may be passed as arguments to a function's body and +// condition functions. The XlaCompiler converts resource variable arguments +// into parameters to the XLA computation and moves them to the end of the +// parameter list, and by using the `return_updated_values_for_all_resources` +// option we ensure that all variables that appear in the input also appear at +// the end of the body's output. This ensures the loop body's input and output +// signatures match. +// +// It is the user's responsibility to ensure that each non-variable _Arg matches +// the corresponding _Retval. +// +// For example, suppose we have a loop body with arguments: +// DT_INT32, DT_RESOURCE (pointing to a DT_BOOL var), DT_FLOAT +// and return values +// DT_INT32, DT_FLOAT +// It is an error for the body to return DT_RESOURCE values. +// +// The body will be lowered into an XLA computation that takes and returns a +// tuple with XLA type (I32, F32, PRED). Note the resource variable appears at +// the end of both the loop body's input and output argument lists. 
+class XlaWhileOp : public XlaOpKernel { + public: + explicit XlaWhileOp(OpKernelConstruction* ctx); + + void Compile(XlaOpKernelContext* ctx) override; + + private: + NameAttrList cond_name_attr_; + NameAttrList body_name_attr_; + + TF_DISALLOW_COPY_AND_ASSIGN(XlaWhileOp); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_WHILE_OP_H_ diff --git a/tensorflow/compiler/tf2xla/literal_util.cc b/tensorflow/compiler/tf2xla/literal_util.cc index 1f2bc01cf4a..e166e8a9b05 100644 --- a/tensorflow/compiler/tf2xla/literal_util.cc +++ b/tensorflow/compiler/tf2xla/literal_util.cc @@ -27,13 +27,13 @@ Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal) { TF_RETURN_IF_ERROR(TensorShapeToXLAShape( host_tensor.dtype(), host_tensor.shape(), literal->mutable_shape())); - xla::LiteralUtil::Reserve(host_tensor.NumElements(), literal); + literal->Reserve(host_tensor.NumElements()); // memcpy over the payload ... // TODO(phawkins): handle string types. size_t total_bytes = host_tensor.TotalBytes(); if (total_bytes > 0) { - void* dst_ptr = xla::LiteralUtil::MutableInternalData(literal); + void* dst_ptr = literal->MutableInternalData(); const void* src_ptr = DMAHelper::base(&host_tensor); memcpy(dst_ptr, src_ptr, total_bytes); } @@ -55,7 +55,7 @@ Status LiteralToHostTensor(const xla::Literal& literal, DataType target_type, *host_tensor = Tensor(target_type, shape); size_t total_bytes = host_tensor->TotalBytes(); if (total_bytes > 0) { - const void* src_ptr = xla::LiteralUtil::InternalData(literal); + const void* src_ptr = literal.InternalData(); void* dst_ptr = DMAHelper::base(host_tensor); memcpy(dst_ptr, src_ptr, total_bytes); } diff --git a/tensorflow/compiler/tf2xla/literal_util_test.cc b/tensorflow/compiler/tf2xla/literal_util_test.cc index 56993bc5853..f3d6787daaa 100644 --- a/tensorflow/compiler/tf2xla/literal_util_test.cc +++ b/tensorflow/compiler/tf2xla/literal_util_test.cc @@ -27,7 +27,7 @@ TEST(LiteralUtil, 
LiteralToHostTensor) { { std::vector int64_values = {1, 2, 3}; std::unique_ptr int64_values_literal = - xla::LiteralUtil::CreateR1(gtl::ArraySlice(int64_values)); + xla::Literal::CreateR1(gtl::ArraySlice(int64_values)); Tensor host_tensor; EXPECT_EQ("Cannot convert literal of type S64 to tensor of type int32", LiteralToHostTensor(*int64_values_literal, DT_INT32, &host_tensor) @@ -48,7 +48,7 @@ TEST(LiteralUtil, LiteralToHostTensor) { Tensor host_tensor; std::vector int32_values = {10, 11}; std::unique_ptr int32_values_literal = - xla::LiteralUtil::CreateR1(gtl::ArraySlice(int32_values)); + xla::Literal::CreateR1(gtl::ArraySlice(int32_values)); EXPECT_TRUE( LiteralToHostTensor(*int32_values_literal, DT_INT32, &host_tensor) .ok()); diff --git a/tensorflow/compiler/tf2xla/ops/BUILD b/tensorflow/compiler/tf2xla/ops/BUILD new file mode 100644 index 00000000000..a2bd06861d5 --- /dev/null +++ b/tensorflow/compiler/tf2xla/ops/BUILD @@ -0,0 +1,38 @@ +package( + default_visibility = ["//tensorflow/compiler/tf2xla:internal"], +) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") + +cc_library( + name = "functional_ops", + srcs = ["functional_ops.cc"], + deps = [ + "//tensorflow/core:framework", + ], + alwayslink = 1, +) + +tf_gen_op_wrapper_py( + name = "gen_functional_ops", + out = "gen_functional_ops.py", + deps = [ + ":functional_ops", + ], +) + +# ----------------------------------------------------------------------------- + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/compiler/tf2xla/ops/functional_ops.cc b/tensorflow/compiler/tf2xla/ops/functional_ops.cc new file mode 100644 index 00000000000..38bcaa32278 --- /dev/null +++ b/tensorflow/compiler/tf2xla/ops/functional_ops.cc @@ -0,0 +1,45 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/op.h" + +namespace tensorflow { + +// TODO(b/37549631) setting the While Op to always be stateful is too +// conservative. +REGISTER_OP("XlaWhile") + .Input("input: T") + .Output("output: T") + .Attr("T: list(type) >= 0") + .Attr("cond: func") + .Attr("body: func") + .SetIsStateful() + .Doc(R"doc( +output = input; While (Cond(output)) { output = Body(output) } + +input: A list of input tensors whose types are T. +output: A list of output tensors whose types are T. +cond: A function takes 'input' and returns a tensor. If the tensor is + a scalar of non-boolean, the scalar is converted to a boolean + according to the following rule: if the scalar is a numerical + value, non-zero means True and zero means False; if the scalar is + a string, non-empty means True and empty means False. If the + tensor is not a scalar, non-emptiness means True and False + otherwise. +body: A function that takes a list of tensors and returns another + list of tensors. Both lists have the same types as specified by T. +)doc"); + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/test_util.cc b/tensorflow/compiler/tf2xla/test_util.cc new file mode 100644 index 00000000000..3c34b8788d5 --- /dev/null +++ b/tensorflow/compiler/tf2xla/test_util.cc @@ -0,0 +1,42 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/test_util.h" + +#include "tensorflow/compiler/xla/status_macros.h" + +namespace tensorflow { + +Status InstantiateFunctionForTest(const string& name, + const FunctionLibraryDefinition& library, + InstantiationResultForTest* result) { + const FunctionDef* fdef = library.Find(name); + TF_RET_CHECK(fdef != nullptr); + + auto get_func_sig = [&library](const string& op, const OpDef** sig) { + return library.LookUpOpDef(op, sig); + }; + InstantiationResult inst; + TF_RETURN_IF_ERROR( + InstantiateFunction(*fdef, AttrSlice(), get_func_sig, &inst)); + result->arg_types = inst.arg_types; + result->ret_types = inst.ret_types; + for (NodeDef& n : inst.nodes) { + *result->gdef.add_node() = std::move(n); + } + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/test_util.h b/tensorflow/compiler/tf2xla/test_util.h new file mode 100644 index 00000000000..362558bcfc0 --- /dev/null +++ b/tensorflow/compiler/tf2xla/test_util.h @@ -0,0 +1,46 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Helper functions for tests. + +#ifndef TENSORFLOW_COMPILER_TF2XLA_TEST_UTIL_H_ +#define TENSORFLOW_COMPILER_TF2XLA_TEST_UTIL_H_ + +#include +#include +#include + +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +// Same as InstantiationResult, but has a GraphDef instead of just nodes. +struct InstantiationResultForTest { + DataTypeVector arg_types; + DataTypeVector ret_types; + GraphDef gdef; +}; + +// Instantiates a function, producing a GraphDef to compare against the +// expected graph. +Status InstantiateFunctionForTest(const string& name, + const FunctionLibraryDefinition& library, + InstantiationResultForTest* result); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_TEST_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.h b/tensorflow/compiler/tf2xla/xla_compilation_device.h index 75630bee396..e4f43f1950d 100644 --- a/tensorflow/compiler/tf2xla/xla_compilation_device.h +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.h @@ -64,26 +64,35 @@ class XlaCompilationDevice : public LocalDevice { std::unique_ptr allocator_; }; -struct XlaVariable { - // If this variable is visible externally, what was its argument number? +// Represents a resource, such as a Variable or TensorArray. 
+struct XlaResource { + enum Kind { + kInvalid, + kVariable, + kTensorArray, + }; + + Kind kind = kInvalid; + + // If this resource is visible externally, what was its argument number? int arg_num = -1; - // A descriptive name for the variable, used in error messages. + // A descriptive name for the resource, used in error messages. string name; - // Current type and value of the variable. Uninitialized variables are + // Current type and value of the resource. Uninitialized resources are // represented by a default (zero) handle and type DT_INVALID. - // While the type of a variable is notionally fixed during execution, when - // a variable is first initialized we do not yet know its type, so we keep + // While the type of a resource is notionally fixed during execution, when + // a resource is first initialized we do not yet know its type, so we keep // track of its type dynamically. DataType type = DT_INVALID; xla::ComputationDataHandle value; - // Value of the variable at computation entry. Used to detect which + // Value of the resource at computation entry. Used to detect which // variables have new values that need to be written back. xla::ComputationDataHandle initial_value; - // We treat TensorArrays as a Variable with some extra metadata. + // TensorArray-specific fields // 'tensor_array_size' stores the expected size of the TensorArray. We need // to store this since sometimes TensorArrays must be initialized lazily since @@ -91,10 +100,10 @@ struct XlaVariable { int64 tensor_array_size = -1; // 'tensor_array_gradient' is a map from TensorArrayGradV3 'source' attributes - // to an XlaVariable containing the gradient TensorArrays. We store a pointer + // to an XlaResource containing the gradient TensorArrays. We store a pointer // here since there should only be one gradient TensorArray per 'source' // string, irrespective of the number of calls to TensorArrayGrad. 
- std::unordered_map tensor_array_gradient; + std::unordered_map tensor_array_gradient; }; // A XlaExpression wraps an XLA computation. Each Tensor on an @@ -115,8 +124,8 @@ class XlaExpression { bool has_constant_value() const { return has_constant_value_; } const Tensor& constant_value() const { return constant_value_; } - void set_variable(XlaVariable* variable) { variable_ = variable; } - XlaVariable* variable() const { return variable_; } + void set_resource(XlaResource* resource) { resource_ = resource; } + XlaResource* resource() const { return resource_; } private: // The XLA handle of the expression's computation. @@ -128,7 +137,7 @@ class XlaExpression { bool has_constant_value_ = false; Tensor constant_value_; - XlaVariable* variable_ = nullptr; // Not owned. + XlaResource* resource_ = nullptr; // Not owned. TF_DISALLOW_COPY_AND_ASSIGN(XlaExpression); }; diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 580ce3d802e..50b384997a7 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include #include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" @@ -85,9 +86,10 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options) (*options_.populate_resource_manager)(device_->resource_manager()); } + flib_def_.reset(new FunctionLibraryDefinition(*options.flib_def)); flib_runtime_.reset(NewFunctionLibraryRuntime( &device_mgr_, Env::Default(), device_, options.graph_def_version, - options.flib_def, OptimizerOptions(), + flib_def_.get(), OptimizerOptions(), nullptr /* custom_kernel_creator */)); } @@ -249,35 +251,36 @@ Status BuildArguments(const std::vector& args, std::vector* input_shapes) { context_args->resize(args.size()); - // Argument numbers of arguments and variables that are to be passed to the + // Argument numbers of arguments and resources that are to be passed to the // XLA computation as runtime parameters. 
- std::vector parameters, variables; + std::vector parameters, resources; parameters.reserve(args.size()); - variables.reserve(args.size()); + resources.reserve(args.size()); for (std::vector::size_type i = 0; i < args.size(); ++i) { XlaContext::Argument& context_arg = (*context_args)[i]; + context_arg.kind = args[i].kind; context_arg.name = args[i].name; context_arg.value.constant_value = args[i].constant_value; context_arg.value.type = args[i].type; switch (args[i].kind) { case XlaCompiler::Argument::kVariable: - variables.push_back(i); - context_arg.is_variable = true; - context_arg.value.is_constant = false; + case XlaCompiler::Argument::kTensorArray: + context_arg.is_resource = true; + if (args[i].initialized) { + resources.push_back(i); + context_arg.value.is_constant = false; + } else { + context_arg.value.is_constant = true; + } context_arg.tensor_array_size = args[i].tensor_array_size; break; case XlaCompiler::Argument::kParameter: parameters.push_back(i); context_arg.value.is_constant = false; break; - case XlaCompiler::Argument::kUninitializedVariable: - context_arg.is_variable = true; - context_arg.value.is_constant = true; - context_arg.tensor_array_size = args[i].tensor_array_size; - break; case XlaCompiler::Argument::kConstant: context_arg.value.is_constant = true; break; @@ -288,7 +291,7 @@ Status BuildArguments(const std::vector& args, // Append parameters containing variable values after the other runtime // parameters. - parameters.insert(parameters.end(), variables.begin(), variables.end()); + parameters.insert(parameters.end(), resources.begin(), resources.end()); if (parameters.empty()) { return Status::OK(); } @@ -329,22 +332,22 @@ Status BuildArguments(const std::vector& args, // variable states, generated by the symbolic evaluation. // If `has_side_effects` is true, the computation has side effects and should be // built even if it has no outputs. 
-// If `return_updated_values_for_all_variables` is true, all variables will be -// included in `variable_updates`, regardless of whether their value changed. +// If `return_updated_values_for_all_resources` is true, all resources will be +// included in `resource_updates`, regardless of whether their value changed. // Sets `*num_nonconst_outputs` to the number of outputs of the `computation`. -// Sets `*variable_updates` to a description of variables whose values are +// Sets `*resource_updates` to a description of resources whose values are // written by the computation; the variable writes are the last -// `variable_updates.size()` return values from the computation. Each entry in -// `variable_updates` is a (input_index, type) pair, where `input_index` is the +// `resource_updates.size()` return values from the computation. Each entry in +// `resource_updates` is a (input_index, type) pair, where `input_index` is the // index of a resource variable argument to the computation, and `type` is the // type of the final output. Status BuildComputation( const std::vector& retvals, - const std::vector>& variables, - bool has_side_effects, bool return_updated_values_for_all_variables, + const std::vector>& resources, + bool has_side_effects, bool return_updated_values_for_all_resources, xla::ComputationBuilder* builder, xla::Computation* computation, int* num_nonconst_outputs, - std::vector* variable_updates) { + std::vector* resource_updates) { std::vector elems; elems.reserve(retvals.size()); for (const XlaContext::HandleOrConstant& retval : retvals) { @@ -354,24 +357,24 @@ Status BuildComputation( } *num_nonconst_outputs = elems.size(); - // Add return values for variables whose values have changed. - std::vector arg_vars; - arg_vars.reserve(variables.size()); - for (const auto& var : variables) { + // Add return values for resources whose values have changed. 
+ std::vector arg_vars; + arg_vars.reserve(resources.size()); + for (const auto& var : resources) { if (var->arg_num >= 0) { arg_vars.push_back(var.get()); } } std::sort(arg_vars.begin(), arg_vars.end(), - [](const XlaVariable* a, const XlaVariable* b) { + [](const XlaResource* a, const XlaResource* b) { return a->arg_num < b->arg_num; }); - for (const XlaVariable* var : arg_vars) { + for (const XlaResource* var : arg_vars) { bool modified = var->value.handle() != var->initial_value.handle(); - if (return_updated_values_for_all_variables || modified) { - variable_updates->emplace_back(); - XlaCompiler::VariableUpdate& update = variable_updates->back(); + if (return_updated_values_for_all_resources || modified) { + resource_updates->emplace_back(); + XlaCompiler::ResourceUpdate& update = resource_updates->back(); update.input_index = var->arg_num; update.type = var->type; update.modified = modified; @@ -413,6 +416,10 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, // Report the error here if initialization failed. TF_RETURN_IF_ERROR(initialization_status_); + // Converts Tensorflow's graph control-flow constructs into functional + // control-flow that can be compiled into XLA code. 
+ TF_RETURN_IF_ERROR(FunctionalizeControlFlow(graph.get(), flib_def_.get())); + xla::ComputationBuilder builder(client(), name); XlaContext* context = new XlaContext(this, &builder, options_.allow_cpu_custom_calls, @@ -433,10 +440,10 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, int num_nonconst_outputs; result->computation = std::make_shared(); TF_RETURN_IF_ERROR(BuildComputation( - context->retvals(), context->variables(), context->has_side_effects(), - options.return_updated_values_for_all_variables, &builder, + context->retvals(), context->resources(), context->has_side_effects(), + options.return_updated_values_for_all_resources, &builder, result->computation.get(), &num_nonconst_outputs, - &result->variable_updates)); + &result->resource_updates)); result->requires_runtime_context = context->has_context_parameter(); @@ -511,15 +518,15 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, } } - for (std::vector::size_type i = 0; - i < result->variable_updates.size(); ++i) { + for (std::vector::size_type i = 0; + i < result->resource_updates.size(); ++i) { if (num_computation_outputs > 1) { - result->variable_updates[i].shape = + result->resource_updates[i].shape = XLAShapeToTensorShape(xla::ShapeUtil::GetTupleElementShape( result->xla_output_shape, computation_output)); } else { CHECK_EQ(0, computation_output); - result->variable_updates[i].shape = + result->resource_updates[i].shape = XLAShapeToTensorShape(result->xla_output_shape); } ++computation_output; diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index 13143055325..58e42c34749 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -85,14 +85,14 @@ class XlaCompiler { // Argument is a compile-time constant. No associated runtime parameter. kConstant, - // Argument is a variable that has not been initialized yet. 
No associated - // runtime parameter. - kUninitializedVariable, - - // Argument is a variable that already has a value set. Expects a runtime - // parameter containing the current value. + // Argument is a variable resource. Has an associated runtime parameter + // iff `initialized` is true. kVariable, + // Argument is a TensorArray resource. Has an associated runtime parameter + // iff `initialized` is true. + kTensorArray, + // Argument is a run-time parameter. kParameter, }; @@ -114,8 +114,11 @@ class XlaCompiler { // The name of this argument, used for debugging. string name; - // For a kVariable or kUninitializedVariable corresponding to a TensorArray, - // what is the tensor array's declared size? + // For a kVariable or kTensorArray, has this resource been initialized? + bool initialized = false; + + // For a kTensorArray, what is the array's declared size? (Used for lazy + // initialization.) int64 tensor_array_size = -1; bool operator==(const Argument& other) const; @@ -133,7 +136,7 @@ class XlaCompiler { }; // Describes a variable write side effect of the computation. - struct VariableUpdate { + struct ResourceUpdate { // Index of the input that contains the variable resource to write to. int input_index; @@ -142,14 +145,14 @@ class XlaCompiler { TensorShape shape; // Was the value of the variable modified by the computation? - // (Always true, unless `return_updated_values_for_all_variables` is true.) + // (Always true, unless `return_updated_values_for_all_resources` is true.) bool modified; }; struct CompilationResult { // Vector that maps from the parameters of the XLA computation to their // original argument positions. To handle compile-time constant inputs and - // variables, the parameters to the XLA computation may be a subset of the + // resources, the parameters to the XLA computation may be a subset of the // original arguments, and are not necessarily in the same order.) 
std::vector input_mapping; @@ -172,10 +175,10 @@ class XlaCompiler { // containing both constant and non-constant results. std::vector outputs; - // Variables whose values were updated by the computation, ordered - // by return value position. Variable updates follow the non-constant + // Resources whose values were updated by the computation, ordered + // by return value position. Resource updates follow the non-constant // results in the outputs of XLA computation. - std::vector variable_updates; + std::vector resource_updates; // The XLA computation built from the tensorflow subgraph. May be null // if the output consists solely of compile-time constants. @@ -229,12 +232,12 @@ class XlaCompiler { // arguments; if false, each argument gets its own parameter. bool use_tuple_arg = false; - // If 'return_updated_values_for_all_variables' is true, then updated - // values of all resource variables arguments will be included in the - // 'variable_updates' of the computation, even if the variable was not + // If 'return_updated_values_for_all_resources' is true, then updated + // values of all resource resources arguments will be included in the + // 'resource_updates' of the computation, even if the resource was not // modified by the computation. Used when compiling loop bodies to ensure // the input and output signatures match. - bool return_updated_values_for_all_variables = false; + bool return_updated_values_for_all_resources = false; }; // Compiles a Tensorflow function `fn_name_attrs` into an XLA computation. 
@@ -294,6 +297,7 @@ class XlaCompiler { XlaCompilationDevice* device_; // Owned by device_mgr_ DeviceMgr device_mgr_; + std::unique_ptr flib_def_; std::unique_ptr flib_runtime_; struct SignatureHash { diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index 58d74057d10..427b14534fd 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -163,9 +163,9 @@ TEST_F(XlaCompilerTest, Simple) { // Tests that the generated computation works. std::unique_ptr param0_literal = - xla::LiteralUtil::CreateR1({7, 42}); + xla::Literal::CreateR1({7, 42}); std::unique_ptr param1_literal = - xla::LiteralUtil::CreateR1({-3, 101}); + xla::Literal::CreateR1({-3, 101}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_data = @@ -179,7 +179,7 @@ TEST_F(XlaCompilerTest, Simple) { client_->Transfer(*actual).ConsumeValueOrDie(); std::unique_ptr expected_literal = - xla::LiteralUtil::CreateR1({4, 143}); + xla::Literal::CreateR1({4, 143}); xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); } @@ -225,7 +225,7 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { // Tests that the generated computation works. std::unique_ptr param0_literal = - xla::LiteralUtil::CreateR1({7, 42}); + xla::Literal::CreateR1({7, 42}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -236,7 +236,7 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { client_->Transfer(*actual).ConsumeValueOrDie(); std::unique_ptr expected_literal = - xla::LiteralUtil::CreateR1({-7, -42}); + xla::Literal::CreateR1({-7, -42}); xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); } @@ -260,7 +260,7 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { // Tests that the generated computation works. 
std::unique_ptr param0_literal = - xla::LiteralUtil::CreateR1({7, 42}); + xla::Literal::CreateR1({7, 42}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -270,12 +270,11 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { std::unique_ptr actual_literal = client_->Transfer(*actual).ConsumeValueOrDie(); - std::unique_ptr expected0 = - xla::LiteralUtil::CreateR0(7); + std::unique_ptr expected0 = xla::Literal::CreateR0(7); std::unique_ptr expected1 = - xla::LiteralUtil::CreateR1({-7, -42}); + xla::Literal::CreateR1({-7, -42}); std::unique_ptr expected = - xla::LiteralUtil::MakeTuple({expected0.get(), expected1.get()}); + xla::Literal::MakeTuple({expected0.get(), expected1.get()}); xla::LiteralTestUtil::ExpectEqual(*expected, *actual_literal); } } diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc index 4440b530696..1a37d61944a 100644 --- a/tensorflow/compiler/tf2xla/xla_context.cc +++ b/tensorflow/compiler/tf2xla/xla_context.cc @@ -129,16 +129,18 @@ void XlaContext::AddSideEffects() { xla::ComputationBuilder* XlaContext::builder() { return builder_; } -Status XlaContext::CreateVariable(int arg_num, string name, DataType type, +Status XlaContext::CreateResource(XlaResource::Kind kind, int arg_num, + string name, DataType type, const xla::ComputationDataHandle& handle, - XlaVariable** variable) { - variables_.emplace_back(new XlaVariable); - *variable = variables_.back().get(); - XlaVariable& var = **variable; - var.arg_num = arg_num; - var.name = std::move(name); - var.type = type; - var.initial_value = var.value = handle; + XlaResource** resource) { + resources_.emplace_back(new XlaResource); + *resource = resources_.back().get(); + XlaResource& r = **resource; + r.kind = kind; + r.arg_num = arg_num; + r.name = std::move(name); + r.type = type; + r.initial_value = r.value = handle; return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/xla_context.h 
b/tensorflow/compiler/tf2xla/xla_context.h index 3978baaf637..dbede52b5d3 100644 --- a/tensorflow/compiler/tf2xla/xla_context.h +++ b/tensorflow/compiler/tf2xla/xla_context.h @@ -52,11 +52,13 @@ class XlaContext : public ResourceBase { }; struct Argument { - // Descriptive name for the variable, for use in error messages. + XlaCompiler::Argument::Kind kind; + + // Descriptive name for the resource, for use in error messages. string name; - // Is this a variable? - bool is_variable = false; + // Is this a resource? + bool is_resource = false; HandleOrConstant value; @@ -106,15 +108,15 @@ class XlaContext : public ResourceBase { bool has_side_effects() const { return has_side_effects_; } - // Creates a variable with variable `variable_id` and initial type `type` and + // Creates a resource with resource `kind` and initial type `type` and // value `handle`. `name` is a descriptive name for use in error messages. - // Fails if the variable already exists. - Status CreateVariable(int arg_num, string name, DataType type, - const xla::ComputationDataHandle& handle, - XlaVariable** variable); + // Fails if the resource already exists. + Status CreateResource(XlaResource::Kind kind, int arg_num, string name, + DataType type, const xla::ComputationDataHandle& handle, + XlaResource** resource); - const std::vector>& variables() { - return variables_; + const std::vector>& resources() { + return resources_; } // Get an XLA lambda to compute Max. This is cached in the @@ -166,8 +168,8 @@ class XlaContext : public ResourceBase { // Does the computation have side effects, i.e., Send() calls? bool has_side_effects_ = false; - // Holds ownership of variables. The variables are not ordered. - std::vector> variables_; + // Holds ownership of resources. The resources are not ordered. + std::vector> resources_; // Cache of prebuilt computations indexed by their type. 
using ComputationMap = std::map; diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc index f060f8f2f17..2366c02dd2b 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.cc +++ b/tensorflow/compiler/tf2xla/xla_helpers.cc @@ -30,28 +30,28 @@ xla::ComputationDataHandle XlaHelpers::MinValue(xla::ComputationBuilder* b, DataType data_type) { xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); - return b->ConstantLiteral(xla::LiteralUtil::MinValue(type)); + return b->ConstantLiteral(xla::Literal::MinValue(type)); } xla::ComputationDataHandle XlaHelpers::MaxValue(xla::ComputationBuilder* b, DataType data_type) { xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); - return b->ConstantLiteral(xla::LiteralUtil::MaxValue(type)); + return b->ConstantLiteral(xla::Literal::MaxValue(type)); } xla::ComputationDataHandle XlaHelpers::Zero(xla::ComputationBuilder* b, DataType data_type) { xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); - return b->ConstantLiteral(xla::LiteralUtil::Zero(type)); + return b->ConstantLiteral(xla::Literal::Zero(type)); } xla::ComputationDataHandle XlaHelpers::One(xla::ComputationBuilder* b, DataType data_type) { xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); - return b->ConstantLiteral(xla::LiteralUtil::One(type)); + return b->ConstantLiteral(xla::Literal::One(type)); } xla::ComputationDataHandle XlaHelpers::IntegerLiteral( @@ -61,28 +61,28 @@ xla::ComputationDataHandle XlaHelpers::IntegerLiteral( TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); switch (type) { case xla::U8: - literal = *xla::LiteralUtil::CreateR0(value); + literal = *xla::Literal::CreateR0(value); break; case xla::U32: - literal = *xla::LiteralUtil::CreateR0(value); + literal = *xla::Literal::CreateR0(value); break; case xla::U64: - literal = *xla::LiteralUtil::CreateR0(value); + literal = 
*xla::Literal::CreateR0(value); break; case xla::S8: - literal = *xla::LiteralUtil::CreateR0(value); + literal = *xla::Literal::CreateR0(value); break; case xla::S32: - literal = *xla::LiteralUtil::CreateR0(value); + literal = *xla::Literal::CreateR0(value); break; case xla::S64: - literal = *xla::LiteralUtil::CreateR0(value); + literal = *xla::Literal::CreateR0(value); break; case xla::F32: - literal = *xla::LiteralUtil::CreateR0(value); + literal = *xla::Literal::CreateR0(value); break; case xla::F64: - literal = *xla::LiteralUtil::CreateR0(value); + literal = *xla::Literal::CreateR0(value); break; case xla::PRED: LOG(FATAL) << "pred element type is not integral"; @@ -91,7 +91,7 @@ xla::ComputationDataHandle XlaHelpers::IntegerLiteral( LOG(FATAL) << "u16/s16 literals not yet implemented"; case xla::F16: literal = - *xla::LiteralUtil::CreateR0(static_cast(value)); + *xla::Literal::CreateR0(static_cast(value)); break; case xla::TUPLE: LOG(FATAL) << "tuple element type is not integral"; diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index 3272b1efa15..d606b32931e 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -39,7 +39,7 @@ static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor) { const XlaExpression* expression = reinterpret_cast(tensor.tensor_data().data()); CHECK(expression->handle().handle() != 0 || - expression->variable() != nullptr); + expression->resource() != nullptr); VLOG(1) << "Fetched T" << expression->handle().handle(); return expression; } @@ -144,9 +144,9 @@ static Status LiteralToInt64Scalar(const xla::Literal& literal, int64* out) { return errors::InvalidArgument("value is not a scalar"); } if (literal.shape().element_type() == xla::S32) { - *out = xla::LiteralUtil::Get(literal, {}); + *out = literal.Get({}); } else if (literal.shape().element_type() == xla::S64) { - *out = xla::LiteralUtil::Get(literal, {}); + 
*out = literal.Get({}); } else { return errors::InvalidArgument("value must be either int32 or int64"); } @@ -168,11 +168,11 @@ static Status LiteralToInt64Vector(const xla::Literal& literal, int64 size = xla::ShapeUtil::ElementsIn(literal.shape()); if (literal.shape().element_type() == xla::S32) { for (int64 i = 0; i < size; ++i) { - out->push_back(xla::LiteralUtil::Get(literal, {i})); + out->push_back(literal.Get({i})); } } else if (literal.shape().element_type() == xla::S64) { for (int64 i = 0; i < size; ++i) { - out->push_back(xla::LiteralUtil::Get(literal, {i})); + out->push_back(literal.Get({i})); } } else { return errors::InvalidArgument("value must be either int32 or int64"); @@ -252,8 +252,9 @@ Status XlaOpKernelContext::ReadVariableInput( int index, xla::ComputationDataHandle* value) { const Tensor& tensor = context_->input(index); const XlaExpression* expression = CastExpressionFromTensor(tensor); - XlaVariable* variable = expression->variable(); + XlaResource* variable = expression->resource(); TF_RET_CHECK(variable != nullptr); + TF_RET_CHECK(variable->kind == XlaResource::kVariable); if (variable->value.handle() == 0) { return errors::InvalidArgument("Read of uninitialized variable ", variable->name); @@ -262,22 +263,13 @@ Status XlaOpKernelContext::ReadVariableInput( return Status::OK(); } -string XlaOpKernelContext::VariableDebugString(int index) { - const Tensor& tensor = context_->input(index); - const XlaExpression* expression = CastExpressionFromTensor(tensor); - XlaVariable* variable = expression->variable(); - if (!variable) { - return ""; - } - return variable->name; -} - Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type, TensorShape* shape) const { const Tensor& tensor = context_->input(index); const XlaExpression* expression = CastExpressionFromTensor(tensor); - XlaVariable* variable = expression->variable(); + XlaResource* variable = expression->resource(); TF_RET_CHECK(variable != nullptr); + 
TF_RET_CHECK(variable->kind == XlaResource::kVariable); if (variable->value.handle() == 0) { return errors::InvalidArgument("Read of uninitialized variable ", variable->name); @@ -337,33 +329,34 @@ void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) { expression->set_constant_value(constant); } -void XlaOpKernelContext::SetVariableOutput(int index, XlaVariable* variable) { +void XlaOpKernelContext::SetResourceOutput(int index, XlaResource* resource) { Tensor* output = nullptr; - // The shape of the output tensor is the shape of the variable resource - // (i.e., a scalar), not the shape of the variable's value. + // The shape of the output tensor is the shape of the resource itself + // (i.e., a scalar), not the shape of the resource's value. OP_REQUIRES_OK(context_, context_->allocate_output(index, TensorShape(), &output)); XlaExpression* expression = CastExpressionFromUninitializedTensor(output); - expression->set_variable(variable); + expression->set_resource(resource); } -Status XlaOpKernelContext::GetVariableInput(int index, XlaVariable** variable) { +Status XlaOpKernelContext::GetResourceInput(int index, XlaResource** resource) { const XlaExpression* expression = CastExpressionFromTensor(context_->input(index)); - TF_RET_CHECK(expression->variable() != nullptr); - *variable = expression->variable(); + TF_RET_CHECK(expression->resource() != nullptr); + *resource = expression->resource(); return Status::OK(); } Status XlaOpKernelContext::AssignVariable( - int index, DataType type, const xla::ComputationDataHandle& handle) { + int input_index, DataType type, const xla::ComputationDataHandle& handle) { TF_RET_CHECK(handle.handle() != 0); SetOpHasSideEffects(); const XlaExpression* expression = - CastExpressionFromTensor(context_->input(index)); - XlaVariable* variable = expression->variable(); + CastExpressionFromTensor(context_->input(input_index)); + XlaResource* variable = expression->resource(); TF_RET_CHECK(variable != nullptr); + 
TF_RET_CHECK(variable->kind == XlaResource::kVariable); if (!((variable->type == DT_INVALID && type != DT_INVALID) || (variable->type == type))) { return errors::InvalidArgument( diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index a25774c3a6a..b151286217a 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -148,6 +148,12 @@ class XlaOpKernelContext { // Variables + // Sets '*resource' to the resource associated with input `index`. + Status GetResourceInput(int index, XlaResource** resource); + + // Sets output 'index' to be a reference to resource 'resource'. + void SetResourceOutput(int index, XlaResource* resource); + // Sets `*type` and `*shape` to the current type and shape of a variable's // value. Status GetVariableTypeAndShape(int index, DataType* type, @@ -158,20 +164,10 @@ class XlaOpKernelContext { Status ReadVariableInput(int index, xla::ComputationDataHandle* value); // Assigns the value `handle` to the variable referenced by input - // `variable_index`. Marks the operator as having side effects. - Status AssignVariable(int variable_index, DataType type, + // `input_index`. Marks the operator as having side effects. + Status AssignVariable(int input_index, DataType type, const xla::ComputationDataHandle& handle); - // Sets '*variable' to the variable associated with input `index`. - Status GetVariableInput(int index, XlaVariable** variable); - - // Sets output 'index' to be a reference to variable 'variable'. Used - // to propagate resource variables through the compilation. - void SetVariableOutput(int index, XlaVariable* variable); - - // Returns a human-readable debug string describing 'variable_index'. 
- string VariableDebugString(int variable_index); - // Helper routines for the OP_REQUIRES macros void CtxFailure(Status s); void CtxFailureWithWarning(Status s); diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc index 1bb0d852899..aaef27f16de 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.cc +++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc @@ -34,11 +34,18 @@ const char* const DEVICE_GPU_XLA_JIT = "XLA_GPU_JIT"; const char* const DEVICE_XLA_CPU = "XLA_CPU"; const char* const DEVICE_XLA_GPU = "XLA_GPU"; -// Is platform 'id' supported by XLA? -static bool IsPlatformSupported(perftools::gputools::Platform::Id id) { - auto platform = perftools::gputools::MultiPlatformManager::PlatformWithId(id); - if (!platform.ok()) return false; - return xla::ClientLibrary::GetOrCreateLocalClient(platform.ValueOrDie()).ok(); +static Status LaunchOpHasKernelForDevice(const DeviceType& device_type) { + const OpDef* op_def; + TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef("_XlaLaunch", &op_def)); + NodeDef node_def; + node_def.set_name("_XlaLaunch-op"); + node_def.set_op("_XlaLaunch"); + string kernel_class_name; + TF_RETURN_IF_ERROR(FindKernelDef(device_type, node_def, /*KernelDef*/ nullptr, + &kernel_class_name)); + VLOG(1) << "LaunchOpHasKernelForDevice" + << " kernel_class_name: " << kernel_class_name; + return Status::OK(); } XlaOpRegistry::XlaOpRegistry() = default; @@ -75,7 +82,7 @@ XlaOpRegistry::~XlaOpRegistry() = default; // GetCompilationDevice is called. 
static void* registration_init = [®istry]() { mutex_lock lock(registry.mutex_); - if (IsPlatformSupported(perftools::gputools::host::kHostPlatformId)) { + if (LaunchOpHasKernelForDevice(DeviceType(DEVICE_CPU)).ok()) { DeviceRegistration& registration = registry.compilation_devices_[DEVICE_CPU]; registration.compilation_device_name = DEVICE_CPU_XLA_JIT; @@ -83,7 +90,7 @@ XlaOpRegistry::~XlaOpRegistry() = default; registration.enable_jit_by_default = false; registration.compile_resource_ops = false; } - if (IsPlatformSupported(perftools::gputools::cuda::kCudaPlatformId)) { + if (LaunchOpHasKernelForDevice(DeviceType(DEVICE_GPU)).ok()) { DeviceRegistration& registration = registry.compilation_devices_[DEVICE_GPU]; registration.compilation_device_name = DEVICE_GPU_XLA_JIT; diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 2491cc3f7a2..c508071f8c1 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -46,21 +46,18 @@ xla_proto_library( ], ) -# This is a headers target that extra XLA devices can use to prevent -# circular dependencies. Devices that are compiled as separate shared -# objects can also use it to prevent linking of library code. 
-cc_header_only_library( - name = "xla_headers_lib", - visibility = ["//visibility:public"], +cc_library( + name = "execution_options_util", + srcs = [ + "execution_options_util.cc", + ], + hdrs = [ + "execution_options_util.h", + ], + visibility = [":friends"], deps = [ - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla:xla_proto", - "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/legacy_flags:layout_util_flags", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/compiler/xla/service:hlo_evaluator", - "//tensorflow/core:framework_headers_lib", - "//tensorflow/core:stream_executor_headers_lib", + ":xla_proto", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", ], ) @@ -602,3 +599,18 @@ filegroup( ), visibility = ["//tensorflow:__subpackages__"], ) + +# This is a headers target that extra XLA devices can use to prevent circular dependencies. Devices that are compiled as separate shared objects can also use it to prevent linking of library code. 
+cc_header_only_library( + name = "xla_headers_lib", + visibility = ["//visibility:public"], + deps = [ + ":xla_data_proto", + ":xla_proto", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/legacy_flags:layout_util_flags", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/core:framework_headers_lib", + "//tensorflow/core:stream_executor_headers_lib", + ], +) diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD index 63c6d9ddaca..f50dc934f22 100644 --- a/tensorflow/compiler/xla/client/BUILD +++ b/tensorflow/compiler/xla/client/BUILD @@ -114,7 +114,6 @@ cc_library( "//tensorflow/compiler/xla/service:compile_only_service", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", - "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "@llvm//:support", ], diff --git a/tensorflow/compiler/xla/client/computation_builder.cc b/tensorflow/compiler/xla/client/computation_builder.cc index 37bf697683b..dcc313707b9 100644 --- a/tensorflow/compiler/xla/client/computation_builder.cc +++ b/tensorflow/compiler/xla/client/computation_builder.cc @@ -971,6 +971,11 @@ ComputationDataHandle ComputationBuilder::Sign( return UnaryOp(UNOP_SIGN, operand); } +ComputationDataHandle ComputationBuilder::Cos( + const ComputationDataHandle& operand) { + return UnaryOp(UNOP_COS, operand); +} + ComputationDataHandle ComputationBuilder::Tanh( const ComputationDataHandle& operand) { return UnaryOp(UNOP_TANH, operand); @@ -1411,6 +1416,52 @@ ComputationDataHandle ComputationBuilder::ReduceWindowWithGeneralPadding( return ParseOpResponse(s, &response); } +ComputationDataHandle ComputationBuilder::BatchNormTraining( + const ComputationDataHandle& operand, const ComputationDataHandle& scale, + const ComputationDataHandle& offset, float epsilon, int64 feature_index) { + if (!first_error_.ok() || !PrepareComputation().ok()) { + return ComputationDataHandle(); + } + 
BatchNormTrainingRequest request; + *request.mutable_operand() = operand; + *request.mutable_scale() = scale; + *request.mutable_offset() = offset; + request.set_epsilon(epsilon); + request.set_feature_index(feature_index); + + OpRequest op_request; + *op_request.mutable_batch_norm_training_request() = request; + *op_request.mutable_computation() = computation_.handle(); + AddOpMetadata(&op_request); + + OpResponse response; + + VLOG(2) << "making BatchNormTraining request"; + + Status s = client_->stub()->Op(&op_request, &response); + return ParseOpResponse(s, &response); +} + +ComputationDataHandle ComputationBuilder::BatchNormInference( + const ComputationDataHandle& operand, const ComputationDataHandle& scale, + const ComputationDataHandle& offset, const ComputationDataHandle& mean, + const ComputationDataHandle& variance, float epsilon, int64 feature_index) { + // TODO(b/62843645): Implement BatchNormInference. + NoteError(Unimplemented("BatchNormInference is not implemented yet.")); + return ComputationDataHandle(); +} + +ComputationDataHandle ComputationBuilder::BatchNormGrad( + const ComputationDataHandle& operand, const ComputationDataHandle& scale, + const ComputationDataHandle& batch_mean, + const ComputationDataHandle& batch_var, + const ComputationDataHandle& grad_output, float epsilon, + int64 feature_index) { + // TODO(b/62843645): Implement BatchNormGrad. 
+ NoteError(Unimplemented("BatchNormGrad is not implemented yet.")); + return ComputationDataHandle(); +} + ComputationDataHandle ComputationBuilder::CrossReplicaSum( const ComputationDataHandle& operand) { if (!first_error_.ok() || !PrepareComputation().ok()) { @@ -1487,6 +1538,28 @@ ComputationDataHandle ComputationBuilder::SelectAndScatterWithGeneralPadding( return ParseOpResponse(s, &response); } +ComputationDataHandle ComputationBuilder::ReducePrecision( + const ComputationDataHandle& operand, const int exponent_bits, + const int mantissa_bits) { + if (!first_error_.ok() || !PrepareComputation().ok()) { + return ComputationDataHandle(); + } + + ReducePrecisionRequest request; + *request.mutable_operand() = operand; + request.set_exponent_bits(exponent_bits); + request.set_mantissa_bits(mantissa_bits); + OpRequest op_request; + *op_request.mutable_computation() = computation_.handle(); + *op_request.mutable_reduce_precision_request() = request; + AddOpMetadata(&op_request); + OpResponse response; + + VLOG(2) << "making reduce-precision request"; + Status s = client_->stub()->Op(&op_request, &response); + return ParseOpResponse(s, &response); +} + void ComputationBuilder::Send(const ComputationDataHandle& operand, const ChannelHandle& handle) { if (!first_error_.ok() || !PrepareComputation().ok()) { diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h index 5cc73c28d03..b411346459e 100644 --- a/tensorflow/compiler/xla/client/computation_builder.h +++ b/tensorflow/compiler/xla/client/computation_builder.h @@ -510,6 +510,9 @@ class ComputationBuilder { // Enqueues a sign instruction onto the computation. ComputationDataHandle Sign(const ComputationDataHandle& operand); + // Enqueues a cosine instruction onto the computation. + ComputationDataHandle Cos(const ComputationDataHandle& operand); + // Enqueues a tanh instruction onto the computation. 
ComputationDataHandle Tanh(const ComputationDataHandle& operand); @@ -597,6 +600,11 @@ class ComputationBuilder { const Computation& body, const ComputationDataHandle& init); + // Enqueues a ReducePrecision node onto the computation. + ComputationDataHandle ReducePrecision(const ComputationDataHandle& operand, + const int exponent_bits, + const int mantissa_bits); + // Enqueues a Send node onto the computation, to send the given operand to // a Recv instruction that shares the same channel handle. void Send(const ComputationDataHandle& operand, const ChannelHandle& handle); @@ -820,87 +828,80 @@ class ComputationBuilder { template ComputationDataHandle ComputationBuilder::ConstantR0(NativeT value) { - return ConstantOp( - [value](Literal* literal) { LiteralUtil::PopulateR0(value, literal); }); + return ConstantOp([value](Literal* literal) { literal->PopulateR0(value); }); } template ComputationDataHandle ComputationBuilder::ConstantR1( tensorflow::gtl::ArraySlice values) { - return ConstantOp([&values](Literal* literal) { - LiteralUtil::PopulateR1(values, literal); - }); + return ConstantOp( + [&values](Literal* literal) { literal->PopulateR1(values); }); } template ComputationDataHandle ComputationBuilder::ConstantR1(int64 length, NativeT value) { return ConstantOp([length, value](Literal* literal) { - LiteralUtil::PopulateWithValue(value, {length}, literal); + literal->PopulateWithValue(value, {length}); }); } inline ComputationDataHandle ComputationBuilder::ConstantR1( const tensorflow::core::Bitmap& values) { - return ConstantOp([&values](Literal* literal) { - LiteralUtil::PopulateR1(values, literal); - }); + return ConstantOp( + [&values](Literal* literal) { literal->PopulateR1(values); }); } template ComputationDataHandle ComputationBuilder::ConstantR2( std::initializer_list> values) { - return ConstantOp([&values](Literal* literal) { - LiteralUtil::PopulateR2(values, literal); - }); + return ConstantOp( + [&values](Literal* literal) { 
literal->PopulateR2(values); }); } template ComputationDataHandle ComputationBuilder::ConstantR2FromArray2DWithLayout( const Array2D& values, const Layout& layout) { return ConstantOp([&values, &layout](Literal* literal) { - LiteralUtil::PopulateR2FromArray2DWithLayout(values, layout, literal); + literal->PopulateR2FromArray2DWithLayout(values, layout); }); } template ComputationDataHandle ComputationBuilder::ConstantR2FromArray2D( const Array2D& values) { - return ConstantOp([&values](Literal* literal) { - LiteralUtil::PopulateR2FromArray2D(values, literal); - }); + return ConstantOp( + [&values](Literal* literal) { literal->PopulateR2FromArray2D(values); }); } template ComputationDataHandle ComputationBuilder::ConstantR3FromArray3DWithLayout( const Array3D& values, const Layout& layout) { return ConstantOp([&values, &layout](Literal* literal) { - LiteralUtil::PopulateR3FromArray3DWithLayout(values, layout, literal); + literal->PopulateR3FromArray3DWithLayout(values, layout); }); } template ComputationDataHandle ComputationBuilder::ConstantR3FromArray3D( const Array3D& values) { - return ConstantOp([&values](Literal* literal) { - LiteralUtil::PopulateR3FromArray3D(values, literal); - }); + return ConstantOp( + [&values](Literal* literal) { literal->PopulateR3FromArray3D(values); }); } template ComputationDataHandle ComputationBuilder::ConstantR4FromArray4DWithLayout( const Array4D& values, const Layout& layout) { return ConstantOp([&values, &layout](Literal* literal) { - LiteralUtil::PopulateR4FromArray4DWithLayout(values, layout, literal); + literal->PopulateR4FromArray4DWithLayout(values, layout); }); } template ComputationDataHandle ComputationBuilder::ConstantR4FromArray4D( const Array4D& values) { - return ConstantOp([&values](Literal* literal) { - LiteralUtil::PopulateR4FromArray4D(values, literal); - }); + return ConstantOp( + [&values](Literal* literal) { literal->PopulateR4FromArray4D(values); }); } } // namespace xla diff --git 
a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index 86b16be62f0..edd971e114c 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -32,6 +32,7 @@ cc_library( srcs = ["testing.cc"], hdrs = ["testing.h"], deps = [ + "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", diff --git a/tensorflow/compiler/xla/client/lib/testing.cc b/tensorflow/compiler/xla/client/lib/testing.cc index daa1557df0b..d8bfc945807 100644 --- a/tensorflow/compiler/xla/client/lib/testing.cc +++ b/tensorflow/compiler/xla/client/lib/testing.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -34,11 +35,11 @@ std::unique_ptr MakeFakeDataOrDie(const Shape& shape, client, tensorflow::strings::StrCat("make_fake_", ShapeUtil::HumanString(shape))); // TODO(b/26811613): Replace this when RNG is supported on all backends. 
- b.Broadcast(b.ConstantLiteral(LiteralUtil::One(shape.element_type())), + b.Broadcast(b.ConstantLiteral(Literal::One(shape.element_type())), AsInt64Slice(shape.dimensions())); Computation computation = b.Build().ConsumeValueOrDie(); - ExecutionOptions execution_options; + auto execution_options = CreateDefaultExecutionOptions(); *execution_options.mutable_shape_with_output_layout() = shape; return client->Execute(computation, /*arguments=*/{}, &execution_options) .ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/executable_run_options.cc b/tensorflow/compiler/xla/executable_run_options.cc index 67f3a6c1df4..33d5b6f1d4d 100644 --- a/tensorflow/compiler/xla/executable_run_options.cc +++ b/tensorflow/compiler/xla/executable_run_options.cc @@ -77,4 +77,14 @@ ExecutionProfile* ExecutableRunOptions::execution_profile() const { return execution_profile_; } +ExecutableRunOptions& ExecutableRunOptions::set_device_assignment( + DeviceAssignment* device_assignment) { + device_assignment_ = device_assignment; + return *this; +} + +DeviceAssignment* ExecutableRunOptions::device_assignment() const { + return device_assignment_; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/executable_run_options.h b/tensorflow/compiler/xla/executable_run_options.h index 03f2d016ad0..deb3ddb203d 100644 --- a/tensorflow/compiler/xla/executable_run_options.h +++ b/tensorflow/compiler/xla/executable_run_options.h @@ -40,6 +40,7 @@ struct ThreadPoolDevice; namespace xla { class DeviceMemoryAllocator; +class DeviceAssignment; class ExecutionProfile; // Class containing options for running a LocalExecutable. 
@@ -79,9 +80,14 @@ class ExecutableRunOptions { ExecutionProfile* execution_profile() const; ExecutableRunOptions& set_execution_profile(ExecutionProfile* profile); + ExecutableRunOptions& set_device_assignment( + DeviceAssignment* device_assignment); + DeviceAssignment* device_assignment() const; + private: DeviceMemoryAllocator* allocator_ = nullptr; int device_ordinal_ = -1; + DeviceAssignment* device_assignment_ = nullptr; perftools::gputools::Stream* stream_ = nullptr; tensorflow::thread::ThreadPool* inter_op_thread_pool_ = nullptr; const Eigen::ThreadPoolDevice* intra_op_thread_pool_ = nullptr; diff --git a/tensorflow/tensorboard/components/tf_color_scale/tf-color-scale.html b/tensorflow/compiler/xla/execution_options_util.cc similarity index 50% rename from tensorflow/tensorboard/components/tf_color_scale/tf-color-scale.html rename to tensorflow/compiler/xla/execution_options_util.cc index a325f0a04cd..e83ff7cddd6 100644 --- a/tensorflow/tensorboard/components/tf_color_scale/tf-color-scale.html +++ b/tensorflow/compiler/xla/execution_options_util.cc @@ -1,6 +1,4 @@ - +==============================================================================*/ +#include "tensorflow/compiler/xla/execution_options_util.h" +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" - - +namespace xla { - - - - - +} // namespace xla diff --git a/tensorflow/compiler/xla/execution_options_util.h b/tensorflow/compiler/xla/execution_options_util.h new file mode 100644 index 00000000000..562da78e837 --- /dev/null +++ b/tensorflow/compiler/xla/execution_options_util.h @@ -0,0 +1,29 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_EXECUTION_OPTIONS_UTIL_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_EXECUTION_OPTIONS_UTIL_H_ + +#include "tensorflow/compiler/xla/xla.pb.h" + +namespace xla { + +// Create a default ExecutionOptions proto; this proto has its debug options +// popupated to the default values taken from flags. +ExecutionOptions CreateDefaultExecutionOptions(); + +} // namespace xla + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_EXECUTION_OPTIONS_UTIL_H_ diff --git a/tensorflow/compiler/xla/legacy_flags/BUILD b/tensorflow/compiler/xla/legacy_flags/BUILD index a147ce67a28..fafd5f591b1 100644 --- a/tensorflow/compiler/xla/legacy_flags/BUILD +++ b/tensorflow/compiler/xla/legacy_flags/BUILD @@ -73,26 +73,12 @@ cc_library( deps = [ ":parse_flags_from_env", - "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_proto", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", ], ) -cc_library( - name = "cpu_compiler_flags", - srcs = ["cpu_compiler_flags.cc"], - hdrs = ["cpu_compiler_flags.h"], - deps = - [ - ":parse_flags_from_env", - "//tensorflow/compiler/xla:types", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - ], -) - cc_library( name = "cpu_runtime_flags", srcs = ["cpu_runtime_flags.cc"], @@ -128,30 +114,6 @@ cc_library( ], ) -cc_library( - name = "gpu_compiler_flags", - srcs = ["gpu_compiler_flags.cc"], - hdrs = ["gpu_compiler_flags.h"], - deps = [ - ":parse_flags_from_env", - "//tensorflow/compiler/xla:types", - 
"//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - ], -) - -cc_library( - name = "gpu_backend_lib_flags", - srcs = ["gpu_backend_lib_flags.cc"], - hdrs = ["gpu_backend_lib_flags.h"], - deps = [ - ":parse_flags_from_env", - "//tensorflow/compiler/xla:types", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - ], -) - cc_library( name = "stream_assignment_flags", srcs = ["stream_assignment_flags.cc"], @@ -175,28 +137,6 @@ cc_library( ], ) -cc_library( - name = "alias_analysis_flags", - srcs = ["alias_analysis_flags.cc"], - hdrs = ["alias_analysis_flags.h"], - deps = [ - ":parse_flags_from_env", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - ], -) - -cc_library( - name = "llvm_util_flags", - srcs = ["llvm_util_flags.cc"], - hdrs = ["llvm_util_flags.h"], - deps = [ - ":parse_flags_from_env", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - ], -) - cc_library( name = "service_flags", srcs = ["service_flags.cc"], diff --git a/tensorflow/compiler/xla/legacy_flags/alias_analysis_flags.cc b/tensorflow/compiler/xla/legacy_flags/alias_analysis_flags.cc deleted file mode 100644 index 474753c10ad..00000000000 --- a/tensorflow/compiler/xla/legacy_flags/alias_analysis_flags.cc +++ /dev/null @@ -1,62 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -// Legacy flags for XLA's alias_analysis module. - -#include // NOLINT(build/c++11): only using std::call_once, not mutex. -#include - -#include "tensorflow/compiler/xla/legacy_flags/alias_analysis_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Pointers to the parsed value of the flags and flag descriptors, initialized -// via flags_init. -static AliasAnalysisFlags* flags; -static std::vector* flag_list; -static std::once_flag flags_init; - -// Allocate *flags. Called via call_once(&flags_init,...). -static void AllocateFlags() { - flags = new AliasAnalysisFlags; - flags->xla_emit_alias_scope = true; - flag_list = new std::vector({ - tensorflow::Flag("xla_emit_alias_scope", &flags->xla_emit_alias_scope, - "Use buffer analysis to refine alias-analysis."), - }); - ParseFlagsFromEnv(*flag_list); -} - -// Append to *append_to flag definitions associated with XLA's alias_analysis -// module. -void AppendAliasAnalysisFlags(std::vector* append_to) { - std::call_once(flags_init, &AllocateFlags); - append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); -} - -// Return a pointer to the AliasAnalysisFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. 
-AliasAnalysisFlags* GetAliasAnalysisFlags() { - std::call_once(flags_init, &AllocateFlags); - return flags; -} - -} // namespace legacy_flags -} // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/alias_analysis_flags.h b/tensorflow/compiler/xla/legacy_flags/alias_analysis_flags.h deleted file mode 100644 index 369f8cd7caa..00000000000 --- a/tensorflow/compiler/xla/legacy_flags/alias_analysis_flags.h +++ /dev/null @@ -1,46 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_ALIAS_ANALYSIS_FLAGS_H_ -#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_ALIAS_ANALYSIS_FLAGS_H_ - -// Legacy flags for XLA's alias_analysis module. - -#include - -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Append to *flag_list flag definitions associated with XLA's alias_analysis -// module. -void AppendAliasAnalysisFlags(std::vector* flag_list); - -// The values of flags associated with XLA's alias_analysis module. -typedef struct { - bool xla_emit_alias_scope; // Use buffer analysis to refine alias-analysis. -} AliasAnalysisFlags; - -// Return a pointer to the AliasAnalysisFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. 
-AliasAnalysisFlags* GetAliasAnalysisFlags(); - -} // namespace legacy_flags -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_ALIAS_ANALYSIS_FLAGS_H_ diff --git a/tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.cc b/tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.cc deleted file mode 100644 index 13d41a8636b..00000000000 --- a/tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.cc +++ /dev/null @@ -1,68 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Legacy flags for XLA's cpu_compiler module. - -#include // NOLINT(build/c++11): only using std::call_once, not mutex. -#include - -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Pointers to the parsed value of the flags and flag descriptors, initialized -// via flags_init. -static CpuCompilerFlags* flags; -static std::vector* flag_list; -static std::once_flag flags_init; - -// Allocate *flags. Called via call_once(&flags_init,...). 
-static void AllocateFlags() { - flags = new CpuCompilerFlags; - flags->xla_cpu_embed_ir = false; - flags->xla_cpu_dump_debug_json_to = ""; - flag_list = new std::vector({ - tensorflow::Flag( - "xla_cpu_embed_ir", &flags->xla_cpu_embed_ir, - "Embed the LLVM IR module string in the resultant CpuExecutable."), - tensorflow::Flag("xla_cpu_dump_debug_json_to", - &flags->xla_cpu_dump_debug_json_to, - "Dump debug JSON to this directory."), - }); - ParseFlagsFromEnv(*flag_list); -} - -// Append to *append_to flag definitions associated with XLA's cpu_compiler -// module. -void AppendCpuCompilerFlags(std::vector* append_to) { - std::call_once(flags_init, &AllocateFlags); - append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); -} - -// Return a pointer to the CpuCompilerFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -CpuCompilerFlags* GetCpuCompilerFlags() { - std::call_once(flags_init, &AllocateFlags); - return flags; -} - -} // namespace legacy_flags -} // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h b/tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h deleted file mode 100644 index bac498e18eb..00000000000 --- a/tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h +++ /dev/null @@ -1,49 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_CPU_COMPILER_FLAGS_H_ -#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_CPU_COMPILER_FLAGS_H_ - -// Legacy flags for the XLA's cpu_compiler module. - -#include - -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Append to *flag_list flag definitions associated with XLA's cpu_compiler -// module. -void AppendCpuCompilerFlags(std::vector* flag_list); - -// The values of flags associated with XLA's cpu_compiler module. -typedef struct { - bool xla_cpu_embed_ir; // Embed the LLVM IR module string in the resultant - // CpuExecutable - string xla_cpu_dump_debug_json_to; // Dump debug JSON to this directory. -} CpuCompilerFlags; - -// Return a pointer to the CpuCompilerFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. 
-CpuCompilerFlags* GetCpuCompilerFlags(); - -} // namespace legacy_flags -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_CPU_COMPILER_FLAGS_H_ diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc index 5e3c4f912bf..5f029a5f532 100644 --- a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc +++ b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc @@ -28,6 +28,12 @@ struct DebugOptionsFlags { string xla_disable_hlo_passes; bool xla_enable_fast_math; int32 xla_backend_optimization_level; + bool xla_embed_ir_in_executable; + string xla_dump_debug_json_to; + + string xla_gpu_cuda_data_dir; + bool xla_gpu_ftz; + string xla_backend_extra_options; }; @@ -44,7 +50,11 @@ void AllocateFlags() { flag_values->xla_generate_hlo_graph = ""; flag_values->xla_disable_hlo_passes = ""; flag_values->xla_enable_fast_math = true; - flag_values->xla_backend_optimization_level = 2; + flag_values->xla_backend_optimization_level = 3; + flag_values->xla_embed_ir_in_executable = false; + flag_values->xla_dump_debug_json_to = ""; + flag_values->xla_gpu_cuda_data_dir = "./cuda_sdk_lib"; + flag_values->xla_gpu_ftz = false; flag_values->xla_backend_extra_options = ""; flag_objects = new std::vector( @@ -52,7 +62,6 @@ void AllocateFlags() { "xla_generate_hlo_graph", &flag_values->xla_generate_hlo_graph, "HLO modules matching this regex will be dumped to a .dot file " "throughout various stages in compilation."), - tensorflow::Flag( "xla_enable_fast_math", &flag_values->xla_enable_fast_math, "Enable unsafe fast-math optimizations in the compiler; " @@ -61,18 +70,31 @@ void AllocateFlags() { "xla_backend_optimization_level", &flag_values->xla_backend_optimization_level, "Numerical optimization level for the XLA compiler backend."), - + tensorflow::Flag( + "xla_disable_hlo_passes", &flag_values->xla_disable_hlo_passes, + "Comma-separated list of hlo passes to be disabled. 
These names " + "must exactly match the passes' names; no whitespace around " + "commas."), + tensorflow::Flag("xla_embed_ir_in_executable", + &flag_values->xla_embed_ir_in_executable, + "Embed the compiler IR as a string in the executable."), + tensorflow::Flag("xla_gpu_cuda_data_dir", + &flag_values->xla_gpu_cuda_data_dir, + "If non-empty, speficies a local directory containing " + "ptxas and nvvm libdevice files; otherwise we use " + "those from runfile directories."), + tensorflow::Flag("xla_gpu_ftz", &flag_values->xla_gpu_ftz, + "If true, flush-to-zero semantics are enabled in the " + "code generated for GPUs."), + tensorflow::Flag( + "xla_dump_debug_json_to", &flag_values->xla_dump_debug_json_to, + "Dump compilation artifacts as JSON into this directory."), tensorflow::Flag("xla_backend_extra_options", &flag_values->xla_backend_extra_options, "Extra options to pass to a backend; " "comma-separated list of 'key=val' strings (=val " - "may be omitted); no whitespace around commas."), + "may be omitted); no whitespace around commas.")}); - tensorflow::Flag( - "xla_disable_hlo_passes", &flag_values->xla_disable_hlo_passes, - "Comma-separated list of HLO passes to be disabled. 
These names " - "must exactly match the passes' names; " - "no whitespace around commas.")}); ParseFlagsFromEnv(*flag_objects); } @@ -99,6 +121,11 @@ xla::DebugOptions GetDebugOptionsFromFlags() { options.set_xla_enable_fast_math(flag_values->xla_enable_fast_math); options.set_xla_backend_optimization_level( flag_values->xla_backend_optimization_level); + options.set_xla_embed_ir_in_executable( + flag_values->xla_embed_ir_in_executable); + options.set_xla_dump_debug_json_to(flag_values->xla_dump_debug_json_to); + options.set_xla_gpu_cuda_data_dir(flag_values->xla_gpu_cuda_data_dir); + options.set_xla_gpu_ftz(flag_values->xla_gpu_ftz); std::vector extra_options_parts = tensorflow::str_util::Split(flag_values->xla_backend_extra_options, ','); diff --git a/tensorflow/compiler/xla/legacy_flags/gpu_backend_lib_flags.cc b/tensorflow/compiler/xla/legacy_flags/gpu_backend_lib_flags.cc deleted file mode 100644 index f8f6ea26b1d..00000000000 --- a/tensorflow/compiler/xla/legacy_flags/gpu_backend_lib_flags.cc +++ /dev/null @@ -1,88 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Legacy flags for XLA's gpu_backend_lib module. - -#include // NOLINT(build/c++11): only using std::call_once, not mutex. 
-#include - -#include "tensorflow/compiler/xla/legacy_flags/gpu_backend_lib_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Pointers to the parsed value of the flags and flag descriptors, initialized -// via flags_init. -static GpuBackendLibFlags* flags; -static std::vector* flag_list; -static std::once_flag flags_init; - -// Allocate *flags. Called via call_once(&flags_init,...). -static void AllocateFlags() { - flags = new GpuBackendLibFlags; - flags->dump_temp_products_to = ""; - flags->ftz = false; - flags->fma = true; - flags->verbose_ptx_asm = false; - flags->kernel = ""; - flags->llvm_dump_passes = false; - flags->llvm_cl_opts = ""; - flags->dump_ir_before_passes = false; - flags->opt_level = 3; - flag_list = new std::vector({ - tensorflow::Flag("dump_temp_products_to", &flags->dump_temp_products_to, - "dump temporary compilation products to this directory. " - "If empty, no dump is produced"), - tensorflow::Flag("ftz", &flags->ftz, "flush to zero semantics"), - tensorflow::Flag("fma", &flags->fma, "use FMA synthesis"), - tensorflow::Flag("verbose_ptx_asm", &flags->verbose_ptx_asm, - "emit PTX assembly with extra comments"), - tensorflow::Flag("kernel", &flags->kernel, - "only emit the IR and PTX for this kernel"), - tensorflow::Flag("llvm_dump_passes", &flags->llvm_dump_passes, - "dump the passes LLVM runs to stderr"), - tensorflow::Flag( - "llvm_cl_opts", &flags->llvm_cl_opts, - "comma-separated list of command line options to pass to " - "LLVM. 
For example, --llvm_cl_opts=--print-before=loop-unroll"), - tensorflow::Flag("dump_ir_before_passes", &flags->dump_ir_before_passes, - "dump the IR before each optimization pass in " - "sequentially-named files."), - tensorflow::Flag("opt_level", &flags->opt_level, - "optimization level (default to 3)"), - }); - ParseFlagsFromEnv(*flag_list); -} - -// Append to *append_to flag definitions associated with XLA's gpu_backend_lib -// module. -void AppendGpuBackendLibFlags(std::vector* append_to) { - std::call_once(flags_init, &AllocateFlags); - append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); -} - -// Return a pointer to the GpuBackendLibFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -GpuBackendLibFlags* GetGpuBackendLibFlags() { - std::call_once(flags_init, &AllocateFlags); - return flags; -} - -} // namespace legacy_flags -} // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/gpu_backend_lib_flags.h b/tensorflow/compiler/xla/legacy_flags/gpu_backend_lib_flags.h deleted file mode 100644 index 31cb50e9da9..00000000000 --- a/tensorflow/compiler/xla/legacy_flags/gpu_backend_lib_flags.h +++ /dev/null @@ -1,55 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_GPU_BACKEND_LIB_FLAGS_H_ -#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_GPU_BACKEND_LIB_FLAGS_H_ - -// Legacy flags for XLA's gpu_backend_lib module. - -#include - -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Append to *flag_list flag definitions associated with XLA's gpu_backend_lib -// module. -void AppendGpuBackendLibFlags(std::vector* flag_list); - -// The values of flags associated with XLA's gpu_backend_lib module. -typedef struct { - string dump_temp_products_to; // temporary compilation products dir - bool ftz; // flush to zero semantics - bool fma; // use FMA synthesis - bool verbose_ptx_asm; // emit PTX assembly with extra comments - string kernel; // only emit the IR and PTX for this kernel - bool llvm_dump_passes; // dump the passes LLVM runs to stderr - string llvm_cl_opts; // comma-separated list of LLVM options - bool dump_ir_before_passes; // dump IR before each pass - int32 opt_level; // optimization level -} GpuBackendLibFlags; - -// Return a pointer to the GpuBackendLibFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -GpuBackendLibFlags* GetGpuBackendLibFlags(); - -} // namespace legacy_flags -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_GPU_BACKEND_LIB_FLAGS_H_ diff --git a/tensorflow/compiler/xla/legacy_flags/gpu_compiler_flags.cc b/tensorflow/compiler/xla/legacy_flags/gpu_compiler_flags.cc deleted file mode 100644 index 131e3ce70ac..00000000000 --- a/tensorflow/compiler/xla/legacy_flags/gpu_compiler_flags.cc +++ /dev/null @@ -1,76 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Legacy flags for XLA's gpu_compiler module. - -#include // NOLINT(build/c++11): only using std::call_once, not mutex. -#include - -#include "tensorflow/compiler/xla/legacy_flags/gpu_compiler_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Pointers to the parsed value of the flags and flag descriptors, initialized -// via flags_init. -static GpuCompilerFlags* flags; -static std::vector* flag_list; -static std::once_flag flags_init; - -// Allocate *flags. Called via call_once(&flags_init,...). -static void AllocateFlags() { - flags = new GpuCompilerFlags; - flags->xla_gpu_embed_ir = false; - flags->xla_cuda_data_dir = "./cuda_sdk_lib"; - flags->xla_gpu_dump_debug_json_to = ""; - flag_list = new std::vector({ - tensorflow::Flag( - "xla_gpu_embed_ir", &flags->xla_gpu_embed_ir, - "Embed the LLVM IR module string in the resultant GpuExecutable."), - tensorflow::Flag( - "xla_cuda_data_dir", &flags->xla_cuda_data_dir, - "If non-empty, specifies a local directory containing ptxas and " - "nvvm libdevice files. 
Otherwise, by default, we use those from " - "runfile directories."), - tensorflow::Flag("xla_ptxas_path", &flags->xla_ptxas_path, - "The path to ptxas. Required to log stats of the ptx."), - tensorflow::Flag("xla_gpu_dump_debug_json_to", - &flags->xla_gpu_dump_debug_json_to, - "Dump debug JSON to this directory."), - }); - ParseFlagsFromEnv(*flag_list); -} - -// Append to *append_to flag definitions associated with XLA's gpu_compiler -// module. -void AppendGpuCompilerFlags(std::vector* append_to) { - std::call_once(flags_init, &AllocateFlags); - append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); -} - -// Return a pointer to the GpuCompilerFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -GpuCompilerFlags* GetGpuCompilerFlags() { - std::call_once(flags_init, &AllocateFlags); - return flags; -} - -} // namespace legacy_flags -} // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/gpu_compiler_flags.h b/tensorflow/compiler/xla/legacy_flags/gpu_compiler_flags.h deleted file mode 100644 index 0cf39e0ab35..00000000000 --- a/tensorflow/compiler/xla/legacy_flags/gpu_compiler_flags.h +++ /dev/null @@ -1,55 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_GPU_COMPILER_FLAGS_H_ -#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_GPU_COMPILER_FLAGS_H_ - -// Legacy flags for XLA's gpu_compiler module. - -#include - -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Append to *flag_list flag definitions associated with XLA's gpu_compiler -// module. -void AppendGpuCompilerFlags(std::vector* flag_list); - -// The values of flags associated with XLA's gpu_compiler module. -typedef struct { - bool xla_gpu_embed_ir; // Embed the LLVM IR module string in the resultant - // GpuExecutable. - string xla_cuda_data_dir; // If non-empty, specifies a local directory - // containing ptxas and nvvm libdevice files. - // Otherwise, by default, we use those from runfile - // directories. - string xla_ptxas_path; // The path to ptxas. Required to log stats of - // the ptx. - string xla_gpu_dump_debug_json_to; // Dump debug JSON to this directory. -} GpuCompilerFlags; - -// Return a pointer to the GpuCompilerFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. -GpuCompilerFlags* GetGpuCompilerFlags(); - -} // namespace legacy_flags -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_GPU_COMPILER_FLAGS_H_ diff --git a/tensorflow/compiler/xla/legacy_flags/llvm_util_flags.cc b/tensorflow/compiler/xla/legacy_flags/llvm_util_flags.cc deleted file mode 100644 index 3c53729a670..00000000000 --- a/tensorflow/compiler/xla/legacy_flags/llvm_util_flags.cc +++ /dev/null @@ -1,63 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Legacy flags for XLA's llvm_util module. - -#include // NOLINT(build/c++11): only using std::call_once, not mutex. -#include - -#include "tensorflow/compiler/xla/legacy_flags/llvm_util_flags.h" -#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h" -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Pointers to the parsed value of the flags and flag descriptors, initialized -// via flags_init. -static LlvmUtilFlags* flags; -static std::vector* flag_list; -static std::once_flag flags_init; - -// Allocate *flags. Called via call_once(&flags_init,...). -static void AllocateFlags() { - flags = new LlvmUtilFlags; - flags->xla_emit_tbaa = true; - flag_list = new std::vector({ - tensorflow::Flag("xla_emit_tbaa", &flags->xla_emit_tbaa, - "Perform type-based alias analysis optimizations for " - "LLVM-based backends."), - }); - ParseFlagsFromEnv(*flag_list); -} - -// Append to *append_to flag definitions associated with XLA's llvm_util -// module. -void AppendLlvmUtilFlags(std::vector* append_to) { - std::call_once(flags_init, &AllocateFlags); - append_to->insert(append_to->end(), flag_list->begin(), flag_list->end()); -} - -// Return a pointer to the LlvmUtilFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. 
-LlvmUtilFlags* GetLlvmUtilFlags() { - std::call_once(flags_init, &AllocateFlags); - return flags; -} - -} // namespace legacy_flags -} // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/llvm_util_flags.h b/tensorflow/compiler/xla/legacy_flags/llvm_util_flags.h deleted file mode 100644 index 98da26b4b80..00000000000 --- a/tensorflow/compiler/xla/legacy_flags/llvm_util_flags.h +++ /dev/null @@ -1,46 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_LLVM_UTIL_FLAGS_H_ -#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_LLVM_UTIL_FLAGS_H_ - -// Legacy flags for XLA's llvm_util module. - -#include - -#include "tensorflow/core/platform/types.h" -#include "tensorflow/core/util/command_line_flags.h" - -namespace xla { -namespace legacy_flags { - -// Append to *flag_list flag definitions associated with XLA's llvm_util module. -void AppendLlvmUtilFlags(std::vector* flag_list); - -// The values of flags associated with XLA's llvm_util module. -typedef struct { - bool xla_emit_tbaa; // Perform type-based alias analysis optimizations for - // LLVM-based backends. -} LlvmUtilFlags; - -// Return a pointer to the LlvmUtilFlags struct; -// repeated calls return the same pointer. -// This should be called only after Flags::Parse() has returned. 
-LlvmUtilFlags* GetLlvmUtilFlags(); - -} // namespace legacy_flags -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_LLVM_UTIL_FLAGS_H_ diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index caef3a3869f..b6bd1158d23 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -321,6 +321,7 @@ Status Literal::Copy(const Literal& src_literal, } std::unique_ptr Literal::Relayout(const Layout& layout) const { + CHECK(ShapeUtil::IsArray(shape())); std::unique_ptr result = CloneToUnique(); *result->mutable_shape()->mutable_layout() = layout; @@ -754,10 +755,30 @@ void Literal::EachCellAsString( } namespace { +template +std::unique_ptr ConvertBetweenNativeTypes(const Literal& src_literal) { + auto result_literal = MakeUnique(); + Shape* result_shape = result_literal->mutable_shape(); + *result_shape = src_literal.shape(); + result_shape->set_element_type( + primitive_util::NativeToPrimitiveType()); + result_literal->Reserve(ShapeUtil::ElementsIn(*result_shape)); + tensorflow::gtl::ArraySlice src_data = + src_literal.GetArraySlice(); + tensorflow::gtl::MutableArraySlice dest_data = + result_literal->GetMutableArraySlice(); + int64 num_elements = ShapeUtil::ElementsIn(src_literal.shape()); + + for (int64 i = 0; i < num_elements; ++i) { + dest_data[i] = static_cast(src_data[i]); + } + return result_literal; +} + template std::unique_ptr ConvertIfTypesMatch(const Literal& src_literal) { CHECK_EQ(primitive_src_type, src_literal.shape().element_type()); - return LiteralUtil::Convert< + return ConvertBetweenNativeTypes< typename primitive_util::PrimitiveTypeToNative::type, typename primitive_util::PrimitiveTypeToNative< primitive_dest_type>::type>(src_literal); @@ -782,19 +803,20 @@ StatusOr> ConvertIfDestTypeMatches( #undef CONVERT_IF_TYPES_MATCH // Other types are not yet supported. 
default: - return tensorflow::errors::InvalidArgument( - "Unimplemented: ConvertIfDestTypeMatches for type " + - PrimitiveType_Name(src_literal.shape().element_type())); + return InvalidArgument( + "Unimplemented: Convert from type %s to type %s", + PrimitiveType_Name(src_literal.shape().element_type()).c_str(), + PrimitiveType_Name(primitive_dest_type).c_str()); } } -} +} // namespace -StatusOr> LiteralUtil::ConvertIfSrcTypeMatches( - const Literal& src_literal, PrimitiveType primitive_dest_type) { - switch (src_literal.shape().element_type()) { +StatusOr> Literal::Convert( + PrimitiveType primitive_dest_type) const { + switch (shape().element_type()) { #define CONVERT_IF_DEST_TYPE_MATCHES(type) \ case (type): \ - return ConvertIfDestTypeMatches<(type)>(src_literal, primitive_dest_type); + return ConvertIfDestTypeMatches<(type)>(*this, primitive_dest_type); CONVERT_IF_DEST_TYPE_MATCHES(PRED) CONVERT_IF_DEST_TYPE_MATCHES(S8) CONVERT_IF_DEST_TYPE_MATCHES(S32) @@ -807,9 +829,9 @@ StatusOr> LiteralUtil::ConvertIfSrcTypeMatches( #undef CONVERT_IF_DEST_TYPE_MATCHES // Other types are not yet supported. default: - return tensorflow::errors::InvalidArgument( - "Unimplemented: ConvertIfSrcTypeMatches for type " + - PrimitiveType_Name(src_literal.shape().element_type())); + return InvalidArgument("Unimplemented: Convert from type %s to type %s", + PrimitiveType_Name(shape().element_type()).c_str(), + PrimitiveType_Name(primitive_dest_type).c_str()); } } diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index 42c8b61acec..ce4a00fa551 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -251,7 +251,7 @@ class Literal { *other = temp; } - // CreatesCreate new literal of a given rank. To minimize ambiguity (for users + // Creates a new literal of a given rank. To minimize ambiguity (for users // and the compiler) these CreateR[0-2] methods should explicitly specify the // native type. 
For example: // @@ -362,10 +362,10 @@ class Literal { template std::unique_ptr Replicate(int64 times) const; - // Creates a literal by converting each element in this literal to a new - // type. - template - std::unique_ptr Convert() const; + // Converts this literal to another primitive type. Returns an error if the + // conversion is not possible. + StatusOr> Convert( + PrimitiveType primitive_dest_type) const; // Creates a literal value zero of the given primitive type. static Literal Zero(PrimitiveType primitive_type); @@ -444,10 +444,21 @@ class Literal { template void Set(tensorflow::gtl::ArraySlice multi_index, NativeT value); - // Retrieves the mutable array slice interface which can be used to manipulate - // pre-allocated literal values. + // Returns a (Mutable)ArraySlice view of the array for this literal for the + // given NativeT (e.g., float). These functions map native type to XLA + // PrimitiveType via template specialization. The unspecialized forms below + // aborts to handle the error case where the given native type does not map to + // an XLA primitive type. template - tensorflow::gtl::MutableArraySlice GetMutableArraySlice(); + tensorflow::gtl::ArraySlice GetArraySlice() const { + static_assert(!std::is_same::value, + "Cannot map native type to primitive type."); + } + template + tensorflow::gtl::MutableArraySlice GetMutableArraySlice() { + static_assert(!std::is_same::value, + "Cannot map native type to primitive type."); + } // Returns the element value at index (0, ..., 0), however many zeroes are // required for that index. @@ -588,17 +599,6 @@ class Literal { bool IsZero(tensorflow::gtl::ArraySlice indices) const; private: - // Returns an ArraySlice view of the array for this literal for the given - // NativeT (e.g., float). These functions map native type to XLA PrimitiveType - // via template specialization. 
The unspecialized forms below aborts to handle - // the error case where the given native type does not map to an XLA primitive - // type. - template - tensorflow::gtl::ArraySlice GetArraySlice() const { - static_assert(!std::is_same::value, - "Cannot map native type to primitive type."); - } - // Copy from a LiteralProto instance. void CopyFromProto(const LiteralProto& literal_proto); @@ -646,544 +646,6 @@ class Literal { std::vector tuple_literals_; }; -// Utility class for dealing with XLA literal values. Most methods are -// templated by native (host) type which corresponds to a unique XLA -// PrimitiveType. See ComputationBuilder for details. Not all primitive types -// defined in xla_data.proto have a corresponding native type or even have a -// storage location in the Literal proto yet (for example, primitive type F16). -// -// TODO(dnovillo) - All functions in this class simply redirect to the -// corresponding function in class Literal. Remove this class after converting -// all user code to use Literal directly. -class LiteralUtil { - public: - // Creates new literal of a given rank. To minimize ambiguity (for users and - // the compiler) these CreateR[0-2] methods should explicitly specify the - // native type. For example: - // - // CreateR1({1.0, 42.0}); - // CreateR2({{1, 2}, {3, 4}}); - // - // The variants not ending with WithLayout use the default XLA layout for the - // literal's linear representation in memory. 
- template - static std::unique_ptr CreateR0(NativeT value) { - return Literal::CreateR0(value); - } - - template - static std::unique_ptr CreateR1( - tensorflow::gtl::ArraySlice values) { - return Literal::CreateR1(values); - } - - static std::unique_ptr CreateR1( - const tensorflow::core::Bitmap& values) { - return Literal::CreateR1(values); - } - - template - static std::unique_ptr CreateR2( - std::initializer_list> values) { - return Literal::CreateR2(values); - } - - template - static std::unique_ptr CreateR2WithLayout( - std::initializer_list> values, - const Layout& layout) { - return Literal::CreateR2WithLayout(values, layout); - } - - template - static std::unique_ptr CreateR3( - std::initializer_list< - std::initializer_list>> - values) { - return Literal::CreateR3(values); - } - - template - static std::unique_ptr CreateR3WithLayout( - std::initializer_list< - std::initializer_list>> - values, - const Layout& layout) { - return Literal::CreateR3WithLayout(values, layout); - } - - template - static std::unique_ptr CreateR4( - std::initializer_list>>> - values) { - return Literal::CreateR4(values); - } - - template - static std::unique_ptr CreateR4WithLayout( - std::initializer_list>>> - values, - const Layout& layout) { - return Literal::CreateR4WithLayout(values, layout); - } - - // Creates a new Literal object with the shape specified as parameter. - // The content of the literal values is the default value of the primitive - // type of literal itself (0 for numeric types, and false for predicates). - static std::unique_ptr CreateFromShape(const Shape& shape) { - return Literal::CreateFromShape(shape); - } - - // Creates a new Literal object with its values havings the primitive_type - // type, and with dimensions defined by the dimensions parameter. - // The content of the literal values is the default value of the primitive - // type of literal itself (0 for numeric types, and false for predicates). 
- static std::unique_ptr CreateFromDimensions( - PrimitiveType primitive_type, - tensorflow::gtl::ArraySlice dimensions) { - return Literal::CreateFromDimensions(primitive_type, dimensions); - } - - // Copies the values from src_literal, starting at src_base shape indexes, - // to dest_literal, starting at dest_base, where the copy size in each - // dimension is specified by copy_size. - // - // The src_literal and dest_literal must have the same primitive type, - // src_base+copy_size must fit the source literal dimensions, as well as - // dest_base+copy_size must fit the destination literal dimensions. - static Status Copy(const Literal& src_literal, - tensorflow::gtl::ArraySlice src_base, - Literal* dest_literal, - tensorflow::gtl::ArraySlice dest_base, - tensorflow::gtl::ArraySlice copy_size) { - return dest_literal->Copy(src_literal, src_base, dest_base, copy_size); - } - - // Creates a new value that has the equivalent value as literal, but conforms - // to new_layout; e.g. a literal matrix that was in {0, 1} minor-to-major - // dimension layout can be re-laid-out as {1, 0} minor-to-major dimension - // layout and the value in the cell at any given logical index (i0, i1) will - // be the same. - // - // Note: this is useful when the client wants to ensure that a value placed in - // the XLA allocation tracker has a particular layout; for efficiency - // purposes or avoiding unimplemented operation/layout combinations. - static std::unique_ptr Relayout(const Literal& literal, - const Layout& new_layout) { - return literal.Relayout(new_layout); - } - - // Reshapes literal 'input' to have 'shape'. Both the original shape and - // 'shape' must contain the same number of elements. The implementation - // currently only supports monotonic dim0-major layouts. 
- static StatusOr> Reshape( - const xla::Literal& input, tensorflow::gtl::ArraySlice shape) { - return input.Reshape(shape); - } - - // Creates a new literal by reordering the dimensions of the original literal. - // The given `permutation` must be a permutation of the dimension numbers - // in the original literal, and it specifies the order of the new dimensions - // in the result literal (i.e., new_order[i] = old_order[permutation[i]]). - // For example, a transpose call on a literal of shape [3 x 8 x 4] and - // `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8]. - static std::unique_ptr Transpose( - const Literal& literal, tensorflow::gtl::ArraySlice permutation) { - return literal.Transpose(permutation); - } - - // Creates a sub-array from the given literal by extracting the indices - // [start_index, limit_index) of each dimension. The result literal has the - // same rank and layout as for the given literal. The number of indices in - // start_indices and limit_indices must be the rank of the literal, and the - // indices follow the order of the dimensions. - static std::unique_ptr Slice( - const Literal& literal, tensorflow::gtl::ArraySlice start_indices, - tensorflow::gtl::ArraySlice limit_indices) { - return literal.Slice(start_indices, limit_indices); - } - - // Creates a literal with a prepended dimension with bound "times"; e.g. a - // f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from the input - // literal replicated four times. - template - static std::unique_ptr Replicate(const Literal& input, int64 times) { - return input.Replicate(times); - } - - // Creates a literal by converting each element in an original literal to a - // new type. 
- template - static std::unique_ptr Convert(const Literal& literal) { - return literal.Convert(); - } - - // Convert a literal to another primitive type, but only if the literal - // type is connvertable into the destination type - static StatusOr> ConvertIfSrcTypeMatches( - const Literal& src_literal, PrimitiveType primitive_dest_type); - - // Creates a literal value zero of the given primitive type. - static Literal Zero(PrimitiveType primitive_type) { - return Literal::Zero(primitive_type); - } - - // Creates a literal value one of the given primitive type. - static Literal One(PrimitiveType primitive_type) { - return Literal::One(primitive_type); - } - - // Creates a literal value containing the minimum value of the given - // primitive type. For floating-point types, returns -inf. - static Literal MinValue(PrimitiveType primitive_type) { - return Literal::MinValue(primitive_type); - } - - // Creates a literal value containing the maximum value of the given - // primitive type. For floating-point types, returns inf. - static Literal MaxValue(PrimitiveType primitive_type) { - return Literal::MaxValue(primitive_type); - } - - // Creates a literal of the given shape where each element is `value`. - template - static std::unique_ptr CreateFullWithMonotonicDim0MajorLayout( - tensorflow::gtl::ArraySlice dimensions, NativeT value) { - return Literal::CreateFullWithMonotonicDim0MajorLayout(dimensions, value); - } - - // Creates a new literal from an array. The variants not ending with - // WithLayout use the default XLA layout for the literal's linear - // representation in memory. 
- template - static std::unique_ptr CreateR2FromArray2D( - const Array2D& values) { - return Literal::CreateR2FromArray2D(values); - } - - template - static std::unique_ptr CreateR2FromArray2DWithLayout( - const Array2D& values, const Layout& layout) { - return Literal::CreateR2FromArray2DWithLayout(values, layout); - } - - template - static std::unique_ptr CreateR3FromArray3D( - const Array3D& values) { - return Literal::CreateR3FromArray3D(values); - } - - template - static std::unique_ptr CreateR3FromArray3DWithLayout( - const Array3D& values, const Layout& layout) { - return Literal::CreateR3FromArray3DWithLayout(values, layout); - } - - template - static std::unique_ptr CreateR4FromArray4D( - const Array4D& values) { - return Literal::CreateR4FromArray4D(values); - } - - template - static std::unique_ptr CreateR4FromArray4DWithLayout( - const Array4D& values, const Layout& layout) { - return Literal::CreateR4FromArray4DWithLayout(values, layout); - } - - // Creates a new vector of U8s literal value from a string. - static std::unique_ptr CreateR1U8(tensorflow::StringPiece value) { - return Literal::CreateR1U8(value); - } - - // Creates a linspace-populated literal with the given number of rows and - // columns. - static std::unique_ptr CreateR2F32Linspace(float from, float to, - int64 rows, int64 cols) { - return Literal::CreateR2F32Linspace(from, to, rows, cols); - } - - // Creates a literal that projects the (x, y) dimensions given in values into - // the z dimension given by "projection". - template - static std::unique_ptr CreateR3Projected( - std::initializer_list> values, - int64 projection) { - return Literal::CreateR3Projected(values, projection); - } - - // Creates a literal that projects the (x, y) dimensions given in values into - // the z and p dimensions given. 
- template - static std::unique_ptr CreateR4Projected( - std::initializer_list> values, - int64 projection_p, int64 projection_z) { - return Literal::CreateR4Projected(values, projection_p, projection_z); - } - - // Clones literal into an owned unique_ptr version. - static std::unique_ptr CloneToUnique(const Literal& literal) { - return literal.CloneToUnique(); - } - - // Returns the linear index of the given index within the literal's - // element_type repeated field. - static int64 LinearIndex(const Literal& literal, - tensorflow::gtl::ArraySlice multi_index) { - return literal.LinearIndex(multi_index); - } - - // Gets or sets an element in the literal at the given index. The index is - // CHECKed against the dimension sizes. - template - static NativeT Get(const Literal& literal, - tensorflow::gtl::ArraySlice multi_index) { - return literal.Get(multi_index); - } - - template - static void Set(Literal* literal, - tensorflow::gtl::ArraySlice multi_index, - NativeT value) { - literal->Set(multi_index, value); - } - - // Retrieves the mutable array slice interface which can be used to manipulate - // pre-allocated literal values. - template - static tensorflow::gtl::MutableArraySlice GetMutableArraySlice( - Literal* literal) { - return literal->GetMutableArraySlice(); - } - - // Returns the element value at index (0, ..., 0), however many zeroes are - // required for that index. - template - static NativeT GetFirstElement(const Literal& literal) { - return literal.GetFirstElement(); - } - - // As Get(), but determines the correct type and converts the value - // into text. - static string GetAsString(const Literal& literal, - tensorflow::gtl::ArraySlice multi_index) { - return literal.GetAsString(multi_index); - } - - // Returns an identity matrix (rank 2) with the given row and column count. - template - static std::unique_ptr MakeIdentityR2(int64 size) { - return Literal::MakeIdentityR2(size); - } - - // Returns a tuple literal composed of given literals. 
- static std::unique_ptr MakeTuple( - tensorflow::gtl::ArraySlice elements) { - return Literal::MakeTuple(elements); - } - - // Validates that the data payload of the literal matches the literal shape; - // if it does not, an appropriate status is returned. - static tensorflow::Status ValidateLiteral(const Literal& literal) { - return literal.ValidateLiteral(); - } - - // Returns a string representation of the literal value. - static string ToString(const Literal& literal) { return literal.ToString(); } - - // Invokes the "per cell" callback for each element in the provided - // literal with the element's indices and a string representation of - // the element's value. - // - // This function is useful if you want a polymorphic representation - // of the tensor's elements (turning it to a string for something - // like representation in a protobuf). - static void EachCellAsString( - const Literal& literal, - const std::function indices, - const string& value)>& per_cell) { - literal.EachCellAsString(per_cell); - } - - template - static void EachCell( - const Literal& literal, - std::function indices, - NativeT value)> - per_cell) { - literal.EachCell(per_cell); - } - - // Templated methods which populate the given repeated field in the Literal - // proto with the given value(s). The Shape field of the Literal proto is set - // to match the array dimensions and type. Examples: - // - // // Populate with floats. - // Array2D float_values = ... - // PopulateR2FromArray2D(values, literal); - // - // // Populate with int32s. 
- // PopulateR2({{1, 2}, {3, 4}}, literal); - // - template - static void PopulateR0(NativeT values, Literal* literal) { - literal->PopulateR0(values); - } - - template - static void PopulateR1(tensorflow::gtl::ArraySlice values, - Literal* literal) { - literal->PopulateR1(values); - } - - static void PopulateR1(const tensorflow::core::Bitmap& values, - Literal* literal) { - literal->PopulateR1(values); - } - - template - static void PopulateR2( - std::initializer_list> values, - Literal* literal) { - literal->PopulateR2(values); - } - - template - static void PopulateR2WithLayout( - std::initializer_list> values, - const Layout& layout, Literal* literal) { - literal->PopulateR2WithLayout(values, layout); - } - - template - static void PopulateR2FromArray2D(const Array2D& values, - Literal* literal) { - literal->PopulateR2FromArray2D(values); - } - - template - static void PopulateR2FromArray2DWithLayout(const Array2D& values, - const Layout& layout, - Literal* literal) { - literal->PopulateR2FromArray2DWithLayout(values, layout); - } - - template - static void PopulateR3FromArray3D(const Array3D& values, - Literal* literal) { - literal->PopulateR3FromArray3D(values); - } - - template - static void PopulateR3FromArray3DWithLayout(const Array3D& values, - const Layout& layout, - Literal* literal) { - literal->PopulateR3FromArray3DWithLayout(values, layout); - } - - template - static void PopulateR4FromArray4D(const Array4D& values, - Literal* literal) { - literal->PopulateR4FromArray4D(values); - } - - template - static void PopulateR4FromArray4DWithLayout(const Array4D& values, - const Layout& layout, - Literal* literal) { - literal->PopulateR4FromArray4DWithLayout(values, layout); - } - - // Populates literal values by calling the generator function for every cell - // in the literal object. 
- template - static Status Populate( - Literal* literal, - const std::function indexes)>& - generator) { - return literal->Populate(generator); - } - - // Creates a Literal of the given dimensions with all elements set to the - // given value. - template - static void PopulateWithValue(NativeT value, - tensorflow::gtl::ArraySlice dimensions, - Literal* literal) { - return literal->PopulateWithValue(value, dimensions); - } - - // Returns a pointer to the underlying vector containing the array data. Use - // with care. - static const void* InternalData(const Literal& literal) { - return literal.InternalData(); - } - - static void* MutableInternalData(Literal* literal) { - return literal->MutableInternalData(); - } - - // Allocates space in the underlying vector of the literal sufficient to hold - // num_elements of the literal's primitive type. Values in the vector are set - // to zero. num_elements must equal the number of elements in the literals - // shape. - static void Reserve(int64 num_elements, Literal* literal) { - literal->Reserve(num_elements); - } - - // Allocates space in the underlying vector of the literal sufficient to hold - // num_elements of the literal's primitive type and sets each element in the - // literal to the given value. num_elements must equal the number of elements - // in the literals shape. - template - static void Resize(int64 num_elements, NativeT value, Literal* literal) { - literal->Resize(num_elements, value); - } - - // Returns true if the two given literals have the same shape and - // values. Layout is not considered in the comparison. - static bool Equal(const Literal& literal1, const Literal& literal2) { - return literal1.Equal(literal2); - } - - // Returns whether every element in the given literal is equal to value. - // - // value is an int8 because we expect this to be called with small - // compile-time constants (0, -1, etc.) 
and so that whatever value you pass - // can be represented exactly by floating-point types as small as 16 bits. - // - // If value doesn't fit in literal's type, returns false. Values of 1/0 are - // considered equal to true/false; other values are not considered equal to - // true. - static bool IsAll(const Literal& literal, int8 value) { - return literal.IsAll(value); - } - - // Like IsAll(const Literal&, int8), except we check whether the literal is - // equal to a particular floating-point number. - // - // If the literal is not a floating-point value, this always returns false. - // - // This casts value to the type of literal, then compares using ==. The usual - // admonishments about floating-point equality checks apply. We expect you to - // use this to check for values that can be expressed precisely as a float, - // e.g. -0.5. - static bool IsAllFloat(const Literal& literal, float value) { - return literal.IsAllFloat(value); - } - - // Returns whether the literal is zero at the specified index. The literal - // must be an array. - static bool IsZero(const Literal& literal, - tensorflow::gtl::ArraySlice indices) { - return literal.IsZero(indices); - } - - TF_DISALLOW_COPY_AND_ASSIGN(LiteralUtil); -}; - // Declarations of template specializations for GetArraySlice and // GetMutableArraySlice. The specializations map native type to XLA primitive // type. 
@@ -1759,27 +1221,6 @@ void Literal::PopulateWithValue(NativeT value, Resize(ShapeUtil::ElementsIn(shape()), value); } -template -std::unique_ptr Literal::Convert() const { - const Shape& this_shape = shape(); - auto result_literal = MakeUnique(); - Shape* result_shape = result_literal->mutable_shape(); - *result_shape = this_shape; - result_shape->set_element_type( - primitive_util::NativeToPrimitiveType()); - result_literal->Reserve(ShapeUtil::ElementsIn(*result_shape)); - tensorflow::gtl::ArraySlice src_data = - GetArraySlice(); - tensorflow::gtl::MutableArraySlice dest_data = - result_literal->GetMutableArraySlice(); - int64 num_elements = ShapeUtil::ElementsIn(this_shape); - - for (int64 i = 0; i < num_elements; ++i) { - dest_data[i] = static_cast(src_data[i]); - } - return result_literal; -} - template /* static */ std::unique_ptr Literal::CreateFullWithMonotonicDim0MajorLayout( diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc index 8d4a75d7aff..af84abd550f 100644 --- a/tensorflow/compiler/xla/literal_util_test.cc +++ b/tensorflow/compiler/xla/literal_util_test.cc @@ -72,11 +72,11 @@ class LiteralUtilTest : public ::testing::Test { layout_r4_dim0minor_ = LayoutUtil::MakeLayout({0, 1, 2, 3}); literal_r4_2x2x3x3_dim0major_ = - LiteralUtil::CreateR4FromArray4DWithLayout(arr4d, - layout_r4_dim0major_); + Literal::CreateR4FromArray4DWithLayout(arr4d, + layout_r4_dim0major_); literal_r4_2x2x3x3_dim0minor_ = - LiteralUtil::CreateR4FromArray4DWithLayout(arr4d, - layout_r4_dim0minor_); + Literal::CreateR4FromArray4DWithLayout(arr4d, + layout_r4_dim0minor_); } Layout layout_r2_dim0major_; @@ -90,43 +90,42 @@ class LiteralUtilTest : public ::testing::Test { }; TEST_F(LiteralUtilTest, LiteralScalarToString) { - auto true_lit = LiteralUtil::CreateR0(true); - ASSERT_EQ("true", LiteralUtil::ToString(*true_lit)); + auto true_lit = Literal::CreateR0(true); + ASSERT_EQ("true", true_lit->ToString()); - auto false_lit = 
LiteralUtil::CreateR0(false); - ASSERT_EQ("false", LiteralUtil::ToString(*false_lit)); + auto false_lit = Literal::CreateR0(false); + ASSERT_EQ("false", false_lit->ToString()); - auto u32_lit = LiteralUtil::CreateR0(42); - ASSERT_EQ("42", LiteralUtil::ToString(*u32_lit)); + auto u32_lit = Literal::CreateR0(42); + ASSERT_EQ("42", u32_lit->ToString()); - auto s32_lit = LiteralUtil::CreateR0(-999); - ASSERT_EQ("-999", LiteralUtil::ToString(*s32_lit)); + auto s32_lit = Literal::CreateR0(-999); + ASSERT_EQ("-999", s32_lit->ToString()); - auto f32_lit = LiteralUtil::CreateR0(3.14f); - ASSERT_EQ("3.14", LiteralUtil::ToString(*f32_lit)); + auto f32_lit = Literal::CreateR0(3.14f); + ASSERT_EQ("3.14", f32_lit->ToString()); - auto f16_lit = LiteralUtil::CreateR0(static_cast(0.5f)); - ASSERT_EQ("0.5", LiteralUtil::ToString(*f16_lit)); + auto f16_lit = Literal::CreateR0(static_cast(0.5f)); + ASSERT_EQ("0.5", f16_lit->ToString()); } TEST_F(LiteralUtilTest, LiteralVectorToString) { - auto pred_vec = LiteralUtil::CreateR1({true, false, true}); - ASSERT_EQ("{101}", LiteralUtil::ToString(*pred_vec)); + auto pred_vec = Literal::CreateR1({true, false, true}); + ASSERT_EQ("{101}", pred_vec->ToString()); } TEST_F(LiteralUtilTest, R2ToString) { - const auto literal = LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}}); + const auto literal = Literal::CreateR2({{1, 2}, {3, 4}, {5, 6}}); const string expected = R"(s32[3,2] { { 1, 2 }, { 3, 4 }, { 5, 6 }, })"; - ASSERT_EQ(expected, LiteralUtil::ToString(*literal)); + ASSERT_EQ(expected, literal->ToString()); } TEST_F(LiteralUtilTest, R3ToString) { - const auto literal = - LiteralUtil::CreateR3({{{1}, {2}}, {{3}, {4}}, {{5}, {6}}}); + const auto literal = Literal::CreateR3({{{1}, {2}}, {{3}, {4}}, {{5}, {6}}}); const string expected = R"(s32[3,2,1] { { { 1 }, { 2 } }, @@ -135,13 +134,13 @@ TEST_F(LiteralUtilTest, R3ToString) { { { 5 }, { 6 } } })"; - ASSERT_EQ(expected, LiteralUtil::ToString(*literal)); + ASSERT_EQ(expected, 
literal->ToString()); } TEST_F(LiteralUtilTest, TupleToString) { - auto scalar = LiteralUtil::CreateR0(1.0); - auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - auto tuple = LiteralUtil::MakeTuple({scalar.get(), matrix.get()}); + auto scalar = Literal::CreateR0(1.0); + auto matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()}); const string expected = R"((f32[], f32[2,2]) ( 1, f32[2,2] { @@ -149,7 +148,7 @@ f32[2,2] { { 3, 4 }, }, ))"; - ASSERT_EQ(expected, LiteralUtil::ToString(*tuple)); + ASSERT_EQ(expected, tuple->ToString()); } TEST_F(LiteralUtilTest, CreateR3FromArray3d) { @@ -164,9 +163,9 @@ TEST_F(LiteralUtilTest, CreateR3FromArray3d) { }); // clang-format on - auto literal = LiteralUtil::CreateR3FromArray3D(array_3d); + auto literal = Literal::CreateR3FromArray3D(array_3d); EXPECT_THAT(literal->shape().dimensions(), ElementsAre(2, 3, 2)); - string result = LiteralUtil::ToString(*literal); + string result = literal->ToString(); const string expected = R"(f32[2,3,2] { { { 1, 2 }, { 3, 4 }, @@ -180,14 +179,14 @@ TEST_F(LiteralUtilTest, CreateR3FromArray3d) { TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) { // clang-format off - auto literal = LiteralUtil::CreateR4Projected({ + auto literal = Literal::CreateR4Projected({ {1, 2}, {1001, 1002}, {2001, 2002}, }, /*projection_p=*/1, /*projection_z=*/2); // clang-format on EXPECT_THAT(literal->shape().dimensions(), ElementsAre(1, 2, 3, 2)); - string result = LiteralUtil::ToString(*literal); + string result = literal->ToString(); const string expected = R"(f32[1,2,3,2] { { // i0=0 { // i1=0 @@ -208,7 +207,7 @@ TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) { TEST_F(LiteralUtilTest, LiteralR4F32Stringifies) { EXPECT_THAT(literal_r4_2x2x3x3_dim0major_->shape().dimensions(), ElementsAre(2, 2, 3, 3)); - string result = LiteralUtil::ToString(*literal_r4_2x2x3x3_dim0major_); + string result = 
literal_r4_2x2x3x3_dim0major_->ToString(); const string expected = R"(f32[2,2,3,3] { { // i0=0 { // i1=0 @@ -240,14 +239,13 @@ TEST_F(LiteralUtilTest, LiteralR4F32Stringifies) { TEST_F(LiteralUtilTest, EachCellR2F32) { // clang-format off - auto literal = LiteralUtil::CreateR2({ + auto literal = Literal::CreateR2({ {3.1f, 4.2f}, {9.3f, 12.4f}, }); // clang-format on std::vector> seen; - LiteralUtil::EachCellAsString( - *literal, + literal->EachCellAsString( [&seen](tensorflow::gtl::ArraySlice indices, const string& value) { seen.emplace_back(indices[0], indices[1], value); }); @@ -259,176 +257,161 @@ TEST_F(LiteralUtilTest, EachCellR2F32) { } TEST_F(LiteralUtilTest, ScalarEquality) { - // Test LiteralUtil::Equal with scalars. - auto f32_42 = LiteralUtil::CreateR0(42.0); - auto f32_42_clone = LiteralUtil::CreateR0(42.0); + // Test Literal::Equal with scalars. + auto f32_42 = Literal::CreateR0(42.0); + auto f32_42_clone = Literal::CreateR0(42.0); - EXPECT_TRUE(LiteralUtil::Equal(*f32_42, *f32_42)); - EXPECT_TRUE(LiteralUtil::Equal(*f32_42, *f32_42_clone)); + EXPECT_TRUE(f32_42->Equal(*f32_42)); + EXPECT_TRUE(f32_42->Equal(*f32_42_clone)); - auto f32_123 = LiteralUtil::CreateR0(123.0); - EXPECT_FALSE(LiteralUtil::Equal(*f32_42, *f32_123)); + auto f32_123 = Literal::CreateR0(123.0); + EXPECT_FALSE(f32_42->Equal(*f32_123)); - auto f64_42 = LiteralUtil::CreateR0(42.0); - EXPECT_FALSE(LiteralUtil::Equal(*f32_42, *f64_42)); + auto f64_42 = Literal::CreateR0(42.0); + EXPECT_FALSE(f32_42->Equal(*f64_42)); } TEST_F(LiteralUtilTest, NonScalarEquality) { - // Test LiteralUtil::Equal with nonscalars. - auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - auto matrix_clone = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - auto matrix_different = - LiteralUtil::CreateR2({{4.0, 3.0}, {1.0, 2.0}}); - auto vector_literal = LiteralUtil::CreateR1({1.0, 2.0, 3.0, 4.0}); - auto scalar = LiteralUtil::CreateR0(1.0); + // Test Literal::Equal with nonscalars. 
+ auto matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto matrix_clone = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto matrix_different = Literal::CreateR2({{4.0, 3.0}, {1.0, 2.0}}); + auto vector_literal = Literal::CreateR1({1.0, 2.0, 3.0, 4.0}); + auto scalar = Literal::CreateR0(1.0); - EXPECT_TRUE(LiteralUtil::Equal(*matrix, *matrix)); - EXPECT_TRUE(LiteralUtil::Equal(*matrix, *matrix_clone)); - EXPECT_FALSE(LiteralUtil::Equal(*matrix, *matrix_different)); - EXPECT_FALSE(LiteralUtil::Equal(*matrix, *vector_literal)); - EXPECT_FALSE(LiteralUtil::Equal(*matrix, *scalar)); + EXPECT_TRUE(matrix->Equal(*matrix)); + EXPECT_TRUE(matrix->Equal(*matrix_clone)); + EXPECT_FALSE(matrix->Equal(*matrix_different)); + EXPECT_FALSE(matrix->Equal(*vector_literal)); + EXPECT_FALSE(matrix->Equal(*scalar)); } TEST_F(LiteralUtilTest, DifferentLayoutEquality) { - // Test LiteralUtil::Equal with literals which have different layouts. + // Test Literal::Equal with literals which have different layouts. 
auto colmajor = MakeUnique(); *colmajor->mutable_shape() = ShapeUtil::MakeShape(F32, {2, 2}); *colmajor->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1}); - LiteralUtil::Reserve(4, colmajor.get()); - LiteralUtil::Set(colmajor.get(), {0, 0}, 1.0); - LiteralUtil::Set(colmajor.get(), {0, 1}, 2.0); - LiteralUtil::Set(colmajor.get(), {1, 0}, 3.0); - LiteralUtil::Set(colmajor.get(), {1, 1}, 4.0); + colmajor.get()->Reserve(4); + colmajor.get()->Set({0, 0}, 1.0); + colmajor.get()->Set({0, 1}, 2.0); + colmajor.get()->Set({1, 0}, 3.0); + colmajor.get()->Set({1, 1}, 4.0); auto rowmajor = MakeUnique(); *rowmajor->mutable_shape() = ShapeUtil::MakeShape(F32, {2, 2}); *rowmajor->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({1, 0}); - LiteralUtil::Reserve(4, rowmajor.get()); - LiteralUtil::Set(rowmajor.get(), {0, 0}, 1.0); - LiteralUtil::Set(rowmajor.get(), {0, 1}, 2.0); - LiteralUtil::Set(rowmajor.get(), {1, 0}, 3.0); - LiteralUtil::Set(rowmajor.get(), {1, 1}, 4.0); + rowmajor.get()->Reserve(4); + rowmajor.get()->Set({0, 0}, 1.0); + rowmajor.get()->Set({0, 1}, 2.0); + rowmajor.get()->Set({1, 0}, 3.0); + rowmajor.get()->Set({1, 1}, 4.0); - EXPECT_TRUE(LiteralUtil::Equal(*rowmajor, *colmajor)); + EXPECT_TRUE(rowmajor->Equal(*colmajor)); } TEST_F(LiteralUtilTest, TupleEquality) { - // Test LiteralUtil::Equal with tuples. - auto scalar = LiteralUtil::CreateR0(1.0); - auto matrix = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); - auto tuple1 = LiteralUtil::MakeTuple({scalar.get(), matrix.get()}); + // Test Literal::Equal with tuples. + auto scalar = Literal::CreateR0(1.0); + auto matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto tuple1 = Literal::MakeTuple({scalar.get(), matrix.get()}); // Tuple with the same elements. One element is shared with the original // tuple, the other is a clone of the element in the original tuple. 
- auto scalar_clone = LiteralUtil::CreateR0(1.0); - auto tuple2 = LiteralUtil::MakeTuple({scalar_clone.get(), matrix.get()}); - EXPECT_TRUE(LiteralUtil::Equal(*tuple1, *tuple2)); + auto scalar_clone = Literal::CreateR0(1.0); + auto tuple2 = Literal::MakeTuple({scalar_clone.get(), matrix.get()}); + EXPECT_TRUE(tuple1->Equal(*tuple2)); // Tuple with elements reversed. - auto reversed_tuple = LiteralUtil::MakeTuple({matrix.get(), scalar.get()}); - EXPECT_FALSE(LiteralUtil::Equal(*tuple1, *reversed_tuple)); + auto reversed_tuple = Literal::MakeTuple({matrix.get(), scalar.get()}); + EXPECT_FALSE(tuple1->Equal(*reversed_tuple)); // Tuple with different value. - auto scalar_42 = LiteralUtil::CreateR0(42.0); - auto different_tuple = - LiteralUtil::MakeTuple({scalar_42.get(), matrix.get()}); - EXPECT_FALSE(LiteralUtil::Equal(*tuple1, *different_tuple)); + auto scalar_42 = Literal::CreateR0(42.0); + auto different_tuple = Literal::MakeTuple({scalar_42.get(), matrix.get()}); + EXPECT_FALSE(tuple1->Equal(*different_tuple)); } TEST_F(LiteralUtilTest, IsAllTuple) { - auto element1 = LiteralUtil::CreateR0(0.0); - auto element2 = LiteralUtil::CreateR2({{0.0, 0.0}, {0.0, 0.0}}); - auto tuple = LiteralUtil::MakeTuple({element1.get(), element1.get()}); + auto element1 = Literal::CreateR0(0.0); + auto element2 = Literal::CreateR2({{0.0, 0.0}, {0.0, 0.0}}); + auto tuple = Literal::MakeTuple({element1.get(), element1.get()}); // Tuples should always return false for IsAll. 
- EXPECT_FALSE(LiteralUtil::IsAll(*tuple, 0)); - EXPECT_FALSE(LiteralUtil::IsAll(*tuple, 1)); + EXPECT_FALSE(tuple->IsAll(0)); + EXPECT_FALSE(tuple->IsAll(1)); } TEST_F(LiteralUtilTest, IsAll) { - EXPECT_TRUE(LiteralUtil::IsAll(*LiteralUtil::CreateR0(false), 0)); - EXPECT_TRUE(LiteralUtil::IsAll(*LiteralUtil::CreateR0(true), 1)); - EXPECT_FALSE(LiteralUtil::IsAll(*LiteralUtil::CreateR0(false), 1)); - EXPECT_FALSE(LiteralUtil::IsAll(*LiteralUtil::CreateR0(false), 2)); - EXPECT_FALSE(LiteralUtil::IsAll(*LiteralUtil::CreateR0(true), 0)); - EXPECT_FALSE(LiteralUtil::IsAll(*LiteralUtil::CreateR0(true), 2)); - EXPECT_FALSE(LiteralUtil::IsAll(*LiteralUtil::CreateR0(true), -1)); + EXPECT_TRUE(Literal::CreateR0(false)->IsAll(0)); + EXPECT_TRUE(Literal::CreateR0(true)->IsAll(1)); + EXPECT_FALSE(Literal::CreateR0(false)->IsAll(1)); + EXPECT_FALSE(Literal::CreateR0(false)->IsAll(2)); + EXPECT_FALSE(Literal::CreateR0(true)->IsAll(0)); + EXPECT_FALSE(Literal::CreateR0(true)->IsAll(2)); + EXPECT_FALSE(Literal::CreateR0(true)->IsAll(-1)); // We shouldn't reinterpret int8_min as an unsigned type and then decide that // it is equal to 255. 
auto int8_min = std::numeric_limits::min(); - EXPECT_FALSE( - LiteralUtil::IsAll(*LiteralUtil::CreateR0(255), int8_min)); + EXPECT_FALSE(Literal::CreateR0(255)->IsAll(int8_min)); - EXPECT_TRUE(LiteralUtil::IsAll(*LiteralUtil::CreateR0(42.0), 42)); - EXPECT_FALSE(LiteralUtil::IsAll(*LiteralUtil::CreateR0(42.0001), 42)); + EXPECT_TRUE(Literal::CreateR0(42.0)->IsAll(42)); + EXPECT_FALSE(Literal::CreateR0(42.0001)->IsAll(42)); - EXPECT_TRUE( - LiteralUtil::IsAll(*LiteralUtil::CreateR1({100, 100, 100}), 100)); - EXPECT_FALSE(LiteralUtil::IsAll( - *LiteralUtil::CreateR1({100, 100, 100.001}), 100)); + EXPECT_TRUE(Literal::CreateR1({100, 100, 100})->IsAll(100)); + EXPECT_FALSE(Literal::CreateR1({100, 100, 100.001})->IsAll(100)); - EXPECT_TRUE( - LiteralUtil::IsAll(*LiteralUtil::CreateR2({{8, 8}, {8, 8}}), 8)); - EXPECT_FALSE( - LiteralUtil::IsAll(*LiteralUtil::CreateR2({{8, 8}, {8, 9}}), 8)); - EXPECT_FALSE( - LiteralUtil::IsAll(*LiteralUtil::CreateR2({{9, 8}, {8, 8}}), 8)); + EXPECT_TRUE(Literal::CreateR2({{8, 8}, {8, 8}})->IsAll(8)); + EXPECT_FALSE(Literal::CreateR2({{8, 8}, {8, 9}})->IsAll(8)); + EXPECT_FALSE(Literal::CreateR2({{9, 8}, {8, 8}})->IsAll(8)); half h8(8.0f); half h9(9.0f); - EXPECT_TRUE( - LiteralUtil::IsAll(*LiteralUtil::CreateR2({{h8}, {h8}}), 8)); - EXPECT_FALSE( - LiteralUtil::IsAll(*LiteralUtil::CreateR2({{h8}, {h9}}), 8)); - EXPECT_FALSE( - LiteralUtil::IsAll(*LiteralUtil::CreateR2({{h9}, {h8}}), 8)); + EXPECT_TRUE(Literal::CreateR2({{h8}, {h8}})->IsAll(8)); + EXPECT_FALSE(Literal::CreateR2({{h8}, {h9}})->IsAll(8)); + EXPECT_FALSE(Literal::CreateR2({{h9}, {h8}})->IsAll(8)); auto uint64_max = std::numeric_limits::max(); - EXPECT_FALSE(LiteralUtil::IsAll( - *LiteralUtil::CreateR2( - {{uint64_max, uint64_max}, {uint64_max, uint64_max}}), - -1)); + EXPECT_FALSE(Literal::CreateR2( + {{uint64_max, uint64_max}, {uint64_max, uint64_max}}) + ->IsAll(-1)); } TEST_F(LiteralUtilTest, IsAllFloat) { // IsAllFloat always returns false when the literal is not 
floating-point. - EXPECT_FALSE(LiteralUtil::IsAllFloat(*LiteralUtil::CreateR0(false), 0)); - EXPECT_FALSE(LiteralUtil::IsAllFloat(*LiteralUtil::CreateR0(0), 0)); - EXPECT_FALSE(LiteralUtil::IsAllFloat(*LiteralUtil::CreateR0(0), 0)); - EXPECT_FALSE(LiteralUtil::IsAllFloat(*LiteralUtil::CreateR0(0), 0)); + EXPECT_FALSE(Literal::CreateR0(false)->IsAllFloat(0)); + EXPECT_FALSE(Literal::CreateR0(0)->IsAllFloat(0)); + EXPECT_FALSE(Literal::CreateR0(0)->IsAllFloat(0)); + EXPECT_FALSE(Literal::CreateR0(0)->IsAllFloat(0)); - EXPECT_TRUE(LiteralUtil::IsAllFloat(*LiteralUtil::CreateR0(0), 0)); - EXPECT_TRUE(LiteralUtil::IsAllFloat(*LiteralUtil::CreateR0(.5), .5)); - EXPECT_TRUE(LiteralUtil::IsAllFloat(*LiteralUtil::CreateR0(-.5), -.5)); + EXPECT_TRUE(Literal::CreateR0(0)->IsAllFloat(0)); + EXPECT_TRUE(Literal::CreateR0(.5)->IsAllFloat(.5)); + EXPECT_TRUE(Literal::CreateR0(-.5)->IsAllFloat(-.5)); + EXPECT_FALSE(Literal::CreateR0(-.5)->IsAllFloat(-.49)); EXPECT_FALSE( - LiteralUtil::IsAllFloat(*LiteralUtil::CreateR0(-.5), -.49)); - EXPECT_FALSE(LiteralUtil::IsAllFloat( - *LiteralUtil::CreateR2({{0, 0, 0}, {0, .1, 0}}), 0)); - EXPECT_TRUE(LiteralUtil::IsAllFloat( - *LiteralUtil::CreateR2({{.5, .5, .5}, {.5, .5, .5}}), .5)); - - EXPECT_TRUE(LiteralUtil::IsAllFloat(*LiteralUtil::CreateR0(0), 0)); - EXPECT_TRUE(LiteralUtil::IsAllFloat(*LiteralUtil::CreateR0(.5), .5)); + Literal::CreateR2({{0, 0, 0}, {0, .1, 0}})->IsAllFloat(0)); EXPECT_TRUE( - LiteralUtil::IsAllFloat(*LiteralUtil::CreateR0(-.5), -.5)); + Literal::CreateR2({{.5, .5, .5}, {.5, .5, .5}})->IsAllFloat(.5)); + + EXPECT_TRUE(Literal::CreateR0(0)->IsAllFloat(0)); + EXPECT_TRUE(Literal::CreateR0(.5)->IsAllFloat(.5)); + EXPECT_TRUE(Literal::CreateR0(-.5)->IsAllFloat(-.5)); + EXPECT_FALSE(Literal::CreateR0(-.5)->IsAllFloat(-.49)); EXPECT_FALSE( - LiteralUtil::IsAllFloat(*LiteralUtil::CreateR0(-.5), -.49)); - EXPECT_FALSE(LiteralUtil::IsAllFloat( - *LiteralUtil::CreateR2({{0, 0, 0}, {0, .1, 0}}), 0)); + Literal::CreateR2({{0, 
0, 0}, {0, .1, 0}})->IsAllFloat(0)); } TEST_F(LiteralUtilTest, IsZero) { - auto scalar_zero = LiteralUtil::CreateR0(0.0f); - auto scalar_one = LiteralUtil::CreateR0(1.0f); - EXPECT_TRUE(LiteralUtil::IsZero(*scalar_zero, {})); - EXPECT_FALSE(LiteralUtil::IsZero(*scalar_one, {})); + auto scalar_zero = Literal::CreateR0(0.0f); + auto scalar_one = Literal::CreateR0(1.0f); + EXPECT_TRUE(scalar_zero->IsZero({})); + EXPECT_FALSE(scalar_one->IsZero({})); - auto array = LiteralUtil::CreateR2({{1, 2, 0, 3}, {1, 0, 1, 2}}); - EXPECT_FALSE(LiteralUtil::IsZero(*array, {0, 1})); - EXPECT_TRUE(LiteralUtil::IsZero(*array, {0, 2})); - EXPECT_TRUE(LiteralUtil::IsZero(*array, {1, 1})); - EXPECT_FALSE(LiteralUtil::IsZero(*array, {1, 2})); + auto array = Literal::CreateR2({{1, 2, 0, 3}, {1, 0, 1, 2}}); + EXPECT_FALSE(array->IsZero({0, 1})); + EXPECT_TRUE(array->IsZero({0, 2})); + EXPECT_TRUE(array->IsZero({1, 1})); + EXPECT_FALSE(array->IsZero({1, 2})); } template @@ -440,127 +423,122 @@ TYPED_TEST_CASE(LiteralUtilTestTemplated, TestedTypes); TYPED_TEST(LiteralUtilTestTemplated, Relayout2x2) { // Make a non-integer for floating point types. 
TypeParam half = TypeParam(1) / TypeParam(2); - auto data = LiteralUtil::CreateR2({{half, 2}, {3, 4}}); + auto data = Literal::CreateR2({{half, 2}, {3, 4}}); const Layout layout01 = LayoutUtil::MakeLayout({0, 1}); const Layout layout10 = LayoutUtil::MakeLayout({1, 0}); - auto data01 = LiteralUtil::Relayout(*data, layout01); + auto data01 = data->Relayout(layout01); EXPECT_TRUE(LayoutUtil::Equal(data01->shape().layout(), layout01)); - EXPECT_TRUE(LiteralUtil::Equal(*data, *data01)); + EXPECT_TRUE(data->Equal(*data01)); - auto data10 = LiteralUtil::Relayout(*data, layout10); + auto data10 = data->Relayout(layout10); EXPECT_TRUE(LayoutUtil::Equal(data10->shape().layout(), layout10)); - EXPECT_TRUE(LiteralUtil::Equal(*data, *data10)); + EXPECT_TRUE(data->Equal(*data10)); } TEST_F(LiteralUtilTest, ReshapeR0) { - auto original = LiteralUtil::CreateR0(1.7f); - auto reshape = - LiteralUtil::Reshape(*original, /*shape=*/{}).ConsumeValueOrDie(); - EXPECT_TRUE(LiteralUtil::Equal(*original, *reshape)); + auto original = Literal::CreateR0(1.7f); + auto reshape = original->Reshape(/*shape=*/{}).ConsumeValueOrDie(); + EXPECT_TRUE(original->Equal(*reshape)); } TEST_F(LiteralUtilTest, ReshapeR4) { // clang-format off // F32[1x3x2x4] - auto original = LiteralUtil::CreateR4WithLayout({{ + auto original = Literal::CreateR4WithLayout({{ {{10, 11, 12, 13}, {14, 15, 16, 17}}, {{18, 19, 20, 21}, {22, 23, 24, 25}}, {{26, 27, 28, 29}, {30, 31, 32, 33}}, }}, layout_r4_dim0major_); // F32[1x3x4x2] - auto expected = LiteralUtil::CreateR3WithLayout({ + auto expected = Literal::CreateR3WithLayout({ {{10, 11}, {12, 13}, {14, 15}, {16, 17}}, {{18, 19}, {20, 21}, {22, 23}, {24, 25}}, {{26, 27}, {28, 29}, {30, 31}, {32, 33}}, }, layout_r3_dim0major_); // clang-format on - auto reshape = LiteralUtil::Reshape(*original, {3, 4, 2}).ConsumeValueOrDie(); + auto reshape = original->Reshape({3, 4, 2}).ConsumeValueOrDie(); - EXPECT_TRUE(LiteralUtil::Equal(*expected, *reshape)); + 
EXPECT_TRUE(expected->Equal(*reshape)); } TEST_F(LiteralUtilTest, ReshapeR4Dim0Minor) { // clang-format off // F32[1x3x2x4] - auto original = LiteralUtil::CreateR4WithLayout({{ + auto original = Literal::CreateR4WithLayout({{ {{10, 11, 12, 13}, {14, 15, 16, 17}}, {{18, 19, 20, 21}, {22, 23, 24, 25}}, {{26, 27, 28, 29}, {30, 31, 32, 33}}, }}, layout_r4_dim0minor_); // F32[1x3x4x2] - auto expected = LiteralUtil::CreateR3WithLayout({ + auto expected = Literal::CreateR3WithLayout({ {{10, 11}, {12, 13}, {14, 15}, {16, 17}}, {{18, 19}, {20, 21}, {22, 23}, {24, 25}}, {{26, 27}, {28, 29}, {30, 31}, {32, 33}}, }, layout_r3_dim0major_); // clang-format on - auto reshape = LiteralUtil::Reshape(*original, {3, 4, 2}).ConsumeValueOrDie(); + auto reshape = original->Reshape({3, 4, 2}).ConsumeValueOrDie(); - EXPECT_TRUE(LiteralUtil::Equal(*expected, *reshape)); + EXPECT_TRUE(expected->Equal(*reshape)); } TEST_F(LiteralUtilTest, TransposeR0) { - auto original = LiteralUtil::CreateR0(1.7f); - auto reshape = LiteralUtil::Transpose(*original, /*permutation=*/{}); - EXPECT_TRUE(LiteralUtil::Equal(*original, *reshape)); + auto original = Literal::CreateR0(1.7f); + auto reshape = original->Transpose(/*permutation=*/{}); + EXPECT_TRUE(original->Equal(*reshape)); } TEST_F(LiteralUtilTest, TransposeR4) { // clang-format off // F32[1x3x2x4] - auto original = LiteralUtil::CreateR4({{ + auto original = Literal::CreateR4({{ {{10, 11, 12, 13}, {14, 15, 16, 17}}, {{18, 19, 20, 21}, {22, 23, 24, 25}}, {{26, 27, 28, 29}, {30, 31, 32, 33}}, }}); // clang-format on - auto reshape = - LiteralUtil::Transpose(*original, /*permutation=*/{2, 3, 0, 1}); + auto reshape = original->Transpose(/*permutation=*/{2, 3, 0, 1}); - LiteralUtil::EachCell( - *reshape, [&](tensorflow::gtl::ArraySlice indices, float value) { - EXPECT_EQ(value, - LiteralUtil::Get(*original, {indices[2], indices[3], - indices[0], indices[1]})); + reshape->EachCell( + [&](tensorflow::gtl::ArraySlice indices, float value) { + 
EXPECT_EQ(value, original->Get( + {indices[2], indices[3], indices[0], indices[1]})); }); } TEST_F(LiteralUtilTest, TestR4RelayoutEquivalence) { // Tests that using Relayout on an array is equivalent to creating it in the // target layout in the first place. - auto dim0minor_relaid_to_dim0major = LiteralUtil::Relayout( - *literal_r4_2x2x3x3_dim0minor_, layout_r4_dim0major_); - EXPECT_TRUE(LiteralUtil::Equal(*literal_r4_2x2x3x3_dim0major_, - *dim0minor_relaid_to_dim0major)); + auto dim0minor_relaid_to_dim0major = + literal_r4_2x2x3x3_dim0minor_->Relayout(layout_r4_dim0major_); + EXPECT_TRUE( + literal_r4_2x2x3x3_dim0major_->Equal(*dim0minor_relaid_to_dim0major)); - auto dim0major_relaid_to_dim0minor = LiteralUtil::Relayout( - *literal_r4_2x2x3x3_dim0major_, layout_r4_dim0minor_); - EXPECT_TRUE(LiteralUtil::Equal(*literal_r4_2x2x3x3_dim0minor_, - *dim0major_relaid_to_dim0minor)); + auto dim0major_relaid_to_dim0minor = + literal_r4_2x2x3x3_dim0major_->Relayout(layout_r4_dim0minor_); + EXPECT_TRUE( + literal_r4_2x2x3x3_dim0minor_->Equal(*dim0major_relaid_to_dim0minor)); } TEST_F(LiteralUtilTest, TestR2LinearLayout) { // Test expected memory layout of R2 dim0-minor (column-major) literal. - auto mat_dim0minor = LiteralUtil::CreateR2WithLayout( - {{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0minor_); + auto mat_dim0minor = Literal::CreateR2WithLayout({{1, 2, 3}, {4, 5, 6}}, + layout_r2_dim0minor_); EXPECT_EQ(mat_dim0minor->s32s_size(), 6); EXPECT_THAT(mat_dim0minor->s32s(), ElementsAre(1, 4, 2, 5, 3, 6)); // Test expected memory layout when using Relayout to row major. - auto relaid_mat_to_dim0major = - LiteralUtil::Relayout(*mat_dim0minor, layout_r2_dim0major_); + auto relaid_mat_to_dim0major = mat_dim0minor->Relayout(layout_r2_dim0major_); EXPECT_THAT(relaid_mat_to_dim0major->s32s(), ElementsAre(1, 2, 3, 4, 5, 6)); // Test expected memory layout of R2 created with dim0-major (row-major). 
- auto mat_dim0major = LiteralUtil::CreateR2WithLayout( - {{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0major_); + auto mat_dim0major = Literal::CreateR2WithLayout({{1, 2, 3}, {4, 5, 6}}, + layout_r2_dim0major_); EXPECT_EQ(mat_dim0major->s32s_size(), 6); EXPECT_THAT(mat_dim0major->s32s(), ElementsAre(1, 2, 3, 4, 5, 6)); // Test expected memory layout when using Relayout to column major. - auto relaid_mat_to_dim0minor = - LiteralUtil::Relayout(*mat_dim0major, layout_r2_dim0minor_); + auto relaid_mat_to_dim0minor = mat_dim0major->Relayout(layout_r2_dim0minor_); EXPECT_THAT(relaid_mat_to_dim0minor->s32s(), ElementsAre(1, 4, 2, 5, 3, 6)); } @@ -578,8 +556,8 @@ TEST_F(LiteralUtilTest, TestR3LinearLayout) { {10, 11, 12}, }, }); // clang-format on - auto lit_dim0minor = LiteralUtil::CreateR3FromArray3DWithLayout( - arr3d, layout_r3_dim0minor_); + auto lit_dim0minor = + Literal::CreateR3FromArray3DWithLayout(arr3d, layout_r3_dim0minor_); EXPECT_EQ(lit_dim0minor->s32s_size(), 12); std::vector expected_dim0minor{1, 7, 4, 10, 2, 8, 5, 11, 3, 9, 6, 12}; @@ -587,122 +565,120 @@ TEST_F(LiteralUtilTest, TestR3LinearLayout) { testing::ElementsAreArray(expected_dim0minor)); // Test expected memory layout when using Relayout to row major. - auto relaid_lit_to_dim0major = - LiteralUtil::Relayout(*lit_dim0minor, layout_r3_dim0major_); + auto relaid_lit_to_dim0major = lit_dim0minor->Relayout(layout_r3_dim0major_); std::vector expected_dim0major{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; EXPECT_THAT(relaid_lit_to_dim0major->s32s(), testing::ElementsAreArray(expected_dim0major)); // Test expected memory layout of R3 created with dim0-major (row-major). 
- auto lit_dim0major = LiteralUtil::CreateR3FromArray3DWithLayout( - arr3d, layout_r3_dim0major_); + auto lit_dim0major = + Literal::CreateR3FromArray3DWithLayout(arr3d, layout_r3_dim0major_); EXPECT_EQ(lit_dim0major->s32s_size(), 12); EXPECT_THAT(lit_dim0major->s32s(), testing::ElementsAreArray(expected_dim0major)); // Test expected memory layout when using Relayout to column major. - auto relaid_lit_to_dim0minor = - LiteralUtil::Relayout(*lit_dim0major, layout_r3_dim0minor_); + auto relaid_lit_to_dim0minor = lit_dim0major->Relayout(layout_r3_dim0minor_); EXPECT_THAT(relaid_lit_to_dim0minor->s32s(), testing::ElementsAreArray(expected_dim0minor)); } TEST_F(LiteralUtilTest, SliceR0S32) { - auto input = LiteralUtil::CreateR0(1); - auto result = LiteralUtil::Slice(*input, {}, {}); - EXPECT_TRUE(LiteralUtil::Equal(*input, *result)); + auto input = Literal::CreateR0(1); + auto result = input->Slice({}, {}); + EXPECT_TRUE(input->Equal(*result)); } TEST_F(LiteralUtilTest, SliceR1F32) { - auto input = LiteralUtil::CreateR1({1.0, 2.0, 3.0, 4.0, 5.0}); - auto result = LiteralUtil::Slice(*input, {3}, {4}); - auto expected = LiteralUtil::CreateR1({4.0}); - EXPECT_TRUE(LiteralUtil::Equal(*expected, *result)); + auto input = Literal::CreateR1({1.0, 2.0, 3.0, 4.0, 5.0}); + auto result = input->Slice({3}, {4}); + auto expected = Literal::CreateR1({4.0}); + EXPECT_TRUE(expected->Equal(*result)); } TEST_F(LiteralUtilTest, SliceR2U32) { - auto input_3x4 = LiteralUtil::CreateR2( - {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}); - auto result = LiteralUtil::Slice(*input_3x4, {0, 2}, {2, 4}); - auto expected = LiteralUtil::CreateR2({{3, 4}, {7, 8}}); - EXPECT_TRUE(LiteralUtil::Equal(*expected, *result)); + auto input_3x4 = + Literal::CreateR2({{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}); + auto result = input_3x4->Slice({0, 2}, {2, 4}); + auto expected = Literal::CreateR2({{3, 4}, {7, 8}}); + EXPECT_TRUE(expected->Equal(*result)); } TEST_F(LiteralUtilTest, SliceR3U32Full) { - auto 
input_2x3x2 = LiteralUtil::CreateR3( + auto input_2x3x2 = Literal::CreateR3( {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}); - auto result = LiteralUtil::Slice(*input_2x3x2, {0, 0, 0}, {2, 3, 2}); - EXPECT_TRUE(LiteralUtil::Equal(*input_2x3x2, *result)); + auto result = input_2x3x2->Slice({0, 0, 0}, {2, 3, 2}); + EXPECT_TRUE(input_2x3x2->Equal(*result)); } TEST_F(LiteralUtilTest, PopulateR1S64) { Literal output; - LiteralUtil::PopulateR1({77}, &output); - auto expected = LiteralUtil::CreateR1({77}); - EXPECT_TRUE(LiteralUtil::Equal(output, *expected)); + output.PopulateR1({77}); + auto expected = Literal::CreateR1({77}); + EXPECT_TRUE(output.Equal(*expected)); } TEST_F(LiteralUtilTest, PopulateR2U64) { Literal output; - LiteralUtil::PopulateR1({{77, 88}}, &output); - auto expected = LiteralUtil::CreateR1({{77, 88}}); - EXPECT_TRUE(LiteralUtil::Equal(output, *expected)); + output.PopulateR1({{77, 88}}); + auto expected = Literal::CreateR1({{77, 88}}); + EXPECT_TRUE(output.Equal(*expected)); } TEST_F(LiteralUtilTest, PopulateWithValueR0F32) { Literal output; - LiteralUtil::PopulateWithValue(2.5f, {}, &output); - auto expected = LiteralUtil::CreateR0(2.5f); - EXPECT_TRUE(LiteralUtil::Equal(output, *expected)); + output.PopulateWithValue(2.5f, {}); + auto expected = Literal::CreateR0(2.5f); + EXPECT_TRUE(output.Equal(*expected)); } TEST_F(LiteralUtilTest, PopulateWithValueR1S64) { Literal output; - LiteralUtil::PopulateWithValue(-7, {3}, &output); - auto expected = LiteralUtil::CreateR1({-7, -7, -7}); - EXPECT_TRUE(LiteralUtil::Equal(output, *expected)); + output.PopulateWithValue(-7, {3}); + auto expected = Literal::CreateR1({-7, -7, -7}); + EXPECT_TRUE(output.Equal(*expected)); } TEST_F(LiteralUtilTest, PopulateWithValueR2U64) { Literal output; - LiteralUtil::PopulateWithValue(42, {2, 2}, &output); - auto expected = LiteralUtil::CreateR2({{42, 42}, {42, 42}}); - EXPECT_TRUE(LiteralUtil::Equal(output, *expected)); + output.PopulateWithValue(42, {2, 2}); + 
auto expected = Literal::CreateR2({{42, 42}, {42, 42}}); + EXPECT_TRUE(output.Equal(*expected)); } TEST_F(LiteralUtilTest, PopulateWithValueR0F16) { Literal output; half h(0.25f); - LiteralUtil::PopulateWithValue(h, {}, &output); - auto expected = LiteralUtil::CreateR0(h); - EXPECT_TRUE(LiteralUtil::Equal(output, *expected)); + output.PopulateWithValue(h, {}); + auto expected = Literal::CreateR0(h); + EXPECT_TRUE(output.Equal(*expected)); } TEST_F(LiteralUtilTest, PopulateWithValueR1F16) { Literal output; half h(0.5f); - LiteralUtil::PopulateWithValue(h, {3}, &output); - auto expected = LiteralUtil::CreateR1({h, h, h}); - EXPECT_TRUE(LiteralUtil::Equal(output, *expected)); + output.PopulateWithValue(h, {3}); + auto expected = Literal::CreateR1({h, h, h}); + EXPECT_TRUE(output.Equal(*expected)); } TEST_F(LiteralUtilTest, PopulateWithValueR2F16) { Literal output; half h(2.0f); - LiteralUtil::PopulateWithValue(h, {2, 2}, &output); - auto expected = LiteralUtil::CreateR2({{h, h}, {h, h}}); - EXPECT_TRUE(LiteralUtil::Equal(output, *expected)); + output.PopulateWithValue(h, {2, 2}); + auto expected = Literal::CreateR2({{h, h}, {h, h}}); + EXPECT_TRUE(output.Equal(*expected)); } TEST_F(LiteralUtilTest, ReplicateR2U32) { - auto input = LiteralUtil::CreateR2( - {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}); - auto output = LiteralUtil::Replicate(*input, 3); - auto expected = LiteralUtil::CreateR3( + auto input = + Literal::CreateR2({{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}); + auto output = input->Replicate(3); + auto expected = Literal::CreateR3( {{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}}); - EXPECT_TRUE(LiteralUtil::Equal(*output, *expected)); + EXPECT_TRUE(output->Equal(*expected)); } TEST_F(LiteralUtilTest, Copy) { @@ -712,13 +688,13 @@ TEST_F(LiteralUtilTest, Copy) { for (const auto& layout : layouts) { Shape shape = ShapeUtil::MakeShapeWithLayout( 
primitive_util::NativeToPrimitiveType(), dimensions, layout); - auto blank = LiteralUtil::CreateFromShape(shape); - auto source = LiteralUtil::CreateFromShape(shape); + auto blank = Literal::CreateFromShape(shape); + auto source = Literal::CreateFromShape(shape); const int64 zero_base[] = {0, 0, 0, 0}; const int64 step[] = {1, 1, 1, 1}; uint32 seqnr = 0; auto init_proc = [&](const std::vector& indexes) { - LiteralUtil::Set(source.get(), indexes, ++seqnr); + source.get()->Set(indexes, ++seqnr); return true; }; @@ -729,8 +705,7 @@ TEST_F(LiteralUtilTest, Copy) { const int64 dest_base[] = {6, 4, 12, 2}; const int64 copy_size[] = {7, 8, 11, 9}; - TF_EXPECT_OK(LiteralUtil::Copy(*source, src_base, blank.get(), dest_base, - copy_size)); + TF_EXPECT_OK(blank.get()->Copy(*source, src_base, dest_base, copy_size)); std::vector source_indexes(TF_ARRAYSIZE(dimensions), 0); std::vector blank_indexes(TF_ARRAYSIZE(dimensions), 0); bool matched = true; @@ -741,9 +716,8 @@ TEST_F(LiteralUtilTest, Copy) { std::copy(indexes.begin(), indexes.end(), blank_indexes.begin()); std::transform(blank_indexes.begin(), blank_indexes.end(), dest_base, blank_indexes.begin(), std::plus()); - auto bval = LiteralUtil::Get(*blank, blank_indexes); - matched = (bval != 0 && - bval == LiteralUtil::Get(*source, source_indexes)); + auto bval = blank->Get(blank_indexes); + matched = (bval != 0 && bval == source->Get(source_indexes)); return matched; }; ShapeUtil::ForEachIndex(source->shape(), zero_base, copy_size, step, @@ -753,25 +727,25 @@ TEST_F(LiteralUtilTest, Copy) { } TEST_F(LiteralUtilTest, CopyScalars) { - auto zero = LiteralUtil::CreateR0(0); - auto nine = LiteralUtil::CreateR0(9); - TF_EXPECT_OK(LiteralUtil::Copy(*nine, {}, zero.get(), {}, {})); - EXPECT_TRUE(LiteralUtil::Equal(*zero, *nine)); + auto zero = Literal::CreateR0(0); + auto nine = Literal::CreateR0(9); + TF_EXPECT_OK(zero.get()->Copy(*nine, {}, {}, {})); + EXPECT_TRUE(zero->Equal(*nine)); - auto vect = LiteralUtil::CreateR1({3, 4, 9, 
12, 5, 17, 21}); - TF_EXPECT_OK(LiteralUtil::Copy(*vect, {5}, zero.get(), {}, {})); - EXPECT_EQ(LiteralUtil::Get(*zero, {}), 17); - TF_EXPECT_OK(LiteralUtil::Copy(*zero, {}, vect.get(), {4}, {})); - EXPECT_EQ(LiteralUtil::Get(*vect, {4}), 17); + auto vect = Literal::CreateR1({3, 4, 9, 12, 5, 17, 21}); + TF_EXPECT_OK(zero.get()->Copy(*vect, {5}, {}, {})); + EXPECT_EQ(zero->Get({}), 17); + TF_EXPECT_OK(vect.get()->Copy(*zero, {}, {4}, {})); + EXPECT_EQ(vect->Get({4}), 17); } TEST_F(LiteralUtilTest, F16) { // Verify that the internal data views are consistent and that they // are in little endian format // TODO - modify if we make the data format machine endianess dependent - auto m1 = LiteralUtil::CreateFromShape(ShapeUtil::MakeShape(F16, {2, 2})); + auto m1 = Literal::CreateFromShape(ShapeUtil::MakeShape(F16, {2, 2})); Literal* l1 = m1.get(); - const char* d1 = static_cast(LiteralUtil::InternalData(*l1)); + const char* d1 = static_cast(l1->InternalData()); EXPECT_EQ(d1[0], 0); EXPECT_EQ(d1[1], 0); EXPECT_EQ(d1[2], 0); @@ -780,14 +754,13 @@ TEST_F(LiteralUtilTest, F16) { EXPECT_EQ(d1[5], 0); EXPECT_EQ(d1[6], 0); EXPECT_EQ(d1[7], 0); - EXPECT_EQ(LiteralUtil::InternalData(*l1), - LiteralUtil::MutableInternalData(l1)); + EXPECT_EQ(l1->InternalData(), l1->MutableInternalData()); half h1(1.0f); half h2(2.0f); - auto m2 = LiteralUtil::CreateR2({{h1, h2}, {h2, h1}}); + auto m2 = Literal::CreateR2({{h1, h2}, {h2, h1}}); Literal* l2 = m2.get(); - const char* d2 = static_cast(LiteralUtil::InternalData(*l2)); + const char* d2 = static_cast(l2->InternalData()); EXPECT_EQ(d2[0], 0); EXPECT_EQ(d2[1], 0x3C); EXPECT_EQ(d2[2], 0); @@ -796,8 +769,7 @@ TEST_F(LiteralUtilTest, F16) { EXPECT_EQ(d2[5], 0x40); EXPECT_EQ(d2[6], 0); EXPECT_EQ(d2[7], 0x3C); - EXPECT_EQ(LiteralUtil::InternalData(*l2), - LiteralUtil::MutableInternalData(l2)); + EXPECT_EQ(l2->InternalData(), l2->MutableInternalData()); } TEST_F(LiteralUtilTest, Populate) { @@ -818,19 +790,19 @@ TEST_F(LiteralUtilTest, Populate) 
{ Shape shape = ShapeUtil::MakeShapeWithLayout( primitive_util::NativeToPrimitiveType(), data.dimensions, data.layout); - auto literal = LiteralUtil::CreateFromShape(shape); + auto literal = Literal::CreateFromShape(shape); auto generator = [&](tensorflow::gtl::ArraySlice indexes) -> uint32 { // Offsets from linear index just to avoid R0 literals to be initialized // with zero. - return LiteralUtil::LinearIndex(*literal, indexes) + 17; + return literal->LinearIndex(indexes) + 17; }; - TF_EXPECT_OK(LiteralUtil::Populate(literal.get(), generator)); + TF_EXPECT_OK(literal.get()->Populate(generator)); std::vector zero_base(data.dimensions.size(), 0); std::vector step(data.dimensions.size(), 1); bool matched = true; auto check_function = [&](const std::vector& indexes) { - auto value = LiteralUtil::Get(*literal, indexes); + auto value = literal->Get(indexes); matched = matched && (value == generator(indexes)); return matched; }; @@ -842,65 +814,66 @@ TEST_F(LiteralUtilTest, Populate) { TEST_F(LiteralUtilTest, ConvertR4) { // clang-format off - auto original = LiteralUtil::CreateR4WithLayout({{ + auto original = Literal::CreateR4WithLayout({{ {{10, 11, 12, 13}, {14, 15, 16, 17}}, {{18, 19, 20, 21}, {22, 23, 24, 25}}, {{26, 27, 28, 29}, {30, 31, 32, 33}}, }}, layout_r4_dim0major_); - auto expected = LiteralUtil::CreateR4WithLayout({{ + auto expected = Literal::CreateR4WithLayout({{ {{10, 11, 12, 13}, {14, 15, 16, 17}}, {{18, 19, 20, 21}, {22, 23, 24, 25}}, {{26, 27, 28, 29}, {30, 31, 32, 33}}, }}, layout_r4_dim0major_); // clang-format on - auto converted = LiteralUtil::Convert(*original); + TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr converted, + original->Convert(U32)); - EXPECT_TRUE(LiteralUtil::Equal(*expected, *converted)); + EXPECT_TRUE(expected->Equal(*converted)); } TEST_F(LiteralUtilTest, ConvertIfTypesMatch) { // clang-format off - auto s8 = LiteralUtil::CreateR4WithLayout({{ + auto s8 = Literal::CreateR4WithLayout({{ {{10, 0, 12, 0}, {0, 15, 0, 17}}, {{0, 19, 0, 
21}, {22, 0, 24, 0}}, {{26, 0, 28, 0}, {0, 31, 0, 33}}, }}, layout_r4_dim0major_); - auto s32 = LiteralUtil::CreateR4WithLayout({{ + auto s32 = Literal::CreateR4WithLayout({{ {{10, 0, 12, 0}, {0, 15, 0, 17}}, {{0, 19, 0, 21}, {22, 0, 24, 0}}, {{26, 0, 28, 0}, {0, 31, 0, 33}}, }}, layout_r4_dim0major_); - auto u32 = LiteralUtil::CreateR4WithLayout({{ + auto u32 = Literal::CreateR4WithLayout({{ {{10, 0, 12, 0}, {0, 15, 0, 17}}, {{0, 19, 0, 21}, {22, 0, 24, 0}}, {{26, 0, 28, 0}, {0, 31, 0, 33}}, }}, layout_r4_dim0major_); - auto s64 = LiteralUtil::CreateR4WithLayout({{ + auto s64 = Literal::CreateR4WithLayout({{ {{10, 0, 12, 0}, {0, 15, 0, 17}}, {{0, 19, 0, 21}, {22, 0, 24, 0}}, {{26, 0, 28, 0}, {0, 31, 0, 33}}, }}, layout_r4_dim0major_); - auto u64 = LiteralUtil::CreateR4WithLayout({{ + auto u64 = Literal::CreateR4WithLayout({{ {{10, 0, 12, 0}, {0, 15, 0, 17}}, {{0, 19, 0, 21}, {22, 0, 24, 0}}, {{26, 0, 28, 0}, {0, 31, 0, 33}}, }}, layout_r4_dim0major_); - auto pred = LiteralUtil::CreateR4WithLayout({{ + auto pred = Literal::CreateR4WithLayout({{ {{true, false, true, false}, {false, true, false, true}}, {{false, true, false, true}, {true, false, true, false}}, {{true, false, true, false}, {false, true, false, true}}, }}, layout_r4_dim0major_); - auto int32_pred = LiteralUtil::CreateR4WithLayout({{ + auto int32_pred = Literal::CreateR4WithLayout({{ {{1, 0, 1, 0}, {0, 1, 0, 1}}, {{0, 1, 0, 1}, {1, 0, 1, 0}}, {{1, 0, 1, 0}, {0, 1, 0, 1}}, }}, layout_r4_dim0major_); - auto f32 = LiteralUtil::CreateR4WithLayout({{ + auto f32 = Literal::CreateR4WithLayout({{ {{10.0f, 0.0f, 12.0f, 0.0f}, {0.0f, 15.0f, 0.0f, 17.0f}}, {{0.0f, 19.0f, 0.0f, 21.0f}, {22.0f, 0.0f, 24.0f, 0.0f}}, {{26.0f, 0.0f, 28.0f, 0.0f}, {0.0f, 31.0f, 0.0f, 33.0f}}, }}, layout_r4_dim0major_); - auto f64 = LiteralUtil::CreateR4WithLayout({{ + auto f64 = Literal::CreateR4WithLayout({{ {{10.0, 0.0, 12.0, 0.0}, {0.0, 15.0, 0.0, 17.0}}, {{0.0, 19.0, 0.0, 21.0}, {22.0, 0.0, 24.0, 0.0}}, {{26.0, 0.0, 28.0, 0.0}, 
{0.0, 31.0, 0.0, 33.0}}, @@ -908,40 +881,40 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) { // clang-format on std::unique_ptr conv; - conv = LiteralUtil::ConvertIfSrcTypeMatches(*s8, U32).ConsumeValueOrDie(); - EXPECT_TRUE(LiteralUtil::Equal(*conv, *u32)); + conv = s8->Convert(U32).ConsumeValueOrDie(); + EXPECT_TRUE(conv->Equal(*u32)); - conv = LiteralUtil::ConvertIfSrcTypeMatches(*s8, S32).ConsumeValueOrDie(); - EXPECT_TRUE(LiteralUtil::Equal(*conv, *s32)); + conv = s8->Convert(S32).ConsumeValueOrDie(); + EXPECT_TRUE(conv->Equal(*s32)); - conv = LiteralUtil::ConvertIfSrcTypeMatches(*s8, U64).ConsumeValueOrDie(); - EXPECT_TRUE(LiteralUtil::Equal(*conv, *u64)); + conv = s8->Convert(U64).ConsumeValueOrDie(); + EXPECT_TRUE(conv->Equal(*u64)); - conv = LiteralUtil::ConvertIfSrcTypeMatches(*s8, S64).ConsumeValueOrDie(); - EXPECT_TRUE(LiteralUtil::Equal(*conv, *s64)); + conv = s8->Convert(S64).ConsumeValueOrDie(); + EXPECT_TRUE(conv->Equal(*s64)); - conv = LiteralUtil::ConvertIfSrcTypeMatches(*s8, PRED).ConsumeValueOrDie(); - EXPECT_TRUE(LiteralUtil::Equal(*conv, *pred)); + conv = s8->Convert(PRED).ConsumeValueOrDie(); + EXPECT_TRUE(conv->Equal(*pred)); - conv = LiteralUtil::ConvertIfSrcTypeMatches(*pred, S32).ConsumeValueOrDie(); - EXPECT_TRUE(LiteralUtil::Equal(*conv, *int32_pred)); + conv = pred->Convert(S32).ConsumeValueOrDie(); + EXPECT_TRUE(conv->Equal(*int32_pred)); - conv = LiteralUtil::ConvertIfSrcTypeMatches(*f32, S32).ConsumeValueOrDie(); - EXPECT_TRUE(LiteralUtil::Equal(*conv, *s32)); + conv = f32->Convert(S32).ConsumeValueOrDie(); + EXPECT_TRUE(conv->Equal(*s32)); - conv = LiteralUtil::ConvertIfSrcTypeMatches(*f64, S32).ConsumeValueOrDie(); - EXPECT_TRUE(LiteralUtil::Equal(*conv, *s32)); + conv = f64->Convert(S32).ConsumeValueOrDie(); + EXPECT_TRUE(conv->Equal(*s32)); - conv = LiteralUtil::ConvertIfSrcTypeMatches(*s32, F32).ConsumeValueOrDie(); - EXPECT_TRUE(LiteralUtil::Equal(*conv, *f32)); + conv = s32->Convert(F32).ConsumeValueOrDie(); + 
EXPECT_TRUE(conv->Equal(*f32)); - EXPECT_EQ(LiteralUtil::ConvertIfSrcTypeMatches(*s32, TUPLE).status().code(), + EXPECT_EQ(s32->Convert(TUPLE).status().code(), tensorflow::error::INVALID_ARGUMENT); - EXPECT_EQ(LiteralUtil::ConvertIfSrcTypeMatches(*s32, F16).status().code(), + EXPECT_EQ(s32->Convert(F16).status().code(), tensorflow::error::INVALID_ARGUMENT); - EXPECT_EQ(LiteralUtil::ConvertIfSrcTypeMatches(*s32, S16).status().code(), + EXPECT_EQ(s32->Convert(S16).status().code(), tensorflow::error::INVALID_ARGUMENT); - EXPECT_EQ(LiteralUtil::ConvertIfSrcTypeMatches(*s32, U16).status().code(), + EXPECT_EQ(s32->Convert(U16).status().code(), tensorflow::error::INVALID_ARGUMENT); } diff --git a/tensorflow/compiler/xla/packed_literal_reader.cc b/tensorflow/compiler/xla/packed_literal_reader.cc index d488830a6cd..11870799066 100644 --- a/tensorflow/compiler/xla/packed_literal_reader.cc +++ b/tensorflow/compiler/xla/packed_literal_reader.cc @@ -58,8 +58,7 @@ StatusOr> PackedLiteralReader::Read( } int64 elements = ShapeUtil::ElementsIn(shape); - LiteralUtil::Resize(elements, std::numeric_limits::quiet_NaN(), - result.get()); + result.get()->Resize(elements, std::numeric_limits::quiet_NaN()); std::vector* field = result->mutable_f32s(); char* data = tensorflow::bit_cast(field->data()); uint64 bytes = elements * sizeof(float); diff --git a/tensorflow/compiler/xla/reference_util_test.cc b/tensorflow/compiler/xla/reference_util_test.cc index f839ac019df..215f2202589 100644 --- a/tensorflow/compiler/xla/reference_util_test.cc +++ b/tensorflow/compiler/xla/reference_util_test.cc @@ -52,7 +52,7 @@ class ReferenceUtilTest : public ::testing::Test { TEST_F(ReferenceUtilTest, TransposeArray2D) { auto result = ReferenceUtil::TransposeArray2D(*matrix_); - auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); + auto actual_literal = Literal::CreateR2FromArray2D(*result); LiteralTestUtil::ExpectR2Near({{1.f, 4.f}, {2.f, 5.f}, {3.f, 6.f}}, *actual_literal, ErrorSpec(0.0001)); } 
@@ -62,7 +62,7 @@ TEST_F(ReferenceUtilTest, MatmulArray2D) { {7.f, 8.f}, {9.f, 10.f}, {11.f, 12.f}, }); auto result = ReferenceUtil::MatmulArray2D(*matrix_, rhs); - auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); + auto actual_literal = Literal::CreateR2FromArray2D(*result); LiteralTestUtil::ExpectR2Near({{58.f, 64.f}, {139.f, 154.f}}, *actual_literal, ErrorSpec(0.0001)); } @@ -70,7 +70,7 @@ TEST_F(ReferenceUtilTest, MatmulArray2D) { TEST_F(ReferenceUtilTest, ReduceToColArray2D) { auto add = [](float lhs, float rhs) { return lhs + rhs; }; auto result = ReferenceUtil::ReduceToColArray2D(*matrix_, 0.0f, add); - auto actual_literal = LiteralUtil::CreateR1(*result); + auto actual_literal = Literal::CreateR1(*result); LiteralTestUtil::ExpectR1Near({6.f, 15.f}, *actual_literal, ErrorSpec(0.0001)); } @@ -78,7 +78,7 @@ TEST_F(ReferenceUtilTest, ReduceToColArray2D) { TEST_F(ReferenceUtilTest, ReduceToRowArray2D) { auto add = [](float lhs, float rhs) { return lhs + rhs; }; auto result = ReferenceUtil::ReduceToRowArray2D(*matrix_, 0.0f, add); - auto actual_literal = LiteralUtil::CreateR1(*result); + auto actual_literal = Literal::CreateR1(*result); LiteralTestUtil::ExpectR1Near({5.f, 7.f, 9.f}, *actual_literal, ErrorSpec(0.0001)); } @@ -86,7 +86,7 @@ TEST_F(ReferenceUtilTest, ReduceToRowArray2D) { TEST_F(ReferenceUtilTest, MapArray2D) { auto identity = [](float value) { return log(exp(value)); }; auto result = ReferenceUtil::MapArray2D(*matrix_, identity); - auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); + auto actual_literal = Literal::CreateR2FromArray2D(*result); LiteralTestUtil::ExpectR2NearArray2D(*matrix_, *actual_literal, ErrorSpec(0.0001)); } @@ -96,7 +96,7 @@ TEST_F(ReferenceUtilTest, MapWithIndexArray2D) { return value + row + col; }; auto result = ReferenceUtil::MapWithIndexArray2D(*matrix_, add_index); - auto actual_literal = LiteralUtil::CreateR2FromArray2D(*result); + auto actual_literal = Literal::CreateR2FromArray2D(*result); 
LiteralTestUtil::ExpectR2Near({{1.f, 3.f, 5.f}, {5.f, 7.f, 9.f}}, *actual_literal, ErrorSpec(0.0001)); } @@ -107,7 +107,7 @@ TEST_F(ReferenceUtilTest, MapArray4D) { input->FillWithMultiples(1.0f); auto multiply_by_two = [](float value) { return 2 * value; }; auto result = ReferenceUtil::MapArray4D(*input, multiply_by_two); - auto actual_literal = LiteralUtil::CreateR4FromArray4D(*result); + auto actual_literal = Literal::CreateR4FromArray4D(*result); Array4D expected(/*planes=*/2, /*depth=*/3, /*height=*/4, /*width=*/5); expected.FillWithMultiples(2.0f); @@ -124,7 +124,7 @@ TEST_F(ReferenceUtilTest, MapWithIndexArray4D) { return value - (3 * 4 * 5 * plane + 4 * 5 * depth + 5 * height + width); }; auto result = ReferenceUtil::MapWithIndexArray4D(*input, subtract_index); - auto actual_literal = LiteralUtil::CreateR4FromArray4D(*result); + auto actual_literal = Literal::CreateR4FromArray4D(*result); Array4D expected(/*planes=*/2, /*depth=*/3, /*height=*/4, /*width=*/5); expected.Fill(0.0f); @@ -161,7 +161,7 @@ TEST_F(ReferenceUtilTest, ConvWithSamePadding) { })); // clang-format on - auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual); + auto actual_literal = Literal::CreateR4FromArray4D(*actual); LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal, ErrorSpec(0.0001)); @@ -195,7 +195,7 @@ TEST_F(ReferenceUtilTest, ConvWithValidPadding) { })); // clang-format on - auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual); + auto actual_literal = Literal::CreateR4FromArray4D(*actual); LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal, ErrorSpec(0.0001)); @@ -247,7 +247,7 @@ TEST_F(ReferenceUtilTest, ConvGeneralDimensionsWithSamePadding) { }}); // clang-format on - auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual); + auto actual_literal = Literal::CreateR4FromArray4D(*actual); LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal, ErrorSpec(0.0001)); @@ -296,7 +296,7 @@ TEST_F(ReferenceUtilTest, 
ConvGeneralDimensionsWithValidPadding) { Array4D expected({{{{2514, 2685}}}}); // clang-format on - auto actual_literal = LiteralUtil::CreateR4FromArray4D(*actual); + auto actual_literal = Literal::CreateR4FromArray4D(*actual); LiteralTestUtil::ExpectR4NearArray4D(expected, *actual_literal, ErrorSpec(0.0001)); @@ -309,7 +309,7 @@ TEST_F(ReferenceUtilTest, ApplyElementwise2D) { auto actual = ReferenceUtil::ApplyElementwise2D( [](float x, float y, float z) { return 100 * x + 10 * y + z; }, a, b, c); - auto actual_literal = LiteralUtil::CreateR2FromArray2D(*actual); + auto actual_literal = Literal::CreateR2FromArray2D(*actual); LiteralTestUtil::ExpectR2Near({{300.f, 600.f}, {900.f, 1200.f}}, *actual_literal, ErrorSpec(0.0001)); } diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 0687368b83d..eaf89f13191 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -112,6 +112,7 @@ cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", "//tensorflow/core:test_main", @@ -330,6 +331,7 @@ cc_library( hdrs = ["backend.h"], deps = [ ":compiler", + ":computation_placer", ":device_memory_allocator", ":platform_util", ":pool", @@ -338,7 +340,6 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla/legacy_flags:backend_flags", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", @@ -382,6 +383,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla:xla_proto", + "//tensorflow/compiler/xla/legacy_flags:backend_flags", "//tensorflow/compiler/xla/legacy_flags:service_flags", 
"//tensorflow/compiler/xla/service/cpu:cpu_compiler", "//tensorflow/core:lib", @@ -416,6 +418,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/legacy_flags:service_flags", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", @@ -948,6 +951,26 @@ cc_test( ], ) +cc_library( + name = "computation_placer", + srcs = ["computation_placer.cc"], + hdrs = ["computation_placer.h"], + deps = [ + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + ], + alwayslink = True, # Contains per-platform computation placer registration +) + cc_library( name = "generic_transfer_manager", srcs = ["generic_transfer_manager.cc"], @@ -1165,6 +1188,7 @@ cc_library( deps = [ ":call_graph", ":hlo", + ":hlo_ordering", ":liveness_util", "//tensorflow/compiler/xla:shape_tree", "//tensorflow/compiler/xla:shape_util", @@ -1398,7 +1422,6 @@ cc_library( ":call_graph", ":flatten_call_graph", ":hlo", - ":hlo_cost_analysis", ":hlo_dce", ":hlo_ordering", ":liveness_util", @@ -1572,10 +1595,8 @@ cc_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", - "//tensorflow/core:lib", ], ) @@ -1777,7 +1798,6 @@ cc_library( ":hlo", ":hlo_proto", "//tensorflow/compiler/xla:status", - "//tensorflow/core:lib", ], ) diff --git 
a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 754ac0c68dc..5709ac3067f 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -48,7 +48,7 @@ namespace { // Returns whether operand is a literal with the given value. bool IsLiteralWithValue(const HloInstruction* operand, int8 value) { return operand->opcode() == HloOpcode::kConstant && - LiteralUtil::IsAll(operand->literal(), value); + operand->literal().IsAll(value); } bool IsAll(const HloInstruction* op, int8 value) { @@ -126,10 +126,9 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { HloInstruction* concatenate, tensorflow::gtl::ArraySlice operands) override; - Status HandleCopy(HloInstruction* copy, HloInstruction* operand) override; + Status HandleCopy(HloInstruction* copy) override; - Status HandleConvert(HloInstruction* convert, - HloInstruction* operand) override; + Status HandleConvert(HloInstruction* convert) override; Status HandleConvolution(HloInstruction* convolution, HloInstruction* lhs, HloInstruction* rhs, const Window& window) override; @@ -179,11 +178,8 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { Status HandleSubtract(HloInstruction* sub, HloInstruction* lhs, HloInstruction* rhs) override; - Status HandleMaximum(HloInstruction* maximum, HloInstruction* lhs, - HloInstruction* rhs) override; - - Status HandleMinimum(HloInstruction* minimum, HloInstruction* lhs, - HloInstruction* rhs) override; + Status HandleMaximum(HloInstruction* maximum) override; + Status HandleMinimum(HloInstruction* minimum) override; // Returns whether algebraic simplification has occurred. 
const bool changed() const { return changed_; } @@ -334,16 +330,16 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add, return Status::OK(); } -Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy, - HloInstruction* operand) { +Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy) { // If a copy feeds a copy, make it a single copy. - if (operand->opcode() == HloOpcode::kCopy) { + if (copy->operand(0)->opcode() == HloOpcode::kCopy) { return ReplaceWithNewInstruction( - copy, HloInstruction::CreateUnary(copy->shape(), HloOpcode::kCopy, - operand->operands()[0])); + copy, HloInstruction::CreateUnary( + copy->shape(), HloOpcode::kCopy, + copy->mutable_operand(0)->mutable_operand(0))); } // All copies can be eliminated (assuming layout constraints are satisified). - ReplaceInstructionIfSameShape(copy, operand); + ReplaceInstructionIfSameShape(copy, copy->mutable_operand(0)); return Status::OK(); } @@ -469,7 +465,7 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot, ShapeUtil::HasZeroElements(lhs->shape()) || ShapeUtil::HasZeroElements(rhs->shape())) { auto zero = computation_->AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); return ReplaceWithNewInstruction( dot, HloInstruction::CreateBroadcast(dot->shape(), zero, {})); } @@ -507,7 +503,7 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot, HloComputation* add_reduce_computation = CreateScalarBinaryComputation( computation_->parent(), F32, HloOpcode::kAdd); auto zero = computation_->AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); auto reduce = computation_->AddInstruction(HloInstruction::CreateReduce( ShapeUtil::MakeShape(dot->shape().element_type(), {}), multiply, zero, {0}, add_reduce_computation)); @@ -531,7 +527,7 @@ Status 
AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot, HloComputation* add_reduce_computation = CreateScalarBinaryComputation( computation_->parent(), F32, HloOpcode::kAdd); auto zero = computation_->AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); HloInstruction* reduce; if (ShapeUtil::Rank(rhs->shape()) == 1) { auto multiply = computation_->AddInstruction(HloInstruction::CreateBinary( @@ -571,7 +567,7 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot, HloComputation* add_reduce_computation = CreateScalarBinaryComputation( computation_->parent(), F32, HloOpcode::kAdd); auto zero = computation_->AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); auto reduce = computation_->AddInstruction(HloInstruction::CreateReduce( ShapeUtil::MakeShape(dot->shape().element_type(), {lhs->shape().dimensions(0)}), @@ -792,12 +788,11 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) { // A conversion to the same element type as the operand is a nop and can be // removed. A conversion of a constant can be simplified by making a new // constant. 
-Status AlgebraicSimplifierVisitor::HandleConvert(HloInstruction* convert, - HloInstruction* operand) { - PrimitiveType src_type = operand->shape().element_type(); +Status AlgebraicSimplifierVisitor::HandleConvert(HloInstruction* convert) { + PrimitiveType src_type = convert->operand(0)->shape().element_type(); PrimitiveType dest_type = convert->shape().element_type(); if (src_type == dest_type) { - return ReplaceInstruction(convert, operand); + return ReplaceInstruction(convert, convert->mutable_operand(0)); } return Status::OK(); } @@ -897,8 +892,8 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power, HloInstruction* rhs) { VLOG(10) << "trying transform [pow(A, 0) => 1]: " << power->ToString(); if (IsAll(rhs, 0)) { - auto one = HloInstruction::CreateConstant(LiteralUtil::CloneToUnique( - LiteralUtil::One(power->shape().element_type()))); + auto one = HloInstruction::CreateConstant( + Literal::One(power->shape().element_type()).CloneToUnique()); std::unique_ptr ones; if (ShapeUtil::IsScalar(power->shape())) { ones = std::move(one); @@ -923,9 +918,8 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power, VLOG(10) << "trying transform [pow(A, -1) => 1/A]: " << power->ToString(); if (IsAll(rhs, -1)) { - auto* one = computation_->AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CloneToUnique( - LiteralUtil::One(rhs->shape().element_type())))); + auto* one = computation_->AddInstruction(HloInstruction::CreateConstant( + Literal::One(rhs->shape().element_type()).CloneToUnique())); return ReplaceWithNewInstruction( power, HloInstruction::CreateBinary(power->shape(), HloOpcode::kDivide, one, lhs)); @@ -1008,7 +1002,7 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) { // dimension. 
if (ShapeUtil::HasZeroElements(reshape->shape())) { auto empty_constant = HloInstruction::CreateConstant( - LiteralUtil::CreateFromShape(reshape->shape())); + Literal::CreateFromShape(reshape->shape())); return ReplaceWithNewInstruction(reshape, std::move(empty_constant)); } @@ -1208,8 +1202,7 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( // try to get more fancy about proving equivalence in cases beyond that. if (pad_value->opcode() != HloOpcode::kConstant || reduce_init_value->opcode() != HloOpcode::kConstant || - !LiteralUtil::Equal(pad_value->literal(), - reduce_init_value->literal())) { + !pad_value->literal().Equal(reduce_init_value->literal())) { VLOG(10) << "Not folding pad into reduce-window due to different pad " "values."; return Status::OK(); @@ -1396,9 +1389,7 @@ bool AlgebraicSimplifierVisitor::TransformToClampIfSameShape( return true; } -Status AlgebraicSimplifierVisitor::HandleMaximum(HloInstruction* maximum, - HloInstruction* lhs, - HloInstruction* rhs) { +Status AlgebraicSimplifierVisitor::HandleMaximum(HloInstruction* maximum) { // Match the following tree: // min_operand operand // \ / @@ -1429,9 +1420,7 @@ Status AlgebraicSimplifierVisitor::HandleMaximum(HloInstruction* maximum, return Status::OK(); } -Status AlgebraicSimplifierVisitor::HandleMinimum(HloInstruction* minimum, - HloInstruction* lhs, - HloInstruction* rhs) { +Status AlgebraicSimplifierVisitor::HandleMinimum(HloInstruction* minimum) { // Match the following tree: // max_operand operand // \ / diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index e4368a7bb25..7e52c8fb0c3 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -55,7 +55,7 @@ TEST_F(AlgebraicSimplifierTest, AddZero) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32, "param0")); 
HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param0, zero)); @@ -76,7 +76,7 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR0Operand) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r2f32, "param0")); HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); HloInstruction* bcast = builder.AddInstruction( HloInstruction::CreateBroadcast(r2f32, zero, {0, 1})); builder.AddInstruction( @@ -99,7 +99,7 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r2f32, "param0")); HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({0, 0, 0}))); + HloInstruction::CreateConstant(Literal::CreateR1({0, 0, 0}))); HloInstruction* bcast = builder.AddInstruction(HloInstruction::CreateBroadcast(r2f32, zero, {1})); builder.AddInstruction( @@ -123,7 +123,7 @@ TEST_F(AlgebraicSimplifierTest, SubZero) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32, "param0")); HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kSubtract, param0, zero)); @@ -145,7 +145,7 @@ TEST_F(AlgebraicSimplifierTest, DivOneScalar) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32, "param0")); HloInstruction* one = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); + 
HloInstruction::CreateConstant(Literal::CreateR0(1.0f))); HloInstruction* div = builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, one)); @@ -167,7 +167,7 @@ TEST_F(AlgebraicSimplifierTest, DivOneArray) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r2f32, "param0")); HloInstruction* one = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1.0, 1.0}, {1.0, 1.0}}))); + Literal::CreateR2({{1.0, 1.0}, {1.0, 1.0}}))); HloInstruction* div = builder.AddInstruction( HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, param0, one)); @@ -300,7 +300,7 @@ TEST_F(AlgebraicSimplifierTest, Pow0Scalar) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32, "param0")); HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + HloInstruction::CreateConstant(Literal::CreateR0(0))); builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, zero)); @@ -315,7 +315,7 @@ TEST_F(AlgebraicSimplifierTest, Pow0Scalar) { HloInstruction* root = computation->root_instruction(); EXPECT_THAT(root, op::Constant()); - EXPECT_EQ(LiteralUtil::GetFirstElement(root->literal()), 1); + EXPECT_EQ(root->literal().GetFirstElement(), 1); } // Test that pow(A, 0) where A is not a scalar is simplified to broadcast(1). 
@@ -325,7 +325,7 @@ TEST_F(AlgebraicSimplifierTest, Pow0Vector) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r1f32, "param0")); HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + HloInstruction::CreateConstant(Literal::CreateR0(0))); builder.AddInstruction( HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, param0, zero)); @@ -344,8 +344,7 @@ TEST_F(AlgebraicSimplifierTest, Pow0Vector) { << ShapeUtil::HumanString(root->shape()); EXPECT_EQ(root->dimensions().size(), 0); EXPECT_TRUE(ShapeUtil::IsScalar(root->operand(0)->shape())); - EXPECT_EQ(LiteralUtil::GetFirstElement(root->operand(0)->literal()), - 1); + EXPECT_EQ(root->operand(0)->literal().GetFirstElement(), 1); } // Test that pow(A, 1) is simplified to A. @@ -355,7 +354,7 @@ TEST_F(AlgebraicSimplifierTest, Pow1) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32, "param0")); HloInstruction* one = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); + HloInstruction::CreateConstant(Literal::CreateR0(1))); builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, one)); @@ -378,7 +377,7 @@ TEST_F(AlgebraicSimplifierTest, Pow2) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32, "param0")); HloInstruction* two = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2))); + HloInstruction::CreateConstant(Literal::CreateR0(2))); builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, two)); @@ -401,7 +400,7 @@ TEST_F(AlgebraicSimplifierTest, PowNegative1) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32, "param0")); HloInstruction* negative_one = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(-1))); + 
HloInstruction::CreateConstant(Literal::CreateR0(-1))); builder.AddInstruction(HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, negative_one)); @@ -416,8 +415,7 @@ TEST_F(AlgebraicSimplifierTest, PowNegative1) { HloInstruction* root = computation->root_instruction(); EXPECT_THAT(root, op::Divide(op::Constant(), param0)); - EXPECT_EQ(LiteralUtil::GetFirstElement(root->operand(0)->literal()), - 1); + EXPECT_EQ(root->operand(0)->literal().GetFirstElement(), 1); } TEST_F(AlgebraicSimplifierTest, ReshapeBroadcast) { @@ -451,7 +449,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeBroadcast) { TEST_F(AlgebraicSimplifierTest, ConvertBetweenSameType) { HloComputation::Builder builder(TestName()); HloInstruction* input = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); builder.AddInstruction( HloInstruction::CreateConvert(ShapeUtil::MakeShape(F32, {}), input)); @@ -519,7 +517,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveEmptyConcatenateOperands) { HloInstruction* param1 = builder.AddInstruction( HloInstruction::CreateParameter(1, r1f32, "param1")); HloInstruction* empty_literal = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({}))); + HloInstruction::CreateConstant(Literal::CreateR1({}))); HloInstruction* empty_slice = builder.AddInstruction(HloInstruction::CreateSlice( ShapeUtil::MakeShape(F32, {0}), param1, {42}, {42}, {1})); @@ -550,7 +548,7 @@ TEST_F(AlgebraicSimplifierTest, OnlyEmptyConcatenateOperands) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r1f32, "param0")); HloInstruction* empty_literal = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({}))); + HloInstruction::CreateConstant(Literal::CreateR1({}))); HloInstruction* empty_slice = builder.AddInstruction(HloInstruction::CreateSlice( ShapeUtil::MakeShape(F32, {0}), param0, {42}, {42}, {1})); @@ -735,7 
+733,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeAfterEffectiveUnary) { builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {1, 2, 3, 4, 5}), param)); HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); builder.AddInstruction( HloInstruction::CreateBinary(ShapeUtil::MakeShape(F32, {1, 2, 3, 4, 5}), HloOpcode::kMaximum, movable_reshape, zero)); @@ -1035,7 +1033,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveNoopPad) { builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {2, 2}), "param")); HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); PaddingConfig no_padding; for (int i = 0; i < 2; ++i) { auto dimension = no_padding.add_dimensions(); @@ -1066,7 +1064,7 @@ TEST_F(AlgebraicSimplifierTest, NegativePadding) { builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {10, 10}), "param")); HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); PaddingConfig padding; int64 low_padding[2] = {-1, -2}; int64 high_padding[2] = {2, -3}; @@ -1376,9 +1374,9 @@ TEST_F(AlgebraicSimplifierTest, MaxMinToClamp) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32, "param0")); HloInstruction* min_value = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); HloInstruction* max_value = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0f))); HloInstruction* min = builder.AddInstruction(HloInstruction::CreateBinary( r0f32, 
HloOpcode::kMinimum, param0, min_value)); builder.AddInstruction( @@ -1406,9 +1404,9 @@ TEST_F(AlgebraicSimplifierTest, MinMaxToClamp) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32, "param0")); HloInstruction* min_value = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); HloInstruction* max_value = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0f))); HloInstruction* max = builder.AddInstruction(HloInstruction::CreateBinary( r0f32, HloOpcode::kMaximum, param0, max_value)); builder.AddInstruction( @@ -1437,9 +1435,9 @@ TEST_F(AlgebraicSimplifierTest, MinMaxWithBroadcastToClamp) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r1f32, "param0")); HloInstruction* min_value = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); HloInstruction* max_value = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0f))); HloInstruction* max = builder.AddInstruction(HloInstruction::CreateBinary( r1f32, HloOpcode::kMaximum, param0, max_value)); builder.AddInstruction( @@ -1497,9 +1495,9 @@ TEST_F(AlgebraicSimplifierTest, MinEquationWithMaxNotToClamp) { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32, "param0")); HloInstruction* min_value = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); HloInstruction* max_value = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0f))); HloInstruction* max = 
builder.AddInstruction(HloInstruction::CreateBinary( r0f32, HloOpcode::kMaximum, param0, max_value)); HloInstruction* fmax = builder.AddInstruction( @@ -1566,7 +1564,7 @@ TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToSlice) { TEST_F(AlgebraicSimplifierTest, ScalarBroadcastToTransposeReshape) { HloComputation::Builder builder(TestName()); HloInstruction* forty_two = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); Shape broadcast_shape = ShapeUtil::MakeShape(F32, {4, 5, 6}); HloInstruction* broadcast = @@ -1614,7 +1612,7 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { padding.mutable_dimensions(3)->set_edge_padding_high(2); HloInstruction* pad_value = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(5.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(5.0f))); HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad( ShapeUtil::MakeShape(F32, {1, 3, 3, 5}), operand, pad_value, padding)); @@ -1645,7 +1643,7 @@ TEST_F(AlgebraicSimplifierTest, FoldPadIntoReduceWindow) { const Shape reduce_window_shape = ShapeUtil::MakeShape(F32, {111, 113, 113, 115}); HloInstruction* reduce_init_value = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(5.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(5.0f))); HloInstruction* reduce_window = builder.AddInstruction(HloInstruction::CreateReduceWindow( reduce_window_shape, pad, reduce_init_value, window, @@ -1714,9 +1712,9 @@ TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) { HloComputation::Builder call_builder(TestName() + ".Call"); HloInstruction* zero = call_builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({0.0f}))); + HloInstruction::CreateConstant(Literal::CreateR1({0.0f}))); HloInstruction* one = call_builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({1.0f}))); + 
HloInstruction::CreateConstant(Literal::CreateR1({1.0f}))); builder.AddInstruction( HloInstruction::CreateCall(r1f32, {zero, one}, dot_computation.get())); diff --git a/tensorflow/compiler/xla/service/backend.cc b/tensorflow/compiler/xla/service/backend.cc index 66d54ad3802..9abe30e3f37 100644 --- a/tensorflow/compiler/xla/service/backend.cc +++ b/tensorflow/compiler/xla/service/backend.cc @@ -22,7 +22,6 @@ limitations under the License. #define EIGEN_USE_THREADS #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include "tensorflow/compiler/xla/legacy_flags/backend_flags.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -51,13 +50,6 @@ perftools::gputools::Platform* BackendOptions::platform() const { return platform_; } -BackendOptions& BackendOptions::set_number_of_replicas(int number_of_replicas) { - number_of_replicas_ = number_of_replicas; - return *this; -} - -int BackendOptions::number_of_replicas() const { return number_of_replicas_; } - BackendOptions& BackendOptions::set_intra_op_parallelism_threads( int num_threads) { intra_op_parallelism_threads_ = num_threads; @@ -85,20 +77,17 @@ struct Backend::EigenThreadPoolWrapper { /* static */ StatusOr> Backend::CreateBackend( const BackendOptions& options) { - int64 replica_count = options.number_of_replicas(); - if (replica_count == -1) { - legacy_flags::BackendFlags* flags = legacy_flags::GetBackendFlags(); - replica_count = flags->xla_replicas; - } perftools::gputools::Platform* platform = options.platform(); TF_ASSIGN_OR_RETURN(auto compiler, Compiler::GetForPlatform(platform)); TF_ASSIGN_OR_RETURN(auto stream_executors, PlatformUtil::GetStreamExecutors(platform)); TF_ASSIGN_OR_RETURN(auto transfer_manager, TransferManager::GetForPlatform(platform)); + TF_ASSIGN_OR_RETURN(auto computation_placer, + ComputationPlacer::GetForPlatform(platform)); std::unique_ptr backend( - new 
Backend(replica_count, platform, compiler, stream_executors, - transfer_manager, options.intra_op_parallelism_threads())); + new Backend(platform, compiler, stream_executors, transfer_manager, + computation_placer, options.intra_op_parallelism_threads())); return std::move(backend); } @@ -132,34 +121,25 @@ StatusOr Backend::BorrowStream( } Backend::Backend( - int64 replica_count, perftools::gputools::Platform* platform, - Compiler* compiler, + perftools::gputools::Platform* platform, Compiler* compiler, tensorflow::gtl::ArraySlice stream_executors, - TransferManager* transfer_manager, int intra_op_parallelism_threads) + TransferManager* transfer_manager, ComputationPlacer* computation_placer, + int intra_op_parallelism_threads) : platform_(platform), compiler_(compiler), transfer_manager_(transfer_manager), - replica_count_(replica_count) { + computation_placer_(computation_placer) { // The given set of stream executors set may include invalid executors. for (se::StreamExecutor* exec : stream_executors) { if (exec != nullptr) { stream_executors_.push_back(exec); } } - CHECK_GE(replica_count, 1) << "Must request at least 1 replica."; - // Create a memory allocator for the valid stream executors. memory_allocator_ = MakeUnique(platform, stream_executors); - - // First check that there are some non-null stream executors to avoid issuing - // an error mentioning replicas in the common case of requesting just 1 - // replica, which means no replication. 
CHECK(!stream_executors_.empty()) << "Service found no devices for backend " << platform_->Name() << '.'; - CHECK_GE(stream_executors_.size(), replica_count) - << "Requested more replicas than there are devices for backend " - << platform_->Name() << '.'; if (platform->id() == se::host::kHostPlatformId) { inter_op_thread_pool_.reset(new tensorflow::thread::ThreadPool( @@ -179,36 +159,6 @@ int Backend::default_device_ordinal() const { return default_stream_executor()->device_ordinal(); } -StatusOr> Backend::Replicas( - int device_ordinal) const { - if (stream_executors_[device_ordinal] == nullptr) { - return InvalidArgument("device %s not supported by XLA service", - device_name(device_ordinal).c_str()); - } - - // Find replica_count_ stream executors starting from the given device - // ordinal. - std::vector replicas; - for (se::StreamExecutor* exec : stream_executors_) { - CHECK(exec != nullptr); - if (exec->device_ordinal() >= device_ordinal) { - replicas.push_back(exec); - if (replicas.size() >= replica_count_) { - return replicas; - } - } - } - - return InvalidArgument( - "Not enough devices for replicas for the device ordinal %d", - device_ordinal); -} - -std::vector Backend::Replicas() const { - CHECK_GE(stream_executors_.size(), replica_count_); - return Replicas(default_device_ordinal()).ValueOrDie(); -} - tensorflow::thread::ThreadPool* Backend::inter_op_thread_pool() const { return inter_op_thread_pool_.get(); } diff --git a/tensorflow/compiler/xla/service/backend.h b/tensorflow/compiler/xla/service/backend.h index e0b15dc43f2..b5ca483b727 100644 --- a/tensorflow/compiler/xla/service/backend.h +++ b/tensorflow/compiler/xla/service/backend.h @@ -22,6 +22,7 @@ limitations under the License. 
#include #include "tensorflow/compiler/xla/service/compiler.h" +#include "tensorflow/compiler/xla/service/computation_placer.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" #include "tensorflow/compiler/xla/service/pool.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" @@ -46,12 +47,6 @@ class BackendOptions { BackendOptions& set_platform(perftools::gputools::Platform* platform); perftools::gputools::Platform* platform() const; - // Set the number of replicas to use when compiling replicated - // programs. The default is -1 meaning that the value is read from - // the xla_replicas flag. - BackendOptions& set_number_of_replicas(int number_of_replicas); - int number_of_replicas() const; - // Sets the thread pool size for parallel execution of an individual operator. // The default value of -1 will result in initializing the thread pool with // the number of threads equal to the number of cores in the system. @@ -60,7 +55,6 @@ class BackendOptions { private: perftools::gputools::Platform* platform_ = nullptr; - int number_of_replicas_ = -1; int intra_op_parallelism_threads_ = -1; }; @@ -74,8 +68,7 @@ class Backend { public: using StreamPtr = Pool::SmartPtr; - // Creates a new backend for the given platform with the given number of - // replicas. + // Creates a new backend. static StatusOr> CreateBackend( const BackendOptions& options); @@ -92,6 +85,7 @@ class Backend { return memory_allocator_.get(); } TransferManager* transfer_manager() const { return transfer_manager_; } + ComputationPlacer* computation_placer() const { return computation_placer_; } // Returns the number of devices of the platform type which are visible. Not // all of these devices may be usable by XLA. @@ -107,24 +101,13 @@ class Backend { return stream_executors_; } - // Returns the replicas for the default stream executor. 
- // - // When the number of replicas is R, the first R stream executors are assigned - // to the replicas of the default stream executor. - std::vector Replicas() const; - - // Returns the replicas for the given device_ordinal. The given device ordinal - // is considered to be the first device ordinal among the replicas. Returns an - // error status if the stream executor for the given given device ordinal does - // not exist or if there are not enough stream executors for the replicas. - StatusOr> Replicas( - int device_ordinal) const; - - // Return the stream executor for the given device ordinal. + // Returns the stream executor for the given device ordinal. StatusOr stream_executor( int device_ordinal) const; - // Return the stream executor for the default device ordinal. + // Returns the stream executor for the default device ordinal. This stream + // executor can only be used when the number of computations is 1 (replication + // can be > 1). perftools::gputools::StreamExecutor* default_stream_executor() const { CHECK(!stream_executors_.empty()); return stream_executors_[0]; @@ -174,18 +157,19 @@ class Backend { private: struct EigenThreadPoolWrapper; - Backend(int64 replica_count, perftools::gputools::Platform* platform, - Compiler* compiler, + Backend(perftools::gputools::Platform* platform, Compiler* compiler, tensorflow::gtl::ArraySlice stream_executors, - TransferManager* transfer_manager, int intra_op_parallelism_threads); + TransferManager* transfer_manager, + ComputationPlacer* computation_placer, + int intra_op_parallelism_threads); Backend(const Backend&) = delete; Backend& operator=(const Backend&) = delete; perftools::gputools::Platform* platform_; Compiler* compiler_; TransferManager* transfer_manager_; - int64 replica_count_ = -1; + ComputationPlacer* computation_placer_; // Vector of stream executors. stream_executors_[0] is the default executor. 
std::vector stream_executors_; diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index f91eb0207a2..44b4f4e3d8d 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -1074,7 +1074,8 @@ void BufferAssigner::AddSetToColocatedBufferSets( // different while instructions. void BufferAssigner::AddWhileSetToColocatedBufferSets( const std::vector& colocated_set, - const LogicalBuffer* while_init_buffer, const HloInstruction* while_hlo, + const LogicalBuffer* while_init_buffer, + const LogicalBuffer* while_result_buffer, const HloInstruction* while_hlo, const HloComputation& computation, const BufferLiveness& buffer_liveness, const LogicalBuffer::SizeFunction& buffer_size, std::vector* colocated_buffer_sets) { @@ -1137,16 +1138,30 @@ void BufferAssigner::AddWhileSetToColocatedBufferSets( continue; } - // Skip predecessor set if the live range of any predecessor buffers - // overlaps with 'while_init_buffer'. Note that tuple element buffer - // forwarding can cause the same buffer to appear on both sides of the - // interference comparison below. - if (std::any_of( - predecessor_while_buffers.begin(), predecessor_while_buffers.end(), - [while_init_buffer, &buffer_liveness](const LogicalBuffer* buffer) { - return while_init_buffer->id() != buffer->id() && - buffer_liveness.MayInterfere(*while_init_buffer, *buffer); - })) { + // Skip predecessor set if the live range of any predecessor + // buffers overlaps with 'while_init_buffer' or + // 'while_result_buffer' (we need to check both since they're + // aliased together, but the points-to analysis is unaware of this + // aliasing). Note that tuple element buffer forwarding can cause + // the same buffer to appear on both sides of the interference + // comparison below. 
+ auto may_interfere_with_init_or_result = [&](const LogicalBuffer* buffer) { + if (while_init_buffer->id() != buffer->id() && + buffer_liveness.MayInterfere(*while_init_buffer, *buffer)) { + return true; + } + + if (while_result_buffer->id() != buffer->id() && + buffer_liveness.MayInterfere(*while_result_buffer, *buffer)) { + return true; + } + + return false; + }; + + if (std::any_of(predecessor_while_buffers.begin(), + predecessor_while_buffers.end(), + may_interfere_with_init_or_result)) { continue; } @@ -1209,8 +1224,8 @@ void BufferAssigner::BuildColocatedBufferSets( AddBufferToColocatedSet(while_hlo->operand(0), index, points_to_analysis, &colocated_set); // Add while.result. - AddBufferToColocatedSet(while_hlo, index, points_to_analysis, - &colocated_set); + auto* result_buffer = AddBufferToColocatedSet( + while_hlo, index, points_to_analysis, &colocated_set); // Add while.cond.parameter. AddBufferToColocatedSet( while_hlo->while_condition()->parameter_instruction(0), index, @@ -1224,8 +1239,9 @@ void BufferAssigner::BuildColocatedBufferSets( while_hlo->while_body()->root_instruction(), index, points_to_analysis, &colocated_set); AddWhileSetToColocatedBufferSets( - colocated_set, init_buffer, while_hlo, *computation, - buffer_liveness, buffer_size, colocated_buffer_sets); + colocated_set, init_buffer, result_buffer, while_hlo, + *computation, buffer_liveness, buffer_size, + colocated_buffer_sets); }); } else if (opcode == HloOpcode::kCall) { const HloInstruction* call_hlo = instruction; diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h index b3933f11c1e..dd84b06b779 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.h +++ b/tensorflow/compiler/xla/service/buffer_assignment.h @@ -511,7 +511,8 @@ class BufferAssigner { // colocated buffers for while instructions. 
void AddWhileSetToColocatedBufferSets( const std::vector& colocated_set, - const LogicalBuffer* while_init_buffer, const HloInstruction* while_hlo, + const LogicalBuffer* while_init_buffer, + const LogicalBuffer* while_result_buffer, const HloInstruction* while_hlo, const HloComputation& computation, const BufferLiveness& buffer_liveness, const LogicalBuffer::SizeFunction& buffer_size, std::vector* colocated_buffer_sets); diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index 892f67a8812..d69a78729ae 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -105,7 +105,7 @@ class BufferAssignmentTest : public HloTestBase { auto param = builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x")); auto value = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); builder.AddInstruction( HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, param, value)); return builder.Build(); @@ -122,7 +122,7 @@ class BufferAssignmentTest : public HloTestBase { const string& name) { auto builder = HloComputation::Builder(name); auto const4 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(4))); + HloInstruction::CreateConstant(Literal::CreateR0(4))); auto param = builder.AddInstruction( HloInstruction::CreateParameter(0, t_s32_f32v4_, "x")); auto index = builder.AddInstruction( @@ -147,9 +147,9 @@ class BufferAssignmentTest : public HloTestBase { const string& name) { auto builder = HloComputation::Builder(name); auto const1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); + HloInstruction::CreateConstant(Literal::CreateR0(1))); auto constv = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({1.1f, 2.2f, 3.3f, 4.4f}))); + 
Literal::CreateR1({1.1f, 2.2f, 3.3f, 4.4f}))); auto param = builder.AddInstruction( HloInstruction::CreateParameter(0, t_s32_f32v4_, "x")); auto indexc = builder.AddInstruction( @@ -264,7 +264,7 @@ static bool BuffersDistinct(const std::vector& a, TEST_F(BufferAssignmentTest, ScalarConstant) { auto builder = HloComputation::Builder(TestName()); auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); @@ -278,9 +278,9 @@ TEST_F(BufferAssignmentTest, BufferForConst) { // no buffers assigned, and their consumer has a buffer. auto builder = HloComputation::Builder(TestName()); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({1.1f, 2.2f, 3.3f, 4.4f}))); + Literal::CreateR1({1.1f, 2.2f, 3.3f, 4.4f}))); auto const1 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({4.1f, 4.2f, 4.3f, 4.4f}))); + Literal::CreateR1({4.1f, 4.2f, 4.3f, 4.4f}))); auto add = builder.AddInstruction( HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, const0, const1)); auto module = CreateNewModule(); @@ -298,7 +298,7 @@ TEST_F(BufferAssignmentTest, BufferForOutputConst) { // This computation copies a constant to output. 
auto builder = HloComputation::Builder(TestName()); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({1.1f, 2.2f, 3.3f, 4.4f}))); + Literal::CreateR1({1.1f, 2.2f, 3.3f, 4.4f}))); auto copy = builder.AddInstruction( HloInstruction::CreateUnary(const0->shape(), HloOpcode::kCopy, const0)); auto module = CreateNewModule(); @@ -586,7 +586,7 @@ TEST_F(BufferAssignmentTest, CannotReuseInputBufferOfReduce) { auto exp2 = builder.AddInstruction( HloInstruction::CreateUnary(f32a100x10_, HloOpcode::kExp, exp1)); auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); auto reduce = builder.AddInstruction(HloInstruction::CreateReduce( /*shape=*/f32vec10_, /*operand=*/exp2, @@ -634,9 +634,9 @@ TEST_F(BufferAssignmentTest, ExampleWhile) { // Creates the main kernel and verifies instruction counts. auto builder = HloComputation::Builder(TestName()); auto const3 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + HloInstruction::CreateConstant(Literal::CreateR0(0))); auto const4 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({1.1f, 2.2f, 3.3f, 4.4f}))); + Literal::CreateR1({1.1f, 2.2f, 3.3f, 4.4f}))); auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({const3, const4})); auto while_op = builder.AddInstruction(HloInstruction::CreateWhile( @@ -1075,9 +1075,8 @@ TEST_F(BufferAssignmentTest, DISABLED_TupleConstantAsOutput) { // Test that a tuple constant which is forwarded to the computation output is // properly handled. 
auto builder = HloComputation::Builder(TestName()); - builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::MakeTuple({LiteralUtil::CreateR0(0).get(), - LiteralUtil::CreateR0(1).get()}))); + builder.AddInstruction(HloInstruction::CreateConstant(Literal::MakeTuple( + {Literal::CreateR0(0).get(), Literal::CreateR0(1).get()}))); auto module = CreateNewModule(); module->AddEntryComputation(builder.Build()); @@ -1369,9 +1368,9 @@ class WhileBufferAssignmentTest : public HloTestBase { builder.AddInstruction( HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state")); auto zero = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + HloInstruction::CreateConstant(Literal::CreateR0(0))); auto ten = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(10))); + HloInstruction::CreateConstant(Literal::CreateR0(10))); builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, zero, ten)); return builder.Build(); @@ -1429,7 +1428,7 @@ TEST_F(WhileBufferAssignmentTest, TwoForwardWhileLoops) { HloInstruction::CreateParameter(2, data_shape_, "weights1")); auto zero = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0))); + HloInstruction::CreateConstant(Literal::CreateR0(0.0))); auto output0 = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape_, zero, {1})); auto output1 = builder.AddInstruction( @@ -1484,7 +1483,7 @@ TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) { HloInstruction::CreateParameter(1, data_shape_, "weights0")); auto zero = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0))); + HloInstruction::CreateConstant(Literal::CreateR0(0.0))); auto output0 = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape_, zero, {1})); auto output1 = builder.AddInstruction( @@ -1532,16 +1531,16 @@ TEST_F(BufferAssignmentTest, TwoCalls) { auto param 
= builder.AddInstruction( HloInstruction::CreateParameter(0, r0f32, "param")); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto add = builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param, constant1)); sub_computation = module->AddEmbeddedComputation(builder.Build(add)); } auto builder = HloComputation::Builder(TestName()); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.0))); + HloInstruction::CreateConstant(Literal::CreateR0(3.0))); auto call1 = builder.AddInstruction( HloInstruction::CreateCall(r0f32, {constant2}, sub_computation)); auto call2 = builder.AddInstruction( @@ -1565,6 +1564,104 @@ TEST_F(BufferAssignmentTest, TwoCalls) { EXPECT_TRUE(BuffersDistinct({call1}, {call2}, *assignment)); } +static bool IsPostOrderTraversal( + const std::vector& sequence) { + tensorflow::gtl::FlatSet seen_so_far; + auto has_not_been_seen_yet = [&](const HloInstruction* instruction) { + return seen_so_far.count(instruction) == 0; + }; + + for (auto instruction : sequence) { + if (std::any_of(instruction->operands().begin(), + instruction->operands().end(), has_not_been_seen_yet) || + std::any_of(instruction->control_predecessors().begin(), + instruction->control_predecessors().end(), + has_not_been_seen_yet)) { + return false; // Not a post order. + } + if (!seen_so_far.insert(instruction).second) { + return false; // Not a "traversal". 
+ } + } + + return true; +} + +TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { + auto module = MakeUnique(TestName()); + auto builder = HloComputation::Builder(TestName()); + + auto zero = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0.0))); + auto one = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + + auto input0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, data_shape_, "input0")); + auto weights0 = builder.AddInstruction( + HloInstruction::CreateParameter(1, data_shape_, "weights0")); + auto output0 = builder.AddInstruction( + HloInstruction::CreateBroadcast(data_shape_, zero, {1})); + + auto input1 = builder.AddInstruction( + HloInstruction::CreateParameter(2, data_shape_, "input1")); + auto weights1 = builder.AddInstruction( + HloInstruction::CreateParameter(3, data_shape_, "weights1")); + auto output1 = builder.AddInstruction( + HloInstruction::CreateBroadcast(data_shape_, one, {1})); + + auto cond = + module->AddEmbeddedComputation(BuildWhileConditionComputation("cond")); + auto body = module->AddEmbeddedComputation(BuildWhileBodyComputation("body")); + + auto tuple0 = builder.AddInstruction( + HloInstruction::CreateTuple({input0, weights0, output0})); + auto tuple1 = builder.AddInstruction( + HloInstruction::CreateTuple({input1, weights1, output1})); + + auto while0 = builder.AddInstruction( + HloInstruction::CreateWhile(loop_state_shape_, cond, body, tuple0)); + auto while1 = builder.AddInstruction( + HloInstruction::CreateWhile(loop_state_shape_, cond, body, tuple1)); + + auto root_add = builder.AddInstruction(HloInstruction::CreateBinary( + while0->shape(), HloOpcode::kAdd, while0, while1)); + module->AddEntryComputation(builder.Build()); + + RunCopyInsertion(module.get()); + + { + FlattenCallGraph flatten; + TF_ASSIGN_OR_ASSERT_OK(bool result, flatten.Run(module.get())); + EXPECT_TRUE(result); + } + + auto sequence = + 
CreateMemoryMinimizingSequence(*module, ByteSizeOf).ConsumeValueOrDie(); + + // To trigger b/38494731, we want a specific Hlo sequence for the + // root computation, so we overwrite that entry with a manually + // crafted sequence. + std::vector sequence_for_buffer_assigment = { + input1, weights1, one, output1, tuple1, while1, input0, + weights0, zero, output0, tuple0, while0, root_add}; + + // If this ASSERT_TRUE fails, we constructed a bogus sequence above + // and this test itself is buggy. + ASSERT_TRUE(IsPostOrderTraversal(sequence_for_buffer_assigment)); + + sequence[module->entry_computation()] = + std::move(sequence_for_buffer_assigment); + + auto assignment = BufferAssigner::Run(module.get(), + MakeUnique( + module.get(), sequence), + ByteSizeOf, 1) + .ConsumeValueOrDie(); + + EXPECT_TRUE(BuffersDistinct({while0}, {while1}, *assignment)); +} + // Test buffer assignment for while nodes with multiple uses. // TODO(b/37245345): Fix buffer assignment for this case. TEST_F(WhileBufferAssignmentTest, DISABLED_TwoWhiles) { @@ -1577,7 +1674,7 @@ TEST_F(WhileBufferAssignmentTest, DISABLED_TwoWhiles) { HloInstruction::CreateParameter(1, data_shape_, "weights0")); auto zero = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0))); + HloInstruction::CreateConstant(Literal::CreateR0(0.0))); auto output0 = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape_, zero, {1})); diff --git a/tensorflow/compiler/xla/service/buffer_liveness.cc b/tensorflow/compiler/xla/service/buffer_liveness.cc index 1b14c26340f..56b8b5e80b7 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness.cc @@ -122,7 +122,7 @@ bool BufferLiveness::live_range_strictly_before(const LogicalBuffer& a, if (b.instruction()->IsUserOf(alias.instruction()) && !CanShareOperandBufferWithUser(alias.instruction(), alias.index(), b.instruction(), b.index(), - points_to_analysis())) { + &points_to_analysis())) 
{ return false; } } diff --git a/tensorflow/compiler/xla/service/buffer_liveness_test.cc b/tensorflow/compiler/xla/service/buffer_liveness_test.cc index fda44ff4d2d..a5f7cc0aebe 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness_test.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness_test.cc @@ -397,13 +397,11 @@ TEST_F(BufferLivenessTest, TupleConstantLiveOut) { // computation. The buffer containing {0, 1} is copied by GetTupleElement, and // the buffers containing {3} and 3 are dead. auto builder = HloComputation::Builder(TestName()); - auto inner_tuple0 = - LiteralUtil::MakeTuple({LiteralUtil::CreateR0(0).get(), - LiteralUtil::CreateR0(1).get()}); - auto inner_tuple1 = - LiteralUtil::MakeTuple({LiteralUtil::CreateR0(3).get()}); + auto inner_tuple0 = Literal::MakeTuple( + {Literal::CreateR0(0).get(), Literal::CreateR0(1).get()}); + auto inner_tuple1 = Literal::MakeTuple({Literal::CreateR0(3).get()}); auto tuple_constant = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::MakeTuple({inner_tuple0.get(), inner_tuple1.get()}))); + Literal::MakeTuple({inner_tuple0.get(), inner_tuple1.get()}))); builder.AddInstruction(HloInstruction::CreateGetTupleElement( inner_tuple0->shape(), tuple_constant, 0)); @@ -450,7 +448,7 @@ TEST_F(BufferLivenessTest, IndependentTupleElements) { builder.AddInstruction(HloInstruction::CreateGetTupleElement( tuple_element0_shape, tuple_param0, 0)); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); + Literal::CreateR1({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( tuple_element0_shape, HloOpcode::kAdd, tuple_element0, const0)); @@ -462,7 +460,7 @@ TEST_F(BufferLivenessTest, IndependentTupleElements) { builder.AddInstruction(HloInstruction::CreateGetTupleElement( tuple_element1_shape, tuple_param0, 1)); auto const1 = 
builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f}))); + Literal::CreateR1({2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f}))); auto add1 = builder.AddInstruction(HloInstruction::CreateBinary( tuple_element1_shape, HloOpcode::kAdd, tuple_element1, const1)); @@ -513,7 +511,7 @@ TEST_F(BufferLivenessTest, DependentTupleElements) { builder.AddInstruction(HloInstruction::CreateGetTupleElement( tuple_element0_shape, tuple_param0, 0)); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); + Literal::CreateR1({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( tuple_element0_shape, HloOpcode::kAdd, tuple_element0, const0)); @@ -585,7 +583,7 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest { HloInstruction::CreateGetTupleElement(data_shape, tuple_param0, 1)); auto update = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({2.f, 2.f, 2.f}))); + Literal::CreateR1({2.f, 2.f, 2.f}))); HloInstruction* slice = nullptr; if (update_uses_tuple_element1) { // Create a slice instruction as an additional user of 'gte1'. @@ -596,7 +594,7 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest { } // Create a DynamicUpdateSlice instruction of tuple element 1 with 'update'. 
auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({2}))); + HloInstruction::CreateConstant(Literal::CreateR1({2}))); auto dynamic_update_slice = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( data_shape, gte1, update, starts)); @@ -715,7 +713,7 @@ class DynamicUpdateSliceLivenessTest : public BufferLivenessTest { HloInstruction::CreateGetTupleElement(data_shape, tuple_param0, 1)); auto update = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({2.f, 2.f, 2.f}))); + Literal::CreateR1({2.f, 2.f, 2.f}))); if (tuple_element1_has_two_uses) { // Add 'gte0' and 'gte1' to create another user of 'gte1'. @@ -724,7 +722,7 @@ class DynamicUpdateSliceLivenessTest : public BufferLivenessTest { } // Create a DynamicUpdateSlice instruction of tuple element 1 with 'update'. auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({2}))); + HloInstruction::CreateConstant(Literal::CreateR1({2}))); auto dynamic_update_slice = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( data_shape, gte1, update, starts)); diff --git a/tensorflow/compiler/xla/service/call_graph.cc b/tensorflow/compiler/xla/service/call_graph.cc index fa7b2a30952..b450e0c4007 100644 --- a/tensorflow/compiler/xla/service/call_graph.cc +++ b/tensorflow/compiler/xla/service/call_graph.cc @@ -133,6 +133,37 @@ CallGraphNode& CallGraph::GetNode(const HloComputation* computation) { return nodes_[it->second]; } +bool CallGraph::DominatesHelper( + const HloComputation* a, const HloComputation* b, + tensorflow::gtl::FlatSet* visited) const { + if (a == b || ContainsKey(*visited, b)) { + // The call graph is guaranteed to be acyclic so any previously visited node + // we encounter was already determined to be dominated. + return true; + } + + const CallGraphNode& b_node = GetNode(b); + if (b_node.callers().empty()) { + // We reached a root node without hitting 'a'. 
'a' does not dominate 'b'. + return false; + } + + // Walk up the callers of 'b' until we hit 'a' or a root node (no callers). + visited->insert(b); + for (const HloComputation* b_caller : b_node.callers()) { + if (!DominatesHelper(a, b_caller, visited)) { + return false; + } + } + return true; +} + +bool CallGraph::Dominates(const HloComputation* a, + const HloComputation* b) const { + tensorflow::gtl::FlatSet visited; + return DominatesHelper(a, b, &visited); +} + namespace { // Returns the call context of a computation which is called from contexts 'a' diff --git a/tensorflow/compiler/xla/service/call_graph.h b/tensorflow/compiler/xla/service/call_graph.h index 7f9990f06d4..a3297ff534f 100644 --- a/tensorflow/compiler/xla/service/call_graph.h +++ b/tensorflow/compiler/xla/service/call_graph.h @@ -189,6 +189,20 @@ class CallGraph { Status VisitNodes(const VisitorFunction& visitor_func, bool visit_unreachable_nodes = true) const; + // Returns true if 'a' dominates 'b' in the call graph. Computation 'a' + // dominates computation 'b' iff all callgraph paths in the caller-to-callee + // direction from a root computation to 'b' pass through computation + // 'a'. Trivially, a computation dominates itself. + bool Dominates(const HloComputation* a, const HloComputation* b) const; + + // Returns whether 'instruction' is contained in 'computation' either directly + // ('instruction->parent' is 'computation') or indirectly ('computation' + // dominates 'instruction->parent' in the call graph). + bool InstructionIsNestedIn(const HloInstruction* instruction, + const HloComputation* computation) const { + return Dominates(computation, instruction->parent()); + } + string ToString() const; private: @@ -205,6 +219,13 @@ class CallGraph { const VisitorFunction& visitor_func, const CallGraphNode& node, tensorflow::gtl::FlatSet* visited) const; + // Recursive helper for computing whether 'a' dominates 'b' in the call + // graph. 
'b_ancestor' is the currently visited node (which starts at 'b'), + // and 'visited' is the set of computations which have been visited. + bool DominatesHelper( + const HloComputation* a, const HloComputation* b, + tensorflow::gtl::FlatSet* visited) const; + // The HLO module represented by this call graph. const HloModule* module_ = nullptr; diff --git a/tensorflow/compiler/xla/service/call_graph_test.cc b/tensorflow/compiler/xla/service/call_graph_test.cc index e276473c90a..3c22871b3bf 100644 --- a/tensorflow/compiler/xla/service/call_graph_test.cc +++ b/tensorflow/compiler/xla/service/call_graph_test.cc @@ -81,7 +81,7 @@ class CallGraphTest : public HloTestBase { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, kScalarShape, "param0")); HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, param0, zero)); return builder.Build(); @@ -314,6 +314,37 @@ TEST_F(CallGraphTest, ComplexGraph) { EXPECT_LT(index_of(cond_computation), index_of(a_computation)); EXPECT_LT(index_of(c_computation), index_of(b_computation)); EXPECT_LT(index_of(b_computation), index_of(a_computation)); + + // Verify dominance relations between computation in the graph. + + // Entry dominates everybody, and is dominated by no one except itself. 
+ EXPECT_TRUE(call_graph->Dominates(entry_computation, entry_computation)); + EXPECT_TRUE(call_graph->Dominates(entry_computation, a_computation)); + EXPECT_TRUE(call_graph->Dominates(entry_computation, b_computation)); + EXPECT_TRUE(call_graph->Dominates(entry_computation, c_computation)); + EXPECT_TRUE(call_graph->Dominates(entry_computation, cond_computation)); + EXPECT_FALSE(call_graph->Dominates(a_computation, entry_computation)); + EXPECT_FALSE(call_graph->Dominates(b_computation, entry_computation)); + EXPECT_FALSE(call_graph->Dominates(c_computation, entry_computation)); + EXPECT_FALSE(call_graph->Dominates(cond_computation, entry_computation)); + + // 'a' only dominates 'b' and 'c'. + EXPECT_TRUE(call_graph->Dominates(a_computation, a_computation)); + EXPECT_TRUE(call_graph->Dominates(a_computation, b_computation)); + EXPECT_TRUE(call_graph->Dominates(a_computation, c_computation)); + EXPECT_FALSE(call_graph->Dominates(b_computation, a_computation)); + EXPECT_FALSE(call_graph->Dominates(c_computation, a_computation)); + EXPECT_FALSE(call_graph->Dominates(a_computation, cond_computation)); + + EXPECT_TRUE(call_graph->Dominates(b_computation, b_computation)); + EXPECT_FALSE(call_graph->Dominates(b_computation, c_computation)); + EXPECT_FALSE(call_graph->Dominates(b_computation, cond_computation)); + + EXPECT_TRUE(call_graph->Dominates(c_computation, c_computation)); + EXPECT_FALSE(call_graph->Dominates(c_computation, cond_computation)); + EXPECT_FALSE(call_graph->Dominates(cond_computation, c_computation)); + + EXPECT_TRUE(call_graph->Dominates(cond_computation, cond_computation)); } TEST_F(CallGraphTest, VisitSingletonComputation) { diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc index 0d1a439724a..cfcf2744fba 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.cc +++ b/tensorflow/compiler/xla/service/compile_only_service.cc @@ -52,14 +52,16 @@ 
CompileOnlyService::NewService(const ServiceOptions& options) { TF_ASSIGN_OR_RETURN(std::unique_ptr compute_constant_backend, CreateComputeConstantBackend()); - std::unique_ptr service( - new CompileOnlyService(compiler, std::move(compute_constant_backend))); + std::unique_ptr service(new CompileOnlyService( + options, compiler, std::move(compute_constant_backend))); return std::move(service); } CompileOnlyService::CompileOnlyService( - Compiler* compiler, std::unique_ptr compute_constant_backend) - : Service(/*backend=*/nullptr, std::move(compute_constant_backend)), + const ServiceOptions& options, Compiler* compiler, + std::unique_ptr compute_constant_backend) + : Service(options, /*backend=*/nullptr, + std::move(compute_constant_backend)), compiler_(compiler) { runs_in_client_process_ = true; } diff --git a/tensorflow/compiler/xla/service/compile_only_service.h b/tensorflow/compiler/xla/service/compile_only_service.h index 3358305c03c..0a1911cbd15 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.h +++ b/tensorflow/compiler/xla/service/compile_only_service.h @@ -103,7 +103,8 @@ class CompileOnlyService : public Service { private: explicit CompileOnlyService( - Compiler* compiler, std::unique_ptr compute_constant_backend); + const ServiceOptions& options, Compiler* compiler, + std::unique_ptr compute_constant_backend); CompileOnlyService(const CompileOnlyService&) = delete; void operator=(const CompileOnlyService&) = delete; diff --git a/tensorflow/compiler/xla/service/computation_placer.cc b/tensorflow/compiler/xla/service/computation_placer.cc new file mode 100644 index 00000000000..cdf277581f4 --- /dev/null +++ b/tensorflow/compiler/xla/service/computation_placer.cc @@ -0,0 +1,151 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/computation_placer.h" + +#include +#include +#include + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace se = ::perftools::gputools; + +namespace xla { + +Status DeviceAssignment::Serialize(DeviceAssignmentProto* proto) const { + proto->set_replica_count(replica_count()); + proto->set_computation_count(computation_count()); + for (int computation = 0; computation < computation_count(); ++computation) { + DeviceAssignmentProto::ComputationDevice* computation_device = + proto->add_computation_devices(); + for (int replica = 0; replica < replica_count(); ++replica) { + computation_device->add_replica_device_ids((*this)(replica, computation)); + } + } + return Status::OK(); +} + +/* static */ StatusOr DeviceAssignment::Deserialize( + const DeviceAssignmentProto& proto) { + TF_RET_CHECK(proto.computation_devices_size() == proto.computation_count()); + DeviceAssignment assignment(proto.replica_count(), 
proto.computation_count()); + for (int computation = 0; computation < proto.computation_count(); + ++computation) { + const auto& computation_device = proto.computation_devices(computation); + TF_RET_CHECK(computation_device.replica_device_ids_size() == + proto.replica_count()); + for (int replica = 0; replica < proto.replica_count(); ++replica) { + assignment(replica, computation) = + computation_device.replica_device_ids(replica); + } + } + return std::move(assignment); +} + +StatusOr ComputationPlacer::DeviceId(int replica, int computation, + int replica_count, + int computation_count) { + TF_RET_CHECK(replica < replica_count); + TF_RET_CHECK(computation < computation_count); + + return computation * replica_count + replica; +} + +StatusOr ComputationPlacer::AssignDevices( + int replica_count, int computation_count) { + DeviceAssignment assignment(replica_count, computation_count); + for (int replica = 0; replica < replica_count; ++replica) { + for (int computation = 0; computation < computation_count; ++computation) { + TF_ASSIGN_OR_RETURN( + int device_id, + DeviceId(replica, computation, replica_count, computation_count)); + assignment(replica, computation) = device_id; + } + } + return std::move(assignment); +} + +/* static */ void ComputationPlacer::RegisterComputationPlacer( + se::Platform::Id platform_id, + ComputationPlacerCreationFunction creation_function) { + tensorflow::mutex_lock lock( + *ComputationPlacer::platform_computation_placer_mutex()); + auto* computation_placers = GetPlatformComputationPlacers(); + CHECK(computation_placers->find(platform_id) == computation_placers->end()); + (*computation_placers)[platform_id].creation_function = creation_function; +} + +/* static */ StatusOr ComputationPlacer::GetForPlatform( + const se::Platform* platform) { + tensorflow::mutex_lock lock( + *ComputationPlacer::platform_computation_placer_mutex()); + auto* computation_placers = GetPlatformComputationPlacers(); + + auto it = 
computation_placers->find(platform->id()); + if (it == computation_placers->end()) { + return NotFound( + "could not find registered computation placer for platform %s -- check " + "target linkage", + platform->Name().c_str()); + } + + if (it->second.placer == nullptr) { + // Lazily create the computation placer the first time it is needed. + it->second.placer = (*it->second.creation_function)(); + } + + return it->second.placer.get(); +} + +/* static */ tensorflow::mutex* +ComputationPlacer::platform_computation_placer_mutex() { + static tensorflow::mutex* m = new tensorflow::mutex; + return m; +} + +/* static */ std::map* +ComputationPlacer::GetPlatformComputationPlacers() { + static auto* r = + new std::map; + return r; +} + +} // namespace xla + +static std::unique_ptr CreateComputationPlacer() { + return xla::MakeUnique(); +} + +static bool InitModule() { + xla::ComputationPlacer::RegisterComputationPlacer(se::host::kHostPlatformId, + &CreateComputationPlacer); + xla::ComputationPlacer::RegisterComputationPlacer(se::cuda::kCudaPlatformId, + &CreateComputationPlacer); + return true; +} +static bool module_initialized = InitModule(); diff --git a/tensorflow/compiler/xla/service/computation_placer.h b/tensorflow/compiler/xla/service/computation_placer.h new file mode 100644 index 00000000000..4d26d6bb85f --- /dev/null +++ b/tensorflow/compiler/xla/service/computation_placer.h @@ -0,0 +1,113 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_PLACER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_PLACER_H_ + +#include +#include +#include + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +// Class that represents the device assignment for a set of XLA replicated +// computations. For R replicas and C computations, R * C devices are required +// execute the computation in parallel. The assigned device ids can be accessed +// by assignment(replica, computation). +class DeviceAssignment : public Array2D { + public: + DeviceAssignment() {} + DeviceAssignment(int replica_count, int computation_count) + : Array2D(replica_count, computation_count, -1) { + CHECK_GT(replica_count, 0); + CHECK_GT(computation_count, 0); + } + + int replica_count() const { return height(); } + int computation_count() const { return width(); } + + // Protocol buffer serialization and deserialization. + Status Serialize(DeviceAssignmentProto* proto) const; + static StatusOr Deserialize( + const DeviceAssignmentProto& proto); +}; + +// A generic implementation of the XLA computation placer, which assigns device +// ids to a set of replicated computations. +class ComputationPlacer { + public: + ComputationPlacer() {} + virtual ~ComputationPlacer() {} + + // Returns the device id assigned to the given replica and computation + // instance for [replica_count x computation_count] setup. 
The returned device + // id must match the assignement from PlaceReplicatedComputation(). + virtual StatusOr DeviceId(int replica, int computation, + int replica_count, int computation_count); + + // Returns the device ids assigned to a set of replicated computations, given + // the number of replicas and the number of computations. + virtual StatusOr AssignDevices(int replica_count, + int computation_count); + + using ComputationPlacerCreationFunction = + std::unique_ptr (*)(); + + // Registers a computation placer creation function for a particular platform. + static void RegisterComputationPlacer( + perftools::gputools::Platform::Id platform_id, + ComputationPlacerCreationFunction creation_function); + + // Returns the computation placer singleton pointer if it is available for the + // given platform, or an error status if it is not. + static StatusOr GetForPlatform( + const perftools::gputools::Platform* platform); + + private: + // Routine that returns the mutex that guards the platform-to-computation + // placer map. Done as a routine to ensure correct initialization ordering, + // since RegisterComputationPlacer can be called during program initialization + // time. + static tensorflow::mutex* platform_computation_placer_mutex(); + + // State kept for each kind of ComputationPlacer. Registration functions set + // up creation_function, and then we use that to lazily create "placer" the + // first time GetForPlatform is invoked for a particular id. + struct State { + std::unique_ptr placer; + ComputationPlacerCreationFunction creation_function = nullptr; + }; + + // Map from platform kind to computation placer singleton. 
+ static std::map* + GetPlatformComputationPlacers(); + + perftools::gputools::Platform::Id platform_id_; + + TF_DISALLOW_COPY_AND_ASSIGN(ComputationPlacer); +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_COMPUTATION_PLACER_H_ diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc index cc77339bb63..026be75757a 100644 --- a/tensorflow/compiler/xla/service/copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc @@ -87,7 +87,7 @@ TEST_F(CopyInsertionTest, SingleParameter) { TEST_F(CopyInsertionTest, SingleConstant) { auto builder = HloComputation::Builder(TestName()); HloInstruction* constant = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); HloInstruction* tuple = builder.AddInstruction(HloInstruction::CreateTuple({constant})); @@ -110,9 +110,9 @@ TEST_F(CopyInsertionTest, MultipleConstantsAndParameters) { auto builder = HloComputation::Builder(TestName()); HloInstruction* constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); HloInstruction* constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); HloInstruction* x = builder.AddInstruction( HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "x")); @@ -140,11 +140,11 @@ TEST_F(CopyInsertionTest, AmbiguousPointsToSet) { // the computation result. Verify that copies are added properly. 
auto builder = HloComputation::Builder(TestName()); HloInstruction* constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); HloInstruction* constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); HloInstruction* constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.0))); + HloInstruction::CreateConstant(Literal::CreateR0(3.0))); HloInstruction* tuple1 = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); @@ -152,7 +152,7 @@ TEST_F(CopyInsertionTest, AmbiguousPointsToSet) { HloInstruction::CreateTuple({constant3, constant2})); HloInstruction* pred = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloInstruction::CreateConstant(Literal::CreateR0(false))); builder.AddInstruction(HloInstruction::CreateTernary( tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2)); @@ -196,9 +196,8 @@ TEST_F(CopyInsertionTest, BitcastConstant) { // The output of a bitcast is its operand (same buffer), so a bitcast // constant feeding the result must have a copy added. auto builder = HloComputation::Builder(TestName()); - HloInstruction* constant = - builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({1.0, 42.0}))); + HloInstruction* constant = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({1.0, 42.0}))); HloInstruction* bitcast = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(F32, {2, 2}), HloOpcode::kBitcast, constant)); @@ -308,9 +307,9 @@ TEST_F(CopyInsertionTest, AmbiguousTopLevelRoot) { // copy is added. 
auto builder = HloComputation::Builder(TestName()); HloInstruction* constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); HloInstruction* constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); HloInstruction* tuple1 = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); @@ -318,7 +317,7 @@ TEST_F(CopyInsertionTest, AmbiguousTopLevelRoot) { HloInstruction::CreateTuple({constant2, constant1})); HloInstruction* pred = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloInstruction::CreateConstant(Literal::CreateR0(false))); HloInstruction* select = builder.AddInstruction(HloInstruction::CreateTernary( tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2)); HloInstruction* gte = @@ -350,7 +349,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest { bool nested = false) { auto builder = HloComputation::Builder(TestName() + ".Condition"); auto limit_const = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(10))); + HloInstruction::CreateConstant(Literal::CreateR0(10))); const Shape& loop_state_shape = nested ? nested_loop_state_shape_ : loop_state_shape_; auto loop_state = builder.AddInstruction( @@ -381,7 +380,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest { builder.AddInstruction(HloInstruction::CreateGetTupleElement( induction_variable_shape_, loop_state, 0)); auto inc = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); + HloInstruction::CreateConstant(Literal::CreateR0(1))); auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc)); // Update data GTE(1). 
@@ -419,7 +418,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest { builder.AddInstruction(HloInstruction::CreateGetTupleElement( induction_variable_shape_, loop_state, 0)); auto inc = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); + HloInstruction::CreateConstant(Literal::CreateR0(1))); // add0 = Add(in0, 1) auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( @@ -488,7 +487,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest { builder.AddInstruction(HloInstruction::CreateGetTupleElement( induction_variable_shape_, loop_state, 0)); auto inc = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); + HloInstruction::CreateConstant(Literal::CreateR0(1))); // add0 = Add(in0, 1) auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc)); @@ -503,9 +502,8 @@ class WhileCopyInsertionTest : public CopyInsertionTest { data = builder.AddInstruction( HloInstruction::CreateGetTupleElement(data_shape_, loop_state, 1)); } - auto update = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1( - {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); + auto update = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); // add1 = Add(in1, {1, 1, 1, 1, 1, 1, 1, 1}) auto add1 = builder.AddInstruction(HloInstruction::CreateBinary( data_shape_, HloOpcode::kAdd, data, update)); @@ -538,7 +536,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest { auto gte0 = builder.AddInstruction(HloInstruction::CreateGetTupleElement( induction_variable_shape_, loop_state, 0)); auto inc = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); + HloInstruction::CreateConstant(Literal::CreateR0(1))); auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( gte0->shape(), HloOpcode::kAdd, gte0, inc)); @@ 
-548,9 +546,8 @@ class WhileCopyInsertionTest : public CopyInsertionTest { // GTE(GTE(loop_state, 1), 0) -> Add auto gte10 = builder.AddInstruction( HloInstruction::CreateGetTupleElement(data_shape_, gte1, 0)); - auto update10 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1( - {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); + auto update10 = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); auto add10 = builder.AddInstruction(HloInstruction::CreateBinary( data_shape_, HloOpcode::kAdd, gte10, update10)); @@ -574,11 +571,10 @@ class WhileCopyInsertionTest : public CopyInsertionTest { bool nested = false) { auto builder = HloComputation::Builder(TestName() + ".While"); auto induction_var_init = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + HloInstruction::CreateConstant(Literal::CreateR0(0))); - auto data_init = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1( - {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}))); + auto data_init = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}))); if (nested) { auto inner_init = builder.AddInstruction( @@ -601,9 +597,8 @@ class WhileCopyInsertionTest : public CopyInsertionTest { HloInstruction* BuildWhileInstruction_InitPointsToConstant() { auto builder = HloComputation::Builder(TestName() + ".While"); - auto data_init = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1( - {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}))); + auto data_init = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}))); return BuildWhileInstructionWithCustomInit(loop_state_shape_, data_init, &builder); } @@ -620,11 +615,11 @@ class WhileCopyInsertionTest : public CopyInsertionTest { auto builder = HloComputation::Builder(TestName() + 
".While"); auto one = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto v1 = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape_, one, {1})); auto zero = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto v2 = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape_, zero, {1})); @@ -632,7 +627,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest { auto tuple2 = builder.AddInstruction(HloInstruction::CreateTuple({v2, v1})); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloInstruction::CreateConstant(Literal::CreateR0(false))); auto data_init = builder.AddInstruction(HloInstruction::CreateTernary( nested_tuple_shape_, HloOpcode::kSelect, pred, tuple1, tuple2)); @@ -644,7 +639,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest { auto builder = HloComputation::Builder(TestName() + ".While"); auto one = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto one_vec = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape_, one, {1})); auto data_init = @@ -657,12 +652,11 @@ class WhileCopyInsertionTest : public CopyInsertionTest { HloInstruction* BuildWhileInstruction_InitPointsToInterfering() { auto builder = HloComputation::Builder(TestName() + ".While"); auto one = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto data_init = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape_, one, {1})); - auto one_vec = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1( - {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); + auto one_vec = 
builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}))); // Take a reference to 'data_init' to make it interfere with while result. builder.AddInstruction(HloInstruction::CreateBinary( data_shape_, HloOpcode::kAdd, data_init, one_vec)); @@ -677,7 +671,7 @@ class WhileCopyInsertionTest : public CopyInsertionTest { const bool nested = ShapeUtil::Equal(loop_state_shape, nested_loop_state_shape_); auto induction_var_init = builder->AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + HloInstruction::CreateConstant(Literal::CreateR0(0))); auto condition = module_->AddEmbeddedComputation(BuildConditionComputation(nested)); auto body = module_->AddEmbeddedComputation( diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 51ecbccd494..de6660e3b5b 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -52,7 +52,6 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", "//tensorflow/compiler/xla/service:algebraic_simplifier", "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:buffer_liveness", @@ -151,9 +150,12 @@ cc_library( cc_library( name = "parallel_cpu_executable", srcs = ["parallel_cpu_executable.cc"], - hdrs = ["parallel_cpu_executable.h"], + hdrs = [ + "parallel_cpu_executable.h", + ], deps = [ ":cpu_runtime", + ":shape_partition", ":simple_orc_jit", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -177,12 +179,15 @@ cc_library( cc_library( name = "ir_emitter", srcs = ["ir_emitter.cc"], - hdrs = ["ir_emitter.h"], + hdrs = [ + "ir_emitter.h", + ], deps = [ ":cpu_runtime", ":dot_op_emitter", ":elemental_ir_emitter", ":ir_emission_utils", + ":shape_partition", 
":simple_orc_jit", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -334,6 +339,7 @@ cc_library( copts = runtime_copts(), deps = [ "//tensorflow/compiler/xla:types", + "//tensorflow/core:lib", ], ) @@ -405,6 +411,7 @@ cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/core:core_cpu_internal", "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", @@ -437,10 +444,15 @@ cc_library( cc_library( name = "cpu_parallelization_preparation", srcs = ["cpu_parallelization_preparation.cc"], - hdrs = ["cpu_parallelization_preparation.h"], + hdrs = [ + "cpu_parallelization_preparation.h", + ], deps = [ + ":ir_emission_utils", + ":shape_partition", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_cost_analysis", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/compiler/xla/service:logical_buffer", "//tensorflow/compiler/xla/service:tuple_points_to_analysis", @@ -502,6 +514,7 @@ cc_library( "//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/core:lib", ], ) @@ -511,6 +524,7 @@ cc_test( srcs = ["conv_canonicalization_test.cc"], deps = [ ":conv_canonicalization", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:hlo", @@ -518,6 +532,27 @@ cc_test( ], ) +cc_library( + name = "shape_partition", + srcs = ["shape_partition.cc"], + hdrs = ["shape_partition.h"], + deps = [ + "//tensorflow/compiler/xla/service:hlo", + ], +) + +cc_test( + name = "shape_partition_test", + srcs = ["shape_partition_test.cc"], + deps = [ + ":shape_partition", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:util", + 
"//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:test_main", + ], +) # ----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc index f5ad431277d..ec992f15e63 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc @@ -59,11 +59,11 @@ TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) { auto builder = HloComputation::Builder(TestName()); // The input dimensions are in CNHW order. auto input = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR4FromArray4D(Array4D( + Literal::CreateR4FromArray4D(Array4D( kInputFeatureCount, kBatchSize, kInputSize, kInputSize)))); // The kernel dimensions are in OIHW order. auto kernel = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR4FromArray4D(Array4D( + Literal::CreateR4FromArray4D(Array4D( kOutputFeatureCount, kInputFeatureCount, kWindowSize, kWindowSize)))); ConvolutionDimensionNumbers dnums; @@ -113,11 +113,11 @@ TEST_F(ConvCanonicalizationTest, CanonicalStaysTheSame) { auto builder = HloComputation::Builder(TestName()); // The input dimensions are in NHWC order. auto input = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR4FromArray4D(Array4D( + Literal::CreateR4FromArray4D(Array4D( kBatchSize, kInputSize, kInputSize, kInputFeatureCount)))); // The kernel dimensions are in HWIO order. 
auto kernel = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR4FromArray4D(Array4D( + Literal::CreateR4FromArray4D(Array4D( kWindowSize, kWindowSize, kInputFeatureCount, kOutputFeatureCount)))); ConvolutionDimensionNumbers dnums; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 34b99f2440b..4786e75fa7a 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -37,7 +37,6 @@ limitations under the License. #include "external/llvm/include/llvm/Support/TargetSelect.h" #include "external/llvm/include/llvm/Target/TargetMachine.h" #include "external/llvm/include/llvm/Target/TargetOptions.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/protobuf_util.h" @@ -285,8 +284,13 @@ Status CpuCompiler::RunHloPasses(HloModule* module, HloDumper dump_hlo) { /*enable_dot_simplification=*/false); pipeline.AddPass(/*is_layout_sensitive=*/true); // Outline ops in the entry computation into calls to subcomputations. + const int max_parallelism = + module->config().intra_op_parallelism_threads() > 0 + ? module->config().intra_op_parallelism_threads() + : tensorflow::port::NumSchedulableCPUs(); if (CpuParallelBackendRequested(module->config())) { - pipeline.AddPass(); + pipeline.AddPass(max_parallelism, + ShapeSizeBytesFunction()); } // Copy insertion should be performed immediately before IR emission to avoid // inserting unnecessary copies (later pass adds an instruction which @@ -299,7 +303,8 @@ Status CpuCompiler::RunHloPasses(HloModule* module, HloDumper dump_hlo) { if (CpuParallelBackendRequested(module->config())) { // Re-run the outlining, in case any copies were inserted into the entry // computation. 
- pipeline.AddPass(); + pipeline.AddPass(max_parallelism, + ShapeSizeBytesFunction()); } pipeline.AddPass(); pipeline.AddPass(); @@ -367,7 +372,14 @@ StatusOr> CpuCompiler::Compile( } std::unique_ptr cpu_executable; - legacy_flags::CpuCompilerFlags* flags = legacy_flags::GetCpuCompilerFlags(); + + // Cache this flag here since we'll want to access it after the module's + // ownership is std::moved. + const bool embed_ir_in_executable = + module->config().debug_options().xla_embed_ir_in_executable(); + const string dump_debug_json_to = + module->config().debug_options().xla_dump_debug_json_to(); + if (CpuParallelBackendRequested(module->config())) { // Run buffer analysis on the HLO graph. This analysis figures out which // temporary buffers are required to run the computation. @@ -381,10 +393,10 @@ StatusOr> CpuCompiler::Compile( MakeUnique(module.get()), BufferSizeBytesFunction(), kMemoryAlignment)); - if (!flags->xla_cpu_dump_debug_json_to.empty()) { + if (!dump_debug_json_to.empty()) { HloProto proto = MakeHloProto(*module, *assignment); TF_RETURN_IF_ERROR(protobuf_util::DumpJsonToDirectory( - proto, flags->xla_cpu_dump_debug_json_to, module->name())); + proto, dump_debug_json_to, module->name())); } // If we are using the parallel CPU backend, we need to create map from @@ -400,7 +412,7 @@ StatusOr> CpuCompiler::Compile( if (instruction->opcode() == HloOpcode::kConstant) { // Copy the constant out of the ProtocolBuffer so that we can give it a // higher alignment. 
- const void* data = LiteralUtil::InternalData(instruction->literal()); + const void* data = instruction->literal().InternalData(); int64 size = CpuExecutable::ShapeSizeBytes(instruction->shape()); auto iter = aligned_constants.emplace( instruction, MakeUnique(size)); @@ -419,6 +431,7 @@ StatusOr> CpuCompiler::Compile( IrEmitter ir_emitter(*module, *assignment, llvm_module.get(), &hlo_to_profile_idx); + std::unique_ptr> function_names( new std::map()); for (auto embedded_computation : @@ -446,7 +459,7 @@ StatusOr> CpuCompiler::Compile( } string ir_module_string; - if (flags->xla_cpu_embed_ir) { + if (embed_ir_in_executable) { ir_module_string = llvm_ir::DumpModuleToString(*llvm_module); } @@ -457,7 +470,7 @@ StatusOr> CpuCompiler::Compile( std::move(function_names), std::move(hlo_to_profile_idx), std::move(aligned_constants))); - if (flags->xla_cpu_embed_ir) { + if (embed_ir_in_executable) { static_cast(*cpu_executable) .set_ir_module_string(ir_module_string); } @@ -478,10 +491,10 @@ StatusOr> CpuCompiler::Compile( MakeUnique(module.get(), module_sequence), BufferSizeBytesFunction(), kMemoryAlignment)); - if (!flags->xla_cpu_dump_debug_json_to.empty()) { + if (!dump_debug_json_to.empty()) { HloProto proto = MakeHloProto(*module, *assignment); TF_RETURN_IF_ERROR(protobuf_util::DumpJsonToDirectory( - proto, flags->xla_cpu_dump_debug_json_to, module->name())); + proto, dump_debug_json_to, module->name())); } // Each computation is a single function. Emit all embedded computations @@ -490,6 +503,7 @@ StatusOr> CpuCompiler::Compile( // before a caller computation. 
IrEmitter ir_emitter(*module, *assignment, llvm_module.get(), &hlo_to_profile_idx); + for (auto embedded_computation : computation->MakeEmbeddedComputationsList()) { TF_RETURN_IF_ERROR( @@ -510,7 +524,7 @@ StatusOr> CpuCompiler::Compile( string function_name = llvm_ir::AsString(entry_function->getName()); string ir_module_string; - if (flags->xla_cpu_embed_ir) { + if (embed_ir_in_executable) { ir_module_string = llvm_ir::DumpModuleToString(*llvm_module); } @@ -520,7 +534,7 @@ StatusOr> CpuCompiler::Compile( std::move(jit), std::move(assignment), std::move(module), function_name, std::move(hlo_to_profile_idx))); - if (flags->xla_cpu_embed_ir) { + if (embed_ir_in_executable) { static_cast(*cpu_executable) .set_ir_module_string(ir_module_string); } @@ -642,11 +656,12 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, module, MakeUnique(module, module_sequence), BufferSizeBytesFunction(), kMemoryAlignment)); - legacy_flags::CpuCompilerFlags* flags = legacy_flags::GetCpuCompilerFlags(); - if (!flags->xla_cpu_dump_debug_json_to.empty()) { + const string dump_debug_json_to = + module->config().debug_options().xla_dump_debug_json_to(); + if (!dump_debug_json_to.empty()) { HloProto proto = MakeHloProto(*module, *assignment); TF_RETURN_IF_ERROR(protobuf_util::DumpJsonToDirectory( - proto, flags->xla_cpu_dump_debug_json_to, module->name())); + proto, dump_debug_json_to, module->name())); } IrEmitter ir_emitter(*module, *assignment, &llvm_module, diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h index 29fa4eac61b..3b130a614e9 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h @@ -132,7 +132,7 @@ class CpuCompiler : public Compiler { // Runs the HLO passes which are necessary for both optimizations and // correctness. 
- Status RunHloPasses(HloModule* hlo_module, HloDumper dump_hlo); + Status RunHloPasses(HloModule* module, HloDumper dump_hlo); TF_DISALLOW_COPY_AND_ASSIGN(CpuCompiler); }; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc b/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc index f6b1dcae75a..af931f7b013 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc @@ -15,19 +15,28 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.h" +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/cpu/shape_partition.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/strings/strcat.h" namespace xla { namespace cpu { StatusOr ParallelizationPreparation::Run(HloModule* module) { + XLA_VLOG_LINES(2, "ParallelizationPreparation ENTRY"); + XLA_VLOG_LINES(2, module->ToString()); + bool changed = false; + TF_ASSIGN_OR_RETURN(changed, RunParallelTaskAssignment(module)); + HloComputation* entry_computation = module->entry_computation(); std::unordered_set outlined; std::vector instructions_to_outline; @@ -44,13 +53,21 @@ StatusOr ParallelizationPreparation::Run(HloModule* module) { instruction->opcode() == HloOpcode::kConstant) { continue; } + + // Outline 'instruction' in isolation if it was assigned parallel tasks. 
+ if (OutlineParallelizableInstruction(instruction)) { + outlined.insert(instruction); + changed = true; + continue; + } + instructions_to_outline.clear(); HloInstruction* outline_candidate = instruction; instructions_to_outline.push_back(outline_candidate); bool all_bitcasts = outline_candidate->opcode() == HloOpcode::kBitcast; // Outline sole users with the current instruction. - while (outline_candidate->users().size() == 1) { + while (CanOutlineWithUser(outline_candidate)) { HloInstruction* prior_candidate = outline_candidate; outline_candidate = *outline_candidate->users().begin(); all_bitcasts |= outline_candidate->opcode() == HloOpcode::kBitcast; @@ -120,8 +137,136 @@ StatusOr ParallelizationPreparation::Run(HloModule* module) { changed = true; } } + + XLA_VLOG_LINES(2, "ParallelizationPreparation EXIT"); + XLA_VLOG_LINES(2, module->ToString()); return changed; } +StatusOr ParallelizationPreparation::RunParallelTaskAssignment( + HloModule* module) { + VLOG(1) << "RunParallelTaskAssignment max_parallelism_: " << max_parallelism_; + bool changed = false; + // Run cost analysis on entry computation. + HloCostAnalysis cost_analysis(shape_size_); + HloComputation* computation = module->entry_computation(); + Status cost_status = computation->root_instruction()->Accept(&cost_analysis); + for (auto& instruction : computation->instructions()) { + // Currently, we do not assign parallel tasks to instructions with at least + // one of the following properties: + // *) Internal threading (library calls to kConv, kDot, and kCustomCall). + // *) Emit custom loops (kSelectAndScatter, FusionKind::kTransposeDot). + // *) Tuple-shaped. + // TODO(b/27458679) Parallelize instructions which are skipped here. 
+ if (instruction->opcode() == HloOpcode::kParameter || + instruction->opcode() == HloOpcode::kConstant || + instruction->opcode() == HloOpcode::kCall || + instruction->opcode() == HloOpcode::kCustomCall || + instruction->opcode() == HloOpcode::kSelectAndScatter || + (instruction->opcode() == HloOpcode::kConvolution && + PotentiallyImplementedAsEigenConvolution(*instruction)) || + PotentiallyImplementedAsEigenDot(*instruction) || + (instruction->opcode() == HloOpcode::kFusion && + instruction->fusion_kind() != HloInstruction::FusionKind::kLoop) || + ShapeUtil::IsTuple(instruction->shape())) { + continue; + } + + // Calculate target parallel task count in [1, max_parallelism_]. + const int64 target_parallel_task_count = GetTargetParallelTaskCount( + cost_status.ok() ? &cost_analysis : nullptr, instruction.get()); + if (target_parallel_task_count == 1) { + continue; + } + + // Assign feasible dimension partitions (based on actual dimension sizes). + auto dim_partition_counts = ShapePartitionAssigner(instruction->shape()) + .Run(target_parallel_task_count); + const int64 total_partition_count = + ShapePartitionAssigner::GetTotalPartitionCount(dim_partition_counts); + if (total_partition_count <= 1) { + // Feasible partition calculation resulting in no partitioning, so skip. + continue; + } + VLOG(2) << "Assigning parallel task count: " << total_partition_count + << " to instruction: " << instruction->name(); + // Map 'instruction' to assigned dimension partitioning. + instruction->set_outer_dimension_partitions(dim_partition_counts); + } + + return changed; +} + +int64 ParallelizationPreparation::GetTargetParallelTaskCount( + const HloCostAnalysis* cost_analysis, HloInstruction* instruction) { + // Default to a simple cost model based on hlo size and typical L2 cache size. + // Note that 'cost_analysis' can be 'nullptr' if HloCostAnalysis returns an + // error status (likely because HLOs like CustomCall are not yet implemented + // in the HloCostAnalysis). 
+ int64 instruction_cost = shape_size_(instruction->shape()); + int64 min_cost_per_thread = 256LL << 10; // 256KB L2 Cache size. + if (cost_analysis != nullptr) { + // Calculate the instruction cost in cycles. + // TODO(29630486) Improve on this linear cost model. + // Consider making 'min_cost_per_thread' be a function of the target + // bandwidth limit for instructions with low arithmetic complexity. + instruction_cost = 1 * cost_analysis->flop_count(*instruction) + + 2 * cost_analysis->transcendental_count(*instruction) + + 10 * cost_analysis->bytes_accessed(*instruction); + // Minimum per-thread cost is 100us of work on a 2GHz core. + min_cost_per_thread = 100000; + } + // Return target parallel task count in [1, max_parallelism_]. + return std::min(max_parallelism_, + std::max(1LL, instruction_cost / min_cost_per_thread)); +} + +bool ParallelizationPreparation::OutlineParallelizableInstruction( + HloInstruction* instruction) { + if (instruction->outer_dimension_partitions().empty()) { + return false; + } + // Store dimension partition counts before outlining (which clones + // 'instruction'). + std::vector dim_partition_counts = + instruction->outer_dimension_partitions(); + // Outline 'instruction' in its own sub-computation. + HloModule* module = instruction->parent()->parent(); + auto* call = module->OutlineExpressionFromComputation( + {instruction}, tensorflow::strings::StrCat("pp_", instruction->name()), + module->entry_computation()); + // Map previously assigned 'dim_partition_counts' to cloned root instruction. + VLOG(1) << "Outlining parallelizable" + << " caller: " << call->name() + << " callee: " << call->to_apply()->root_instruction()->name(); + call->to_apply()->root_instruction()->set_outer_dimension_partitions( + dim_partition_counts); + return true; +} + +bool ParallelizationPreparation::CanOutlineWithUser( + HloInstruction* instruction) { + if (instruction->users().size() != 1) { + // Do not outline 'instruction' with multiple users. 
+ return false; + } + if (AssignedParallelTasks(instruction) || + AssignedParallelTasks(*instruction->users().begin())) { + // Do not outline if 'instruction' (or user) were assigned parallel tasks. + return false; + } + return true; +} + +bool ParallelizationPreparation::AssignedParallelTasks( + HloInstruction* instruction) { + return !instruction->outer_dimension_partitions().empty() || + (instruction->opcode() == HloOpcode::kCall && + !instruction->to_apply() + ->root_instruction() + ->outer_dimension_partitions() + .empty()); +} + } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.h b/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.h index 62999f5686d..70e34929342 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_PARALLELIZATION_PREPARATION_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_PARALLELIZATION_PREPARATION_H_ +#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" @@ -32,14 +33,51 @@ namespace cpu { // handle While constructs. class ParallelizationPreparation : public HloPassInterface { public: + // 'max_parallelism': the maximum parallel task count per instruction. + // 'shape_size': shape size function used by HloCostAnalysis during parallel + // task assignment. + ParallelizationPreparation( + const int64 max_parallelism, + const HloCostAnalysis::ShapeSizeFunction& shape_size) + : max_parallelism_(max_parallelism), shape_size_(shape_size) {} ~ParallelizationPreparation() override {} + tensorflow::StringPiece name() const override { return "cpu-parallel-prepare"; } - // Run instruction fusion on the given computation. 
Returns whether the + // Run parallel preparation on the given computation. Returns whether the // computation was changed. StatusOr Run(HloModule* module) override; + + private: + // Assigns parallel task partitions to conformant instructions in 'module'. + // Returns true on success or error status otherwise. + StatusOr RunParallelTaskAssignment(HloModule* module); + + // Returns the target parallel task count for 'instruction'. + // Utilizes 'cost_analysis' if non-null. + // Otherwise defaults to a simple HLO output size-based cost model. + int64 GetTargetParallelTaskCount(const HloCostAnalysis* cost_analysis, + HloInstruction* instruction); + + // Outlines 'instruction' from entry computation, if it had + // been assigned parallel tasks in an earlier pass through the computation. + // Returns true if 'instruction' was successfully outlined, false otherwise. + bool OutlineParallelizableInstruction(HloInstruction* instruction); + + // Returns true if 'instruction' can be outlined into the same sub-computation + // with its single user (parallelizable instructions are not outlined with + // each other). Returns false otherwise. + bool CanOutlineWithUser(HloInstruction* instruction); + + // Returns true if 'instruction' (or the root of the sub-computation that + // 'instruction' calls) has had parallel tasks assigned in an earlier pass. + // Returns false otherwise. + bool AssignedParallelTasks(HloInstruction* instruction); + + const int64 max_parallelism_; + const HloCostAnalysis::ShapeSizeFunction shape_size_; }; } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 7ad497ff1a2..fee5fd88305 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -55,6 +55,7 @@ limitations under the License. 
#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/lib/strings/stringprintf.h" namespace xla { @@ -84,6 +85,12 @@ StatusOr IrEmitter::EmitComputation( std::vector* instruction_order) { string function_name = name_uniquer_.GetUniqueName(function_name_prefix); VLOG(2) << "Emitting IR for CPU function [" << function_name_prefix << "]"; + num_dynamic_loop_bounds_ = 0; + if (!computation->root_instruction()->outer_dimension_partitions().empty()) { + num_dynamic_loop_bounds_ = + computation->root_instruction()->outer_dimension_partitions().size(); + } + InitializeIrFunction(function_name, is_entry_computation); // The rdtscp instruction is x86 specific. We will fallback to LLVM's generic // readcyclecounter if it is unavailable. @@ -112,7 +119,7 @@ void IrEmitter::InitializeIrFunction(const string& function_name, bool is_entry_computation) { // The function signature is: // void function(i8* retval, i8* run_options, i8** params, i8** temps, - // i64* prof_counters) + // i64* dynamic_loop_bounds, i64* prof_counters) // // retval: points to the returned value. // params: address of an array with pointers to parameters. @@ -152,6 +159,10 @@ void IrEmitter::InitializeIrFunction(const string& function_name, // | temp 0 | | temp 1 | | temp N-1 | // \---------/ \---------/ \-----------/ // + // /--------------------------------------------\ + // dynamic loop bounds -> | outer_dim0_start | outer_dim0_limit | .....| + // (elided for aot) \--------------------------------------------/ + // // /---------------------------------------------\ // prof counters -> | counter 0 | counter 1 | ..... 
| counter N-1 | // (elided for aot) \---------------------------------------------/ @@ -164,6 +175,9 @@ void IrEmitter::InitializeIrFunction(const string& function_name, llvm::Type* i64_ptr_type = llvm::Type::getInt64PtrTy(module_->getContext()); std::vector compute_function_params( {i8_ptr_type, i8_ptr_type, i8_ptr_ptr_type, i8_ptr_ptr_type}); + if (num_dynamic_loop_bounds_ > 0) { + compute_function_params.push_back(i64_ptr_type); + } if (hlo_to_profile_idx_) { compute_function_params.push_back(i64_ptr_type); } @@ -190,6 +204,9 @@ void IrEmitter::InitializeIrFunction(const string& function_name, (++arg_iter)->setName("run_options"); (++arg_iter)->setName("params"); (++arg_iter)->setName("temps"); + if (num_dynamic_loop_bounds_ > 0) { + (++arg_iter)->setName("dynamic_loop_bounds"); + } if (hlo_to_profile_idx_) { (++arg_iter)->setName("prof_counters"); } @@ -242,12 +259,12 @@ Status IrEmitter::HandleConstant(HloInstruction* constant, return Status::OK(); } -Status IrEmitter::HandleCopy(HloInstruction* copy, HloInstruction* operand) { +Status IrEmitter::HandleCopy(HloInstruction* copy) { if (ShapeUtil::IsTuple(copy->shape())) { // kCopy shallow copies a tuple so just memcpy the top-level buffer. TF_ASSIGN_OR_RETURN(llvm::Value * copy_value, EmitTargetAddressForOp(copy)); emitted_value_[copy] = copy_value; - return EmitMemcpy(*operand, *copy); + return EmitMemcpy(*(copy->operand(0)), *copy); } else { // Use the elemental emitter for non-tuple shapes. return DefaultAction(copy); @@ -1039,6 +1056,231 @@ Status IrEmitter::HandleCrossReplicaSum(HloInstruction* crs) { "Cross replica sum not implemented on CPU. See b/33011107."); } +// Fills up the free variables in 'index_with_free_var' with values from +// 'filler_index'. The size of free variables must be the same as the +// size of 'filler_index'. 
+// +// This is often used after dimension reduction, where +// 'index_with_free_var' has one or more dimensions reduced, which serves as +// free variables (represented as nullptr). For example, if we have a 4 +// dimensional input and index for the dimension being reduced is +// 2 (third dimension), we will have an index like [i, j, NULL, k] +// after reduced dimension. +// +// Here we fill up that free variable by 'filler_index', which contains +// the value in the reduced dimension. +static llvm_ir::IrArray::Index FillReducedDimensionIndex( + llvm_ir::IrArray::Index index_with_free_var, + llvm_ir::IrArray::Index filler_index) { + llvm_ir::IrArray::Index::const_iterator it = filler_index.begin(); + + for (size_t i = 0; i < index_with_free_var.size(); ++i) { + if (index_with_free_var[i] == nullptr) { + index_with_free_var[i] = *it++; + } + } + CHECK(filler_index.end() == it); + return index_with_free_var; +} + +Status IrEmitter::HandleBatchNormTraining(HloInstruction* batch_norm_training) { + // The output of BatchNormTraining is a tuple of three element: + // - An N-dimensional array containing normalized values. + // - A 1 dimensional array containing the mean value for each feature. + // - A 1 dimensional array containing the variance value for each feature. + HloInstruction* operand = batch_norm_training->operands()[0]; + HloInstruction* scale = batch_norm_training->operands()[1]; + HloInstruction* offset = batch_norm_training->operands()[2]; + float epsilon = batch_norm_training->epsilon(); + int64 feature_index = batch_norm_training->feature_index(); + TF_RET_CHECK(ShapeUtil::IsTuple(batch_norm_training->shape()) && + ShapeUtil::TupleElementCount(batch_norm_training->shape()) == 3); + + const Shape& output_shape = + ShapeUtil::GetTupleElementShape(batch_norm_training->shape(), 0); + const Shape& feature_shape = + ShapeUtil::GetTupleElementShape(batch_norm_training->shape(), 1); + + // Reduce vector of the non-feature dimensions. 
+ std::vector dimensions_to_reduce; + + for (int64 i = 0; i < operand->shape().dimensions_size(); ++i) { + if (i != feature_index) { + dimensions_to_reduce.push_back(i); + } + } + + // Get the second and third allocations in the output tuple, which should be + // used to store the result of mean and variance value calculation. + TF_ASSIGN_OR_RETURN( + const BufferAllocation::Slice slice_mean, + assignment_.GetUniqueSlice(batch_norm_training, /*index=*/{1})); + TF_ASSIGN_OR_RETURN( + const BufferAllocation::Slice slice_var, + assignment_.GetUniqueSlice(batch_norm_training, /*index=*/{2})); + const int feature_count = output_shape.dimensions(feature_index); + const int size_in_elements = ShapeUtil::ElementsIn(output_shape); + TF_RET_CHECK(ShapeUtil::ElementsIn(operand->shape()) == size_in_elements); + const int elements_per_feature = size_in_elements / feature_count; + + llvm::Value* mean = EmitTempBufferPointer(slice_mean, feature_shape); + llvm_ir::IrArray mean_array(mean, feature_shape); + + llvm::Value* var = EmitTempBufferPointer(slice_var, feature_shape); + llvm_ir::IrArray var_array(var, feature_shape); + + // This loop calculates mean and variance for each feature. + // + // In theory this could be swapped by multi-output fusion. We will evaluate + // this when it's ready. + // + // For variance calculation, we use a simplified formula so we can fuse the + // computation into the same loop to calculate mean: Var=E(X^2) - E(X)^2. + TF_RETURN_IF_ERROR( + llvm_ir::LoopEmitter( + [this, operand, dimensions_to_reduce, feature_shape, var_array, + elements_per_feature](const llvm_ir::IrArray::Index& index) { + PrimitiveType element_type = operand->shape().element_type(); + // Used to calculate E(X). + llvm::Value* sum_address = llvm_ir::EmitAllocaAtFunctionEntry( + llvm_ir::PrimitiveTypeToIrType(element_type, &ir_builder_), + "sum_address", &ir_builder_, + MinimumAlignmentForPrimitiveType(element_type)); + + // Used to calculate E(X^2). 
+ llvm::Value* sum_square_address = + llvm_ir::EmitAllocaAtFunctionEntry( + llvm_ir::PrimitiveTypeToIrType(element_type, &ir_builder_), + "sum_square_address", &ir_builder_, + MinimumAlignmentForPrimitiveType(element_type)); + + ir_builder_.CreateStore( + llvm::ConstantFP::get(ir_builder_.getFloatTy(), 0.0), + sum_address); + + ir_builder_.CreateStore( + llvm::ConstantFP::get(ir_builder_.getFloatTy(), 0.0), + sum_square_address); + + llvm_ir::ForLoopNest loops(&ir_builder_); + + const llvm_ir::IrArray::Index reduced_dims_index = + loops.AddLoopsForShapeOnDimensions( + operand->shape(), dimensions_to_reduce, "reduction_dim"); + + SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), + &ir_builder_); + + llvm_ir::IrArray operand_array(GetIrArrayForOp(operand)); + llvm_ir::IrArray::Index input_index = + FillReducedDimensionIndex(reduced_dims_index, index); + llvm::Value* new_value = + operand_array.EmitReadArrayElement(input_index, &ir_builder_); + + llvm::Value* new_value_square = + ir_builder_.CreateFMul(new_value, new_value); + + llvm::Value* current_sum = ir_builder_.CreateLoad(sum_address); + llvm::Value* current_sum_square = + ir_builder_.CreateLoad(sum_square_address); + // Update sum. + ir_builder_.CreateStore( + ir_builder_.CreateFAdd(current_sum, new_value), sum_address); + + // Update sum square. + ir_builder_.CreateStore( + ir_builder_.CreateFAdd(current_sum_square, new_value_square), + sum_square_address); + + SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), + &ir_builder_); + + llvm::Value* sum = ir_builder_.CreateLoad(sum_address); + llvm::Value* elements_per_feature_value = llvm::ConstantFP::get( + ir_builder_.getFloatTy(), elements_per_feature); + llvm::Value* mean = + ir_builder_.CreateFDiv(sum, elements_per_feature_value); + llvm::Value* mean_square = ir_builder_.CreateFMul(mean, mean); + llvm::Value* sum_square = + ir_builder_.CreateLoad(sum_square_address); + + // Var=E(X^2) - E(X)^2. 
+ llvm::Value* var = ir_builder_.CreateFSub( + ir_builder_.CreateFDiv(sum_square, elements_per_feature_value), + mean_square); + + var_array.EmitWriteArrayElement(index, var, &ir_builder_); + return mean; + }, + mean_array, &ir_builder_) + .EmitLoop()); + + TF_ASSIGN_OR_RETURN(llvm::Value * target_address, + EmitTargetAddressForOp(batch_norm_training)); + + TF_ASSIGN_OR_RETURN( + const BufferAllocation::Slice slice, + assignment_.GetUniqueSlice(batch_norm_training, /*index=*/{0})); + + llvm::Value* normalized = EmitTempBufferPointer(slice, output_shape); + + llvm_ir::IrArray target_array(normalized, output_shape); + + AddAliasingInformationToIrArray(*batch_norm_training, &target_array); + + TF_RETURN_IF_ERROR( + llvm_ir::LoopEmitter( + [this, mean_array, var_array, epsilon, operand, dimensions_to_reduce, + feature_index, offset, scale](const llvm_ir::IrArray::Index& index) { + // The following logic normalizes the input value, scales and shifts + // it: + // + // normalized = (input - mean) / sqrt(variance + epsilon) + // result = normalized * scale + offset + + // Current index in the feature dimension. 
+ llvm_ir::IrArray::Index feature_index_value(1, + index[feature_index]); + + llvm::Value* mean = mean_array.EmitReadArrayElement( + feature_index_value, &ir_builder_); + llvm::Value* var = var_array.EmitReadArrayElement( + feature_index_value, &ir_builder_); + + llvm_ir::IrArray operand_array(GetIrArrayForOp(operand)); + llvm::Value* input = + operand_array.EmitReadArrayElement(index, &ir_builder_); + + llvm::Value* variance_with_epsilon = ir_builder_.CreateFAdd( + var, llvm::ConstantFP::get(ir_builder_.getFloatTy(), epsilon)); + llvm::Function* func_llvm_sqrt = llvm::Intrinsic::getDeclaration( + module_, llvm::Intrinsic::sqrt, {ir_builder_.getFloatTy()}); + llvm::Value* variance_sqrt = + ir_builder_.CreateCall(func_llvm_sqrt, {variance_with_epsilon}); + llvm::Value* normalized = ir_builder_.CreateFDiv( + ir_builder_.CreateFSub(input, mean), variance_sqrt); + llvm_ir::IrArray offset_array(GetIrArrayForOp(offset)); + llvm::Value* offset = offset_array.EmitReadArrayElement( + feature_index_value, &ir_builder_); + llvm_ir::IrArray scale_array(GetIrArrayForOp(scale)); + llvm::Value* scale = scale_array.EmitReadArrayElement( + feature_index_value, &ir_builder_); + llvm::Value* result = ir_builder_.CreateFAdd( + ir_builder_.CreateFMul(normalized, scale), offset); + + return result; + }, + target_array, &ir_builder_) + .EmitLoop()); + + llvm_ir::EmitTuple( + llvm_ir::IrArray(target_address, batch_norm_training->shape()), + {normalized, mean, var}, &ir_builder_); + emitted_value_[batch_norm_training] = target_address; + + return Status::OK(); +} + Status IrEmitter::HandleParameter(HloInstruction* parameter) { VLOG(2) << "HandleParameter: " << parameter->ToString(); auto param_number = parameter->parameter_number(); @@ -1606,13 +1848,24 @@ llvm::Argument* IrEmitter::GetResultArgument() { } llvm::Argument* IrEmitter::GetProfileCountersArgument() { - return hlo_to_profile_idx_ ? GetArg(compute_function_, 4) : nullptr; + const int64 arg_index = num_dynamic_loop_bounds_ > 0 ? 
5 : 4; + return hlo_to_profile_idx_ ? GetArg(compute_function_, arg_index) : nullptr; } llvm::Value* IrEmitter::GetTempBuffersArgument() { return GetArg(compute_function_, 3); } +llvm::Value* IrEmitter::GetDynamicLoopBound(const int64 offset) { + CHECK_GT(num_dynamic_loop_bounds_, 0); + CHECK_LT(offset, num_dynamic_loop_bounds_ * 2); + llvm::Argument* loop_bounds_arg = GetArg(compute_function_, 4); + string name = tensorflow::strings::StrCat("dynamic_loop_bound_", offset); + return ir_builder_.CreateLoad( + ir_builder_.CreateGEP(loop_bounds_arg, ir_builder_.getInt64(offset), + llvm_ir::AsStringRef(name))); +} + llvm::Value* IrEmitter::GetExecutableRunOptionsArgument() { return GetArg(compute_function_, 1); } @@ -1745,7 +1998,7 @@ StatusOr IrEmitter::EmitTargetAddressForOp( // For the root node, we write directly to the output buffer of the // function. llvm::Argument* retval = GetResultArgument(); - if (!ShapeUtil::HasZeroElements(target_shape)) { + if (!ShapeUtil::IsNil(target_shape)) { llvm::AttrBuilder attr_builder; attr_builder.addAlignmentAttr(MinimumAlignmentForShape(target_shape)); attr_builder.addDereferenceableAttr(ByteSizeOf(target_shape)); @@ -1776,13 +2029,76 @@ Status IrEmitter::EmitTargetElementLoop( llvm_ir::IrArray target_array(target_address, target_shape); AddAliasingInformationToIrArray(*target_op, &target_array); - TF_RETURN_IF_ERROR( - llvm_ir::LoopEmitter(element_generator, target_array, &ir_builder_) - .EmitLoop()); + if (num_dynamic_loop_bounds_ > 0 && + target_op == target_op->parent()->root_instruction()) { + // Emit parallel loop for root instruction if dynamic outer-dimension loop + // bounds were specified. 
+ TF_RETURN_IF_ERROR(EmitParallelTargetElementLoop( + target_shape, element_generator, &target_array)); + } else { + TF_RETURN_IF_ERROR( + llvm_ir::LoopEmitter(element_generator, target_array, &ir_builder_) + .EmitLoop()); + } + emitted_value_[target_op] = target_address; return Status::OK(); } +Status IrEmitter::EmitParallelTargetElementLoop( + const Shape& target_shape, + const llvm_ir::ElementGenerator& element_generator, + llvm_ir::IrArray* target_array) { + CHECK(!ShapeUtil::IsTuple(target_shape)); + CHECK(!ShapeUtil::IsScalar(target_shape)); + + // Emit code to read dynamic loop bounds from function argument 4. + std::vector dynamic_loop_bounds(2 * num_dynamic_loop_bounds_); + for (int i = 0; i < 2 * num_dynamic_loop_bounds_; ++i) { + dynamic_loop_bounds[i] = GetDynamicLoopBound(i); + } + + llvm_ir::ForLoopNest loop_nest(&ir_builder_); + const int64 num_dims = target_shape.dimensions_size(); + llvm_ir::IrArray::Index array_index(num_dims); + + // Add loops from outer-most to inner-most dimensions. + for (int i = target_shape.layout().minor_to_major_size() - 1; i >= 0; --i) { + const int64 dimension = target_shape.layout().minor_to_major(i); + const int bounds_index = num_dims - 1 - i; + if (bounds_index < num_dynamic_loop_bounds_) { + // Emit dynamic loop bounds for this dimension. Dynamic loop bounds + // are read from ir function dynamic loop bounds argument. + llvm::Value* start_index = dynamic_loop_bounds[bounds_index * 2 + 0]; + llvm::Value* end_index = dynamic_loop_bounds[bounds_index * 2 + 1]; + + std::unique_ptr loop = loop_nest.AddLoop( + /*suffix=*/tensorflow::strings::Printf("dim.%lld", dimension), + start_index, end_index); + array_index[dimension] = loop->GetIndVarValue(); + } else { + // Emit static loop bounds for this dimension. 
+ std::unique_ptr loop = loop_nest.AddLoop( + /*start_index=*/0, + /*end_index=*/target_shape.dimensions(dimension), + /*suffix=*/tensorflow::strings::Printf("dim.%lld", dimension)); + array_index[dimension] = loop->GetIndVarValue(); + } + } + // Point IR builder at inner loop BB. + SetToFirstInsertPoint(loop_nest.GetInnerLoopBodyBasicBlock(), &ir_builder_); + + // Emit loop body. + TF_ASSIGN_OR_RETURN(llvm::Value * target_element, + element_generator(array_index)); + target_array->EmitWriteArrayElement(array_index, target_element, + &ir_builder_); + // Point IR builder at outer loop exit BB. + SetToFirstInsertPoint(loop_nest.GetOuterLoopExitBasicBlock(), &ir_builder_); + + return Status::OK(); +} + Status IrEmitter::EmitMemcpy(const HloInstruction& source, const HloInstruction& destination) { llvm::Value* source_value = GetEmittedValueFor(&source); diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index ebb7296a075..a1b7bd9e6dc 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -96,7 +96,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status HandleBitcast(HloInstruction* bitcast) override; Status HandleConstant(HloInstruction* constant, const Literal& literal) override; - Status HandleCopy(HloInstruction* copy, HloInstruction* operand) override; + Status HandleCopy(HloInstruction* copy) override; Status HandleGetTupleElement(HloInstruction* get_tuple_element, HloInstruction* operand) override; Status HandleSelect(HloInstruction* select, HloInstruction* pred, @@ -106,6 +106,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { HloInstruction* rhs) override; Status HandleConvolution(HloInstruction* convolution, HloInstruction* lhs, HloInstruction* rhs, const Window& window) override; + Status HandleBatchNormTraining(HloInstruction* batch_norm_training) override; Status HandleCrossReplicaSum(HloInstruction* crs) override; 
Status HandleInfeed(HloInstruction* infeed) override; Status HandleOutfeed(HloInstruction* infeed) override; @@ -192,6 +193,11 @@ class IrEmitter : public DfsHloVisitorWithDefault { // computation function being emitted by this emitter. llvm::Value* GetTempBuffersArgument(); + // Emit ir to read and return the ir value for the dynamic loop bound at + // 'offset' from the "dynamic_loop_bounds" argument of the computation + // function being emitted by this emitter. + llvm::Value* GetDynamicLoopBound(const int64 offset); + // Emits code that computes the address of the given temporary buffer to the // function. target_shape is the shape of this temporary buffer. // The returned Value's type is a pointer to element_type. @@ -262,6 +268,15 @@ class IrEmitter : public DfsHloVisitorWithDefault { HloInstruction* target_op, const llvm_ir::ElementGenerator& element_generator); + // Emit IR to perform a computation for every element in a partition/slice of + // 'target_shape'. The loop bounds for the outer-dimension partitions are + // passed into the compute function as a runtime argument (accessible from + // GetDynamicLoopBound). + Status EmitParallelTargetElementLoop( + const Shape& target_shape, + const llvm_ir::ElementGenerator& element_generator, + llvm_ir::IrArray* target_array); + // Emits a memcpy from the source instruction's result value to the // destination's. Both source and destination must have an entry in the // emitted_value_ table. @@ -319,6 +334,10 @@ class IrEmitter : public DfsHloVisitorWithDefault { llvm_ir::AliasAnalysis alias_analysis_; + // The number of root instruction outer dimensions used in parallel loop + // emission (EmitParallelTargetElementLoop). + int64 num_dynamic_loop_bounds_ = 0; + // This struct contains all the state needed to emit instructions for // profiling a computation. 
class ProfilingState { diff --git a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc index bdddca99c2f..19909f4bed8 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" +#include "tensorflow/compiler/xla/service/cpu/shape_partition.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -70,7 +71,7 @@ ParallelCpuExecutable::ParallelCpuExecutable( // Type of the computation function we expect in the JIT. using ComputeFunctionType = void (*)(void*, const void*, const void**, void**, - uint64*); + int64*, uint64*); // Given a pointer to an output buffer (following the CPU JIT calling // conventions), mark addresses that are "live". The initial pointer itself is @@ -95,6 +96,232 @@ static void MarkLiveAddressesInOutput( } } +namespace { + +// Executor manages the concurrent execution of 'functions' for instructions +// in 'pending' on 'thread_pool' (storing resulting data in 'results'). +class Executor { + public: + Executor(const std::map& functions, + const ServiceExecutableRunOptions* run_options, + std::list* pending, + std::map* results, void** temps_array, + uint64* profile_counters_array, BufferAssignment* assignment) + : functions_(functions), + run_options_(run_options), + pending_(pending), + results_(results), + temps_array_(temps_array), + profile_counters_array_(profile_counters_array), + thread_pool_(CHECK_NOTNULL(run_options_->xla_intra_op_thread_pool())), + assignment_(assignment) {} + + // Executes pending list of instructions on thread pool. 
+ // Returns OK status on success, error status otherwise. + Status Run(); + + private: + // Schedules a parallel invocation of compute function for 'instruction' on + // 'thread_pool_', storing result in 'result_buffer'. + // If 'partition_buffers' is non-null, parallel task will be invoked on + // per-dimension partition [start, limit) values stored in + // 'partition_buffers'. + void Schedule(HloInstruction* instruction, int64* partition_buffers, + void* result_buffer); + + // Returns true if 'instruction' has been assigned parallel tasks (returns + // false otherwise). + bool HasParallelTasks(HloInstruction* instruction); + + // Returns in 'partition_buffers' the partition [size, limit) for each + // dimension. + int64* GetPartitionBuffers( + const std::vector>& partition); + + // Returns array of result buffers for all operands in 'instruction'. + const void** GetOperandBuffers(HloInstruction* instruction); + + // Arguments passed into Executor. + const std::map& functions_; + const ServiceExecutableRunOptions* run_options_; + std::list* pending_; + std::map* results_; + void** temps_array_; + uint64* profile_counters_array_; + tensorflow::thread::ThreadPool* thread_pool_; + BufferAssignment* assignment_; + + // Members used to manage instruction execution. + tensorflow::mutex completion_queue_lock_; + tensorflow::condition_variable completion_queue_cv_; + std::deque completion_queue_; + int64 instructions_in_flight_ = 0; + std::unordered_map tasks_in_flight_; +}; + +Status Executor::Run() { + while (!pending_->empty() || instructions_in_flight_ > 0) { + auto pending_it = pending_->begin(); + while (pending_it != pending_->end()) { + HloInstruction* instruction = *pending_it; + // Skip pending instructions whose operands aren't ready. 
+ if (std::any_of(instruction->operands().begin(), + instruction->operands().end(), + [&](HloInstruction* operand) { + return !ContainsKey(*results_, operand); + })) { + ++pending_it; + continue; + } + + // Get 'result_buffer' reference to result buffer for 'instruction'. + TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice, + assignment_->GetUniqueTopLevelSlice(instruction)); + void* result_buffer = + static_cast(temps_array_[result_slice.index()]) + + result_slice.offset(); + + if (HasParallelTasks(instruction)) { + // 'instruction' has been assigned parallel task partitions. + CHECK_EQ(HloOpcode::kCall, instruction->opcode()); + HloInstruction* root = instruction->to_apply()->root_instruction(); + + // Create ShapePartitionIterator to iterate through all outer dimension + // partitions of 'instruction'. + ShapePartitionIterator partition_iterator( + root->shape(), root->outer_dimension_partitions()); + + const int64 partition_count = + partition_iterator.GetTotalPartitionCount(); + + // Record total parallel task count for 'instruction' before dispatch. + { + tensorflow::mutex_lock l(completion_queue_lock_); + tasks_in_flight_.insert(std::make_pair(instruction, partition_count)); + VLOG(2) << "Schedule PARALLEL" + << " instruction: " << instruction->name() + << " instruction.callee: " + << instruction->to_apply()->root_instruction()->name() + << " partition_count: " << partition_count; + } + + for (int64 i = 0; i < partition_count; ++i) { + // Get partition [start, limit) for each dimension. + auto partition_buffers = + GetPartitionBuffers(partition_iterator.GetPartition(i)); + Schedule(instruction, partition_buffers, result_buffer); + } + + } else { + // Set tasks in-flight to '1' for sequential instruction execution. 
+ { + tensorflow::mutex_lock l(completion_queue_lock_); + tasks_in_flight_.insert(std::make_pair(instruction, 1)); + VLOG(2) << "Schedule SEQUENTIAL" + << " instruction: " << instruction->name() + << " instruction.callee: " + << instruction->to_apply()->root_instruction()->name(); + } + Schedule(instruction, nullptr, result_buffer); + } + + ++instructions_in_flight_; + pending_it = pending_->erase(pending_it); + } + // Wait for a completed HLO instruction to be present in the queue. We will + // pop it out of the queue and make the result available to its users. + HloInstruction* instruction; + do { + tensorflow::mutex_lock l(completion_queue_lock_); + if (completion_queue_.empty()) { + completion_queue_cv_.wait(l); + } + if (!completion_queue_.empty()) { + instruction = completion_queue_.front(); + completion_queue_.pop_front(); + break; + } + } while (1); + TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice, + assignment_->GetUniqueTopLevelSlice(instruction)); + void* result_buffer = + static_cast(temps_array_[result_slice.index()]) + + result_slice.offset(); + InsertOrDie(results_, instruction, result_buffer); + --instructions_in_flight_; + } + return Status::OK(); +} + +void Executor::Schedule(HloInstruction* instruction, int64* partition_buffers, + void* result_buffer) { + // The thread pool entry takes ownership of |operand_buffers|. 
+ auto operand_buffers = GetOperandBuffers(instruction); + + auto function = FindOrDie(functions_, instruction); + const auto* exec_run_options = &run_options_->run_options(); + thread_pool_->Schedule([this, instruction, result_buffer, operand_buffers, + partition_buffers, exec_run_options, function]() { + function(result_buffer, exec_run_options, operand_buffers, temps_array_, + partition_buffers, profile_counters_array_); + + delete[] operand_buffers; + delete[] partition_buffers; + // Push the completed HLO instruction on the queue, the main + // thread will pop it off and potentially launch more work which + // uses the result. + // TODO(b/27458679) Consider alternative task scheduling and synchronization + // schemes. For example, we could avoid the overhead associated with the + // condvar here if the thread just dequeued the next instruction to execute + // on completion. + { + tensorflow::mutex_lock l(completion_queue_lock_); + // Decrement in-flight task count for this completion. + if (--FindOrDie(tasks_in_flight_, instruction) == 0) { + completion_queue_.push_back(instruction); + completion_queue_cv_.notify_all(); + tasks_in_flight_.erase(instruction); + } + } + }); +} + +int64* Executor::GetPartitionBuffers( + const std::vector>& partition) { + // Return in 'partition_buffers' partition [start, limit) for each dimension. 
+ auto partition_buffers = new int64[partition.size() * 2]; + for (int i = 0; i < partition.size(); ++i) { + partition_buffers[2 * i + 0] = partition[i].first; + partition_buffers[2 * i + 1] = partition[i].first + partition[i].second; + } + return partition_buffers; +} + +bool Executor::HasParallelTasks(HloInstruction* instruction) { + return instruction->opcode() == HloOpcode::kCall && + !instruction->to_apply() + ->root_instruction() + ->outer_dimension_partitions() + .empty(); +} + +const void** Executor::GetOperandBuffers(HloInstruction* instruction) { + // We cannot use a move-only RAII type like std::unique_ptr because the + // list of operands is allocated on the main thread and transferred to the + // worker via the lambda passed to enqueue_function. In order for the + // lambda to take ownership, we would need to use generalized lambda + // capture which is a feature new to C++14. + // TODO(b/27458679) Avoid dynamic allocations in Executor. + auto operand_buffers = new const void*[instruction->operand_count()]; + std::transform(instruction->operands().begin(), instruction->operands().end(), + operand_buffers, [this](HloInstruction* operand) { + return FindOrDie(*results_, operand); + }); + return operand_buffers; +} + +} // namespace + Status ParallelCpuExecutable::AllocateBuffers( DeviceMemoryAllocator* memory_allocator, int device_ordinal, std::vector* buffers) { @@ -210,88 +437,16 @@ Status ParallelCpuExecutable::ExecuteComputeFunctions( } } - void** temps_array = buffer_pointers.data(); - uint64* profile_counters_array = profile_counters.data(); - auto* thread_pool = CHECK_NOTNULL(run_options->xla_intra_op_thread_pool()); - tensorflow::mutex completion_queue_lock; - tensorflow::condition_variable completion_queue_cv; - std::deque completion_queue; - int64 instructions_in_flight = 0; - while (!pending.empty() || instructions_in_flight > 0) { - auto pending_it = pending.begin(); - while (pending_it != pending.end()) { - HloInstruction* instruction = 
*pending_it; - // Skip pending instructions whose operands aren't ready. - if (std::any_of(instruction->operands().begin(), - instruction->operands().end(), - [&](HloInstruction* operand) { - return !ContainsKey(results, operand); - })) { - ++pending_it; - continue; - } + // TODO(b/27458679) Manage scheduling based on in-flight concurrency limits. + // For example, if we expect a library conv/matmul call to run at max + // concurrency, we should not dispatch runnable instructions until the + // libary call is finished (to avoid expensive cache invalidation). + Executor executor(functions, run_options, &pending, &results, + buffer_pointers.data(), profile_counters.data(), + assignment_.get()); - TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice, - assignment_->GetUniqueTopLevelSlice(instruction)); - void* result_buffer = - static_cast(temps_array[result_slice.index()]) + - result_slice.offset(); - // We cannot use a move-only RAII type like std::unique_ptr because the - // list of operands is allocated on the main thread and transferred to the - // worker via the lambda passed to enqueue_function. In order for the - // lambda to take ownership, we would need to use generalized lambda - // capture which is a feature new to C++14. - auto operand_buffers = new const void*[instruction->operand_count()]; - std::transform(instruction->operands().begin(), - instruction->operands().end(), operand_buffers, - [&results](HloInstruction* operand) { - return FindOrDie(results, operand); - }); - auto function = FindOrDie(functions, instruction); - // The thread pool entry takes ownership of |operand_buffers|. 
- const auto* exec_run_options = &run_options->run_options(); - thread_pool->Schedule([instruction, &completion_queue, - &completion_queue_lock, &completion_queue_cv, - result_buffer, exec_run_options, operand_buffers, - temps_array, profile_counters_array, function] { - function(result_buffer, exec_run_options, operand_buffers, temps_array, - profile_counters_array); - delete[] operand_buffers; - // Push the completed HLO instruction on the queue, the main thread - // will pop it off and potentially launch more work which uses the - // result. - { - tensorflow::mutex_lock l(completion_queue_lock); - completion_queue.push_back(instruction); - completion_queue_cv.notify_all(); - } - }); + TF_RETURN_IF_ERROR(executor.Run()); - ++instructions_in_flight; - pending_it = pending.erase(pending_it); - } - // Wait for a completed HLO instruction to be present in the queue. We will - // pop it out of the queue and make the result available to its users. - HloInstruction* instruction; - do { - tensorflow::mutex_lock l(completion_queue_lock); - if (completion_queue.empty()) { - completion_queue_cv.wait(l); - } - if (!completion_queue.empty()) { - instruction = completion_queue.front(); - completion_queue.pop_front(); - break; - } - } while (1); - TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice, - assignment_->GetUniqueTopLevelSlice(instruction)); - void* result_buffer = - static_cast(temps_array[result_slice.index()]) + - result_slice.offset(); - InsertOrDie(&results, instruction, result_buffer); - --instructions_in_flight; - } uint64 end_micros = tensorflow::Env::Default()->NowMicros(); { diff --git a/tensorflow/compiler/xla/service/cpu/sample_harness.cc b/tensorflow/compiler/xla/service/cpu/sample_harness.cc index 8f1ce82d49a..b3f4609d465 100644 --- a/tensorflow/compiler/xla/service/cpu/sample_harness.cc +++ b/tensorflow/compiler/xla/service/cpu/sample_harness.cc @@ -38,13 +38,12 @@ int main(int argc, char** argv) { // Transfer parameters. 
std::unique_ptr<xla::Literal> param0_literal = - xla::LiteralUtil::CreateR1<float>({1.1f, 2.2f, 3.3f, 5.5f}); + xla::Literal::CreateR1<float>({1.1f, 2.2f, 3.3f, 5.5f}); std::unique_ptr<xla::GlobalData> param0_data = client->TransferToServer(*param0_literal).ConsumeValueOrDie(); - std::unique_ptr<xla::Literal> param1_literal = - xla::LiteralUtil::CreateR2<float>( - {{3.1f, 4.2f, 7.3f, 9.5f}, {1.1f, 2.2f, 3.3f, 4.4f}}); + std::unique_ptr<xla::Literal> param1_literal = xla::Literal::CreateR2<float>( + {{3.1f, 4.2f, 7.3f, 9.5f}, {1.1f, 2.2f, 3.3f, 4.4f}}); std::unique_ptr<xla::GlobalData> param1_data = client->TransferToServer(*param1_literal).ConsumeValueOrDie(); @@ -69,7 +68,7 @@ int main(int argc, char** argv) { LOG(INFO) << tensorflow::strings::Printf("computation took %lldns", profile.compute_time_ns()); - LOG(INFO) << xla::LiteralUtil::ToString(*actual); + LOG(INFO) << actual->ToString(); return 0; } diff --git a/tensorflow/compiler/xla/service/cpu/shape_partition.cc b/tensorflow/compiler/xla/service/cpu/shape_partition.cc new file mode 100644 index 00000000000..e27ff13edd6 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/shape_partition.cc @@ -0,0 +1,160 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/shape_partition.h" + +namespace xla { +namespace cpu { + +std::vector<int64> ShapePartitionAssigner::Run(int64 target_partition_count) { + // Gather outer-most dims where dim_size >= 'target_partition_count'. + // Note: always leave inner-dim static for vectorization/optimizations. + std::vector<int64> outer_dims; + int64 outer_dim_size = 1; + // TODO(b/27458679) Consider reserving enough minor dimensions (based on + // target vector register width) to enable vector instructions. + for (int i = shape_.layout().minor_to_major_size() - 1; i >= 1; --i) { + const int64 dimension = shape_.layout().minor_to_major(i); + outer_dims.push_back(dimension); + outer_dim_size *= shape_.dimensions(dimension); + if (outer_dim_size >= target_partition_count) { + break; + } + } + + // Clip target partition count if outer dim size is insufficient to cover. + target_partition_count = std::min(outer_dim_size, target_partition_count); + + // Calculate the target number of partitions per-dimension, by factoring + // 'target_partition_count' into 'num_outer_dims' equal terms. + // EX: + // *) target_partition_count = 16 + // *) out_dim_count = 2 + // *) target_dim_partition_count = 16 ^ (1.0 / 2) == 4 + const int64 target_dim_partition_count = std::pow( + static_cast<double>(target_partition_count), 1.0 / outer_dims.size()); + + // Assign feasible dimension partitions based on 'target_dim_partition_count' + // and actual dimension sizes from 'shape_'. + std::vector<int64> dimension_partition_counts(outer_dims.size()); + for (int64 i = 0; i < outer_dims.size(); ++i) { + dimension_partition_counts[i] = + std::min(static_cast<int64>(shape_.dimensions(outer_dims[i])), + target_dim_partition_count); + } + + // Check if total partition count is below 'target_partition_count'. + // This can occur if some dimensions in 'shape_' are below the + // 'target_dim_partition_count' threshold. 
+ if (GetTotalPartitionCount(dimension_partition_counts) < + target_partition_count) { + // Assign additional partitions (greedily to outer dimensions), if doing + // so would keep the total number of partitions <= 'target_partition_count', + // using one pass over 'dimension_partition_counts'. + for (int64 i = 0; i < dimension_partition_counts.size(); ++i) { + const int64 current_dim_partition_count = dimension_partition_counts[i]; + const int64 other_dims_partition_count = + GetTotalPartitionCount(dimension_partition_counts) / + current_dim_partition_count; + // Constraint: (current + additional) * other <= target + // Calculate: additional = target / other - current + int64 additional_partition_count = + target_partition_count / other_dims_partition_count - + current_dim_partition_count; + // Clip 'additional_partition_count' by current dimension size. + additional_partition_count = std::min( + shape_.dimensions(outer_dims[i]) - dimension_partition_counts[i], + additional_partition_count); + if (additional_partition_count > 0) { + dimension_partition_counts[i] += additional_partition_count; + } + } + } + + return dimension_partition_counts; +} + +int64 ShapePartitionAssigner::GetTotalPartitionCount( + const std::vector<int64>& dimension_partition_counts) { + int64 total_partition_count = 1; + for (int64 dim_partition_count : dimension_partition_counts) { + total_partition_count *= dim_partition_count; + } + return total_partition_count; +} + +ShapePartitionIterator::ShapePartitionIterator( + const Shape& shape, const std::vector<int64>& dimension_partition_counts) + : shape_(shape), + dimension_partition_counts_(dimension_partition_counts), + dimensions_(dimension_partition_counts_.size()), + dimension_partition_sizes_(dimension_partition_counts_.size()), + dimension_partition_strides_(dimension_partition_counts_.size()) { + // Store partitioned outer dimensions from 'shape_'. 
+ for (int i = 0; i < dimensions_.size(); ++i) { + dimensions_[i] = shape_.layout().minor_to_major( + shape_.layout().minor_to_major_size() - 1 - i); + } + + // Calculate partition size for each dimension (note that the size of + // the last partition in each dimension may be different if the dimension + // size is not a multiple of partition size). + for (int i = 0; i < dimension_partition_sizes_.size(); ++i) { + const int64 dim_size = shape_.dimensions(dimensions_[i]); + dimension_partition_sizes_[i] = + std::max(1LL, dim_size / dimension_partition_counts_[i]); + } + + // Calculate the partition strides for each dimension. + dimension_partition_strides_[dimension_partition_strides_.size() - 1] = 1; + for (int i = dimension_partition_strides_.size() - 2; i >= 0; --i) { + dimension_partition_strides_[i] = dimension_partition_strides_[i + 1] * + dimension_partition_counts_[i + 1]; + } +} + +std::vector<std::pair<int64, int64>> ShapePartitionIterator::GetPartition( + int64 index) const { + // Calculate and return the partition for 'index'. + // Returns for each dimension: (partition_start, partition_size). + std::vector<std::pair<int64, int64>> partition(dimensions_.size()); + for (int64 i = 0; i < partition.size(); ++i) { + // Calculate the index for dimension 'i'. + const int64 partition_index = index / dimension_partition_strides_[i]; + // Calculate dimension partition start at 'partition_index'. + partition[i].first = partition_index * dimension_partition_sizes_[i]; + // Calculate dimension partition size (note that the last partition size + // may be adjusted if dimension size is not a multiple of partition size). + if (partition_index == dimension_partition_counts_[i] - 1) { + // Last partition in this dimension. + partition[i].second = + shape_.dimensions(dimensions_[i]) - partition[i].first; + } else { + partition[i].second = dimension_partition_sizes_[i]; + } + CHECK_GT(partition[i].second, 0); + // Update index to remove contribution from current dimension. 
+ index -= partition_index * dimension_partition_strides_[i]; + } + return partition; +} + +int64 ShapePartitionIterator::GetTotalPartitionCount() const { + return ShapePartitionAssigner::GetTotalPartitionCount( + dimension_partition_counts_); +} + +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/shape_partition.h b/tensorflow/compiler/xla/service/cpu/shape_partition.h new file mode 100644 index 00000000000..bdbcb874c1d --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/shape_partition.h @@ -0,0 +1,105 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_SHAPE_PARTITION_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_SHAPE_PARTITION_H_ + +#include <vector> + +#include "tensorflow/compiler/xla/shape_util.h" + +namespace xla { +namespace cpu { + +// ShapePartitionAssigner partitions the most-major dimensions of 'shape' such +// that the total partition count <= 'target_partition_count'. +// +// Example 1: +// +// Let 'shape' = [8, 16, 32] and 'target_partition_count' = 6. +// +// Because the most-major dimension size is <= 'target_partition_count', we +// can generate our target number of partitions by partitioning the most-major +// dimensions. 
+// +// This will result in the following partitions of the most-major dimension: +// +// [0, 1), [1, 2), [2, 3), [3, 4), [4, 5), [5, 8) +// +// Note that the last partition has a residual because the dimension size is +// not a multiple of the partition count. +// +// +// Example 2: +// +// Let 'shape' = [8, 16, 32] and 'target_partition_count' = 16. +// +// Because the most-major dimension only has size 8, we must also partition +// the next most-major dimension to generate the target of 16 partitions. +// We factor 'target_partition_count' by the number of most-major dimensions +// we need to partition, to get a per-dimension target partition count: +// +// target_dimension_partition_count = 16 ^ (1 / 2) == 4 +// +// This will result in the following partitions of the most-major dimension: +// +// [0, 2), [2, 4), [4, 6), [6, 8) +// +// This will result in the following partitions of the second most-major +// dimension: +// +// [0, 4), [4, 8), [8, 12), [12, 16) +// +class ShapePartitionAssigner { + public: + ShapePartitionAssigner(const Shape& shape) : shape_(shape) {} + + // Returns dimension partition counts (starting at outer-most dimension). + std::vector<int64> Run(int64 target_partition_count); + + // Returns the total partition count based on 'dimension_partition_counts'. + static int64 GetTotalPartitionCount( + const std::vector<int64>& dimension_partition_counts); + + private: + const Shape& shape_; +}; + +// ShapePartitionIterator iterates through outer-dimension partitions of +// 'shape' as specified by 'dimension_partition_counts'. +class ShapePartitionIterator { + public: + ShapePartitionIterator(const Shape& shape, + const std::vector<int64>& dimension_partition_counts); + + // Returns a partition [start, size] for each dimension. + // Partitions are listed starting from outer-most dimension first. 
+ std::vector<std::pair<int64, int64>> GetPartition(int64 index) const; + + int64 GetTotalPartitionCount() const; + + private: + const Shape& shape_; + const std::vector<int64> dimension_partition_counts_; + + std::vector<int64> dimensions_; + std::vector<int64> dimension_partition_sizes_; + std::vector<int64> dimension_partition_strides_; +}; + +} // namespace cpu +} // namespace xla + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_SHAPE_PARTITION_H_ diff --git a/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc b/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc new file mode 100644 index 00000000000..6cc6e3fe85b --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/shape_partition_test.cc @@ -0,0 +1,248 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/shape_partition.h" + +#include <map> +#include <random> + +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/util.h" + +namespace xla { +namespace cpu { +namespace { + +class ShapePartitionAssignerTest : public HloTestBase { + protected: + typedef std::vector<int64> Vec; + + void RunR2Test(const Shape& shape, const int64 expected_max_partition_count) { + ShapePartitionAssigner assigner(shape); + // Check all partitions of outer dimension. 
+ for (int64 i = 1; i <= expected_max_partition_count; ++i) { + EXPECT_TRUE(ContainersEqual(Vec({i}), + assigner.Run(/*target_partition_count=*/i))); + } + // Check target_partition_count > outer dimension size. + EXPECT_TRUE(ContainersEqual( + Vec({expected_max_partition_count}), + assigner.Run( + /*target_partition_count=*/expected_max_partition_count + 1))); + } +}; + +TEST_F(ShapePartitionAssignerTest, Shape13WithLayout10) { + RunR2Test(ShapeUtil::MakeShapeWithLayout(F32, {1, 3}, {1, 0}), 1); +} + +TEST_F(ShapePartitionAssignerTest, Shape31WithLayout01) { + RunR2Test(ShapeUtil::MakeShapeWithLayout(F32, {3, 1}, {0, 1}), 1); +} + +TEST_F(ShapePartitionAssignerTest, Shape53WithLayout10) { + RunR2Test(ShapeUtil::MakeShapeWithLayout(F32, {5, 3}, {1, 0}), 5); +} + +TEST_F(ShapePartitionAssignerTest, Shape53WithLayout01) { + RunR2Test(ShapeUtil::MakeShapeWithLayout(F32, {5, 3}, {0, 1}), 3); +} + +TEST_F(ShapePartitionAssignerTest, Shape532WithLayout210) { + Shape shape = ShapeUtil::MakeShapeWithLayout(F32, {5, 3, 2}, {2, 1, 0}); + ShapePartitionAssigner assigner(shape); + + for (int64 i = 1; i <= 5; ++i) { + EXPECT_TRUE(ContainersEqual(Vec({i}), assigner.Run( + /*target_partition_count=*/i))); + } + + EXPECT_TRUE( + ContainersEqual(Vec({3, 2}), assigner.Run(/*target_partition_count=*/6))); + EXPECT_TRUE( + ContainersEqual(Vec({3, 2}), assigner.Run(/*target_partition_count=*/7))); + EXPECT_TRUE( + ContainersEqual(Vec({4, 2}), assigner.Run(/*target_partition_count=*/8))); + EXPECT_TRUE( + ContainersEqual(Vec({3, 3}), assigner.Run(/*target_partition_count=*/9))); + EXPECT_TRUE(ContainersEqual(Vec({3, 3}), + assigner.Run(/*target_partition_count=*/10))); + EXPECT_TRUE(ContainersEqual(Vec({3, 3}), + assigner.Run(/*target_partition_count=*/11))); + EXPECT_TRUE(ContainersEqual(Vec({4, 3}), + assigner.Run(/*target_partition_count=*/12))); + EXPECT_TRUE(ContainersEqual(Vec({4, 3}), + assigner.Run(/*target_partition_count=*/13))); + EXPECT_TRUE(ContainersEqual(Vec({4, 3}), + 
assigner.Run(/*target_partition_count=*/14))); + EXPECT_TRUE(ContainersEqual(Vec({5, 3}), + assigner.Run(/*target_partition_count=*/15))); + EXPECT_TRUE(ContainersEqual(Vec({5, 3}), + assigner.Run(/*target_partition_count=*/16))); +} + +TEST_F(ShapePartitionAssignerTest, Shape532WithLayout201) { + Shape shape = ShapeUtil::MakeShapeWithLayout(F32, {5, 3, 2}, {2, 0, 1}); + ShapePartitionAssigner assigner(shape); + + for (int64 i = 1; i <= 3; ++i) { + EXPECT_TRUE(ContainersEqual(Vec({i}), assigner.Run( + /*target_partition_count=*/i))); + } + + EXPECT_TRUE( + ContainersEqual(Vec({2, 2}), assigner.Run(/*target_partition_count=*/4))); + EXPECT_TRUE( + ContainersEqual(Vec({2, 2}), assigner.Run(/*target_partition_count=*/5))); + EXPECT_TRUE( + ContainersEqual(Vec({3, 2}), assigner.Run(/*target_partition_count=*/6))); + EXPECT_TRUE( + ContainersEqual(Vec({3, 2}), assigner.Run(/*target_partition_count=*/7))); + EXPECT_TRUE( + ContainersEqual(Vec({3, 2}), assigner.Run(/*target_partition_count=*/8))); + EXPECT_TRUE( + ContainersEqual(Vec({3, 3}), assigner.Run(/*target_partition_count=*/9))); + EXPECT_TRUE(ContainersEqual(Vec({3, 3}), + assigner.Run(/*target_partition_count=*/10))); + EXPECT_TRUE(ContainersEqual(Vec({3, 3}), + assigner.Run(/*target_partition_count=*/11))); + EXPECT_TRUE(ContainersEqual(Vec({3, 4}), + assigner.Run(/*target_partition_count=*/12))); + EXPECT_TRUE(ContainersEqual(Vec({3, 4}), + assigner.Run(/*target_partition_count=*/13))); + EXPECT_TRUE(ContainersEqual(Vec({3, 4}), + assigner.Run(/*target_partition_count=*/14))); + EXPECT_TRUE(ContainersEqual(Vec({3, 5}), + assigner.Run(/*target_partition_count=*/15))); + EXPECT_TRUE(ContainersEqual(Vec({3, 5}), + assigner.Run(/*target_partition_count=*/16))); +} + +class ShapePartitionIteratorTest : public HloTestBase { + protected: + typedef std::vector> Partition; +}; + +TEST_F(ShapePartitionIteratorTest, Shape53WithLayout10) { + Shape shape = ShapeUtil::MakeShapeWithLayout(F32, {5, 3}, {1, 0}); + + { + 
ShapePartitionIterator iterator(shape, {1}); + EXPECT_EQ(1, iterator.GetTotalPartitionCount()); + EXPECT_TRUE(ContainersEqual(Partition({{0, 5}}), iterator.GetPartition(0))); + } + + { + ShapePartitionIterator iterator(shape, {2}); + EXPECT_EQ(2, iterator.GetTotalPartitionCount()); + EXPECT_TRUE(ContainersEqual(Partition({{0, 2}}), iterator.GetPartition(0))); + EXPECT_TRUE(ContainersEqual(Partition({{2, 3}}), iterator.GetPartition(1))); + } + + { + ShapePartitionIterator iterator(shape, {3}); + EXPECT_EQ(3, iterator.GetTotalPartitionCount()); + EXPECT_TRUE(ContainersEqual(Partition({{0, 1}}), iterator.GetPartition(0))); + EXPECT_TRUE(ContainersEqual(Partition({{1, 1}}), iterator.GetPartition(1))); + EXPECT_TRUE(ContainersEqual(Partition({{2, 3}}), iterator.GetPartition(2))); + } +} + +TEST_F(ShapePartitionIteratorTest, Shape532WithLayout210) { + Shape shape = ShapeUtil::MakeShapeWithLayout(F32, {5, 3, 2}, {2, 1, 0}); + + { + ShapePartitionIterator iterator(shape, {1, 1}); + EXPECT_EQ(1, iterator.GetTotalPartitionCount()); + EXPECT_TRUE( + ContainersEqual(Partition({{0, 5}, {0, 3}}), iterator.GetPartition(0))); + } + + { + ShapePartitionIterator iterator(shape, {2, 2}); + EXPECT_EQ(4, iterator.GetTotalPartitionCount()); + EXPECT_TRUE( + ContainersEqual(Partition({{0, 2}, {0, 1}}), iterator.GetPartition(0))); + EXPECT_TRUE( + ContainersEqual(Partition({{0, 2}, {1, 2}}), iterator.GetPartition(1))); + EXPECT_TRUE( + ContainersEqual(Partition({{2, 3}, {0, 1}}), iterator.GetPartition(2))); + EXPECT_TRUE( + ContainersEqual(Partition({{2, 3}, {1, 2}}), iterator.GetPartition(3))); + } +} + +class RandomShapePartitionIteratorTest : public HloTestBase { + protected: + typedef std::vector<std::pair<int64, int64>> Partition; + RandomShapePartitionIteratorTest() + : generator_(rd_()), distribution_(1, 10) {} + + std::vector<int64> RandR4Dims() { return {Rand(), Rand(), Rand(), Rand()}; } + + int64 Rand() { return distribution_(generator_); } + + std::random_device rd_; + std::mt19937 generator_; + 
std::uniform_int_distribution<int64> distribution_; +}; + +TEST_F(RandomShapePartitionIteratorTest, RandomShapeAndPartitions) { + // Choose random dimensions for R4 shape. + Shape shape = ShapeUtil::MakeShapeWithLayout(F32, RandR4Dims(), {3, 2, 1, 0}); + // Choose random number of outer dimensions to partition. + const int num_outer_dims_to_partition = 1 + (Rand() % 3); + // Choose random outer dimension partition counts. + std::vector<int64> dim_sizes(num_outer_dims_to_partition); + std::vector<int64> dim_partition_counts(num_outer_dims_to_partition); + int64 total_dim_size = 1; + for (int i = 0; i < num_outer_dims_to_partition; ++i) { + const int64 dimension = shape.layout().minor_to_major( + shape.layout().minor_to_major_size() - 1 - i); + dim_sizes[i] = shape.dimensions(dimension); + total_dim_size *= dim_sizes[i]; + // Choose dimension partition count in [1, dim_size] + const int64 dim_partition_count = 1 + Rand() % dim_sizes[i]; + dim_partition_counts[i] = dim_partition_count; + } + // Iterate through all partitions: for each partition record covered + // index ranges by dimension. + std::vector<std::map<int64, int64>> ranges(num_outer_dims_to_partition); + ShapePartitionIterator partition_iterator(shape, dim_partition_counts); + const int64 partition_count = partition_iterator.GetTotalPartitionCount(); + for (int64 i = 0; i < partition_count; ++i) { + const auto& dim_partition = partition_iterator.GetPartition(i); + for (int dim = 0; dim < dim_partition.size(); ++dim) { + ranges[dim].insert( + std::make_pair(dim_partition[dim].first, + dim_partition[dim].first + dim_partition[dim].second)); + } + } + // Check that partitions cover entire dimension size range (for each + // partitioned dimension). 
+ for (int i = 0; i < ranges.size(); ++i) { + int64 expected_index = 0; + for (auto& r : ranges[i]) { + EXPECT_EQ(expected_index, r.first); + expected_index = r.second; + } + EXPECT_EQ(expected_index, dim_sizes[i]); + } +} + +} // namespace +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index 7c74912a7ab..04d4d8f075b 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -21,6 +21,8 @@ limitations under the License. #include #include +#include "external/llvm/include/llvm/ExecutionEngine/ExecutionEngine.h" +#include "external/llvm/include/llvm/ExecutionEngine/SectionMemoryManager.h" #include "external/llvm/include/llvm/IR/Mangler.h" #include "external/llvm/include/llvm/Support/CodeGen.h" #include "external/llvm/include/llvm/Support/Host.h" diff --git a/tensorflow/compiler/xla/service/cpu_transfer_manager.cc b/tensorflow/compiler/xla/service/cpu_transfer_manager.cc index 2d9d9c7de62..1d553cab1ad 100644 --- a/tensorflow/compiler/xla/service/cpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/cpu_transfer_manager.cc @@ -75,20 +75,26 @@ Status CpuTransferManager::TransferLiteralToInfeed(se::StreamExecutor* executor, ShapeUtil::HumanString(literal.shape()).c_str()); } - cpu::runtime::InfeedManager* infeed_manager = - cpu::runtime::GetInfeedManager(); - int64 size = GetByteSizeRequirement(shape); if (size > std::numeric_limits::max()) { return Unimplemented("Infeed shape is too large: %s needs %lld bytes", ShapeUtil::HumanString(literal.shape()).c_str(), size); } + + return TransferBufferToInfeed(executor, size, literal.InternalData()); +} + +Status CpuTransferManager::TransferBufferToInfeed(se::StreamExecutor* executor, + int64 size, + const void* source) { int32 size_32 = static_cast(size); CpuInfeedBuffer* queued_buffer = new CpuInfeedBuffer(size_32); - 
TF_RETURN_IF_ERROR(TransferBufferToDevice( - executor, /*size=*/size, /*source=*/LiteralUtil::InternalData(literal), - queued_buffer->device_memory())); + TF_RETURN_IF_ERROR(TransferBufferToDevice(executor, /*size=*/size, + /*source=*/source, + queued_buffer->device_memory())); + cpu::runtime::InfeedManager* infeed_manager = + cpu::runtime::GetInfeedManager(); infeed_manager->EnqueueBuffer(queued_buffer); return Status::OK(); diff --git a/tensorflow/compiler/xla/service/cpu_transfer_manager.h b/tensorflow/compiler/xla/service/cpu_transfer_manager.h index 727462252d7..5d10b62a178 100644 --- a/tensorflow/compiler/xla/service/cpu_transfer_manager.h +++ b/tensorflow/compiler/xla/service/cpu_transfer_manager.h @@ -37,6 +37,8 @@ class CpuTransferManager : public GenericTransferManager { Status TransferLiteralToInfeed(perftools::gputools::StreamExecutor* executor, const Literal& literal) override; + Status TransferBufferToInfeed(perftools::gputools::StreamExecutor* executor, + int64 size, const void* source) override; private: TF_DISALLOW_COPY_AND_ASSIGN(CpuTransferManager); diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.cc b/tensorflow/compiler/xla/service/dfs_hlo_visitor.cc index 5b296861006..5121d368665 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.cc +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.cc @@ -24,16 +24,13 @@ limitations under the License. 
namespace xla { Status DfsHloVisitor::HandleElementwiseUnary(HloInstruction* hlo, - HloOpcode opcode, - HloInstruction* operand) { + HloOpcode opcode) { return Unimplemented("DfsHloVisitor::HandleElementwiseUnary: %s", HloOpcodeString(opcode).c_str()); } Status DfsHloVisitor::HandleElementwiseBinary(HloInstruction* hlo, - HloOpcode opcode, - HloInstruction* lhs, - HloInstruction* rhs) { + HloOpcode opcode) { return Unimplemented("DfsHloVisitor::HandleElementwiseBinary: %s", HloOpcodeString(opcode).c_str()); } @@ -51,22 +48,18 @@ void DfsHloVisitor::SetVisited(const HloInstruction& instruction) { } bool DfsHloVisitor::IsVisiting(const HloInstruction& instruction) { - if (visit_state_.count(&instruction) == 0) { - return false; - } - return visit_state_[&instruction] == VisitState::kVisiting; + auto it = visit_state_.find(&instruction); + return it != visit_state_.end() && it->second == VisitState::kVisiting; } bool DfsHloVisitor::DidVisit(const HloInstruction& instruction) { - if (visit_state_.count(&instruction) == 0) { - return false; - } - return visit_state_[&instruction] == VisitState::kVisited; + auto it = visit_state_.find(&instruction); + return it != visit_state_.end() && it->second == VisitState::kVisited; } bool DfsHloVisitor::NotVisited(const HloInstruction& instruction) { - return visit_state_.count(&instruction) == 0 || - visit_state_[&instruction] == VisitState::kNotVisited; + auto it = visit_state_.find(&instruction); + return it == visit_state_.end() || it->second == VisitState::kNotVisited; } Status DfsHloVisitor::Preprocess(HloInstruction* hlo) { return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index 78a398f8efa..3f9b71cf2b6 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -65,43 +65,37 @@ class DfsHloVisitor { // These routines are self-descriptive, see class comment for usage // 
information. - virtual Status HandleElementwiseUnary(HloInstruction* hlo, HloOpcode opcode, - HloInstruction* operand); - virtual Status HandleElementwiseBinary(HloInstruction* hlo, HloOpcode opcode, - HloInstruction* lhs, - HloInstruction* rhs); + virtual Status HandleElementwiseUnary(HloInstruction* hlo, HloOpcode opcode); + virtual Status HandleElementwiseBinary(HloInstruction* hlo, HloOpcode opcode); virtual Status HandleClamp(HloInstruction* clamp, HloInstruction* min, HloInstruction* arg, HloInstruction* max) = 0; virtual Status HandleSelect(HloInstruction* select, HloInstruction* pred, HloInstruction* on_true, HloInstruction* on_false) = 0; - virtual Status HandleMaximum(HloInstruction* maximum, HloInstruction* lhs, - HloInstruction* rhs) { - return HandleElementwiseBinary(maximum, HloOpcode::kMaximum, lhs, rhs); + virtual Status HandleMaximum(HloInstruction* maximum) { + return HandleElementwiseBinary(maximum, HloOpcode::kMaximum); } - virtual Status HandleMinimum(HloInstruction* minimum, HloInstruction* lhs, - HloInstruction* rhs) { - return HandleElementwiseBinary(minimum, HloOpcode::kMinimum, lhs, rhs); + virtual Status HandleMinimum(HloInstruction* minimum) { + return HandleElementwiseBinary(minimum, HloOpcode::kMinimum); } virtual Status HandleConcatenate( HloInstruction* concatenate, tensorflow::gtl::ArraySlice operands) = 0; - virtual Status HandleConvert(HloInstruction* convert, - HloInstruction* operand) { - return HandleElementwiseUnary(convert, HloOpcode::kConvert, operand); + virtual Status HandleConvert(HloInstruction* convert) { + return HandleElementwiseUnary(convert, HloOpcode::kConvert); } - virtual Status HandleCopy(HloInstruction* copy, HloInstruction* operand) { - return HandleElementwiseUnary(copy, HloOpcode::kCopy, operand); + virtual Status HandleCopy(HloInstruction* copy) { + return HandleElementwiseUnary(copy, HloOpcode::kCopy); } virtual Status HandleMultiply(HloInstruction* multiply, HloInstruction* lhs, HloInstruction* rhs) { - 
return HandleElementwiseBinary(multiply, HloOpcode::kMultiply, lhs, rhs); + return HandleElementwiseBinary(multiply, HloOpcode::kMultiply); } virtual Status HandleDot(HloInstruction* dot, HloInstruction* lhs, HloInstruction* rhs) = 0; virtual Status HandlePower(HloInstruction* power, HloInstruction* lhs, HloInstruction* rhs) { - return HandleElementwiseBinary(power, HloOpcode::kPower, lhs, rhs); + return HandleElementwiseBinary(power, HloOpcode::kPower); } virtual Status HandleConvolution(HloInstruction* convolution, HloInstruction* lhs, HloInstruction* rhs, @@ -109,64 +103,71 @@ class DfsHloVisitor { virtual Status HandleCrossReplicaSum(HloInstruction* crs) = 0; virtual Status HandleCompare(HloInstruction* compare, HloOpcode opcode, HloInstruction* lhs, HloInstruction* rhs) { - return HandleElementwiseBinary(compare, opcode, lhs, rhs); + return HandleElementwiseBinary(compare, opcode); } virtual Status HandleAdd(HloInstruction* add, HloInstruction* lhs, HloInstruction* rhs) { - return HandleElementwiseBinary(add, HloOpcode::kAdd, lhs, rhs); + return HandleElementwiseBinary(add, HloOpcode::kAdd); } virtual Status HandleDivide(HloInstruction* divide, HloInstruction* lhs, HloInstruction* rhs) { - return HandleElementwiseBinary(divide, HloOpcode::kDivide, lhs, rhs); + return HandleElementwiseBinary(divide, HloOpcode::kDivide); } virtual Status HandleRemainder(HloInstruction* remainder, HloInstruction* lhs, HloInstruction* rhs) { - return HandleElementwiseBinary(remainder, HloOpcode::kRemainder, lhs, rhs); + return HandleElementwiseBinary(remainder, HloOpcode::kRemainder); } virtual Status HandleSubtract(HloInstruction* subtract, HloInstruction* lhs, HloInstruction* rhs) { - return HandleElementwiseBinary(subtract, HloOpcode::kSubtract, lhs, rhs); + return HandleElementwiseBinary(subtract, HloOpcode::kSubtract); } virtual Status HandleAbs(HloInstruction* abs, HloInstruction* operand) { - return HandleElementwiseUnary(abs, HloOpcode::kAbs, operand); + return 
HandleElementwiseUnary(abs, HloOpcode::kAbs); } virtual Status HandleSign(HloInstruction* sign, HloInstruction* operand) { - return HandleElementwiseUnary(sign, HloOpcode::kSign, operand); + return HandleElementwiseUnary(sign, HloOpcode::kSign); } virtual Status HandleNegate(HloInstruction* negate, HloInstruction* operand) { - return HandleElementwiseUnary(negate, HloOpcode::kNegate, operand); + return HandleElementwiseUnary(negate, HloOpcode::kNegate); } virtual Status HandleExp(HloInstruction* exp, HloInstruction* operand) { - return HandleElementwiseUnary(exp, HloOpcode::kExp, operand); + return HandleElementwiseUnary(exp, HloOpcode::kExp); } virtual Status HandleFloor(HloInstruction* floor, HloInstruction* operand) { - return HandleElementwiseUnary(floor, HloOpcode::kFloor, operand); + return HandleElementwiseUnary(floor, HloOpcode::kFloor); } virtual Status HandleCeil(HloInstruction* ceil, HloInstruction* operand) { - return HandleElementwiseUnary(ceil, HloOpcode::kCeil, operand); + return HandleElementwiseUnary(ceil, HloOpcode::kCeil); } virtual Status HandleLog(HloInstruction* log, HloInstruction* operand) { - return HandleElementwiseUnary(log, HloOpcode::kLog, operand); + return HandleElementwiseUnary(log, HloOpcode::kLog); + } + virtual Status HandleCos(HloInstruction* cos, HloInstruction* operand) { + return HandleElementwiseUnary(cos, HloOpcode::kCos); } virtual Status HandleTanh(HloInstruction* tanh, HloInstruction* operand) { - return HandleElementwiseUnary(tanh, HloOpcode::kTanh, operand); + return HandleElementwiseUnary(tanh, HloOpcode::kTanh); } virtual Status HandleIsFinite(HloInstruction* is_finite, HloInstruction* operand) { - return HandleElementwiseUnary(is_finite, HloOpcode::kIsFinite, operand); + return HandleElementwiseUnary(is_finite, HloOpcode::kIsFinite); } virtual Status HandleLogicalAnd(HloInstruction* logical_and, HloInstruction* lhs, HloInstruction* rhs) { - return HandleElementwiseBinary(logical_and, HloOpcode::kLogicalAnd, lhs, - 
rhs); + return HandleElementwiseBinary(logical_and, HloOpcode::kLogicalAnd); } virtual Status HandleLogicalNot(HloInstruction* logical_not, HloInstruction* operand) { - return HandleElementwiseUnary(logical_not, HloOpcode::kLogicalNot, operand); + return HandleElementwiseUnary(logical_not, HloOpcode::kLogicalNot); } virtual Status HandleLogicalOr(HloInstruction* logical_or, HloInstruction* lhs, HloInstruction* rhs) { - return HandleElementwiseBinary(logical_or, HloOpcode::kLogicalOr, lhs, rhs); + return HandleElementwiseBinary(logical_or, HloOpcode::kLogicalOr); + } + virtual Status HandleReducePrecision(HloInstruction* reduce_precision, + HloInstruction* operand) { + return HandleElementwiseUnary(reduce_precision, + HloOpcode::kReducePrecision); } virtual Status HandleInfeed(HloInstruction* infeed) = 0; @@ -225,6 +226,8 @@ class DfsHloVisitor { virtual Status HandleRecv(HloInstruction* recv) = 0; + virtual Status HandleBatchNormTraining(HloInstruction* batchNormTraining) = 0; + // Invoked to inform the visitor that the traversal has completed, and that // the root was "root". virtual Status FinishVisit(HloInstruction* root) = 0; diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index 6557c3aa8e6..2970ba8cc41 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -41,15 +41,19 @@ class DfsHloVisitorWithDefault : public DfsHloVisitor { // Default action performed on HloInstruction. 
virtual Status DefaultAction(HloInstruction* hlo_instruction) = 0; - Status HandleElementwiseUnary(HloInstruction* hlo, HloOpcode opcode, - HloInstruction* operand) override { + Status HandleElementwiseUnary(HloInstruction* hlo, + HloOpcode opcode) override { return DefaultAction(hlo); } - Status HandleElementwiseBinary(HloInstruction* hlo, HloOpcode opcode, - HloInstruction* lhs, - HloInstruction* rhs) override { + Status HandleElementwiseBinary(HloInstruction* hlo, + HloOpcode opcode) override { return DefaultAction(hlo); } + + Status HandleBatchNormTraining(HloInstruction* hlo) override { + return DefaultAction(hlo); + } + Status HandleClamp(HloInstruction* clamp, HloInstruction* /*min*/, HloInstruction* /*arg*/, HloInstruction* /*max*/) override { @@ -60,12 +64,10 @@ class DfsHloVisitorWithDefault : public DfsHloVisitor { tensorflow::gtl::ArraySlice /*operands*/) override { return DefaultAction(concatenate); } - Status HandleConvert(HloInstruction* convert, - HloInstruction* /*operand*/) override { + Status HandleConvert(HloInstruction* convert) override { return DefaultAction(convert); } - Status HandleCopy(HloInstruction* copy, - HloInstruction* /*operand*/) override { + Status HandleCopy(HloInstruction* copy) override { return DefaultAction(copy); } Status HandleSelect(HloInstruction* select, HloInstruction* /*pred*/, diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index be4aadb6522..f79b6826b59 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -172,6 +172,10 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::log, {operand_value}, {operand_value->getType()}, ir_builder_); + case HloOpcode::kCos: + return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::cos, {operand_value}, + {operand_value->getType()}, + ir_builder_); case HloOpcode::kFloor: return 
llvm_ir::EmitCallToIntrinsic( llvm::Intrinsic::floor, {operand_value}, {operand_value->getType()}, @@ -381,6 +385,24 @@ StatusOr ElementalIrEmitter::EmitErfcInv( return EmitErfInv(prim_type, ir_builder_->CreateFSub(one, value)); } +StatusOr ElementalIrEmitter::EmitReducePrecision( + const HloInstruction* hlo, llvm::Value* x) const { + if (hlo->operand(0)->shape().element_type() != F32) { + return Unimplemented("reduce-precision only implemented for F32"); + } + // As a preliminary implementation, we only implement this for the case + // where it is a no-op -- that is, where the exponent and mantissa bit + // counts are equal to the (IEEE f32) bit counts for the input values. + if (hlo->exponent_bits() != 8) { + return Unimplemented("reduce-precision requires 8 exponent bits"); + } + if (hlo->mantissa_bits() != 23) { + return Unimplemented("reduce-precision requires 23 mantissa bits"); + } + + return x; +} + StatusOr ElementalIrEmitter::EmitIntegerBinaryOp( const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value, bool is_signed) const { @@ -588,20 +610,37 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeRngElementGenerator( llvm::Intrinsic::ctlz, {r, ir_builder_->getInt1(1)}, {param_ir_type}, ir_builder_); auto in_block = ir_builder_->GetInsertBlock(); - auto body_block = in_block->splitBasicBlock( - ir_builder_->GetInsertPoint(), "rng_body"); - SetToFirstInsertPoint(body_block, ir_builder_); - auto out_block = body_block->splitBasicBlock( - ir_builder_->GetInsertPoint(), "rng_out"); + + // A terminator should be present iff we're emitting code + // into the middle (as opposed to the end) of a basic block. 
+ CHECK_EQ(ir_builder_->GetInsertPoint() == in_block->end(), + in_block->getTerminator() == nullptr); + + llvm::BasicBlock* body_block; + llvm::BasicBlock* out_block; + + if (ir_builder_->GetInsertPoint() == in_block->end()) { + body_block = + llvm_ir::CreateBasicBlock(nullptr, "rng_body", ir_builder_); + out_block = + llvm_ir::CreateBasicBlock(nullptr, "rng_out", ir_builder_); + llvm::BranchInst::Create(body_block, in_block); + } else { + body_block = in_block->splitBasicBlock( + ir_builder_->GetInsertPoint(), "rng_body"); + out_block = body_block->splitBasicBlock( + ir_builder_->GetInsertPoint(), "rng_out"); + body_block->getTerminator()->eraseFromParent(); + } + SetToFirstInsertPoint(body_block, ir_builder_); auto random = ir_builder_->CreateAnd( ir_builder_->CreateZExtOrTrunc(get_next_i64(), param_ir_type), ir_builder_->CreateLShr(llvm::ConstantInt::get(param_ir_type, ~0), leading_zeros)); - llvm::ReplaceInstWithInst( - body_block->getTerminator(), - llvm::BranchInst::Create(out_block, body_block, - ir_builder_->CreateICmpULT(random, r))); + llvm::BranchInst::Create(out_block, body_block, + ir_builder_->CreateICmpULT(random, r), + body_block); SetToFirstInsertPoint(out_block, ir_builder_); return ir_builder_->CreateAdd( p, ir_builder_->CreateSelect( @@ -647,6 +686,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( case HloOpcode::kCeil: case HloOpcode::kConvert: case HloOpcode::kCopy: + case HloOpcode::kCos: case HloOpcode::kExp: case HloOpcode::kFloor: case HloOpcode::kIsFinite: @@ -720,6 +760,14 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( ElementwiseSourceIndex(index, *hlo, 2))); return EmitFloatMin(max_value, EmitFloatMax(min_value, arg_value)); }; + case HloOpcode::kReducePrecision: + return [this, hlo, &operand_to_generator]( + const IrArray::Index& index) -> StatusOr { + TF_ASSIGN_OR_RETURN(llvm::Value * operand_value, + operand_to_generator.at(hlo->operand(0))( + ElementwiseSourceIndex(index, *hlo, 0))); + 
return EmitReducePrecision(hlo, operand_value); + }; case HloOpcode::kConcatenate: return [this, hlo, &operand_to_generator]( const IrArray::Index target_index) -> StatusOr { diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/elemental_ir_emitter.h index 2576d3823e0..bb9117ca61e 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.h @@ -84,6 +84,9 @@ class ElementalIrEmitter { virtual StatusOr EmitErfcInv(PrimitiveType prim_type, llvm::Value* value) const; + virtual StatusOr EmitReducePrecision(const HloInstruction* hlo, + llvm::Value* x) const; + // A helper method for MakeElementGenerator. Given an elementwise op `hlo` and // the target array index, computes the source array index of its // `operand_no`-th operand. diff --git a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc index bb4712c86f6..a08506d84d1 100644 --- a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc +++ b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc @@ -80,7 +80,7 @@ class FlattenCallGraphTest : public HloTestBase { HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, kScalarShape, "param0")); HloInstruction* zero = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(PRED, {}), HloOpcode::kGt, param0, zero)); return builder.Build(); @@ -157,7 +157,7 @@ TEST_F(FlattenCallGraphTest, SharedWhileConditionAndBody) { builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(PRED, {}), "param0")); HloInstruction* false_constant = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + 
HloInstruction::CreateConstant(Literal::CreateR0(false))); builder.AddInstruction( HloInstruction::CreateBinary(ShapeUtil::MakeShape(PRED, {}), HloOpcode::kEq, param0, false_constant)); @@ -168,7 +168,7 @@ TEST_F(FlattenCallGraphTest, SharedWhileConditionAndBody) { { HloComputation::Builder builder(TestName() + ".entry"); HloInstruction* false_constant = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloInstruction::CreateConstant(Literal::CreateR0(false))); builder.AddInstruction(HloInstruction::CreateWhile( ShapeUtil::MakeShape(PRED, {}), cond_computation, cond_computation, false_constant)); diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc index eb8b93330fb..476b2b8d6f8 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc @@ -82,13 +82,12 @@ Status GenericTransferManager::TransferLiteralFromDevice( } *literal->mutable_shape() = device_shape; - LiteralUtil::Reserve(ShapeUtil::ElementsIn(device_shape), literal); + literal->Reserve(ShapeUtil::ElementsIn(device_shape)); TF_RETURN_IF_ERROR(TransferBufferFromDevice( executor, source, /*size=*/ShapeUtil::ByteSizeOf(device_shape), - /*destination=*/LiteralUtil::MutableInternalData(literal))); + /*destination=*/literal->MutableInternalData())); if (!ShapeUtil::Equal(literal_shape, device_shape)) { - literal->Swap( - LiteralUtil::Relayout(*literal, literal_shape.layout()).get()); + literal->Swap(literal->Relayout(literal_shape.layout()).get()); } TF_RET_CHECK(ShapeUtil::Equal(literal_shape, literal->shape())); return Status::OK(); @@ -152,14 +151,20 @@ Status GenericTransferManager::TransferLiteralToDevice( tuple_elements_on_device.data(), destination); } - return TransferBufferToDevice( - executor, /*size=*/GetByteSizeRequirement(shape), - /*source=*/LiteralUtil::InternalData(literal), destination); + return 
TransferBufferToDevice(executor, + /*size=*/GetByteSizeRequirement(shape), + /*source=*/literal.InternalData(), destination); } Status GenericTransferManager::TransferLiteralToInfeed( se::StreamExecutor* executor, const Literal& literal) { - return Unimplemented("Infeed is not supported on GPU (b/30467474)"); + return Unimplemented("Generic transfer to Infeed"); +} + +Status GenericTransferManager::TransferBufferToInfeed( + perftools::gputools::StreamExecutor* executor, int64 size, + const void* source) { + return Unimplemented("Generic transfer to Infeed"); } Status GenericTransferManager::TransferLiteralFromOutfeed( diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.h b/tensorflow/compiler/xla/service/generic_transfer_manager.h index 2fbdb94f06f..48c061d28e5 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.h +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.h @@ -54,6 +54,8 @@ class GenericTransferManager : public TransferManager { Status TransferLiteralToInfeed(perftools::gputools::StreamExecutor* executor, const Literal& literal) override; + Status TransferBufferToInfeed(perftools::gputools::StreamExecutor* executor, + int64 size, const void* source) override; Status TransferLiteralFromOutfeed( perftools::gputools::StreamExecutor* executor, const Shape& literal_shape, diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 86986934117..52b4a13296f 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -267,7 +267,7 @@ cc_library( "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core/platform/default/build_config:cublas_plugin", "//tensorflow/core/platform/default/build_config:cudnn_plugin", - "//tensorflow/core/platform/default/build_config:stream_executor_cuda", + "//tensorflow/core/platform/default/build_config:stream_executor_cuda", # build_cleaner: keep ], ) @@ -376,7 +376,6 @@ cc_test( 
":fusion_merger", ":instruction_fusion", "//tensorflow/compiler/xla:test_helpers", - "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", ], ) @@ -418,7 +417,6 @@ cc_library( "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla/legacy_flags:gpu_compiler_flags", "//tensorflow/compiler/xla/service:algebraic_simplifier", "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:buffer_liveness", diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc index 2987c8913d7..c2dec7ed6af 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc @@ -55,7 +55,7 @@ using tensorflow::strings::StrAppend; // Returns whether operand is a floating-point literal with the given value. bool IsFPLiteralWithValue(const HloInstruction* operand, float value) { return operand->opcode() == HloOpcode::kConstant && - LiteralUtil::IsAllFloat(operand->literal(), value); + operand->literal().IsAllFloat(value); } GpuElementalIrEmitter::GpuElementalIrEmitter( diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc index 8afc32dea97..242c32936d3 100644 --- a/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc +++ b/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc @@ -59,7 +59,7 @@ class FusionMergerTest : public HloTestBase { // Create const vector of ones to be used in element-wise computations. auto one_vec = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({1.f, 1.f, 1.f, 1.f}))); + Literal::CreateR1({1.f, 1.f, 1.f, 1.f}))); // Create simple fusable computation for tuple element 0 (wont get merged). 
auto out0 = builder.AddInstruction(HloInstruction::CreateBinary( @@ -138,7 +138,7 @@ class FusionMergerTest : public HloTestBase { // Create two sub-computations, both of which are users of 'mul0'. auto one_vec = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({1.f, 1.f, 1.f, 1.f}))); + Literal::CreateR1({1.f, 1.f, 1.f, 1.f}))); // First sub-computation: out0 = Mul(Add(mul0, one_vec), one_vec) auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( @@ -209,7 +209,7 @@ class FusionMergerTest : public HloTestBase { // Create two fusable sub-computations which are dependent on shared // computation 'reduce_out'. auto one_vec = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({1.f, 1.f, 1.f, 1.f}))); + Literal::CreateR1({1.f, 1.f, 1.f, 1.f}))); // First sub-computation: out0 = Mul(Add(reduce_out, one_vec), one_vec) auto add2 = builder.AddInstruction(HloInstruction::CreateBinary( diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index 86137a569f9..7ce8b8290a4 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -23,7 +23,6 @@ limitations under the License. #include "external/llvm/include/llvm/IR/DiagnosticPrinter.h" #include "external/llvm/include/llvm/IR/LLVMContext.h" #include "external/llvm/include/llvm/IR/Module.h" -#include "tensorflow/compiler/xla/legacy_flags/gpu_compiler_flags.h" #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" @@ -95,11 +94,9 @@ constexpr int64 kMemoryAlignment = 256; // called in GpuCompiler's constructor, so can't return an error. But // GpuCompiler::Compile will return an error when the wanted libdevice file // doesn't exist in the folder this function returns. 
-string GetLibdeviceDir() { +string GetLibdeviceDir(const HloModuleConfig& config) { std::vector potential_libdevice_dirs; - // Flag xla_cuda_data_dir specified by the user. - legacy_flags::GpuCompilerFlags* flags = legacy_flags::GetGpuCompilerFlags(); - const string datadir = flags->xla_cuda_data_dir; + const string datadir = config.debug_options().xla_gpu_cuda_data_dir(); if (!datadir.empty()) { potential_libdevice_dirs.push_back(datadir); } @@ -230,8 +227,7 @@ void DumpPtxasInfo(const string& ptx) { } // namespace GpuCompiler::GpuCompiler() - : libdevice_dir_(GetLibdeviceDir()), - pointer_size_(llvm::DataLayout(kDataLayout).getPointerSize()) {} + : pointer_size_(llvm::DataLayout(kDataLayout).getPointerSize()) {} StatusOr> GpuCompiler::Compile( std::unique_ptr module, HloDumper dump_hlo, @@ -273,11 +269,12 @@ StatusOr> GpuCompiler::Compile( BufferAssigner::Run(module.get(), hlo_schedule->ConsumeHloOrdering(), BufferSizeBytesFunction(), kMemoryAlignment)); - legacy_flags::GpuCompilerFlags* flags = legacy_flags::GetGpuCompilerFlags(); - if (!flags->xla_gpu_dump_debug_json_to.empty()) { + const string dump_debug_json_to = + module->config().debug_options().xla_dump_debug_json_to(); + if (!dump_debug_json_to.empty()) { HloProto proto = MakeHloProto(*module, *buffer_assignment); TF_RETURN_IF_ERROR(protobuf_util::DumpJsonToDirectory( - proto, flags->xla_gpu_dump_debug_json_to, module->name())); + proto, dump_debug_json_to, module->name())); } IrEmitterContext ir_emitter_context(module.get(), buffer_assignment.get(), @@ -292,7 +289,9 @@ StatusOr> GpuCompiler::Compile( entry_computation->root_instruction()->Accept(&ir_emitter)); string ir_module_string_before_opt; - if (VLOG_IS_ON(2) || flags->xla_gpu_embed_ir) { + const bool embed_ir_in_executable = + module->config().debug_options().xla_embed_ir_in_executable(); + if (VLOG_IS_ON(2) || embed_ir_in_executable) { ir_module_string_before_opt = llvm_ir::DumpModuleToString(llvm_module); VLOG(2) << "LLVM module before 
optimizations:"; XLA_VLOG_LINES(2, ir_module_string_before_opt); @@ -313,6 +312,10 @@ StatusOr> GpuCompiler::Compile( cc_major = 2; cc_minor = 0; } + if (libdevice_dir_.empty()) { + // Compute libdevice_dir_ just once and cache it in this member. + libdevice_dir_ = GetLibdeviceDir(module->config()); + } TF_ASSIGN_OR_RETURN(*ptx, CompileToPtx(&llvm_module, {cc_major, cc_minor}, module->config(), libdevice_dir_)); @@ -333,7 +336,7 @@ StatusOr> GpuCompiler::Compile( auto* gpu_executable = new GpuExecutable(*ptx, std::move(thunk_schedule), std::move(module), std::move(buffer_assignment), ShapeSizeBytesFunction()); - if (flags->xla_gpu_embed_ir) { + if (embed_ir_in_executable) { DCHECK_NE("", ir_module_string_before_opt); gpu_executable->set_ir_module_string(ir_module_string_before_opt); } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h index da52f5ab1f8..68dddcdc198 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h @@ -65,7 +65,7 @@ class GpuCompiler : public Compiler { private: // The parent directory of libdevice IR libraries. - const string libdevice_dir_; + string libdevice_dir_; // The list of PTX strings generated by this GpuCompiler. We let GpuCompiler // to own them because they need to be alive across the life span of the diff --git a/tensorflow/compiler/xla/service/gpu/infeed_manager.cc b/tensorflow/compiler/xla/service/gpu/infeed_manager.cc index 120a3f7fba2..8b948d89f5a 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_manager.cc +++ b/tensorflow/compiler/xla/service/gpu/infeed_manager.cc @@ -13,8 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/gpu/infeed_manager.h" + +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/core/platform/logging.h" namespace se = ::perftools::gputools; @@ -22,23 +24,23 @@ namespace se = ::perftools::gputools; namespace xla { namespace gpu { -InfeedManager::InfeedManager() - : current_buffer_(nullptr), - host_to_device_executor_(nullptr) {} +InfeedManager::InfeedManager() : host_to_device_executor_(nullptr) {} void InfeedManager::Reset() { tensorflow::mutex_lock l(mu_); - CHECK(!current_buffer_); + CHECK(dequeued_buffer_.empty()); for (auto buffer : enqueued_buffer_) { buffer->Done(); } enqueued_buffer_.clear(); } -void InfeedManager::EnqueueBuffer(InfeedBuffer* buffer) { +void InfeedManager::EnqueueBuffers(std::vector buffers) { tensorflow::mutex_lock l(mu_); bool was_empty = enqueued_buffer_.empty(); - enqueued_buffer_.push_back(buffer); + for (gpu::InfeedBuffer* b : buffers) { + enqueued_buffer_.push_back(b); + } if (was_empty) { // This has the potential to suffer from the notified thread // immediately trying and failing to acquire mu_, but seems @@ -53,18 +55,23 @@ InfeedBuffer* InfeedManager::BlockingDequeueBuffer() { while (enqueued_buffer_.empty()) { cv_.wait(l); } - CHECK(!current_buffer_); - current_buffer_ = enqueued_buffer_.front(); + InfeedBuffer* current_buffer = enqueued_buffer_.front(); enqueued_buffer_.pop_front(); - return current_buffer_; + dequeued_buffer_.insert(current_buffer); + return current_buffer; } -void InfeedManager::ReleaseCurrentBuffer(se::DeviceMemoryBase* device_memory) { - tensorflow::mutex_lock l(mu_); - CHECK(current_buffer_); - CHECK(device_memory->IsSameAs(*current_buffer_->device_memory())); - current_buffer_->Done(); - current_buffer_ = nullptr; +void InfeedManager::ReleaseBuffers(std::vector buffers) { + { 
+ tensorflow::mutex_lock l(mu_); + for (gpu::InfeedBuffer* b : buffers) { + CHECK(ContainsKey(dequeued_buffer_, b)); + dequeued_buffer_.erase(b); + } + } + for (gpu::InfeedBuffer* b : buffers) { + b->Done(); + } } se::Stream* InfeedManager::GetStream(se::StreamExecutor* executor) { diff --git a/tensorflow/compiler/xla/service/gpu/infeed_manager.h b/tensorflow/compiler/xla/service/gpu/infeed_manager.h index 50d0ce340f3..23fbc5bd083 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_manager.h +++ b/tensorflow/compiler/xla/service/gpu/infeed_manager.h @@ -23,6 +23,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -81,25 +82,19 @@ class InfeedManager { // condition is to call Reset when no computation is taking place. void Reset(); - // Adds buffer to the infeed queue. buffer->Done will be called when - // the buffer will no longer be accessed by the InfeedManager, - // either as a result of a call to Reset or because the runtime has - // dequeued and used the buffer. - void EnqueueBuffer(InfeedBuffer* buffer); + // Adds a set of buffers to the infeed queue atomically. buffer->Done + // will be called when the buffer will no longer be accessed by the + // InfeedManager, either as a result of a call to Reset or because the + // runtime has dequeued and used the buffer. + void EnqueueBuffers(std::vector buffers); // Blocks until the infeed queue is non-empty, then returns the - // buffer at the head of the queue. Sets the current buffer to be - // the returned buffer. It is an error to call BlockingDequeueBuffer - // if there is an unreleased current buffer, i.e., - // ReleaseCurrentBuffer must be called between calls to - // BlockingDequeueBuffer. + // buffer at the head of the queue. Adds the current buffer to the + // to-be released set. 
InfeedBuffer* BlockingDequeueBuffer(); - // Releases the current buffer, which is the last buffer returned by - // BlockingDequeueBuffer and not yet released. device_memory must - // match that of the current buffer. - void ReleaseCurrentBuffer( - perftools::gputools::DeviceMemoryBase* device_memory); + // Releases a set of buffers from the to-be released set. + void ReleaseBuffers(std::vector buffers); // Returns a cached stream associated with an executor. Allocates a // new stream on the first invocation. On subsequent invocations, if @@ -109,18 +104,25 @@ class InfeedManager { perftools::gputools::StreamExecutor* executor); private: + // TODO(b/30467474): Revisit if this mutex becomes a point of + // contention. tensorflow::mutex mu_; + // Condition variable that is signaled every time a buffer is // enqueued to an empty queue. tensorflow::condition_variable cv_; + // InfeedBuffer* queue contents are not owned, but buffer->Done must // be called when the buffer is no longer needed by the runtime. std::deque enqueued_buffer_; - // If non-NULL, the buffer that is currently being processed by the + + // Buffers that are dequeued and currently being processed by the // runtime. Not owned. - InfeedBuffer* current_buffer_; + tensorflow::gtl::FlatSet dequeued_buffer_; + // Cached host to device stream for queuing infeed data. std::unique_ptr host_to_device_stream_; + // Executor that the host_to_device_stream belongs to. Not owned. perftools::gputools::StreamExecutor* host_to_device_executor_; }; diff --git a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc index 6f144c7273e..e33e904692c 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc @@ -21,31 +21,59 @@ limitations under the License. 
namespace xla { namespace gpu { -InfeedThunk::InfeedThunk(const BufferAllocation::Slice& destination_buffer, - uint64 mem_size, const HloInstruction* hlo_instruction) +InfeedThunk::InfeedThunk( + tensorflow::gtl::ArraySlice tuple_element_buffers, + const BufferAllocation::Slice& destination_buffer, + const HloInstruction* hlo_instruction) : Thunk(Kind::kInfeed, hlo_instruction), - destination_buffer_(destination_buffer), - mem_size_(mem_size) {} + tuple_element_buffers_(tuple_element_buffers.begin(), + tuple_element_buffers.end()), + destination_buffer_(destination_buffer) {} tensorflow::Status InfeedThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, perftools::gputools::Stream* stream) { VLOG(2) << "Infeeding to GPU "; - perftools::gputools::DeviceMemoryBase destination_data = + + perftools::gputools::DeviceMemoryBase destination_address = buffer_allocations.GetDeviceAddress(destination_buffer_); InfeedManager* infeed_manager = GetOrCreateInfeedManager(); - InfeedBuffer* buffer = infeed_manager->BlockingDequeueBuffer(); - CHECK_EQ(buffer->length(), mem_size_); - stream->ThenMemcpy(&destination_data, *(buffer->device_memory()), - buffer->length()); + std::vector infeed_buffers; + if (ShapeUtil::IsTuple(hlo_instruction()->shape())) { + CHECK(!ShapeUtil::IsNestedTuple(hlo_instruction()->shape())); + // Transfer the tuple elements first. + std::vector tuple_element_addresses; + for (BufferAllocation::Slice tuple_element_buffer : + tuple_element_buffers_) { + perftools::gputools::DeviceMemoryBase tuple_element_address = + buffer_allocations.GetDeviceAddress(tuple_element_buffer); + + InfeedBuffer* buffer = infeed_manager->BlockingDequeueBuffer(); + infeed_buffers.push_back(buffer); + stream->ThenMemcpy(&tuple_element_address, *(buffer->device_memory()), + buffer->length()); + tuple_element_addresses.push_back(tuple_element_address.opaque()); + } + // Transfer the tuple outer buffer. 
+ auto host_size = tuple_element_addresses.size() * sizeof(void*); + stream->ThenMemcpy(&destination_address, tuple_element_addresses.data(), + host_size); + } else { + InfeedBuffer* buffer = infeed_manager->BlockingDequeueBuffer(); + infeed_buffers.push_back(buffer); + stream->ThenMemcpy(&destination_address, *(buffer->device_memory()), + buffer->length()); + } + if (!stream->BlockHostUntilDone()) { return InternalError("Failed to complete data transfer on stream %p", stream); } - // Since Infeeds are totally ordered, no other infeed should sneak - // in and we should be able to release the same buffer we dequeued. - infeed_manager->ReleaseCurrentBuffer(buffer->device_memory()); + + infeed_manager->ReleaseBuffers(infeed_buffers); + + VLOG(2) << "Infeeding to GPU complete"; return tensorflow::Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/infeed_thunk.h b/tensorflow/compiler/xla/service/gpu/infeed_thunk.h index 0a808186c21..371d71f9dbd 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/infeed_thunk.h @@ -35,8 +35,10 @@ class InfeedThunk : public Thunk { // infeed queue to the device buffer // `destination_buffer`. `mem_size` is the size of the data in // bytes. 
- InfeedThunk(const BufferAllocation::Slice& destination_buffer, - uint64 mem_size, const HloInstruction* hlo_instruction); + InfeedThunk(tensorflow::gtl::ArraySlice + tuple_element_buffers, + const BufferAllocation::Slice& destination_buffer, + const HloInstruction* hlo_instruction); InfeedThunk(const InfeedThunk&) = delete; InfeedThunk& operator=(const InfeedThunk&) = delete; @@ -46,8 +48,8 @@ class InfeedThunk : public Thunk { perftools::gputools::Stream* stream) override; private: + const std::vector tuple_element_buffers_; const BufferAllocation::Slice destination_buffer_; - const uint64 mem_size_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h index 607a366ac67..de72ac738ea 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h @@ -231,7 +231,7 @@ class IrEmitterUnnested : public IrEmitter { // IrEmitterUnnested handles the following instructions differently from // IrEmitter. 
- Status HandleCopy(HloInstruction* copy, HloInstruction* operand) override; + Status HandleCopy(HloInstruction* copy) override; Status HandleConvolution(HloInstruction* convolution, HloInstruction* lhs, HloInstruction* rhs, const Window& window) override; Status HandleDot(HloInstruction* dot, HloInstruction* lhs_instruction, diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 5fa2bfdd7e4..ea71d924179 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -722,8 +722,7 @@ int64 EmitTranspose021Tiled(llvm_ir::IrArray input, llvm_ir::IrArray output, } // namespace -Status IrEmitterUnnested::HandleCopy(HloInstruction* copy, - HloInstruction* operand) { +Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) { if (ImplementedAsMemcpy(*copy)) { thunk_sequence_->emplace_back(BuildCopyThunk(copy)); return Status::OK(); @@ -731,7 +730,7 @@ Status IrEmitterUnnested::HandleCopy(HloInstruction* copy, bool is_transpose_021; Shape reduced_input_shape, reduced_output_shape; std::tie(is_transpose_021, reduced_input_shape, reduced_output_shape) = - IsTranspose021(operand->shape(), copy->shape()); + IsTranspose021(copy->operand(0)->shape(), copy->shape()); if (is_transpose_021 && reduced_input_shape.dimensions(1) >= kMinDimensionToTransposeTiled && reduced_input_shape.dimensions(2) >= kMinDimensionToTransposeTiled) { @@ -739,7 +738,8 @@ Status IrEmitterUnnested::HandleCopy(HloInstruction* copy, VLOG(3) << "Emitting tiled 0-2-1 transposition"; constexpr int64 tile_size = 32; int64 num_tiles = EmitTranspose021Tiled( - GetIrArray(*operand).CastToShape(reduced_input_shape, &ir_builder_), + GetIrArray(*(copy->operand(0))) + .CastToShape(reduced_input_shape, &ir_builder_), GetIrArray(*copy).CastToShape(reduced_output_shape, &ir_builder_), tile_size, &ir_builder_); UpdateLaunchDimensions(LaunchDimensions(num_tiles, 
tile_size), LastThunk(), @@ -747,7 +747,7 @@ Status IrEmitterUnnested::HandleCopy(HloInstruction* copy, return Status::OK(); } - return IrEmitter::HandleCopy(copy, operand); + return IrEmitter::HandleCopy(copy); } Status IrEmitterUnnested::EmitColumnReduction( @@ -1648,7 +1648,7 @@ std::unique_ptr IrEmitterUnnested::BuildCopyThunk( const HloInstruction* operand = inst->operand(0); CHECK_EQ(HloOpcode::kConstant, operand->opcode()); return MakeUnique( - /*source_address=*/LiteralUtil::InternalData(operand->literal()), + /*source_address=*/operand->literal().InternalData(), /*destination_buffer=*/GetAllocationSlice(*inst), /*mem_size=*/ llvm_ir::ByteSizeOf(operand->shape(), @@ -1659,12 +1659,18 @@ std::unique_ptr IrEmitterUnnested::BuildCopyThunk( std::unique_ptr IrEmitterUnnested::BuildInfeedThunk( const HloInstruction* inst) { CHECK_EQ(HloOpcode::kInfeed, inst->opcode()); + + std::vector tuple_element_buffers; + for (int64 i = 0; i < inst->shape().tuple_shapes_size(); ++i) { + BufferAllocation::Slice buffer = ir_emitter_context_->buffer_assignment() + .GetUniqueSlice(inst, {i}) + .ConsumeValueOrDie(); + tuple_element_buffers.push_back(buffer); + } + return MakeUnique( - /*destination_buffer=*/GetAllocationSlice(*inst), - /*mem_size=*/ - llvm_ir::ByteSizeOf(inst->shape(), - ir_emitter_context_->llvm_module()->getDataLayout()), - inst); + tuple_element_buffers, + /*destination_buffer=*/GetAllocationSlice(*inst), inst); } std::unique_ptr IrEmitterUnnested::BuildGemmThunk( diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD index 724549c0c4e..1d1e5bee542 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/BUILD @@ -28,10 +28,10 @@ cc_library( "utils.h", ], deps = [ + "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", 
"//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla/legacy_flags:gpu_backend_lib_flags", "//tensorflow/compiler/xla/service:hlo_module_config", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc index e03571a9672..881522a0298 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc @@ -20,7 +20,6 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/legacy_flags/gpu_backend_lib_flags.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/dump_ir_pass.h" #include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/utils.h" @@ -134,13 +133,8 @@ static string GetSmName(std::pair compute_capability) { // from the input filename. string MakeNameForTempProduct(const std::string& input_filename, tensorflow::StringPiece extension) { - legacy_flags::GpuBackendLibFlags* flags = - legacy_flags::GetGpuBackendLibFlags(); - return tensorflow::io::JoinPath( - flags->dump_temp_products_to, - ReplaceFilenameExtension( - tensorflow::io::Basename(llvm_ir::AsString(input_filename)), - extension)); + return ReplaceFilenameExtension( + tensorflow::io::Basename(llvm_ir::AsString(input_filename)), extension); } // Initializes LLVM passes. Uses the PassRegistry mechanism. @@ -177,20 +171,16 @@ std::unique_ptr GetTargetMachine( .xla_enable_fast_math(), &target_options); - // Enable FMA synthesis if desired. - legacy_flags::GpuBackendLibFlags* flags = - legacy_flags::GetGpuBackendLibFlags(); - if (flags->fma) { - target_options.AllowFPOpFusion = FPOpFusion::Fast; - } + // Enable FMA synthesis. + target_options.AllowFPOpFusion = FPOpFusion::Fast; // Set the verbose assembly options. 
- target_options.MCOptions.AsmVerbose = flags->verbose_ptx_asm; + target_options.MCOptions.AsmVerbose = false; // The selection of codegen optimization level is copied from function // GetCodeGenOptLevel in //external/llvm/tools/opt/opt.cpp. CodeGenOpt::Level codegen_opt_level; - switch (flags->opt_level) { + switch (hlo_module_config.debug_options().xla_backend_optimization_level()) { case 1: codegen_opt_level = CodeGenOpt::Less; break; @@ -262,12 +252,10 @@ string EmitModuleToPTX(Module* module, llvm::TargetMachine* target_machine) { // The extension is stripped by IrDumpingPassManager, so we need to // get creative to add a suffix. string module_id(llvm_ir::AsString(module->getModuleIdentifier())); - legacy_flags::GpuBackendLibFlags* flags = - legacy_flags::GetGpuBackendLibFlags(); IrDumpingPassManager codegen_passes( ReplaceFilenameExtension(tensorflow::io::Basename(module_id), "-nvptx.dummy"), - flags->dump_temp_products_to, flags->dump_ir_before_passes); + "", false); codegen_passes.add(new llvm::TargetLibraryInfoWrapperPass( llvm::Triple(module->getTargetTriple()))); @@ -345,36 +333,19 @@ StatusOr CompileModuleToPtx(llvm::Module* module, TF_RETURN_IF_ERROR( LinkLibdeviceIfNecessary(module, compute_capability, libdevice_dir_path)); - legacy_flags::GpuBackendLibFlags* flags = - legacy_flags::GetGpuBackendLibFlags(); - if (!flags->dump_temp_products_to.empty()) { - string linked_filename = - MakeNameForTempProduct(module->getModuleIdentifier(), "linked.bc"); - LOG(INFO) << "dumping bitcode after linking libdevice to: " - << linked_filename; - EmitBitcodeToFile(*module, linked_filename); - } - // Set the flush-denormals-to-zero flag on the module so the NVVM reflect pass // can access it. - module->addModuleFlag(llvm::Module::Override, "nvvm-reflect-ftz", flags->ftz); + module->addModuleFlag(llvm::Module::Override, "nvvm-reflect-ftz", + hlo_module_config.debug_options().xla_gpu_ftz()); // If ftz is enabled, set it as an attribute on every function in the module. 
- if (flags->ftz) { + if (hlo_module_config.debug_options().xla_gpu_ftz()) { for (llvm::Function& fn : *module) { fn.addFnAttr("nvptx-f32ftz", "true"); } } - // Run IR-level optimizations. - if (flags->dump_ir_before_passes && flags->dump_temp_products_to.empty()) { - LOG(FATAL) << "--dump_ir_before_passes must be specified with " - "--dump_temp_products_to"; - } - - IrDumpingPassManager module_passes(module->getModuleIdentifier(), - flags->dump_temp_products_to, - flags->dump_ir_before_passes); + IrDumpingPassManager module_passes(module->getModuleIdentifier(), "", false); // Add an appropriate TargetLibraryInfo pass for the module's triple. llvm::TargetLibraryInfoWrapperPass* tliwp = @@ -406,8 +377,16 @@ StatusOr CompileModuleToPtx(llvm::Module* module, // too. llvm::legacy::FunctionPassManager function_passes(module); - AddOptimizationPasses(flags->opt_level, /*size_level=*/0, - target_machine.get(), &module_passes, &function_passes); + int32 opt_level = + hlo_module_config.debug_options().xla_backend_optimization_level(); + + CHECK_GE(opt_level, 2) + << "The XLA GPU backend doesn't support unoptimized code generation"; + + AddOptimizationPasses(opt_level, + /*size_level=*/0, target_machine.get(), &module_passes, + &function_passes); + // Loop unrolling exposes more opportunities for SROA. Therefore, we run SROA // again after the standard optimization passes [http://b/13329423]. // TODO(jingyue): SROA may further expose more optimization opportunities, such @@ -415,7 +394,7 @@ StatusOr CompileModuleToPtx(llvm::Module* module, // the inlining cost of a function). For now, running SROA already emits good // enough code for the evaluated benchmarks. We may want to run more // optimizations later. - if (flags->opt_level > 0) { + if (opt_level > 0) { // LLVM's optimizer turns on SROA when the optimization level is greater // than 0. We mimic this behavior here. 
module_passes.add(llvm::createSROAPass()); @@ -433,14 +412,6 @@ StatusOr CompileModuleToPtx(llvm::Module* module, function_passes.doFinalization(); module_passes.run(*module); - if (!flags->dump_temp_products_to.empty()) { - string optimized_filename = - MakeNameForTempProduct(module->getModuleIdentifier(), "optimized.bc"); - LOG(INFO) << "dumping bitcode after optimizations to: " - << optimized_filename; - EmitBitcodeToFile(*module, optimized_filename); - } - // Finally, produce PTX. return EmitModuleToPTX(module, target_machine.get()); } @@ -473,22 +444,6 @@ void GPUBackendInit() { // between those loads. FeedLLVMWithFlags({"-memdep-block-scan-limit=500"}); - legacy_flags::GpuBackendLibFlags* flags = - legacy_flags::GetGpuBackendLibFlags(); - if (!flags->llvm_cl_opts.empty()) { - std::vector opts = - tensorflow::str_util::Split(flags->llvm_cl_opts, ','); - FeedLLVMWithFlags(opts); - } - - if (flags->llvm_dump_passes) { - // Enable LLVM pass debugging dump. LLVM dumps this information when a pass - // manager is initialized for execution. It's done to stderr (this is - // hardcoded within LLVM to the dbgs() stream, we can't change it from the - // outside). - FeedLLVMWithFlags({"-debug-pass=Arguments"}); - } - // Initialize the NVPTX target; it's the only target we link with, so call its // specific initialization functions instead of the catch-all InitializeAll*. 
LLVMInitializeNVPTXTarget(); diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc index a12a9a71682..b8c61620845 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc @@ -61,7 +61,7 @@ HloInstruction* MaybePaddedAndSlicedInput( PrimitiveType element_type = input->shape().element_type(); HloInstruction* padding = computation->AddInstruction(HloInstruction::CreateConstant( - MakeUnique(LiteralUtil::Zero(element_type)))); + MakeUnique(Literal::Zero(element_type)))); input = computation->AddInstruction(HloInstruction::CreatePad( ShapeInference::InferPadShape( /*operand_shape=*/input->shape(), @@ -127,7 +127,7 @@ HloInstruction* MaybePaddedKernel(const Window& conv_window, PrimitiveType element_type = kernel->shape().element_type(); HloInstruction* padding = computation->AddInstruction(HloInstruction::CreateConstant( - MakeUnique(LiteralUtil::Zero(element_type)))); + MakeUnique(Literal::Zero(element_type)))); return computation->AddInstruction(HloInstruction::CreatePad( ShapeInference::InferPadShape( /*operand_shape=*/kernel->shape(), @@ -242,9 +242,9 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution( // Create a new backward convolution replacing the old one. 
HloComputation* computation = backward_conv->parent(); HloInstruction* output = backward_conv->mutable_operand(1); - HloInstruction* padding = computation->AddInstruction( - HloInstruction::CreateConstant(MakeUnique( - LiteralUtil::Zero(input->shape().element_type())))); + HloInstruction* padding = + computation->AddInstruction(HloInstruction::CreateConstant( + MakeUnique(Literal::Zero(input->shape().element_type())))); HloInstruction* padded_input = computation->AddInstruction(HloInstruction::CreatePad( ShapeInference::InferPadShape(input->shape(), padding->shape(), diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer.cc b/tensorflow/compiler/xla/service/gpu/while_transformer.cc index 06b01d311da..3034ed06b7e 100644 --- a/tensorflow/compiler/xla/service/gpu/while_transformer.cc +++ b/tensorflow/compiler/xla/service/gpu/while_transformer.cc @@ -37,8 +37,8 @@ namespace { // patterns to match. // // Each ExprTree node is comprised of an HloOpcode, and a set of operands (each -// of type ExprTree). Operands can be added by specifying the index and HloOpcode -// of the operand. +// of type ExprTree). Operands can be added by specifying the index and +// HloOpcode of the operand. 
// // For example, the following computation: // @@ -197,10 +197,9 @@ class MatcherBase { return InvalidArgument("Must use S32 or S64 integral types."); } if (type == S32) { - *const_value = - static_cast(LiteralUtil::GetFirstElement(literal)); + *const_value = static_cast(literal.GetFirstElement()); } else if (type == S64) { - *const_value = LiteralUtil::GetFirstElement(literal); + *const_value = literal.GetFirstElement(); } return tensorflow::Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc index e82491fd6f9..51d38f84212 100644 --- a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc +++ b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc @@ -41,7 +41,7 @@ class WhileTransformerTest : public HloTestBase { const int64 tuple_index, const int64 limit) { auto builder = HloComputation::Builder(TestName() + ".Condition"); auto limit_const = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(limit))); + HloInstruction::CreateConstant(Literal::CreateR0(limit))); auto loop_state = builder.AddInstruction( HloInstruction::CreateParameter(0, loop_state_shape_, "loop_state")); auto induction_variable = @@ -64,8 +64,8 @@ class WhileTransformerTest : public HloTestBase { auto induction_variable = builder.AddInstruction(HloInstruction::CreateGetTupleElement( induction_variable_shape_, loop_state, ind_var_tuple_index)); - auto inc = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR0(increment))); + auto inc = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(increment))); auto add0 = builder.AddInstruction(HloInstruction::CreateBinary( induction_variable->shape(), HloOpcode::kAdd, induction_variable, inc)); // Update data GTE(data_tuple_index). 
@@ -88,12 +88,10 @@ class WhileTransformerTest : public HloTestBase { const int64 ind_var_tuple_index, const int64 ind_var_init) { auto builder = HloComputation::Builder(TestName() + ".While"); - auto induction_var_init = - builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR0(ind_var_init))); - auto data_init = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1( - {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}))); + auto induction_var_init = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(ind_var_init))); + auto data_init = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}))); auto loop_state_init = ind_var_tuple_index == 0 ? builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/gpu_transfer_manager.cc b/tensorflow/compiler/xla/service/gpu_transfer_manager.cc index 4b8d190a463..cd1b182b222 100644 --- a/tensorflow/compiler/xla/service/gpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/gpu_transfer_manager.cc @@ -20,7 +20,6 @@ limitations under the License. #include #include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/compiler/xla/service/gpu/infeed_manager.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -44,16 +43,66 @@ GpuTransferManager::GpuTransferManager() Status GpuTransferManager::TransferLiteralToInfeed(se::StreamExecutor* executor, const Literal& literal) { const Shape& shape = literal.shape(); - VLOG(2) << "Transferring literal shape to infeed: " + VLOG(2) << "Transferring literal to infeed with shape: " << ShapeUtil::HumanString(shape); - // TODO(b/30467474) handle tuples. 
+ std::vector buffers; + if (ShapeUtil::IsTuple(shape)) { - return Unimplemented("Infeed with a tuple shape is not supported: %s", - ShapeUtil::HumanString(literal.shape()).c_str()); + if (ShapeUtil::IsNestedTuple(shape)) { + return Unimplemented( + "Infeed with a nested tuple shape is not supported: %s", + ShapeUtil::HumanString(literal.shape()).c_str()); + } + + for (const auto& tuple_element : literal.tuple_literals()) { + TF_ASSIGN_OR_RETURN( + gpu::InfeedBuffer * buffer, + TransferLiteralToInfeedInternal(executor, tuple_element)); + buffers.push_back(buffer); + } + } else { + TF_ASSIGN_OR_RETURN(gpu::InfeedBuffer * buffer, + TransferLiteralToInfeedInternal(executor, literal)); + buffers.push_back(buffer); } + gpu::InfeedManager* infeed_manager = gpu::GetOrCreateInfeedManager(); + se::Stream* stream = infeed_manager->GetStream(executor); + + // TODO(b/30467474): Since this stream is shared across different + // infeed requests, blocking on the stream might be + // heavy-handed. Figure out if finer-grained acknowledgement is + // possible. 
+ if (!stream->BlockHostUntilDone()) { + for (gpu::InfeedBuffer* b : buffers) { + b->Done(); + } + return InternalError("Failed to complete data transfer on stream %p", + stream); + } + + infeed_manager->EnqueueBuffers(buffers); + + VLOG(2) << "Infeed data transferred"; + + return Status::OK(); +} + +Status GpuTransferManager::TransferBufferToInfeed(se::StreamExecutor* executor, + int64 size, + const void* source) { + return TransferBufferToInfeedInternal(executor, size, source).status(); +} + +StatusOr +GpuTransferManager::TransferLiteralToInfeedInternal( + se::StreamExecutor* executor, const Literal& literal) { + const Shape& shape = literal.shape(); + CHECK(!ShapeUtil::IsTuple(shape)); + int64 size = GetByteSizeRequirement(shape); + if (size > std::numeric_limits::max()) { return Unimplemented("Infeed shape is too large: %s needs %lld bytes", ShapeUtil::HumanString(literal.shape()).c_str(), size); @@ -64,6 +113,11 @@ Status GpuTransferManager::TransferLiteralToInfeed(se::StreamExecutor* executor, ShapeUtil::HumanString(literal.shape()).c_str()); } + return TransferBufferToInfeedInternal(executor, size, literal.InternalData()); +} + +StatusOr GpuTransferManager::TransferBufferToInfeedInternal( + se::StreamExecutor* executor, int64 size, const void* source) { gpu::InfeedManager* infeed_manager = gpu::GetOrCreateInfeedManager(); se::Stream* stream = infeed_manager->GetStream(executor); if (stream == nullptr) { @@ -71,21 +125,11 @@ Status GpuTransferManager::TransferLiteralToInfeed(se::StreamExecutor* executor, } gpu::InfeedBuffer* buffer = new gpu::InfeedBuffer(executor, size); - stream->ThenMemcpy(buffer->device_memory(), - LiteralUtil::InternalData(literal), size); + stream->ThenMemcpy(buffer->device_memory(), source, size); VLOG(2) << "Queued infeed data on stream " << stream; - if (!stream->BlockHostUntilDone()) { - buffer->Done(); - return InternalError("Failed to complete data transfer on stream %p", - stream); - } - - infeed_manager->EnqueueBuffer(buffer); - 
- VLOG(2) << "Infeed data transferred"; - return Status::OK(); + return buffer; } } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu_transfer_manager.h b/tensorflow/compiler/xla/service/gpu_transfer_manager.h index 6dfe7ba0295..4fc6c911a4b 100644 --- a/tensorflow/compiler/xla/service/gpu_transfer_manager.h +++ b/tensorflow/compiler/xla/service/gpu_transfer_manager.h @@ -19,6 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/generic_transfer_manager.h" +#include "tensorflow/compiler/xla/service/gpu/infeed_manager.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -37,8 +38,20 @@ class GpuTransferManager : public GenericTransferManager { Status TransferLiteralToInfeed(perftools::gputools::StreamExecutor* executor, const Literal& literal) override; + Status TransferBufferToInfeed(perftools::gputools::StreamExecutor* executor, + int64 size, const void* source) override; private: + // Internal helper function for TransferLiteralToInfeed(). Input + // literal cannot be a tuple. + StatusOr TransferLiteralToInfeedInternal( + perftools::gputools::StreamExecutor* executor, const Literal& literal); + + // Internal helper function for TransferLiteralToInfeed(). 
+ StatusOr TransferBufferToInfeedInternal( + perftools::gputools::StreamExecutor* executor, int64 size, + const void* source); + TF_DISALLOW_COPY_AND_ASSIGN(GpuTransferManager); }; diff --git a/tensorflow/compiler/xla/service/graphviz_example.cc b/tensorflow/compiler/xla/service/graphviz_example.cc index cd00a41a037..d194b3a3102 100644 --- a/tensorflow/compiler/xla/service/graphviz_example.cc +++ b/tensorflow/compiler/xla/service/graphviz_example.cc @@ -47,7 +47,7 @@ HloComputation* AddScalarConstantComputation(int64 addend, HloModule* module) { auto x_value = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {}), "x_value")); auto half = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.5))); + HloInstruction::CreateConstant(Literal::CreateR0(0.5))); builder.AddInstruction(HloInstruction::CreateBinary( half->shape(), HloOpcode::kAdd, x_value, half)); return module->AddEmbeddedComputation(builder.Build()); @@ -118,7 +118,7 @@ std::unique_ptr MakeBigGraph() { auto rng = builder.AddInstruction( HloInstruction::CreateRng(vshape, RNG_UNIFORM, {param_m, param_m})); auto one = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto add_computation = ScalarSumComputation(module.get()); builder.AddInstruction( HloInstruction::CreateReduce(vshape, rng, one, {1}, add_computation)); diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc index 86f62accd3b..c662cec9c70 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.cc +++ b/tensorflow/compiler/xla/service/heap_simulator.cc @@ -187,7 +187,7 @@ Status HeapSimulator::RunComputation( buffer->instruction()->opcode() != HloOpcode::kCopy && CanShareOperandBufferWithUser( operand_buffer->instruction(), operand_buffer->index(), - buffer->instruction(), buffer->index(), points_to_analysis)) { + 
buffer->instruction(), buffer->index(), &points_to_analysis)) { ShareBuffer(buffer, operand_buffer, instruction); shared = true; break; diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc index 60a0768a86b..fefc4c6a0f2 100644 --- a/tensorflow/compiler/xla/service/heap_simulator_test.cc +++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc @@ -173,7 +173,7 @@ class HeapSimulatorTest : public HloTestBase { TEST_F(HeapSimulatorTest, ScalarConstant) { auto builder = HloComputation::Builder(TestName()); auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); // Constants aren't assigned. See b/32248867 HeapSimulatorTracker tracker(TestName(), builder.Build(), {const0}); diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc index 24c467d411b..442585772f1 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc @@ -37,12 +37,12 @@ using ::testing::UnorderedElementsAre; class HloAliasAnalysisTest : public HloTestBase { protected: - HloAliasAnalysisTest() : module_(TestName()) {} + HloAliasAnalysisTest() : module_(CreateNewModule()) {} // Run alias analysis on the member module. For convenience returns a // reference to the generated analysis stored in analysis_. const HloAliasAnalysis& RunAnalysis() { - analysis_ = HloAliasAnalysis::Run(&module_).ConsumeValueOrDie(); + analysis_ = HloAliasAnalysis::Run(module_.get()).ConsumeValueOrDie(); return *analysis_; } @@ -77,7 +77,31 @@ class HloAliasAnalysisTest : public HloTestBase { return analysis_->dataflow_analysis().GetValue(buffer.value_ids()[0]); } - HloModule module_; + // Returns true if any values held in the same buffer interfere. 
Generally, in + // the compiler pipeline copy-insertion will guarantee that this interference + // never occurs, but HLO graphs with interference can be explicitly + // constructed. + bool AnyValuesInSameBufferInterfere() { + DependencyHloOrdering ordering(module_.get()); + for (const HloBuffer* buffer : analysis_->buffers()) { + for (HloValue::Id value_id_a : buffer->value_ids()) { + for (HloValue::Id value_id_b : buffer->value_ids()) { + if (value_id_a != value_id_b && + analysis_->dataflow_analysis().MayInterfere( + value_id_a, value_id_b, ordering)) { + VLOG(1) << analysis_->dataflow_analysis().GetValue(value_id_a) + << " interferes with " + << analysis_->dataflow_analysis().GetValue(value_id_b) + << " in buffer: " << *buffer; + return true; + } + } + } + } + return false; + } + + std::unique_ptr module_; std::unique_ptr analysis_; const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {}); @@ -87,12 +111,12 @@ TEST_F(HloAliasAnalysisTest, BinaryOperation) { // Test the analysis on a single binary operation (Add). 
auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto add = builder.AddInstruction(HloInstruction::CreateBinary( scalar_shape_, HloOpcode::kAdd, constant1, constant2)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); const HloAliasAnalysis& analysis = RunAnalysis(); @@ -107,6 +131,8 @@ TEST_F(HloAliasAnalysisTest, BinaryOperation) { EXPECT_FALSE(analysis.GetInstructionBufferSet(add).IsAmbiguous()); EXPECT_TRUE(analysis.GetInstructionBufferSet(add).IsDistinct()); + + EXPECT_FALSE(AnyValuesInSameBufferInterfere()); } TEST_F(HloAliasAnalysisTest, TupleAndGtes) { @@ -124,7 +150,7 @@ TEST_F(HloAliasAnalysisTest, TupleAndGtes) { HloInstruction::CreateGetTupleElement(scalar_shape_, tuple, 1)); builder.AddInstruction( HloInstruction::CreateBinary(scalar_shape_, HloOpcode::kAdd, gte0, gte1)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); const HloAliasAnalysis& analysis = RunAnalysis(); @@ -156,6 +182,8 @@ TEST_F(HloAliasAnalysisTest, TupleAndGtes) { EXPECT_FALSE(analysis.GetInstructionBufferSet(tuple).IsAmbiguous()); EXPECT_TRUE(analysis.GetInstructionBufferSet(tuple).IsDistinct()); + + EXPECT_FALSE(AnyValuesInSameBufferInterfere()); } TEST_F(HloAliasAnalysisTest, NondistinctTuple) { @@ -168,7 +196,7 @@ TEST_F(HloAliasAnalysisTest, NondistinctTuple) { // param0 is included twice in the tuple. 
auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({param0, param1, param0})); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); const HloAliasAnalysis& analysis = RunAnalysis(); @@ -179,6 +207,8 @@ TEST_F(HloAliasAnalysisTest, NondistinctTuple) { EXPECT_FALSE(analysis.GetInstructionBufferSet(tuple).IsAmbiguous()); EXPECT_FALSE(analysis.GetInstructionBufferSet(tuple).IsDistinct()); + + EXPECT_FALSE(AnyValuesInSameBufferInterfere()); } TEST_F(HloAliasAnalysisTest, SingleCall) { @@ -192,16 +222,16 @@ TEST_F(HloAliasAnalysisTest, SingleCall) { auto add = subbuilder.AddInstruction(HloInstruction::CreateBinary( scalar_shape_, HloOpcode::kAdd, subparam0, subparam1)); HloComputation* called_computation = - module_.AddEmbeddedComputation(subbuilder.Build()); + module_->AddEmbeddedComputation(subbuilder.Build()); auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto call = builder.AddInstruction(HloInstruction::CreateCall( scalar_shape_, {constant1, constant2}, called_computation)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); const HloAliasAnalysis& analysis = RunAnalysis(); @@ -217,6 +247,8 @@ TEST_F(HloAliasAnalysisTest, SingleCall) { EXPECT_THAT( analysis.GetUniqueBufferAt(add).locations(), UnorderedElementsAre(HloLocation{add, {}}, HloLocation{call, {}})); + + EXPECT_FALSE(AnyValuesInSameBufferInterfere()); } TEST_F(HloAliasAnalysisTest, ComputationCalledTwice) { @@ -229,18 +261,18 @@ TEST_F(HloAliasAnalysisTest, ComputationCalledTwice) { auto add = subbuilder.AddInstruction(HloInstruction::CreateBinary( scalar_shape_, HloOpcode::kAdd, 
subparam0, subparam1)); HloComputation* called_computation = - module_.AddEmbeddedComputation(subbuilder.Build()); + module_->AddEmbeddedComputation(subbuilder.Build()); auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto call1 = builder.AddInstruction(HloInstruction::CreateCall( scalar_shape_, {constant1, constant2}, called_computation)); auto call2 = builder.AddInstruction(HloInstruction::CreateCall( scalar_shape_, {call1, constant2}, called_computation)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); const HloAliasAnalysis& analysis = RunAnalysis(); @@ -269,6 +301,8 @@ TEST_F(HloAliasAnalysisTest, ComputationCalledTwice) { EXPECT_FALSE(analysis.GetInstructionBufferSet(subparam1).IsAmbiguous()); EXPECT_TRUE(analysis.GetInstructionBufferSet(subparam0).IsDistinct()); EXPECT_TRUE(analysis.GetInstructionBufferSet(subparam1).IsDistinct()); + + EXPECT_FALSE(AnyValuesInSameBufferInterfere()); } TEST_F(HloAliasAnalysisTest, SingleWhile) { @@ -303,27 +337,27 @@ TEST_F(HloAliasAnalysisTest, SingleWhile) { scalar_shape_, HloOpcode::kAdd, body_element_0, body_element_1)); auto body_tuple = body_builder.AddInstruction( HloInstruction::CreateTuple({body_element_0, add})); - HloComputation* body = module_.AddEmbeddedComputation(body_builder.Build()); + HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build()); // Condition computation trivially returns a constant "false". 
auto cond_builder = HloComputation::Builder("condition"); auto cond_param = cond_builder.AddInstruction( HloInstruction::CreateParameter(0, tuple_shape, "param")); cond_builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloInstruction::CreateConstant(Literal::CreateR0(false))); HloComputation* condition = - module_.AddEmbeddedComputation(cond_builder.Build()); + module_->AddEmbeddedComputation(cond_builder.Build()); auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto xla_while = builder.AddInstruction( HloInstruction::CreateWhile(tuple_shape, condition, body, tuple)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); const HloAliasAnalysis& analysis = RunAnalysis(); @@ -356,6 +390,8 @@ TEST_F(HloAliasAnalysisTest, SingleWhile) { GetValueDefinedAt(body_param, {1}), GetValueDefinedAt(cond_param, {1}), GetValueDefinedAt(add))); + + EXPECT_FALSE(AnyValuesInSameBufferInterfere()); } TEST_F(HloAliasAnalysisTest, SequentialWhiles) { @@ -392,21 +428,21 @@ TEST_F(HloAliasAnalysisTest, SequentialWhiles) { scalar_shape_, HloOpcode::kAdd, body_element_0, body_element_1)); body_builder.AddInstruction( HloInstruction::CreateTuple({body_element_0, add})); - HloComputation* body = module_.AddEmbeddedComputation(body_builder.Build()); + HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build()); auto cond_builder = HloComputation::Builder("condition"); cond_builder.AddInstruction( HloInstruction::CreateParameter(0, tuple_shape, "param")); 
cond_builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloInstruction::CreateConstant(Literal::CreateR0(false))); HloComputation* condition = - module_.AddEmbeddedComputation(cond_builder.Build()); + module_->AddEmbeddedComputation(cond_builder.Build()); auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto xla_while0 = builder.AddInstruction( @@ -415,7 +451,7 @@ TEST_F(HloAliasAnalysisTest, SequentialWhiles) { HloInstruction::CreateWhile(tuple_shape, condition, body, xla_while0)); auto xla_while2 = builder.AddInstruction( HloInstruction::CreateWhile(tuple_shape, condition, body, xla_while1)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); const HloAliasAnalysis& analysis = RunAnalysis(); @@ -449,13 +485,21 @@ TEST_F(HloAliasAnalysisTest, NestedWhiles) { const Shape tuple_shape = ShapeUtil::MakeTupleShape({scalar_shape_, scalar_shape_}); - auto cond_builder = HloComputation::Builder("condition"); - cond_builder.AddInstruction( - HloInstruction::CreateParameter(0, tuple_shape, "param")); - cond_builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); - HloComputation* condition = - module_.AddEmbeddedComputation(cond_builder.Build()); + auto build_cond_computation = [&tuple_shape]() { + auto cond_builder = HloComputation::Builder("condition"); + cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "param")); + cond_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(false))); + return 
cond_builder.Build(); + }; + // Build separate condition computations so the call graph is flat. The + // callgraph is always flattened in the compiler pipeline, and the flattened + // callgraph enables representive interference analysis. + HloComputation* condition1 = + module_->AddEmbeddedComputation(build_cond_computation()); + HloComputation* condition2 = + module_->AddEmbeddedComputation(build_cond_computation()); // Element 0 passes transparently through the body. auto inner_builder = HloComputation::Builder("inner_body"); @@ -470,7 +514,7 @@ TEST_F(HloAliasAnalysisTest, NestedWhiles) { inner_builder.AddInstruction( HloInstruction::CreateTuple({inner_element_0, add})); HloComputation* inner_body = - module_.AddEmbeddedComputation(inner_builder.Build()); + module_->AddEmbeddedComputation(inner_builder.Build()); // Element 1 passes transparently through the body. auto outer_builder = HloComputation::Builder("outer_body"); @@ -485,20 +529,20 @@ TEST_F(HloAliasAnalysisTest, NestedWhiles) { auto outer_tuple = outer_builder.AddInstruction( HloInstruction::CreateTuple({negate, outer_element_1})); auto nested_while = outer_builder.AddInstruction(HloInstruction::CreateWhile( - tuple_shape, condition, inner_body, outer_tuple)); + tuple_shape, condition1, inner_body, outer_tuple)); HloComputation* outer_body = - module_.AddEmbeddedComputation(outer_builder.Build()); + module_->AddEmbeddedComputation(outer_builder.Build()); auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto entry_while = builder.AddInstruction( - HloInstruction::CreateWhile(tuple_shape, 
condition, outer_body, tuple)); - module_.AddEntryComputation(builder.Build()); + HloInstruction::CreateWhile(tuple_shape, condition2, outer_body, tuple)); + module_->AddEntryComputation(builder.Build()); const HloAliasAnalysis& analysis = RunAnalysis(); @@ -515,6 +559,8 @@ TEST_F(HloAliasAnalysisTest, NestedWhiles) { analysis.GetUniqueBufferAt(nested_while, /*index=*/{1})); EXPECT_EQ(analysis.GetUniqueBufferAt(constant2), analysis.GetUniqueBufferAt(inner_element_1)); + + EXPECT_FALSE(AnyValuesInSameBufferInterfere()); } TEST_F(HloAliasAnalysisTest, SwizzlingWhile) { @@ -548,28 +594,28 @@ TEST_F(HloAliasAnalysisTest, SwizzlingWhile) { HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 2)); body_builder.AddInstruction(HloInstruction::CreateTuple( {body_element_1, body_element_2, body_element_0})); - HloComputation* body = module_.AddEmbeddedComputation(body_builder.Build()); + HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build()); auto cond_builder = HloComputation::Builder("condition"); cond_builder.AddInstruction( HloInstruction::CreateParameter(0, tuple_shape, "param")); auto cond_constant = cond_builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloInstruction::CreateConstant(Literal::CreateR0(false))); HloComputation* condition = - module_.AddEmbeddedComputation(cond_builder.Build()); + module_->AddEmbeddedComputation(cond_builder.Build()); auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.0))); + HloInstruction::CreateConstant(Literal::CreateR0(3.0))); auto tuple = 
builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2, constant3})); auto xla_while = builder.AddInstruction( HloInstruction::CreateWhile(tuple_shape, condition, body, tuple)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); const HloAliasAnalysis& analysis = RunAnalysis(); @@ -593,6 +639,10 @@ TEST_F(HloAliasAnalysisTest, SwizzlingWhile) { analysis.GetUniqueBufferAt(constant2)); EXPECT_EQ(analysis.GetUniqueBufferAt(constant1), analysis.GetUniqueBufferAt(constant3)); + + // All elements in of the loop state tuple are forced into the same buffer + // resulting liveness interference. + EXPECT_TRUE(AnyValuesInSameBufferInterfere()); } TEST_F(HloAliasAnalysisTest, TupleSelect) { @@ -600,15 +650,15 @@ TEST_F(HloAliasAnalysisTest, TupleSelect) { // instruction. auto builder = HloComputation::Builder(TestName()); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloInstruction::CreateConstant(Literal::CreateR0(false))); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.0))); + HloInstruction::CreateConstant(Literal::CreateR0(3.0))); auto constant4 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(4.0))); + HloInstruction::CreateConstant(Literal::CreateR0(4.0))); auto tuple1 = builder.AddInstruction(HloInstruction::CreateTuple({constant1})); auto tuple2 = @@ -627,7 +677,7 @@ TEST_F(HloAliasAnalysisTest, TupleSelect) { auto select1234 = builder.AddInstruction(HloInstruction::CreateTernary( tuple_shape, HloOpcode::kSelect, pred, select12, select34)); - 
module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); const HloAliasAnalysis& analysis = RunAnalysis(); @@ -655,6 +705,8 @@ TEST_F(HloAliasAnalysisTest, TupleSelect) { EXPECT_TRUE(analysis.GetInstructionBufferSet(select12).IsDistinct()); EXPECT_TRUE(analysis.GetInstructionBufferSet(select34).IsDistinct()); EXPECT_TRUE(analysis.GetInstructionBufferSet(select1234).IsDistinct()); + + EXPECT_FALSE(AnyValuesInSameBufferInterfere()); } TEST_F(HloAliasAnalysisTest, TupleSelectToWhile) { @@ -688,22 +740,22 @@ TEST_F(HloAliasAnalysisTest, TupleSelectToWhile) { auto negate = body_builder.AddInstruction(HloInstruction::CreateUnary( scalar_shape_, HloOpcode::kNegate, body_element)); body_builder.AddInstruction(HloInstruction::CreateTuple({negate})); - HloComputation* body = module_.AddEmbeddedComputation(body_builder.Build()); + HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build()); auto cond_builder = HloComputation::Builder("condition"); auto cond_param = cond_builder.AddInstruction( HloInstruction::CreateParameter(0, tuple_shape, "param")); cond_builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloInstruction::CreateConstant(Literal::CreateR0(false))); HloComputation* condition = - module_.AddEmbeddedComputation(cond_builder.Build()); + module_->AddEmbeddedComputation(cond_builder.Build()); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloInstruction::CreateConstant(Literal::CreateR0(false))); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto tuple1 = builder.AddInstruction(HloInstruction::CreateTuple({constant1})); auto tuple2 = @@ -713,7 
+765,7 @@ TEST_F(HloAliasAnalysisTest, TupleSelectToWhile) { auto xla_while = builder.AddInstruction( HloInstruction::CreateWhile(tuple_shape, condition, body, select)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); const HloAliasAnalysis& analysis = RunAnalysis(); @@ -736,17 +788,21 @@ TEST_F(HloAliasAnalysisTest, TupleSelectToWhile) { EXPECT_TRUE(analysis.GetInstructionBufferSet(select).IsDistinct()); EXPECT_TRUE(analysis.GetInstructionBufferSet(xla_while).IsDistinct()); + + // The two operands of the select get flattened into the same buffer resulting + // in liveness interference. + EXPECT_TRUE(AnyValuesInSameBufferInterfere()); } TEST_F(HloAliasAnalysisTest, Bitcast) { // Bitcasting a value should not produce a new buffer. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto bitcast = builder.AddInstruction(HloInstruction::CreateUnary( scalar_shape_, HloOpcode::kBitcast, constant)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); const HloAliasAnalysis& analysis = RunAnalysis(); diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc index 5d49c83e2d0..057d1ce09bd 100644 --- a/tensorflow/compiler/xla/service/hlo_computation_test.cc +++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc @@ -110,7 +110,7 @@ TEST_F(HloComputationTest, PostOrderSingleton) { // Test GetInstructionPostOrder for a computation with one instruction. 
auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); auto computation = builder.Build(); EXPECT_THAT(computation->MakeInstructionPostOrder(), ElementsAre(constant)); @@ -121,7 +121,7 @@ TEST_F(HloComputationTest, PostOrderSimple) { // instructions. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); auto negate1 = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant)); auto negate2 = builder.AddInstruction( @@ -136,7 +136,7 @@ TEST_F(HloComputationTest, PostOrderTrace) { // Test GetInstructionPostOrder for a computation with a trace instruction. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); auto negate1 = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant)); auto trace = @@ -155,13 +155,13 @@ TEST_F(HloComputationTest, PostOrderDisconnectedInstructions) { // which are not connected. 
auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); auto constant4 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); auto computation = builder.Build(); EXPECT_THAT(computation->MakeInstructionPostOrder(), @@ -173,11 +173,11 @@ TEST_F(HloComputationTest, PostOrderWithMultipleRoots) { // which are not connected. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); auto add1 = builder.AddInstruction(HloInstruction::CreateBinary( r0f32_, HloOpcode::kAdd, constant1, constant2)); auto add2 = builder.AddInstruction(HloInstruction::CreateBinary( @@ -197,11 +197,11 @@ TEST_F(HloComputationTest, VisitWithMultipleRoots) { // computation has multiple roots (dead code). 
auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); // Add three disconnected add expressions. builder.AddInstruction(HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, constant1, constant2)); @@ -248,7 +248,7 @@ TEST_F(HloComputationTest, DeepCopyArray) { // Test that DeepCopyInstruction properly copies an array. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({1.0, 2.0, 3.0}))); + Literal::CreateR1({1.0, 2.0, 3.0}))); auto computation = builder.Build(); auto copy = computation->DeepCopyInstruction(constant).ValueOrDie(); @@ -260,9 +260,9 @@ TEST_F(HloComputationTest, DeepCopyTuple) { // Test that DeepCopyInstruction properly copies a tuple. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({1.0, 2.0, 3.0}))); + Literal::CreateR1({1.0, 2.0, 3.0}))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); @@ -280,7 +280,7 @@ TEST_F(HloComputationTest, CycleDetection) { // Test whether the visitor can detect cycles in the graph. 
auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); auto negate = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant)); auto add = builder.AddInstruction( @@ -303,7 +303,7 @@ TEST_F(HloComputationTest, RemoveInstructionWithDuplicateOperand) { // twice. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); auto dead_negate = builder.AddInstruction( HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, constant)); auto dead_add = builder.AddInstruction(HloInstruction::CreateBinary( @@ -326,9 +326,9 @@ TEST_F(HloComputationTest, RemoveInstructionWithDuplicateOperand) { TEST_F(HloComputationTest, CloneWithControlDependency) { auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0f))); auto add = builder.AddInstruction(HloInstruction::CreateBinary( r0f32_, HloOpcode::kAdd, constant1, constant2)); diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.cc b/tensorflow/compiler/xla/service/hlo_constant_folding.cc index 93f448e7018..804efdd906a 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding.cc @@ -58,6 +58,13 @@ StatusOr HloConstantFolding::Run(HloModule* module) { continue; } + // Broadcasts dramatically increase the size of constants with is often + // detrimental to performance and memory capacity so 
do not fold + // broadcasts. + if (instruction->opcode() == HloOpcode::kBroadcast) { + continue; + } + std::unique_ptr result = evaluator->TryEvaluate(instruction); // Currently we skip unimplemented operations. // TODO(b/35975797): Fold constant computations for more operations. diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc index 31b81052cb2..1c60b06dddc 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding_test.cc @@ -41,7 +41,7 @@ using HloConstantFoldingTest = HloTestBase; TEST_F(HloConstantFoldingTest, ConvertF32ToS64) { HloComputation::Builder builder(TestName()); HloInstruction* input = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); builder.AddInstruction( HloInstruction::CreateConvert(ShapeUtil::MakeShape(S64, {}), input)); @@ -55,15 +55,14 @@ TEST_F(HloConstantFoldingTest, ConvertF32ToS64) { EXPECT_TRUE(result); EXPECT_THAT(computation->root_instruction(), op::Constant()); - EXPECT_EQ(LiteralUtil::GetFirstElement( - computation->root_instruction()->literal()), + EXPECT_EQ(computation->root_instruction()->literal().GetFirstElement(), 42); } TEST_F(HloConstantFoldingTest, ConvertS64ToF32) { HloComputation::Builder builder(TestName()); HloInstruction* input = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42))); + HloInstruction::CreateConstant(Literal::CreateR0(42))); builder.AddInstruction( HloInstruction::CreateConvert(ShapeUtil::MakeShape(F32, {}), input)); @@ -77,15 +76,14 @@ TEST_F(HloConstantFoldingTest, ConvertS64ToF32) { EXPECT_TRUE(result); EXPECT_THAT(computation->root_instruction(), op::Constant()); - EXPECT_EQ(LiteralUtil::GetFirstElement( - computation->root_instruction()->literal()), + 
EXPECT_EQ(computation->root_instruction()->literal().GetFirstElement(), 42.0f); } TEST_F(HloConstantFoldingTest, ConvertF32ArrayToS64Array) { HloComputation::Builder builder(TestName()); - HloInstruction* input = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({42.0f, 19.0f}))); + HloInstruction* input = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({42.0f, 19.0f}))); builder.AddInstruction( HloInstruction::CreateConvert(ShapeUtil::MakeShape(S64, {2}), input)); @@ -99,12 +97,8 @@ TEST_F(HloConstantFoldingTest, ConvertF32ArrayToS64Array) { EXPECT_TRUE(result); EXPECT_THAT(computation->root_instruction(), op::Constant()); - EXPECT_EQ( - LiteralUtil::Get(computation->root_instruction()->literal(), {0}), - 42); - EXPECT_EQ( - LiteralUtil::Get(computation->root_instruction()->literal(), {1}), - 19); + EXPECT_EQ(computation->root_instruction()->literal().Get({0}), 42); + EXPECT_EQ(computation->root_instruction()->literal().Get({1}), 19); } TEST_F(HloConstantFoldingTest, Concatenate) { @@ -126,7 +120,7 @@ TEST_F(HloConstantFoldingTest, Concatenate) { for (auto csize : test_config.concat_sizes) { dimensions[test_config.concat_dimension] = csize; concat_size += csize; - auto literal = LiteralUtil::CreateFromDimensions(F32, dimensions); + auto literal = Literal::CreateFromDimensions(F32, dimensions); HloInstruction* insn = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); operands.push_back(insn); @@ -180,7 +174,7 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) { TF_ASSIGN_OR_ASSERT_OK(auto literal, LiteralTestUtil::CreateRandomLiteral( ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); - auto literal_clone = LiteralUtil::CloneToUnique(*literal); + auto literal_clone = literal->Literal::CloneToUnique(); HloInstruction* literal_instruction = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); Shape shape = ShapeUtil::MakeShape(F32, {8, 7, 11, 9, 5}); @@ 
-200,12 +194,10 @@ TEST_F(HloConstantFoldingTest, TransposeConstantFold) { using NativeT = typename primitive_util::PrimitiveTypeToNative::type; bool matched = true; - LiteralUtil::EachCell( - root->literal(), + root->literal().EachCell( [&](tensorflow::gtl::ArraySlice indices, NativeT value) { std::vector rindexes = Permute(permutation, indices); - matched = matched && (value == LiteralUtil::Get(*literal_clone, - rindexes)); + matched = matched && (value == literal_clone->Get(rindexes)); }); EXPECT_TRUE(matched); } diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index 38cc74b0f1e..f3a6cd43c2a 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -75,15 +75,12 @@ Status HloCostAnalysis::HandleElementwiseOp(HloInstruction* hlo_instruction) { } Status HloCostAnalysis::HandleElementwiseUnary(HloInstruction* hlo, - HloOpcode opcode, - HloInstruction* operand) { + HloOpcode opcode) { return HandleElementwiseOp(hlo); } Status HloCostAnalysis::HandleElementwiseBinary(HloInstruction* hlo, - HloOpcode opcode, - HloInstruction* lhs, - HloInstruction* rhs) { + HloOpcode opcode) { return HandleElementwiseOp(hlo); } @@ -100,6 +97,11 @@ Status HloCostAnalysis::HandleClamp(HloInstruction* clamp, return HandleElementwiseOp(clamp); } +Status HloCostAnalysis::HandleReducePrecision(HloInstruction* hlo, + HloInstruction* operand) { + return HandleElementwiseOp(hlo); +} + Status HloCostAnalysis::HandleParameter(HloInstruction* parameter) { current_bytes_accessed_ = 0; return Status::OK(); @@ -164,13 +166,11 @@ Status HloCostAnalysis::HandleConcatenate( return Status::OK(); } -Status HloCostAnalysis::HandleConvert(HloInstruction* convert, - HloInstruction* operand) { +Status HloCostAnalysis::HandleConvert(HloInstruction* convert) { return HandleElementwiseOp(convert); } -Status HloCostAnalysis::HandleCopy(HloInstruction* copy, - 
HloInstruction* operand) { +Status HloCostAnalysis::HandleCopy(HloInstruction* copy) { return Status::OK(); } @@ -314,6 +314,12 @@ Status HloCostAnalysis::HandleReshape(HloInstruction* reshape) { return Status::OK(); } +Status HloCostAnalysis::HandleBatchNormTraining( + HloInstruction* batchNormTraining) { + // TODO(b/62294698): Implement cost analysis for batch-norm-learning. + return Status::OK(); +} + Status HloCostAnalysis::HandleTranspose(HloInstruction* transpose) { return Status::OK(); } @@ -362,6 +368,19 @@ Status HloCostAnalysis::HandleFusion(HloInstruction* fusion) { HloCostAnalysis visitor([](const Shape&) { return 0; }); TF_RETURN_IF_ERROR(fused_expression_root->Accept(&visitor)); + // If a fusion node produces a tuple, it also produces the operands of that + // tuple. + current_bytes_accessed_ = 0; + ShapeUtil::ForEachSubshape( + fusion->shape(), + [this](const Shape& subshape, const ShapeIndex& /*shape_index*/) { + current_bytes_accessed_ += shape_size_(subshape); + }); + + for (const HloInstruction* operand : fusion->operands()) { + current_bytes_accessed_ += shape_size_(operand->shape()); + } + // Attribute the cost of the fused expression to the fusion node. 
current_transcendental_count_ = visitor.transcendental_count(); current_flop_count_ = visitor.flop_count(); diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h index b2c40f75ca4..3f0dfcc619f 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h @@ -42,11 +42,9 @@ class HloCostAnalysis : public DfsHloVisitor { explicit HloCostAnalysis(const ShapeSizeFunction& shape_size) : shape_size_(shape_size) {} - Status HandleElementwiseUnary(HloInstruction* hlo, HloOpcode opcode, - HloInstruction* operand) override; - Status HandleElementwiseBinary(HloInstruction* hlo, HloOpcode opcode, - HloInstruction* lhs, - HloInstruction* rhs) override; + Status HandleElementwiseUnary(HloInstruction* hlo, HloOpcode opcode) override; + Status HandleElementwiseBinary(HloInstruction* hlo, + HloOpcode opcode) override; Status HandleConstant(HloInstruction* constant, const Literal& literal) override; Status HandleGetTupleElement(HloInstruction* get_tuple_element, @@ -58,14 +56,15 @@ class HloCostAnalysis : public DfsHloVisitor { HloInstruction* lhs, HloInstruction* rhs) override; Status HandleClamp(HloInstruction* clamp, HloInstruction* min, HloInstruction* arg, HloInstruction* max) override; + Status HandleReducePrecision(HloInstruction* hlo, + HloInstruction* operand) override; Status HandleConcatenate( HloInstruction* concatenate, tensorflow::gtl::ArraySlice operands) override; Status HandleSend(HloInstruction* send) override; Status HandleRecv(HloInstruction* recv) override; - Status HandleConvert(HloInstruction* convert, - HloInstruction* operand) override; - Status HandleCopy(HloInstruction* copy, HloInstruction* operand) override; + Status HandleConvert(HloInstruction* convert) override; + Status HandleCopy(HloInstruction* copy) override; Status HandleDot(HloInstruction* dot, HloInstruction* lhs, HloInstruction* rhs) override; Status 
HandleConvolution(HloInstruction* convolution, HloInstruction* lhs, @@ -83,6 +82,7 @@ class HloCostAnalysis : public DfsHloVisitor { HloInstruction* init_value, tensorflow::gtl::ArraySlice dimensions, HloComputation* function_handle) override; + Status HandleBatchNormTraining(HloInstruction* batchNormTraining) override; Status HandleFusion(HloInstruction* fusion) override; Status HandleCall(HloInstruction* call) override; Status HandleCustomCall(HloInstruction* custom_call, diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc index b74c7eb4e07..5c71056bb5d 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc @@ -342,11 +342,11 @@ TEST_F(FusionCostAnalysis, LoopFusion) { // mul = Mul(exp, C3) // sub = Sub(mul, clamp) // tuple = Tuple({sub, sub, mul, C1}) - auto c1 = HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace( + auto c1 = HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( /*from=*/0.0f, /*to=*/1.0f, /*rows=*/2, /*cols=*/2)); - auto c2 = HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace( + auto c2 = HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( /*from=*/1.0f, /*to=*/2.0f, /*rows=*/2, /*cols=*/2)); - auto c3 = HloInstruction::CreateConstant(LiteralUtil::CreateR2F32Linspace( + auto c3 = HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( /*from=*/2.0f, /*to=*/3.0f, /*rows=*/2, /*cols=*/2)); auto add = @@ -383,9 +383,8 @@ TEST_F(FusionCostAnalysis, NoLayout) { shape_without_layout.clear_layout(); auto c1 = HloInstruction::CreateConstant( - LiteralUtil::CreateR4FromArray4D(Array4D(2, 3, 4, 5))); - auto c2 = - HloInstruction::CreateConstant(LiteralUtil::CreateR1({1, 2, 3})); + Literal::CreateR4FromArray4D(Array4D(2, 3, 4, 5))); + auto c2 = HloInstruction::CreateConstant(Literal::CreateR1({1, 2, 3})); auto broadcast = 
HloInstruction::CreateBroadcast(shape_without_layout, c2.get(), {1}); diff --git a/tensorflow/compiler/xla/service/hlo_cse.cc b/tensorflow/compiler/xla/service/hlo_cse.cc index 4c6af5c40fa..0fef89a06d0 100644 --- a/tensorflow/compiler/xla/service/hlo_cse.cc +++ b/tensorflow/compiler/xla/service/hlo_cse.cc @@ -68,7 +68,7 @@ bool CombineConstants(HloComputation* computation, bool is_layout_sensitive) { auto range = constants.equal_range(shape_string); HloInstruction* match = nullptr; for (auto it = range.first; it != range.second; ++it) { - if (LiteralUtil::Equal(instruction->literal(), it->second->literal())) { + if (instruction->literal().Equal(it->second->literal())) { match = it->second; break; } diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc index cc39c3ac203..8b0b9c8bbd0 100644 --- a/tensorflow/compiler/xla/service/hlo_cse_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc @@ -51,9 +51,9 @@ TEST_F(HloCseTest, CombineTwoConstants) { // Test that two identical constants are commoned. 
auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); builder.AddInstruction(HloInstruction::CreateBinary( constant1->shape(), HloOpcode::kAdd, constant1, constant2)); @@ -67,10 +67,10 @@ TEST_F(HloCseTest, CombineTwoConstants) { EXPECT_EQ(2, computation->instruction_count()); HloInstruction* constant = computation->instructions().begin()->get(); - EXPECT_EQ(42.0f, LiteralUtil::Get(constant->literal(), {})); + EXPECT_EQ(42.0f, constant->literal().Get({})); auto result = ExecuteAndTransfer(std::move(module), {}); - auto expected = LiteralUtil::CreateR0(84.0); + auto expected = Literal::CreateR0(84.0); LiteralTestUtil::ExpectNear(*expected, *result, ErrorSpec(1e-4)); } @@ -102,7 +102,7 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndInsensitive) { EXPECT_THAT(add, op::Add(first_operand, first_operand)); auto result = ExecuteAndTransfer(std::move(module), {}); - auto expected = LiteralUtil::CreateR2({{2.0, 4.0}, {6.0, 8.0}}); + auto expected = Literal::CreateR2({{2.0, 4.0}, {6.0, 8.0}}); LiteralTestUtil::ExpectNear(*expected, *result, ErrorSpec(1e-4)); } @@ -132,7 +132,7 @@ TEST_F(HloCseTest, CombineTwoConstantsDifferentLayoutsAndSensitive) { EXPECT_THAT(add, op::Add(constant1, constant2)); auto result = ExecuteAndTransfer(std::move(module), {}); - auto expected = LiteralUtil::CreateR2({{2.0, 4.0}, {6.0, 8.0}}); + auto expected = Literal::CreateR2({{2.0, 4.0}, {6.0, 8.0}}); LiteralTestUtil::ExpectNear(*expected, *result, ErrorSpec(1e-4)); } @@ -141,20 +141,20 @@ TEST_F(HloCseTest, ConstantsSameValueDifferentType) { // commoned. 
auto builder = HloComputation::Builder(TestName()); builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42))); + HloInstruction::CreateConstant(Literal::CreateR0(42))); builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42))); + HloInstruction::CreateConstant(Literal::CreateR0(42))); builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0))); builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0))); builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0))); builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); // Duplicate the float constant to verify something happens. builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); @@ -171,13 +171,13 @@ TEST_F(HloCseTest, NonscalarConstants) { // Test that identical nonscalar constants are merged. auto builder = HloComputation::Builder(TestName()); auto common_constant1 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); + Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); auto common_constant2 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); + Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); // Create a constant which has the same shape but a different value. 
auto uncommon_constant = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{2.0, 4.0}, {6.0, 8.0}}))); + Literal::CreateR2({{2.0, 4.0}, {6.0, 8.0}}))); // Tie the constants together with a tuple. This makes it easier to refer to // the constant instructions via their use. @@ -206,7 +206,7 @@ TEST_F(HloCseTest, IdenticalInstructions) { // Test that three identical instructions are commoned. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0))); auto exp1 = builder.AddInstruction(HloInstruction::CreateUnary( constant->shape(), HloOpcode::kExp, constant)); auto exp2 = builder.AddInstruction(HloInstruction::CreateUnary( @@ -236,7 +236,7 @@ TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsSensitive) { // commoned if the pass is layout sensitive. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); + Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); auto exp1 = builder.AddInstruction(HloInstruction::CreateUnary( constant->shape(), HloOpcode::kExp, constant)); @@ -267,7 +267,7 @@ TEST_F(HloCseTest, IdenticalInstructionsDifferentLayoutsInsensitive) { // the pass is layout insensitive. auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); + Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); auto exp1 = builder.AddInstruction(HloInstruction::CreateUnary( constant->shape(), HloOpcode::kExp, constant)); @@ -311,7 +311,7 @@ TEST_F(HloCseTest, IdenticalExpressions) { // The *1 instructions should be merged with the *2 instructions. 
auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0))); auto negate1 = builder.AddInstruction(HloInstruction::CreateUnary( constant->shape(), HloOpcode::kNegate, constant)); @@ -349,9 +349,9 @@ TEST_F(HloCseTest, DoNotCombineRng) { // Test that two RNG ops are not commoned. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0f))); auto rng1 = builder.AddInstruction(HloInstruction::CreateRng( ShapeUtil::MakeShape(F32, {}), RandomDistribution::RNG_UNIFORM, {constant1, constant2})); @@ -392,9 +392,9 @@ TEST_F(HloCseTest, DISABLED_DoNotCombineCallsToImpureFunctions) { Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); auto builder = HloComputation::Builder(TestName() + "_rng_fun"); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0f))); auto rng = builder.AddInstruction(HloInstruction::CreateRng( scalar_shape, RandomDistribution::RNG_UNIFORM, {constant1, constant2})); auto param = builder.AddInstruction(HloInstruction::CreateParameter( @@ -409,7 +409,7 @@ TEST_F(HloCseTest, DISABLED_DoNotCombineCallsToImpureFunctions) { { auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({5.0f}))); + 
HloInstruction::CreateConstant(Literal::CreateR1({5.0f}))); auto rng1 = builder.AddInstruction( HloInstruction::CreateMap(constant->shape(), {constant}, rng_function)); auto rng2 = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index d1b87256445..7e951721ba2 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -76,7 +76,8 @@ HloValue::HloValue(HloValue::Id id, HloInstruction* instruction, } bool HloValue::operator==(const HloValue& other) const { - bool equal = instruction() == other.instruction() && index() == other.index(); + bool equal = defining_instruction() == other.defining_instruction() && + defining_index() == other.defining_index(); // If the values are equal they most both be phi (or non phi). CHECK(!(equal && is_phi() != other.is_phi())); return equal; @@ -87,10 +88,11 @@ bool HloValue::operator!=(const HloValue& other) const { } string HloValue::ToShortString() const { - string index_str = - ShapeUtil::IsTuple(instruction()->shape()) ? index().ToString() : ""; - return StrCat(is_phi_ ? "PHI " : "", instruction()->FullyQualifiedName(), - index_str); + string index_str = ShapeUtil::IsTuple(defining_instruction()->shape()) + ? defining_index().ToString() + : ""; + return StrCat(is_phi_ ? "PHI " : "", + defining_instruction()->FullyQualifiedName(), index_str); } string HloValue::ToString(int indent) const { @@ -106,6 +108,50 @@ string HloValue::ToString(int indent) const { return out; } +namespace { + +// Returns true if the instruction 'user' may use the value at the given +// ShapeIndex in the given operand. Generally, instruction which pass through +// values transparently without reading the value are not considered to use the +// value. 
+bool MayUseOperandValue(int64 operand_number, const ShapeIndex& index, + const HloInstruction* user) { + switch (user->opcode()) { + case HloOpcode::kGetTupleElement: + case HloOpcode::kCopy: + // These instructions only access the top-level values of their + // operand. Non-top-level (nested) values are passed through + // transparently. + CHECK_EQ(operand_number, 0); + return index.empty(); + case HloOpcode::kSelect: + // Select does not use any nested elements of its selected-from operands + // (operand 1 and 2) + CHECK_GE(operand_number, 0); + CHECK_LE(operand_number, 2); + return operand_number == 0 || index.empty(); + + case HloOpcode::kCall: + case HloOpcode::kTuple: + // These instructions always pass through their operands transparently. + return false; + + case HloOpcode::kWhile: + // Though the while instructions passes through its operands, we return + // true because in SSA form there may be a Phi at the parameter of the + // while which is considered a use of its incoming value because the Phi + // input values are not passed through into the body computation. Because + // this function is used in both SSA and non-SSA forms of the analysis + // conservatively return true. + return true; + + default: + return true; + } +} + +} // namespace + void HloValue::AddLocation(HloInstruction* instruction, const ShapeIndex& index) { // The given location should not already exist in locations_. @@ -118,7 +164,7 @@ void HloValue::AddLocation(HloInstruction* instruction, // Update uses. for (HloInstruction* user : instruction->users()) { for (int64 operand_number : user->OperandIndices(instruction)) { - if (!DoesNotUseOperandBuffer(instruction, index, user)) { + if (MayUseOperandValue(operand_number, index, user)) { for (const HloUse& use : uses_) { // Verify that this use does not already exist. 
DCHECK(!(use.instruction == user && @@ -136,12 +182,16 @@ void HloValue::AddLocation(HloInstruction* instruction, if (instruction == module.entry_computation()->root_instruction()) { live_out_of_module_ = true; } + + if (instruction == instruction->parent()->root_instruction()) { + live_out_of_computation_ = true; + } } void HloValue::RemoveLocation(HloInstruction* instruction, const ShapeIndex& index) { // The defining location cannot be removed. - CHECK(!(instruction == this->instruction() && index == this->index())); + CHECK(!(instruction == defining_instruction() && index == defining_index())); int64 size_before = locations_.size(); locations_.erase( @@ -163,19 +213,27 @@ void HloValue::RemoveLocation(HloInstruction* instruction, }), uses_.end()); + // Returns whether this value is contained in the given instruction's output. + auto is_contained_in = [this](const HloInstruction* instruction) { + for (const HloLocation& location : locations()) { + if (location.instruction == instruction) { + return true; + } + } + return false; + }; + const HloModule& module = *instruction->parent()->parent(); if (instruction == module.entry_computation()->root_instruction()) { // Value has been removed from a location in the entry root instruction. - // Check if the value is still live out of the module by walking all - // remaining locations. - live_out_of_module_ = false; - for (const HloLocation& location : locations()) { - if (location.instruction == - module.entry_computation()->root_instruction()) { - live_out_of_module_ = true; - break; - } - } + live_out_of_module_ = + is_contained_in(module.entry_computation()->root_instruction()); + } + if (instruction == defining_instruction()->parent()->root_instruction()) { + // Value has been removed from the root of the computation the value has + // been defined in. 
+ live_out_of_computation_ = + is_contained_in(defining_instruction()->parent()->root_instruction()); } } @@ -259,7 +317,8 @@ bool HloDataflowAnalysis::ValueIsDefinedAt(const HloInstruction* instruction, if (value_set.value_ids().size() != 1) { return false; } - return GetValue(value_set.GetUniqueValueId()).instruction() == instruction; + return GetValue(value_set.GetUniqueValueId()).defining_instruction() == + instruction; } const HloValue& HloDataflowAnalysis::GetValueDefinedAt( @@ -468,8 +527,8 @@ void HloDataflowAnalysis::UpdateLocationsOfValuesAt( } // Don't remove the defining location of the value. HloValue& value = GetValue(value_id); - if (instruction == value.instruction()) { - CHECK_EQ(index, value.index()); + if (instruction == value.defining_instruction()) { + CHECK_EQ(index, value.defining_index()); } else { value.RemoveLocation(instruction, index); } @@ -482,8 +541,8 @@ void HloDataflowAnalysis::UpdateLocationsOfValuesAt( const HloValueSet& value_set) { for (HloValue::Id value_id : value_set.value_ids()) { HloValue& value = GetValue(value_id); - if (instruction == value.instruction()) { - CHECK_EQ(index, value.index()); + if (instruction == value.defining_instruction()) { + CHECK_EQ(index, value.defining_index()); } else { value.AddLocation(instruction, index); } @@ -694,15 +753,24 @@ InstructionValueSet HloDataflowAnalysis::RecomputeParameterValueSet( std::vector inputs; bool called_from_while = false; for (const CallSite& callsite : call_graph_node.caller_callsites()) { - inputs.push_back(&GetInstructionValueSet( - callsite.instruction()->operand(parameter->parameter_number()))); - if (callsite.instruction()->opcode() == HloOpcode::kWhile) { - // In a while instruction, the backedge is also a dataflow input to the - // parameter instruction. This code covers the case where the parameter is - // in the while body or the parameter is in the while condition. 
+ if (callsite.instruction()->opcode() == HloOpcode::kCall) { + // The operand values of a call instruction are forwarded to the + // respective parameter instruction of the subcomputation. + inputs.push_back(&GetInstructionValueSet( + callsite.instruction()->operand(parameter->parameter_number()))); + } else if (callsite.instruction()->opcode() == HloOpcode::kWhile) { + // In a while instruction, the while operand (ie, the init value) and the + // backedge are dataflow inputs to the parameter instruction. This is the + // case for parameters of both the body and condition computations. + CHECK_EQ(parameter->parameter_number(), 0); + inputs.push_back( + &GetInstructionValueSet(callsite.instruction()->operand(0))); inputs.push_back(&GetInstructionValueSet( callsite.instruction()->while_body()->root_instruction())); called_from_while = true; + } else { + LOG(FATAL) << "CallContext::kSequential computations should only be " + "called from call or while instructions"; } } @@ -804,6 +872,149 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() { return Status::OK(); } +bool HloDataflowAnalysis::IsDefinedBefore(const HloValue& a, const HloValue& b, + const HloOrdering& ordering) const { + // If 'b' is an entry param then 'a' cannot be defined before 'b' because 'b' + // is live into the module. + if (b.defining_instruction()->parent() == module_->entry_computation() && + b.defining_instruction()->opcode() == HloOpcode::kParameter) { + return false; + } + + // Phi values require special handling. Because XLA does not have a phi + // instruction, the definition instruction of the phis values are + // placeholders: either the subcomputation parameter (body or condition) or + // the while instruction. However, the program point where these values are + // logically defined does not necessarily coincide exactly with program point + // of these place-holder instructions. 
So we explicitly define the following + // order for phi values: + // + // body/condition parameter phi: + // Defined before all values defined in its computation excepting other + // phis. + // + // while phi: + // defined after all values defined in the condition or body. + // + auto is_body_or_condition_phi = [](const HloValue& v) { + return v.is_phi() && + v.defining_instruction()->opcode() == HloOpcode::kParameter; + }; + if (is_body_or_condition_phi(a) && !is_body_or_condition_phi(b) && + call_graph_->InstructionIsNestedIn(b.defining_instruction(), + a.defining_instruction()->parent())) { + return true; + } + if (is_body_or_condition_phi(b) && + call_graph_->InstructionIsNestedIn(a.defining_instruction(), + b.defining_instruction()->parent())) { + return false; + } + + // If 'b' is a while phi and 'a' is in the body or condition, then 'a' + // executes before 'b'. + if (b.is_phi() && b.defining_instruction()->opcode() == HloOpcode::kWhile && + (call_graph_->InstructionIsNestedIn( + a.defining_instruction(), b.defining_instruction()->while_body()) || + call_graph_->InstructionIsNestedIn( + a.defining_instruction(), + b.defining_instruction()->while_condition()))) { + return true; + } + + return ordering.ExecutesBefore(a.defining_instruction(), + b.defining_instruction()); +} + +bool HloDataflowAnalysis::UseIsBeforeValueDefinition( + const HloUse& use, const HloValue& value, + const HloOrdering& ordering) const { + if (ordering.ExecutesBefore(use.instruction, value.defining_instruction())) { + return true; + } + + // If the use is at the instruction where the value is defined, then the use + // is before the def if the instruction allows buffer sharing (in place + // computation). 
+ if (use.instruction == value.defining_instruction() && + CanShareOperandBufferWithUser( + use.instruction->mutable_operand(use.operand_number), + use.operand_index, value.defining_instruction(), + value.defining_index())) { + return true; + } + + // The use at a while is an input to a phi, and logically occurs before values + // are defined in the body or condition computations. + if (use.instruction->opcode() == HloOpcode::kWhile) { + const HloInstruction* xla_while = use.instruction; + if (call_graph_->InstructionIsNestedIn(value.defining_instruction(), + xla_while->while_body()) || + call_graph_->InstructionIsNestedIn(value.defining_instruction(), + xla_while->while_condition())) { + return true; + } + } + + // Similarly if the value is defined at a while, it logically occurs after any + // uses in the body or condition computations. + if (value.defining_instruction()->opcode() == HloOpcode::kWhile) { + CHECK(ssa_form_); + const HloInstruction* xla_while = value.defining_instruction(); + if (call_graph_->InstructionIsNestedIn(use.instruction, + xla_while->while_body()) || + call_graph_->InstructionIsNestedIn(use.instruction, + xla_while->while_condition())) { + return true; + } + } + return false; +} + +bool HloDataflowAnalysis::LiveRangeStrictlyBefore( + const HloValue& a, const HloValue& b, const HloOrdering& ordering) const { + VLOG(4) << "LiveRangeStrictlyBefore(a = " << a.ToShortString() + << ", b = " << b.ToShortString() << ")"; + if (!IsDefinedBefore(a, b, ordering)) { + VLOG(4) << "a not defined before b"; + return false; + } + + // Live-out values from the module can never have ranges strictly before any + // other value. + if (a.live_out_of_module()) { + VLOG(4) << "a is live out of module"; + return false; + } + + // Live-out values of computations can never have ranges strictly before any + // other value in the computation (including values nested in + // subcomputations). 
+ if (a.live_out_of_computation() && + call_graph_->InstructionIsNestedIn(b.defining_instruction(), + a.defining_instruction()->parent())) { + VLOG(4) << "a is live out of computation containing b"; + return false; + } + + // All uses of 'a' must be before 'b' is defined. + for (const HloUse& use : a.uses()) { + if (!UseIsBeforeValueDefinition(use, b, ordering)) { + VLOG(4) << "use of a (" << use << ") not before b is defined"; + return false; + } + } + + return true; +} + +bool HloDataflowAnalysis::MayInterfere(const HloValue& a, const HloValue& b, + const HloOrdering& ordering) const { + // Buffers without disjoint liveness may interfere. + return !LiveRangeStrictlyBefore(a, b, ordering) && + !LiveRangeStrictlyBefore(b, a, ordering); +} + /* static */ StatusOr> HloDataflowAnalysis::Run( HloModule* module, bool ssa_form, bool bitcast_defines_value) { diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h index 2f9b0a64be5..7087bd978d9 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/call_graph.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/shape_tree.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status.h" @@ -123,16 +124,16 @@ class HloValue { bool is_phi() const { return is_phi_; } // Return the location where this value is defined. - const HloLocation& DefinitionLocation() const { return locations_[0]; } + const HloLocation& defining_location() const { return locations_[0]; } // Return the instruction which defines this HloValue. 
- HloInstruction* instruction() const { - return DefinitionLocation().instruction; + HloInstruction* defining_instruction() const { + return defining_location().instruction; } // Return the shape index at which this HloValue is defined in the output of - // instruction(). - const ShapeIndex& index() const { return DefinitionLocation().index; } + // its defining instruction. + const ShapeIndex& defining_index() const { return defining_location().index; } // Add or remove a location at which the HloValue appears. The definition // location can not be removed. The uses of the HloValue are updated. @@ -145,9 +146,12 @@ class HloValue { // Return all uses of the HloValue. const std::vector& uses() const { return uses_; } - // Set/get whether this HloValue is live out of the module. + // Get whether this HloValue is live out of the module. bool live_out_of_module() const { return live_out_of_module_; } + // Get whether this HloValue is live out of the computation it is defined in. + bool live_out_of_computation() const { return live_out_of_computation_; } + bool operator==(const HloValue& other) const; bool operator!=(const HloValue& other) const; @@ -172,6 +176,9 @@ class HloValue { // Whether this value is live out of the HLO module. bool live_out_of_module_ = false; + + // Whether this value is live out of its computation. + bool live_out_of_computation_ = false; }; std::ostream& operator<<(std::ostream& out, const HloValue& hlo_value); @@ -309,6 +316,17 @@ class HloDataflowAnalysis { const HloValue& GetValue(HloValue::Id value_id) const; HloValue& GetValue(HloValue::Id value_id); + // Returns whether the given values interfere assuming the given HLO + // ordering. Two values interfere if they may both be simultaneously live. + bool MayInterfere(const HloValue& a, const HloValue& b, + const HloOrdering& ordering) const; + + // Overload which takes HloValue:Ids. 
+ bool MayInterfere(HloValue::Id a, HloValue::Id b, + const HloOrdering& ordering) const { + return MayInterfere(GetValue(a), GetValue(b), ordering); + } + // Return the total number of HloValues. int64 value_count() const { return values_.size(); } @@ -374,6 +392,20 @@ class HloDataflowAnalysis { HloInstruction* instruction, const InstructionValueSet& new_value_set, const InstructionValueSet* prev_value_set = nullptr); + // Returns true if the live range of the given value 'a' is strictly before + // the live range of value 'b' using the given HLO ordering. + bool LiveRangeStrictlyBefore(const HloValue& a, const HloValue& b, + const HloOrdering& ordering) const; + + // Returns whether the value 'a' is defined before the value 'b' under the + // given ordering. + bool IsDefinedBefore(const HloValue& a, const HloValue& b, + const HloOrdering& ordering) const; + + // Returns whether the given use is before the given value definition. + bool UseIsBeforeValueDefinition(const HloUse& use, const HloValue& value, + const HloOrdering& ordering) const; + HloModule* const module_; const bool ssa_form_; const bool bitcast_defines_value_; diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index 21344af5f22..79edd0fcb59 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -39,14 +39,14 @@ using ::testing::UnorderedElementsAre; class HloDataflowAnalysisTest : public HloTestBase, public ::testing::WithParamInterface { protected: - HloDataflowAnalysisTest() : module_(TestName()) {} + HloDataflowAnalysisTest() : module_(CreateNewModule()) {} // Run dataflow analysis on the member module. For convenience returns a // reference to the generated analysis stored in analysis_. 
const HloDataflowAnalysis& RunAnalysis(bool ssa_form, bool bitcast_defines_value = false) { analysis_ = - HloDataflowAnalysis::Run(&module_, ssa_form, bitcast_defines_value) + HloDataflowAnalysis::Run(module_.get(), ssa_form, bitcast_defines_value) .ConsumeValueOrDie(); return *analysis_; } @@ -63,22 +63,34 @@ class HloDataflowAnalysisTest : public HloTestBase, return values; } - HloModule module_; + // Returns true if the top-level values for instructions 'a' and 'b' may + // interfere. Precondition: 'a' and 'b' define array-shaped values. + bool InstructionsMayInterfere(const HloOrdering& ordering, + const HloInstruction* a, + const HloInstruction* b) { + EXPECT_FALSE(ShapeUtil::IsTuple(a->shape())); + EXPECT_FALSE(ShapeUtil::IsTuple(b->shape())); + return analysis_->MayInterfere(analysis_->GetValueDefinedAt(a), + analysis_->GetValueDefinedAt(b), ordering); + } + + std::unique_ptr module_; std::unique_ptr analysis_; const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {}); + const Shape vector_shape_ = ShapeUtil::MakeShape(F32, {42}); }; TEST_P(HloDataflowAnalysisTest, BinaryOperation) { // Test the dataflow for a simple binary operation (Add). 
auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto add = builder.AddInstruction(HloInstruction::CreateBinary( scalar_shape_, HloOpcode::kAdd, constant1, constant2)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -126,7 +138,7 @@ TEST_P(HloDataflowAnalysisTest, TupleAndGtes) { HloInstruction::CreateGetTupleElement(scalar_shape_, tuple, 1)); auto add = builder.AddInstruction( HloInstruction::CreateBinary(scalar_shape_, HloOpcode::kAdd, gte0, gte1)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -158,27 +170,21 @@ TEST_P(HloDataflowAnalysisTest, TupleAndGtes) { // Verify uses. Of interest is that a GetTupleElement instruction is only a // use of the top-level value in the tuple operand. 
EXPECT_THAT(analysis.GetValueDefinedAt(param0).uses(), - UnorderedElementsAre(HloUse{tuple, 0, {}}, HloUse{add, 0, {}})); + UnorderedElementsAre(HloUse{add, 0, {}})); EXPECT_THAT(analysis.GetValueDefinedAt(param1).uses(), - UnorderedElementsAre(HloUse{tuple, 1, {}}, HloUse{add, 1, {}})); + UnorderedElementsAre(HloUse{add, 1, {}})); EXPECT_THAT(analysis.GetValueDefinedAt(tuple, /*index=*/{}).uses(), UnorderedElementsAre(HloUse{gte0, 0, {}}, HloUse{gte1, 0, {}})); EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_module()); } TEST_P(HloDataflowAnalysisTest, NestedTuple) { - // Verify the dataflow through a nested tuple of the following form for two - // constants %constant1 and %constant2: - // - // %nested_tuple = {{%constant1, %constant2}, - // {%constant1, %constant2}, - // %constant1} - // + // Verify the dataflow through a nested tuple. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto nested_tuple = builder.AddInstruction( @@ -187,7 +193,7 @@ TEST_P(HloDataflowAnalysisTest, NestedTuple) { HloInstruction::CreateGetTupleElement(tuple->shape(), nested_tuple, 1)); auto gte_out = builder.AddInstruction( HloInstruction::CreateGetTupleElement(scalar_shape_, gte_tuple, 0)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -202,18 +208,15 @@ TEST_P(HloDataflowAnalysisTest, NestedTuple) { HloLocation{nested_tuple, {0, 0}}, HloLocation{nested_tuple, {1, 0}}, HloLocation{nested_tuple, {2}}, HloLocation{gte_tuple, 
{0}}, HloLocation{gte_out, {}})); - EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(), - UnorderedElementsAre( - HloUse{tuple, 0, {}}, HloUse{nested_tuple, 0, {0}}, - HloUse{nested_tuple, 1, {0}}, HloUse{nested_tuple, 2, {}})); - EXPECT_THAT( - analysis.GetValueDefinedAt(constant2).uses(), - UnorderedElementsAre(HloUse{tuple, 1, {}}, HloUse{nested_tuple, 0, {1}}, - HloUse{nested_tuple, 1, {1}})); + // Constant values should have no uses though one is live out. The locations + // where they appear as operands are on instructions which do not use the + // values (eg, Tuple). + EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).uses().empty()); + EXPECT_TRUE(analysis.GetValueDefinedAt(constant2).uses().empty()); + + // The top-level tuple values are used in GTE instructions. EXPECT_THAT(analysis.GetValueDefinedAt(tuple, /*index=*/{}).uses(), - UnorderedElementsAre(HloUse{nested_tuple, 0, {}}, - HloUse{nested_tuple, 1, {}}, - HloUse{gte_out, 0, {}})); + UnorderedElementsAre(HloUse{gte_out, 0, {}})); EXPECT_THAT(analysis.GetValueDefinedAt(nested_tuple, /*index=*/{}).uses(), UnorderedElementsAre(HloUse{gte_tuple, 0, {}})); @@ -236,16 +239,16 @@ TEST_P(HloDataflowAnalysisTest, SingleCall) { auto add = subbuilder.AddInstruction(HloInstruction::CreateBinary( scalar_shape_, HloOpcode::kAdd, subparam0, subparam1)); HloComputation* called_computation = - module_.AddEmbeddedComputation(subbuilder.Build()); + module_->AddEmbeddedComputation(subbuilder.Build()); auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto call = builder.AddInstruction(HloInstruction::CreateCall( scalar_shape_, {constant1, constant2}, called_computation)); - 
module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -268,11 +271,12 @@ TEST_P(HloDataflowAnalysisTest, SingleCall) { EXPECT_EQ(analysis.GetUniqueValueAt(call), analysis.GetValueDefinedAt(add)); EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(), - UnorderedElementsAre(HloUse{add, 0, {}}, HloUse{call, 0, {}})); + UnorderedElementsAre(HloUse{add, 0, {}})); EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(), - UnorderedElementsAre(HloUse{add, 1, {}}, HloUse{call, 1, {}})); + UnorderedElementsAre(HloUse{add, 1, {}})); EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_module()); + EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_computation()); } TEST_P(HloDataflowAnalysisTest, ComputationCalledTwiceWithSameArguments) { @@ -285,20 +289,20 @@ TEST_P(HloDataflowAnalysisTest, ComputationCalledTwiceWithSameArguments) { auto add = subbuilder.AddInstruction(HloInstruction::CreateBinary( scalar_shape_, HloOpcode::kAdd, subparam0, subparam1)); HloComputation* called_computation = - module_.AddEmbeddedComputation(subbuilder.Build()); + module_->AddEmbeddedComputation(subbuilder.Build()); auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto call1 = builder.AddInstruction(HloInstruction::CreateCall( scalar_shape_, {constant1, constant2}, called_computation)); auto call2 = builder.AddInstruction(HloInstruction::CreateCall( scalar_shape_, {constant1, constant2}, called_computation)); auto sub = builder.AddInstruction(HloInstruction::CreateBinary( scalar_shape_, HloOpcode::kSubtract, call1, call2)); - 
module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -316,17 +320,18 @@ TEST_P(HloDataflowAnalysisTest, ComputationCalledTwiceWithSameArguments) { EXPECT_TRUE(analysis.ValueIsDefinedAt(sub)); EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(), - UnorderedElementsAre(HloUse{add, 0, {}}, HloUse{call1, 0, {}}, - HloUse{call2, 0, {}})); + UnorderedElementsAre(HloUse{add, 0, {}})); EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(), - UnorderedElementsAre(HloUse{add, 1, {}}, HloUse{call1, 1, {}}, - HloUse{call2, 1, {}})); + UnorderedElementsAre(HloUse{add, 1, {}})); // The Add from the subcomputation is used as both operands of the Subtract. EXPECT_THAT(analysis.GetValueDefinedAt(add).uses(), UnorderedElementsAre(HloUse{sub, 0, {}}, HloUse{sub, 1, {}})); EXPECT_FALSE(analysis.GetValueDefinedAt(add).live_out_of_module()); + EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_computation()); + EXPECT_TRUE(analysis.GetValueDefinedAt(sub).live_out_of_module()); + EXPECT_TRUE(analysis.GetValueDefinedAt(sub).live_out_of_computation()); } TEST_P(HloDataflowAnalysisTest, ComputationCalledTwiceWithDifferentArguments) { @@ -339,18 +344,18 @@ TEST_P(HloDataflowAnalysisTest, ComputationCalledTwiceWithDifferentArguments) { auto add = subbuilder.AddInstruction(HloInstruction::CreateBinary( scalar_shape_, HloOpcode::kAdd, subparam0, subparam1)); HloComputation* called_computation = - module_.AddEmbeddedComputation(subbuilder.Build()); + module_->AddEmbeddedComputation(subbuilder.Build()); auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + 
HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto call1 = builder.AddInstruction(HloInstruction::CreateCall( scalar_shape_, {constant1, constant2}, called_computation)); auto call2 = builder.AddInstruction(HloInstruction::CreateCall( scalar_shape_, {call1, constant2}, called_computation)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -392,7 +397,7 @@ TEST_P(HloDataflowAnalysisTest, NestedCalls) { auto add = inner_builder.AddInstruction(HloInstruction::CreateBinary( scalar_shape_, HloOpcode::kAdd, inner_param0, inner_param1)); HloComputation* inner_computation = - module_.AddEmbeddedComputation(inner_builder.Build()); + module_->AddEmbeddedComputation(inner_builder.Build()); auto outer_builder = HloComputation::Builder("OuterComputation"); auto outer_param0 = outer_builder.AddInstruction( @@ -400,19 +405,19 @@ TEST_P(HloDataflowAnalysisTest, NestedCalls) { auto outer_param1 = outer_builder.AddInstruction( HloInstruction::CreateParameter(1, scalar_shape_, "param1")); // Swizzle parameters. 
- auto nested_call = outer_builder.AddInstruction(HloInstruction::CreateCall( + outer_builder.AddInstruction(HloInstruction::CreateCall( scalar_shape_, {outer_param1, outer_param0}, inner_computation)); HloComputation* outer_computation = - module_.AddEmbeddedComputation(outer_builder.Build()); + module_->AddEmbeddedComputation(outer_builder.Build()); auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); - auto call = builder.AddInstruction(HloInstruction::CreateCall( + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + builder.AddInstruction(HloInstruction::CreateCall( scalar_shape_, {constant1, constant2}, outer_computation)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -423,14 +428,10 @@ TEST_P(HloDataflowAnalysisTest, NestedCalls) { // Verify that the uses of the constants are properly swizzled by parameter // permutation in nested_call. 
- EXPECT_THAT( - analysis.GetValueDefinedAt(constant1).uses(), - UnorderedElementsAre(HloUse{call, 0, {}}, HloUse{nested_call, 1, {}}, - HloUse{add, 1, {}})); - EXPECT_THAT( - analysis.GetValueDefinedAt(constant2).uses(), - UnorderedElementsAre(HloUse{call, 1, {}}, HloUse{nested_call, 0, {}}, - HloUse{add, 0, {}})); + EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(), + UnorderedElementsAre(HloUse{add, 1, {}})); + EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(), + UnorderedElementsAre(HloUse{add, 0, {}})); EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_module()); } @@ -465,33 +466,37 @@ TEST_P(HloDataflowAnalysisTest, SingleWhile) { HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1)); auto add = body_builder.AddInstruction(HloInstruction::CreateBinary( scalar_shape_, HloOpcode::kAdd, body_element_0, body_element_1)); - auto body_tuple = body_builder.AddInstruction( + body_builder.AddInstruction( HloInstruction::CreateTuple({body_element_0, add})); - HloComputation* body = module_.AddEmbeddedComputation(body_builder.Build()); + HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build()); // Condition computation trivially returns a constant "false". 
auto cond_builder = HloComputation::Builder("condition"); auto cond_param = cond_builder.AddInstruction( HloInstruction::CreateParameter(0, tuple_shape, "param")); - cond_builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + auto cond_constant = cond_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(false))); HloComputation* condition = - module_.AddEmbeddedComputation(cond_builder.Build()); + module_->AddEmbeddedComputation(cond_builder.Build()); auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto xla_while = builder.AddInstruction( HloInstruction::CreateWhile(tuple_shape, condition, body, tuple)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); + EXPECT_TRUE( + analysis.GetValueDefinedAt(cond_constant).live_out_of_computation()); + EXPECT_FALSE(analysis.GetValueDefinedAt(cond_constant).live_out_of_module()); + if (ssa_form) { // Element 0 of the tuple passed through the body so no phi value is // defined. 
@@ -507,15 +512,17 @@ TEST_P(HloDataflowAnalysisTest, SingleWhile) { EXPECT_TRUE(analysis.ValueIsDefinedAt(cond_param, /*index=*/{1})); EXPECT_TRUE(analysis.GetValueDefinedAt(cond_param, /*index=*/{1}).is_phi()); - EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(), - UnorderedElementsAre(HloUse{add, 0, {}}, HloUse{tuple, 0, {}}, - HloUse{xla_while, 0, {0}}, - HloUse{body_tuple, 0, {}})); + EXPECT_THAT( + analysis.GetValueDefinedAt(constant1).uses(), + UnorderedElementsAre(HloUse{add, 0, {}}, HloUse{xla_while, 0, {0}})); // Constant1 passes through the body and out of the module. EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module()); EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{1}) .live_out_of_module()); + + EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_computation()); + EXPECT_FALSE(analysis.GetValueDefinedAt(add).live_out_of_module()); } else { // While instruction and subcomputation parameters should not define values // in non-ssa form. 
@@ -528,6 +535,7 @@ TEST_P(HloDataflowAnalysisTest, SingleWhile) { EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module()); EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_module()); + EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_computation()); } } @@ -565,21 +573,21 @@ TEST_P(HloDataflowAnalysisTest, SequentialWhiles) { scalar_shape_, HloOpcode::kAdd, body_element_0, body_element_1)); body_builder.AddInstruction( HloInstruction::CreateTuple({body_element_0, add})); - HloComputation* body = module_.AddEmbeddedComputation(body_builder.Build()); + HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build()); auto cond_builder = HloComputation::Builder("condition"); cond_builder.AddInstruction( HloInstruction::CreateParameter(0, tuple_shape, "param")); cond_builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloInstruction::CreateConstant(Literal::CreateR0(false))); HloComputation* condition = - module_.AddEmbeddedComputation(cond_builder.Build()); + module_->AddEmbeddedComputation(cond_builder.Build()); auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto xla_while0 = builder.AddInstruction( @@ -588,7 +596,7 @@ TEST_P(HloDataflowAnalysisTest, SequentialWhiles) { HloInstruction::CreateWhile(tuple_shape, condition, body, xla_while0)); auto xla_while2 = builder.AddInstruction( HloInstruction::CreateWhile(tuple_shape, condition, body, xla_while1)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); bool ssa_form = 
GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -630,9 +638,9 @@ TEST_P(HloDataflowAnalysisTest, NestedWhiles) { cond_builder.AddInstruction( HloInstruction::CreateParameter(0, tuple_shape, "param")); cond_builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloInstruction::CreateConstant(Literal::CreateR0(false))); HloComputation* condition = - module_.AddEmbeddedComputation(cond_builder.Build()); + module_->AddEmbeddedComputation(cond_builder.Build()); // Element 0 passes transparently through the body. auto inner_builder = HloComputation::Builder("inner_body"); @@ -647,7 +655,7 @@ TEST_P(HloDataflowAnalysisTest, NestedWhiles) { inner_builder.AddInstruction( HloInstruction::CreateTuple({inner_element_0, add})); HloComputation* inner_body = - module_.AddEmbeddedComputation(inner_builder.Build()); + module_->AddEmbeddedComputation(inner_builder.Build()); // Element 1 passes transparently through the body. auto outer_builder = HloComputation::Builder("outer_body"); @@ -664,18 +672,18 @@ TEST_P(HloDataflowAnalysisTest, NestedWhiles) { auto nested_while = outer_builder.AddInstruction(HloInstruction::CreateWhile( tuple_shape, condition, inner_body, outer_tuple)); HloComputation* outer_body = - module_.AddEmbeddedComputation(outer_builder.Build()); + module_->AddEmbeddedComputation(outer_builder.Build()); auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto entry_while = builder.AddInstruction( HloInstruction::CreateWhile(tuple_shape, condition, outer_body, tuple)); - 
module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -751,26 +759,26 @@ TEST_P(HloDataflowAnalysisTest, SwizzlingWhile) { HloInstruction::CreateGetTupleElement(scalar_shape_, body_param, 1)); body_builder.AddInstruction( HloInstruction::CreateTuple({body_element_1, body_element_0})); - HloComputation* body = module_.AddEmbeddedComputation(body_builder.Build()); + HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build()); auto cond_builder = HloComputation::Builder("condition"); auto cond_param = cond_builder.AddInstruction( HloInstruction::CreateParameter(0, tuple_shape, "param")); cond_builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloInstruction::CreateConstant(Literal::CreateR0(false))); HloComputation* condition = - module_.AddEmbeddedComputation(cond_builder.Build()); + module_->AddEmbeddedComputation(cond_builder.Build()); auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto xla_while = builder.AddInstruction( HloInstruction::CreateWhile(tuple_shape, condition, body, tuple)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -817,15 +825,15 @@ TEST_P(HloDataflowAnalysisTest, ArraySelect) { // Test a kSelect of an array value. 
auto builder = HloComputation::Builder(TestName()); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloInstruction::CreateConstant(Literal::CreateR0(false))); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto select = builder.AddInstruction(HloInstruction::CreateTernary( scalar_shape_, HloOpcode::kSelect, pred, constant1, constant2)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -841,15 +849,15 @@ TEST_P(HloDataflowAnalysisTest, TupleSelect) { // instruction. auto builder = HloComputation::Builder(TestName()); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloInstruction::CreateConstant(Literal::CreateR0(false))); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.0))); + HloInstruction::CreateConstant(Literal::CreateR0(3.0))); auto constant4 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(4.0))); + HloInstruction::CreateConstant(Literal::CreateR0(4.0))); auto tuple1 = builder.AddInstruction(HloInstruction::CreateTuple({constant1})); auto tuple2 = @@ -868,7 +876,7 @@ TEST_P(HloDataflowAnalysisTest, TupleSelect) { auto select1234 = 
builder.AddInstruction(HloInstruction::CreateTernary( tuple_shape, HloOpcode::kSelect, pred, select12, select34)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -899,31 +907,33 @@ TEST_P(HloDataflowAnalysisTest, TupleSelect) { analysis.GetValueDefinedAt(constant4))); EXPECT_THAT( - analysis.GetValueDefinedAt(constant1).uses(), - UnorderedElementsAre(HloUse{tuple1, 0, {}}, HloUse{select11, 1, {0}}, - HloUse{select11, 2, {0}}, HloUse{select12, 1, {0}}, - HloUse{select1234, 1, {0}})); - EXPECT_THAT( - analysis.GetValueDefinedAt(constant2).uses(), - UnorderedElementsAre(HloUse{tuple2, 0, {}}, HloUse{select12, 2, {0}}, - HloUse{select1234, 1, {0}})); + analysis.GetValueDefinedAt(tuple1, /*index=*/{}).uses(), + UnorderedElementsAre(HloUse{select11, 1, {}}, HloUse{select11, 2, {}}, + HloUse{select12, 1, {}})); + + // The two constant values just pass through the Selects and are not + // used. They are live out however. + EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).uses().empty()); + EXPECT_TRUE(analysis.GetValueDefinedAt(constant2).uses().empty()); + EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module()); + EXPECT_TRUE(analysis.GetValueDefinedAt(constant2).live_out_of_module()); } TEST_P(HloDataflowAnalysisTest, NestedTupleSelect) { // Test kSelect of a nested tuple. 
auto builder = HloComputation::Builder(TestName()); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloInstruction::CreateConstant(Literal::CreateR0(false))); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.0))); + HloInstruction::CreateConstant(Literal::CreateR0(3.0))); auto constant4 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(4.0))); + HloInstruction::CreateConstant(Literal::CreateR0(4.0))); auto constant5 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(5.0))); + HloInstruction::CreateConstant(Literal::CreateR0(5.0))); auto inner_tuple1 = builder.AddInstruction( HloInstruction::CreateTuple({constant2, constant3})); auto tuple1 = builder.AddInstruction( @@ -935,7 +945,7 @@ TEST_P(HloDataflowAnalysisTest, NestedTupleSelect) { auto select = builder.AddInstruction(HloInstruction::CreateTernary( tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -993,24 +1003,24 @@ TEST_P(HloDataflowAnalysisTest, TupleSelectToWhile) { scalar_shape_, HloOpcode::kAdd, body_element_0, body_element_1)); body_builder.AddInstruction( HloInstruction::CreateTuple({body_element_0, add})); - HloComputation* body = module_.AddEmbeddedComputation(body_builder.Build()); + HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build()); auto cond_builder = HloComputation::Builder("condition"); 
cond_builder.AddInstruction( HloInstruction::CreateParameter(0, tuple_shape, "param")); cond_builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloInstruction::CreateConstant(Literal::CreateR0(false))); HloComputation* condition = - module_.AddEmbeddedComputation(cond_builder.Build()); + module_->AddEmbeddedComputation(cond_builder.Build()); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloInstruction::CreateConstant(Literal::CreateR0(false))); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.0))); + HloInstruction::CreateConstant(Literal::CreateR0(3.0))); auto tuple1 = builder.AddInstruction(HloInstruction::CreateTuple({constant1})); auto tuple2 = @@ -1024,7 +1034,7 @@ TEST_P(HloDataflowAnalysisTest, TupleSelectToWhile) { auto xla_while = builder.AddInstruction( HloInstruction::CreateWhile(tuple->shape(), condition, body, tuple)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -1062,11 +1072,11 @@ TEST_P(HloDataflowAnalysisTest, BitcastDefinesValue) { // Test the bitcast_defines_value flag to the dataflow analysis. 
auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto bitcast = builder.AddInstruction(HloInstruction::CreateUnary( scalar_shape_, HloOpcode::kBitcast, constant)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); bool ssa_form = GetParam(); { @@ -1102,7 +1112,7 @@ TEST_P(HloDataflowAnalysisTest, TupleCopy) { builder.AddInstruction(HloInstruction::CreateTuple({param0, param1})); auto copy = builder.AddInstruction( HloInstruction::CreateUnary(tuple->shape(), HloOpcode::kCopy, tuple)); - module_.AddEntryComputation(builder.Build()); + module_->AddEntryComputation(builder.Build()); bool ssa_form = GetParam(); const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); @@ -1126,6 +1136,352 @@ TEST_P(HloDataflowAnalysisTest, TupleCopy) { analysis.GetValueDefinedAt(copy, /*index=*/{}).live_out_of_module()); } +TEST_P(HloDataflowAnalysisTest, ElementwiseChainInterference) { + // A simple chain of elementwise operations. No values should interfere. + // + // param --> negate -> exp -> log + // + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, vector_shape_, "param")); + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, param)); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, negate)); + auto log = builder.AddInstruction( + HloInstruction::CreateUnary(vector_shape_, HloOpcode::kLog, exp)); + + module_->AddEntryComputation(builder.Build()); + RunAnalysis(GetParam()); + + DependencyHloOrdering ordering(module_.get()); + + // No values should interfere. 
+ EXPECT_FALSE(InstructionsMayInterfere(ordering, param, negate)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, param, exp)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, param, log)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, negate, exp)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, negate, log)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, exp, negate)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, exp, log)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, log, negate)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, log, exp)); + + // Values should interfere with itself. + EXPECT_TRUE(InstructionsMayInterfere(ordering, exp, exp)); +} + +TEST_P(HloDataflowAnalysisTest, MultipleEntryParameters_Sequential) { + // Two entry params, which interfere with each other. + // + // param0 --> negate ---------------\ + // param1 --> exp --> add + auto builder = HloComputation::Builder(TestName()); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, vector_shape_, "param0")); + auto param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, vector_shape_, "param1")); + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, param0)); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, param1)); + auto add = builder.AddInstruction(HloInstruction::CreateBinary( + vector_shape_, HloOpcode::kAdd, negate, exp)); + + auto entry = module_->AddEntryComputation(builder.Build()); + RunAnalysis(GetParam()); + + SequentialHloOrdering::HloModuleSequence sequence; + sequence.insert({entry, {param0, negate, param1, exp, add}}); + SequentialHloOrdering ordering(module_.get(), sequence); + + // Entry parameters interfere as if they are defined simultaneously at + // the very beginning. 
+ EXPECT_TRUE(InstructionsMayInterfere(ordering, param0, param1)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, param0, negate)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, param0, exp)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, param0, add)); + EXPECT_TRUE(InstructionsMayInterfere(ordering, param1, param0)); + EXPECT_TRUE(InstructionsMayInterfere(ordering, param1, negate)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, param1, exp)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, param1, add)); + + // Negate and exp still interfere. + EXPECT_TRUE(InstructionsMayInterfere(ordering, negate, exp)); + EXPECT_TRUE(InstructionsMayInterfere(ordering, exp, negate)); + + // But {negate, add} and {exp, add} don't interfere. + EXPECT_FALSE(InstructionsMayInterfere(ordering, negate, add)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, add, negate)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, exp, add)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, add, exp)); +} + +TEST_P(HloDataflowAnalysisTest, WhileParameters_Sequential) { + // Similar to MultipleEntryParameters_Sequential, but the parameter is of + // while body computation. Body computation in the sequential order: + // + // %constant = Constant(...) + // %exp = Exp(%constant) + // %param = Param(0) + // %add = Add(%param, %exp) ;; Root of body + // %dead_constant = Constant(...) + // %dead_negate = Negate(%dead_constant) + // + // %constant and its only use %exp are ordered before 'param'. However, the + // %constant and %param values still interfere because the parameter is + // considered live into the while body. + // + // Similarly, %dead_constant and %dead_negate are ordered after the root of + // the body computation %add. However, %add is liveout of the computation so + // %dead_constant and %add interfere. 
+ auto body_builder = HloComputation::Builder(TestName()); + auto body_param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "body_param")); + auto constant = body_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + auto exp = body_builder.AddInstruction( + HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kExp, constant)); + auto add = body_builder.AddInstruction(HloInstruction::CreateBinary( + scalar_shape_, HloOpcode::kAdd, exp, body_param)); + auto dead_constant = body_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + auto dead_negate = body_builder.AddInstruction(HloInstruction::CreateUnary( + scalar_shape_, HloOpcode::kNegate, dead_constant)); + HloComputation* body = module_->AddEmbeddedComputation( + body_builder.Build(/*root_instruction=*/add)); + + auto cond_builder = HloComputation::Builder("condition"); + auto cond_param = cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "cond_param")); + auto cond_constant = cond_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloComputation* condition = + module_->AddEmbeddedComputation(cond_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "param")); + auto xla_while = builder.AddInstruction( + HloInstruction::CreateWhile(scalar_shape_, condition, body, param)); + + auto entry = module_->AddEntryComputation(builder.Build()); + bool ssa_form = GetParam(); + const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form); + + SequentialHloOrdering::HloModuleSequence sequence; + sequence.insert({entry, {param, xla_while}}); + sequence.insert({condition, {cond_param, cond_constant}}); + // Construct the order such that 'constant' and its use 'exp' are before + // body_param. 
+ sequence.insert({body, {constant, exp, body_param, add}}); + + SequentialHloOrdering ordering(module_.get(), sequence); + + // 'add' is the body root even though later instructions follow in the order + // like 'dead_negate'. Only 'add' should be live out of the computation. + EXPECT_TRUE(analysis.GetValueDefinedAt(add).live_out_of_computation()); + EXPECT_FALSE( + analysis.GetValueDefinedAt(dead_negate).live_out_of_computation()); + + // 'add' is live out of the body and will interfere with an later instructions + // such as 'dead_constant' and 'dead_negate'. + EXPECT_TRUE(InstructionsMayInterfere(ordering, add, dead_constant)); + EXPECT_TRUE(InstructionsMayInterfere(ordering, add, dead_negate)); + + // The remaining checks test phi values defined by body and condition + // parameters which only occur in the SSA form of the analysis. + if (ssa_form) { + // Though the ordering suggests 'constant' and 'param' should not interfere, + // 'param' is live in and thus interferes with any earlier instruction of + // the computation in the order (eg 'constant')' + EXPECT_TRUE(InstructionsMayInterfere(ordering, body_param, constant)); + EXPECT_TRUE(InstructionsMayInterfere(ordering, body_param, exp)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, body_param, add)); + + // The following values end up in the same buffer: + // (1) the init value: 'param' + // (2) the body parameter: 'body_param' + // (3) the condition parameter: 'cond_param' + // (4) the root value of the while body: 'add' + // (5) the while value: 'xla_while' + // None should interfere. 
+ EXPECT_FALSE(InstructionsMayInterfere(ordering, param, body_param)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, param, cond_param)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, param, add)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, param, xla_while)); + + EXPECT_FALSE(InstructionsMayInterfere(ordering, body_param, cond_param)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, body_param, add)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, body_param, xla_while)); + + EXPECT_FALSE(InstructionsMayInterfere(ordering, cond_param, add)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, cond_param, xla_while)); + + EXPECT_FALSE(InstructionsMayInterfere(ordering, add, xla_while)); + } +} + +TEST_P(HloDataflowAnalysisTest, NonElementwiseOperand) { + // A chain of operations with two elementwise and one non-elementwise. The + // elementwise op should not interfere with its operand, while the + // non-elementwise op should interfere. Entry params always interfere. + // + // param --> exp -> negate -> reverse + // + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, vector_shape_, "param")); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, param)); + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, exp)); + auto reverse = builder.AddInstruction( + HloInstruction::CreateReverse(vector_shape_, negate, {0})); + + module_->AddEntryComputation(builder.Build()); + RunAnalysis(GetParam()); + + DependencyHloOrdering ordering(module_.get()); + + EXPECT_FALSE(InstructionsMayInterfere(ordering, param, exp)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, param, negate)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, param, reverse)); + + // Negate is elementwise, so doesn't interfere with its operand. + // Reverse is non-elementwise, so does interfere with its operand. 
+ EXPECT_FALSE(InstructionsMayInterfere(ordering, exp, negate)); + EXPECT_TRUE(InstructionsMayInterfere(ordering, negate, reverse)); +} + +TEST_P(HloDataflowAnalysisTest, OverlappedValues) { + // Verify simultaneously live values interfere (exp and negate). + // + // param --> negate -> add + // \---> exp -----/ + // + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, vector_shape_, "param")); + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, param)); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, param)); + auto add = builder.AddInstruction(HloInstruction::CreateBinary( + vector_shape_, HloOpcode::kAdd, negate, exp)); + + module_->AddEntryComputation(builder.Build()); + RunAnalysis(GetParam()); + + DependencyHloOrdering ordering(module_.get()); + + EXPECT_TRUE(InstructionsMayInterfere(ordering, param, negate)); + EXPECT_TRUE(InstructionsMayInterfere(ordering, param, exp)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, param, add)); + + // Negate and exp interfere with each other, but not with add. + EXPECT_TRUE(InstructionsMayInterfere(ordering, negate, exp)); + EXPECT_TRUE(InstructionsMayInterfere(ordering, exp, negate)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, negate, add)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, add, negate)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, exp, add)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, add, exp)); +} + +TEST_P(HloDataflowAnalysisTest, OverlappedValuesSequentialOrder) { + // Identical to the test OverlappedValue but using a sequential ordering of + // HLO instructions. + // + // param --> negate -> add + // \---> exp -----/ + // + // Sequential order: + // param, negate, exp, add + // + // Liveness is identical to the DependencyHloOrdering. 
+ auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, vector_shape_, "param")); + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, param)); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, param)); + auto add = builder.AddInstruction(HloInstruction::CreateBinary( + vector_shape_, HloOpcode::kAdd, negate, exp)); + + auto entry = module_->AddEntryComputation(builder.Build()); + RunAnalysis(GetParam()); + + SequentialHloOrdering::HloModuleSequence sequence; + std::vector order = {param, negate, exp, add}; + sequence.emplace(entry, order); + + SequentialHloOrdering ordering(module_.get(), sequence); + + EXPECT_TRUE(InstructionsMayInterfere(ordering, param, negate)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, param, exp)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, param, add)); + + // Negate and exp interfere with each other, but not with add. + EXPECT_TRUE(InstructionsMayInterfere(ordering, negate, exp)); + EXPECT_TRUE(InstructionsMayInterfere(ordering, exp, negate)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, negate, add)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, add, negate)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, exp, add)); + EXPECT_FALSE(InstructionsMayInterfere(ordering, add, exp)); +} + +TEST_P(HloDataflowAnalysisTest, EmbeddedComputationInterference) { + // Test MayInterfere() for embedded computation, specifically the interference + // of values in different computations. 
+ // + // embedded_computation: + // %embedded_param = Param(0) + // %embedded_log = Log(%embedded_param) + // + // entry computation: + // %param = Param(0) + // %negate = Negate(%param) + // %exp = Negate(%exp) + // %call = Call(embedded_computation, {%exp}) + // %add = Add(%negate, %call) + // + // Note %negate is live across the call and should interfere with all values + // in the embedded computation. + auto embedded_builder = HloComputation::Builder(TestName() + "_embedded"); + auto embedded_param = embedded_builder.AddInstruction( + HloInstruction::CreateParameter(0, vector_shape_, "embedded_param")); + auto embedded_log = + embedded_builder.AddInstruction(HloInstruction::CreateUnary( + vector_shape_, HloOpcode::kLog, embedded_param)); + auto embedded_computation = + module_->AddEmbeddedComputation(embedded_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, vector_shape_, "param")); + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(vector_shape_, HloOpcode::kNegate, param)); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(vector_shape_, HloOpcode::kExp, param)); + auto call = builder.AddInstruction( + HloInstruction::CreateCall(vector_shape_, {exp}, embedded_computation)); + builder.AddInstruction(HloInstruction::CreateBinary( + vector_shape_, HloOpcode::kAdd, negate, call)); + module_->AddEntryComputation(builder.Build()); + RunAnalysis(GetParam()); + + DependencyHloOrdering ordering(module_.get()); + + // Exp only use is the call so it should not interfere with values inside the + // embedded computation. 
+ EXPECT_FALSE(InstructionsMayInterfere(ordering, exp, embedded_log)); + + // Negate is live across the call and should interfere with values in the + // embedded computation + EXPECT_TRUE(InstructionsMayInterfere(ordering, negate, embedded_log)); +} + INSTANTIATE_TEST_CASE_P(HloDataflowAnalysisInstantiation, HloDataflowAnalysisTest, ::testing::Values(false, true)); diff --git a/tensorflow/compiler/xla/service/hlo_dce_test.cc b/tensorflow/compiler/xla/service/hlo_dce_test.cc index 10cd7ca7c09..704b8dfca70 100644 --- a/tensorflow/compiler/xla/service/hlo_dce_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dce_test.cc @@ -45,9 +45,9 @@ TEST_F(HloDceTest, NoDeadCode) { // Verify that no dead code is removed from a computation with no dead code. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(123.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(123.0f))); builder.AddInstruction(HloInstruction::CreateBinary( constant1->shape(), HloOpcode::kAdd, constant1, constant2)); @@ -98,9 +98,9 @@ TEST_F(HloDceTest, ControlDependencies) { // Verify that instructions with control dependencies are not removed. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(123.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(123.0f))); // Create two dead instructions: a negate and an add. 
auto dead_negate = builder.AddInstruction(HloInstruction::CreateUnary( diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index 3e7f5b1f3d9..393e25380e0 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -89,11 +89,11 @@ StatusOr> Compare(const Shape& shape, HloOpcode opcode, << HloOpcodeString(opcode); } - auto result = LiteralUtil::CreateFromShape(shape); - TF_RETURN_IF_ERROR(LiteralUtil::Populate( - result.get(), [&](tensorflow::gtl::ArraySlice multi_index) { - return compare_op(LiteralUtil::Get(lhs_literal, multi_index), - LiteralUtil::Get(rhs_literal, multi_index)); + auto result = Literal::CreateFromShape(shape); + TF_RETURN_IF_ERROR(result.get()->Populate( + [&](tensorflow::gtl::ArraySlice multi_index) { + return compare_op(lhs_literal.Get(multi_index), + rhs_literal.Get(multi_index)); })); return std::move(result); @@ -117,12 +117,11 @@ StatusOr> ElementWiseUnaryOpImpl( ShapeUtil::HumanString(operand->shape()).c_str()); } - auto result = LiteralUtil::CreateFromShape(shape); + auto result = Literal::CreateFromShape(shape); - TF_RETURN_IF_ERROR(LiteralUtil::Populate( - result.get(), [&](tensorflow::gtl::ArraySlice multi_index) { - return unary_op( - LiteralUtil::Get(operand_literal, multi_index)); + TF_RETURN_IF_ERROR(result.get()->Populate( + [&](tensorflow::gtl::ArraySlice multi_index) { + return unary_op(operand_literal.Get(multi_index)); })); return std::move(result); } @@ -168,6 +167,23 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return HandleAbs(abs, operand); }; + Status HandleBroadcast(HloInstruction* broadcast) override { + parent_->evaluated_[broadcast] = + Literal::CreateFromShape(broadcast->shape()); + auto output = parent_->evaluated_[broadcast].get(); + auto operand_to_broadcast = + parent_->GetEvaluatedLiteralFor(broadcast->operand(0)); + std::vector broadcast_indices( + 
ShapeUtil::Rank(broadcast->operand(0)->shape()), 0); + return output->Populate( + [&](tensorflow::gtl::ArraySlice multi_index) { + for (int64 i = 0; i < broadcast->dimensions().size(); ++i) { + broadcast_indices[i] = multi_index[broadcast->dimensions(i)]; + } + return operand_to_broadcast.Get(broadcast_indices); + }); + } + Status HandleCeil(HloInstruction* ceil, HloInstruction* operand) override { TF_ASSIGN_OR_RETURN(parent_->evaluated_[ceil], ElementWiseUnaryOp(ceil, [](ReturnT elem_operand) { @@ -176,7 +192,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); }; - Status HandleCopy(HloInstruction* copy, HloInstruction* operand) override { + Status HandleCopy(HloInstruction* copy) override { TF_ASSIGN_OR_RETURN(parent_->evaluated_[copy], ElementWiseUnaryOp(copy, [](ReturnT elem_operand) { return elem_operand; @@ -184,42 +200,19 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); }; - template - std::unique_ptr ConvertIfTypesMatch(const Literal& src_literal) { - DCHECK_EQ(src_type, src_literal.shape().element_type()); - return LiteralUtil::Convert< - typename primitive_util::PrimitiveTypeToNative::type, - typename primitive_util::PrimitiveTypeToNative::type>( - src_literal); - } + Status HandleConvert(HloInstruction* convert) override { + const HloInstruction* operand = convert->operand(0); + TF_RET_CHECK(ShapeUtil::SameDimensions(operand->shape(), convert->shape())); + TF_ASSIGN_OR_RETURN(std::unique_ptr result, + parent_->GetEvaluatedLiteralFor(operand).Convert( + convert->shape().element_type())); - Status HandleConvert(HloInstruction* convert, - HloInstruction* operand) override { - auto operand_literal = parent_->GetEvaluatedLiteralFor(operand); - - switch (operand->shape().element_type()) { -#define CONVERT_IF_TYPES_MATCH(src_type) \ - case (src_type): \ - parent_->evaluated_[convert] = LiteralUtil::Convert< \ - typename primitive_util::PrimitiveTypeToNative::type, \ - 
ReturnT>(operand_literal); \ - break; - CONVERT_IF_TYPES_MATCH(PRED) - CONVERT_IF_TYPES_MATCH(S8) - CONVERT_IF_TYPES_MATCH(S32) - CONVERT_IF_TYPES_MATCH(S64) - CONVERT_IF_TYPES_MATCH(U8) - CONVERT_IF_TYPES_MATCH(U32) - CONVERT_IF_TYPES_MATCH(U64) - CONVERT_IF_TYPES_MATCH(F32) - CONVERT_IF_TYPES_MATCH(F64) -#undef CONVERT_IF_TYPES_MATCH - // Other types are not yet supported. - default: - LOG(FATAL) << "unimplemented operand type for HandleCovert: " - << PrimitiveType_Name(operand->shape().element_type()); + if (LayoutUtil::LayoutsInShapesEqual(result->shape(), convert->shape())) { + parent_->evaluated_[convert] = std::move(result); + } else { + parent_->evaluated_[convert] = + result->Relayout(convert->shape().layout()); } - return Status::OK(); } @@ -322,8 +315,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); }; - Status HandleMaximum(HloInstruction* maximum, HloInstruction* lhs, - HloInstruction* rhs) override { + Status HandleMaximum(HloInstruction* maximum) override { TF_ASSIGN_OR_RETURN( parent_->evaluated_[maximum], ElementWiseBinaryOp(maximum, [](ReturnT lhs, ReturnT rhs) { @@ -332,8 +324,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); }; - Status HandleMinimum(HloInstruction* minimum, HloInstruction* lhs, - HloInstruction* rhs) override { + Status HandleMinimum(HloInstruction* minimum) override { TF_ASSIGN_OR_RETURN( parent_->evaluated_[minimum], ElementWiseBinaryOp(minimum, [](ReturnT lhs_el, ReturnT rhs_el) { @@ -446,12 +437,12 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); - auto result = LiteralUtil::CreateFromShape(shape); + auto result = Literal::CreateFromShape(shape); - TF_RETURN_IF_ERROR(LiteralUtil::Populate( - result.get(), [&](tensorflow::gtl::ArraySlice multi_index) { - return 
binary_op(LiteralUtil::Get(lhs_literal, multi_index), - LiteralUtil::Get(rhs_literal, multi_index)); + TF_RETURN_IF_ERROR(result.get()->Populate( + [&](tensorflow::gtl::ArraySlice multi_index) { + return binary_op(lhs_literal.Get(multi_index), + rhs_literal.Get(multi_index)); })); return std::move(result); } @@ -483,14 +474,13 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); const Literal& ehs_literal = parent_->GetEvaluatedLiteralFor(ehs); - auto result = LiteralUtil::CreateFromShape(shape); + auto result = Literal::CreateFromShape(shape); - TF_RETURN_IF_ERROR(LiteralUtil::Populate( - result.get(), [&](tensorflow::gtl::ArraySlice multi_index) { - return ternary_op( - LiteralUtil::Get(lhs_literal, multi_index), - LiteralUtil::Get(rhs_literal, multi_index), - LiteralUtil::Get(ehs_literal, multi_index)); + TF_RETURN_IF_ERROR(result.get()->Populate( + [&](tensorflow::gtl::ArraySlice multi_index) { + return ternary_op(lhs_literal.Get(multi_index), + rhs_literal.Get(multi_index), + ehs_literal.Get(multi_index)); })); return std::move(result); @@ -552,7 +542,7 @@ StatusOr> HloEvaluator::Evaluate( if (operand->opcode() == HloOpcode::kParameter) { const Literal* input_literal = arg_literals_[operand->parameter_number()]; VLOG(2) << "Parameter operand evaluated to: " - << LiteralUtil::ToString(*input_literal); + << input_literal->ToString(); TF_RET_CHECK(ShapeUtil::Equal(operand->shape(), input_literal->shape())); evaluated_[operand] = MakeUnique(*input_literal); @@ -589,8 +579,7 @@ std::unique_ptr HloEvaluator::TryEvaluate( Status HloEvaluator::HandleParameter(HloInstruction* parameter) { VLOG(2) << "HandleParameter: " << parameter->ToString(); const Literal* input_literal = arg_literals_[parameter->parameter_number()]; - VLOG(2) << "Parameter evaluated to: " - << LiteralUtil::ToString(*input_literal); + VLOG(2) << "Parameter evaluated to: " << input_literal->ToString(); 
DCHECK(ShapeUtil::Equal(parameter->shape(), input_literal->shape())); evaluated_[parameter] = MakeUnique(*input_literal); @@ -606,14 +595,14 @@ Status HloEvaluator::HandleConstant(HloInstruction* constant, Status HloEvaluator::HandleReshape(HloInstruction* reshape) { TF_ASSIGN_OR_RETURN( evaluated_[reshape], - LiteralUtil::Reshape(GetEvaluatedLiteralFor(reshape->operand(0)), - AsInt64Slice(reshape->shape().dimensions()))); + GetEvaluatedLiteralFor(reshape->operand(0)) + .Reshape(AsInt64Slice(reshape->shape().dimensions()))); return Status::OK(); } Status HloEvaluator::HandleTranspose(HloInstruction* transpose) { - evaluated_[transpose] = LiteralUtil::Transpose( - GetEvaluatedLiteralFor(transpose->operand(0)), transpose->dimensions()); + evaluated_[transpose] = GetEvaluatedLiteralFor(transpose->operand(0)) + .Transpose(transpose->dimensions()); return Status::OK(); } @@ -641,16 +630,16 @@ Status HloEvaluator::HandleConcatenate( ShapeUtil::GetDimension(operand_shape, concat_dim); } - auto result_literal = LiteralUtil::CreateFromDimensions( + auto result_literal = Literal::CreateFromDimensions( reference_shape.element_type(), concat_dimensions); DimensionVector source_indices(rank, 0); DimensionVector dest_indices(concat_dimensions.size(), 0); for (auto operand : operands) { const Shape& operand_shape = operand->shape(); - TF_RETURN_IF_ERROR(LiteralUtil::Copy( - GetEvaluatedLiteralFor(operand), source_indices, result_literal.get(), - dest_indices, AsInt64Slice(operand_shape.dimensions()))); + TF_RETURN_IF_ERROR(result_literal.get()->Copy( + GetEvaluatedLiteralFor(operand), source_indices, dest_indices, + AsInt64Slice(operand_shape.dimensions()))); dest_indices[concat_dim] += ShapeUtil::GetDimension(operand_shape, concat_dim); } @@ -775,14 +764,14 @@ Status HloEvaluator::HandleCompare(HloInstruction* compare, HloOpcode opcode, Status HloEvaluator::HandleSlice(HloInstruction* slice, HloInstruction* operand) { const Shape& shape = slice->shape(); - auto literal = 
LiteralUtil::CreateFromDimensions( + auto literal = Literal::CreateFromDimensions( shape.element_type(), AsInt64Slice(shape.dimensions())); DimensionVector dest_indices(slice->slice_starts().size(), 0); - TF_RETURN_IF_ERROR(LiteralUtil::Copy( - GetEvaluatedLiteralFor(operand), slice->slice_starts(), literal.get(), - dest_indices, AsInt64Slice(shape.dimensions()))); + TF_RETURN_IF_ERROR(literal.get()->Copy(GetEvaluatedLiteralFor(operand), + slice->slice_starts(), dest_indices, + AsInt64Slice(shape.dimensions()))); evaluated_[slice] = std::move(literal); return Status::OK(); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index b26ece28b75..a11a5abc03d 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -34,7 +35,7 @@ limitations under the License. namespace xla { namespace { -class HloEvaluatorTest : public ::testing::Test { +class HloEvaluatorTest : public HloTestBase { protected: HloEvaluatorTest() { evaluator_ = MakeUnique(); } @@ -44,9 +45,9 @@ class HloEvaluatorTest : public ::testing::Test { // Verifies that HloEvaluator evaluates a HLO instruction that performs clamp // with 3 operands. 
TEST_F(HloEvaluatorTest, DoesClamp) { - auto low = LiteralUtil::CreateR2({{0.f, 2.f}, {2.f, 4.f}}); - auto high = LiteralUtil::CreateR2({{2.f, 4.f}, {4.f, 4.f}}); - auto value = LiteralUtil::CreateR2({{0.f, 5.f}, {0.f, 4.f}}); + auto low = Literal::CreateR2({{0.f, 2.f}, {2.f, 4.f}}); + auto high = Literal::CreateR2({{2.f, 4.f}, {4.f, 4.f}}); + auto value = Literal::CreateR2({{0.f, 5.f}, {0.f, 4.f}}); Shape shape = low->shape(); auto c1 = HloInstruction::CreateConstant(std::move(low)); @@ -58,17 +59,17 @@ TEST_F(HloEvaluatorTest, DoesClamp) { std::unique_ptr result = evaluator_->Evaluate(instruction.get(), {}).ConsumeValueOrDie(); - auto expected = LiteralUtil::CreateR2({{0, 4}, {2, 4}}); + auto expected = Literal::CreateR2({{0, 4}, {2, 4}}); - EXPECT_TRUE(LiteralUtil::Equal(*result, *expected)); + EXPECT_TRUE(result->Equal(*expected)); } // Verifies that HloEvaluator evaluates a HLO instruction that performs select // with 3 operands. TEST_F(HloEvaluatorTest, DoesSelect) { - auto pred = LiteralUtil::CreateR2({{true, false}, {false, true}}); - auto on_true = LiteralUtil::CreateR2({{2.f, 4.f}, {4.f, 4.f}}); - auto on_false = LiteralUtil::CreateR2({{0.f, 5.f}, {0.f, 4.f}}); + auto pred = Literal::CreateR2({{true, false}, {false, true}}); + auto on_true = Literal::CreateR2({{2.f, 4.f}, {4.f, 4.f}}); + auto on_false = Literal::CreateR2({{0.f, 5.f}, {0.f, 4.f}}); Shape shape = on_true->shape(); auto c1 = HloInstruction::CreateConstant(std::move(pred)); @@ -80,16 +81,16 @@ TEST_F(HloEvaluatorTest, DoesSelect) { std::unique_ptr result = evaluator_->Evaluate(instruction.get(), {}).ConsumeValueOrDie(); - auto expected = LiteralUtil::CreateR2({{2, 5}, {0, 4}}); + auto expected = Literal::CreateR2({{2, 5}, {0, 4}}); - EXPECT_TRUE(LiteralUtil::Equal(*result, *expected)); + EXPECT_TRUE(result->Equal(*expected)); } // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise addition with 2 operands. 
TEST_F(HloEvaluatorTest, DoesAdd) { - auto lhs = LiteralUtil::CreateR2({{1, 0}, {-100, 4}}); - auto rhs = LiteralUtil::CreateR2({{2, 4}, {4, 4}}); + auto lhs = Literal::CreateR2({{1, 0}, {-100, 4}}); + auto rhs = Literal::CreateR2({{2, 4}, {4, 4}}); Shape shape = ShapeUtil::MakeShape(S64, {2, 2}); auto c1 = HloInstruction::CreateConstant(std::move(lhs)); @@ -100,16 +101,16 @@ TEST_F(HloEvaluatorTest, DoesAdd) { std::unique_ptr result = evaluator_->Evaluate(instruction.get(), {}).ConsumeValueOrDie(); - auto expected = LiteralUtil::CreateR2({{3, 4}, {-96, 8}}); + auto expected = Literal::CreateR2({{3, 4}, {-96, 8}}); - EXPECT_TRUE(LiteralUtil::Equal(*result, *expected)); + EXPECT_TRUE(result->Equal(*expected)); } // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise divide with 2 operands. TEST_F(HloEvaluatorTest, DoesDivide) { - auto lhs_s64 = LiteralUtil::CreateR2({{1, 0}, {-100, 4}}); - auto rhs_s64 = LiteralUtil::CreateR2({{2, 4}, {4, 4}}); + auto lhs_s64 = Literal::CreateR2({{1, 0}, {-100, 4}}); + auto rhs_s64 = Literal::CreateR2({{2, 4}, {4, 4}}); Shape shape_s64 = ShapeUtil::MakeShape(S64, {2, 2}); auto c1_s64 = HloInstruction::CreateConstant(std::move(lhs_s64)); @@ -120,12 +121,12 @@ TEST_F(HloEvaluatorTest, DoesDivide) { std::unique_ptr result = evaluator_->Evaluate(instruction.get(), {}).ConsumeValueOrDie(); - auto expected = LiteralUtil::CreateR2({{0, 0}, {-25, 1}}); + auto expected = Literal::CreateR2({{0, 0}, {-25, 1}}); - EXPECT_TRUE(LiteralUtil::Equal(*result, *expected)); + EXPECT_TRUE(result->Equal(*expected)); - auto lhs_f64 = LiteralUtil::CreateR2({{1.0, 0.0}, {-100.0, 4.0}}); - auto rhs_f64 = LiteralUtil::CreateR2({{2.2, 4.0}, {4.0, 4.0}}); + auto lhs_f64 = Literal::CreateR2({{1.0, 0.0}, {-100.0, 4.0}}); + auto rhs_f64 = Literal::CreateR2({{2.2, 4.0}, {4.0, 4.0}}); Shape shape_f64 = ShapeUtil::MakeShape(F64, {2, 2}); auto c1_f64 = HloInstruction::CreateConstant(std::move(lhs_f64)); @@ -135,16 +136,15 @@ 
TEST_F(HloEvaluatorTest, DoesDivide) { result = evaluator_->Evaluate(instruction.get(), {}).ConsumeValueOrDie(); - expected = - LiteralUtil::CreateR2({{0.45454545454545453, 0}, {-25, 1}}); + expected = Literal::CreateR2({{0.45454545454545453, 0}, {-25, 1}}); - EXPECT_TRUE(LiteralUtil::Equal(*result, *expected)); + EXPECT_TRUE(result->Equal(*expected)); } // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise abs op with 1 operand. TEST_F(HloEvaluatorTest, DoesAbs) { - auto operand = LiteralUtil::CreateR2({{1, -20}, {-100, 4}}); + auto operand = Literal::CreateR2({{1, -20}, {-100, 4}}); const Shape& shape = ShapeUtil::MakeShape(S64, {2, 2}); auto c1 = HloInstruction::CreateConstant(std::move(operand)); auto instruction = @@ -153,42 +153,40 @@ TEST_F(HloEvaluatorTest, DoesAbs) { std::unique_ptr result = evaluator_->Evaluate(instruction.get(), {}).ConsumeValueOrDie(); - auto expected = LiteralUtil::CreateR2({{1, 20}, {100, 4}}); + auto expected = Literal::CreateR2({{1, 20}, {100, 4}}); - EXPECT_TRUE(LiteralUtil::Equal(*result, *expected)); + EXPECT_TRUE(result->Equal(*expected)); // For R0 literal. const Shape& r0 = ShapeUtil::MakeShape(F32, {}); - operand = LiteralUtil::CreateR0(-1.0f); + operand = Literal::CreateR0(-1.0f); c1 = HloInstruction::CreateConstant(std::move(operand)); instruction = HloInstruction::CreateUnary(r0, HloOpcode::kAbs, c1.get()); result = evaluator_->Evaluate(instruction.get()).ConsumeValueOrDie(); - expected = LiteralUtil::CreateR0(1.0f); + expected = Literal::CreateR0(1.0f); - EXPECT_TRUE(LiteralUtil::Equal(*result, *expected)); + EXPECT_TRUE(result->Equal(*expected)); // For R1 literal with dimension of size 0. 
Shape empty_r1 = ShapeUtil::MakeShape(F32, {0}); - operand = LiteralUtil::CreateR1({}); + operand = Literal::CreateR1({}); c1 = HloInstruction::CreateConstant(std::move(operand)); instruction = HloInstruction::CreateUnary(empty_r1, HloOpcode::kAbs, c1.get()); result = evaluator_->Evaluate(instruction.get()).ConsumeValueOrDie(); - expected = LiteralUtil::CreateR1({}); + expected = Literal::CreateR1({}); - EXPECT_TRUE(LiteralUtil::Equal(*result, *expected)); + EXPECT_TRUE(result->Equal(*expected)); } // namespace // Verifies that HloEvaluator evaluates a HLO Computation with non-parameter nor // constant operands. -TEST_F(HloEvaluatorTest, DoesTraveseInstructions) { - HloComputation::Builder builder( - ::testing::UnitTest::GetInstance()->current_test_info()->name()); - - auto lhs = LiteralUtil::CreateR2({{1, 0}, {-100, 4}}); - auto rhs = LiteralUtil::CreateR2({{2, 4}, {4, 4}}); - auto rhs2 = LiteralUtil::CreateR2({{1, -20}, {-100, 4}}); +TEST_F(HloEvaluatorTest, DoesTraverseInstructions) { + HloComputation::Builder builder(TestName()); + auto lhs = Literal::CreateR2({{1, 0}, {-100, 4}}); + auto rhs = Literal::CreateR2({{2, 4}, {4, 4}}); + auto rhs2 = Literal::CreateR2({{1, -20}, {-100, 4}}); std::vector args = {lhs.get(), rhs.get(), rhs2.get()}; Shape shape = ShapeUtil::MakeShape(S64, {2, 2}); @@ -206,21 +204,19 @@ TEST_F(HloEvaluatorTest, DoesTraveseInstructions) { std::unique_ptr result = evaluator_->Evaluate(builder.Build().get(), args).ConsumeValueOrDie(); - auto expected = LiteralUtil::CreateR2({{4, -16}, {-196, 12}}); + auto expected = Literal::CreateR2({{4, -16}, {-196, 12}}); - EXPECT_TRUE(LiteralUtil::Equal(*result, *expected)); + EXPECT_TRUE(result->Equal(*expected)); } // Verifies Reshape operation is correctly evaluated. 
TEST_F(HloEvaluatorTest, DoesReshape) { - HloComputation::Builder builder( - ::testing::UnitTest::GetInstance()->current_test_info()->name()); - + HloComputation::Builder builder(TestName()); const int64 dimensions[] = {11, 8, 7, 5, 9}; TF_ASSIGN_OR_ASSERT_OK(auto literal, LiteralTestUtil::CreateRandomLiteral( ShapeUtil::MakeShape(F32, dimensions), 0.0, 1.0)); - auto literal_clone = LiteralUtil::CloneToUnique(*literal); + auto literal_clone = literal->CloneToUnique(); HloInstruction* literal_instruction = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); @@ -233,13 +229,73 @@ TEST_F(HloEvaluatorTest, DoesReshape) { evaluator_->Evaluate(builder.Build().get(), {}).ConsumeValueOrDie(); using NativeT = typename primitive_util::PrimitiveTypeToNative::type; - LiteralUtil::EachCell( - *result, [&](tensorflow::gtl::ArraySlice indices, NativeT value) { + result->EachCell( + [&](tensorflow::gtl::ArraySlice indices, NativeT value) { std::vector rindexes = Permute(permutation, indices); - EXPECT_TRUE(value == - LiteralUtil::Get(*literal_clone, rindexes)); + EXPECT_TRUE(value == literal_clone->Get(rindexes)); }); } +// Verifies Broadcast operation is correctly evaluated. 
+TEST_F(HloEvaluatorTest, DoesBroadcast) { + HloComputation::Builder builder(TestName()); + auto input_literal = Literal::CreateR2({{1, 2}, {3, 4}, {5, 6}}); + auto output_literal = Literal::CreateR3( + {{{1, 2}, {3, 4}, {5, 6}}, {{1, 2}, {3, 4}, {5, 6}}}); + HloInstruction* literal_instruction = builder.AddInstruction( + HloInstruction::CreateConstant(std::move(input_literal))); + + builder.AddInstruction(HloInstruction::CreateBroadcast( + output_literal->shape(), literal_instruction, {1, 2})); + + std::unique_ptr result = + evaluator_->Evaluate(builder.Build().get(), {}).ConsumeValueOrDie(); + + LiteralTestUtil::ExpectEqual(*result, *output_literal); +} + +TEST_F(HloEvaluatorTest, ConvertWithSameLayout) { + HloComputation::Builder builder(TestName()); + + auto input_literal = Literal::CreateR2({{1, 2}, {3, 4}, {5, 6}}); + auto expected = + Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}); + ASSERT_TRUE(LayoutUtil::LayoutsInShapesEqual(input_literal->shape(), + expected->shape())); + + HloInstruction* constant = builder.AddInstruction( + HloInstruction::CreateConstant(std::move(input_literal))); + builder.AddInstruction( + HloInstruction::CreateConvert(expected->shape(), constant)); + + std::unique_ptr result = + evaluator_->Evaluate(builder.Build().get(), {}).ConsumeValueOrDie(); + + EXPECT_TRUE(ShapeUtil::Equal(result->shape(), expected->shape())); + LiteralTestUtil::ExpectEqual(*result, *expected); +} + +TEST_F(HloEvaluatorTest, ConvertWithDifferentLayout) { + HloComputation::Builder builder(TestName()); + + auto input_literal = Literal::CreateR2WithLayout( + {{1, 2}, {3, 4}, {5, 6}}, LayoutUtil::MakeLayout({0, 1})); + auto expected = Literal::CreateR2WithLayout( + {{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}, LayoutUtil::MakeLayout({1, 0})); + ASSERT_FALSE(LayoutUtil::LayoutsInShapesEqual(input_literal->shape(), + expected->shape())); + + HloInstruction* constant = builder.AddInstruction( + HloInstruction::CreateConstant(std::move(input_literal))); + 
builder.AddInstruction( + HloInstruction::CreateConvert(expected->shape(), constant)); + + std::unique_ptr result = + evaluator_->Evaluate(builder.Build().get(), {}).ConsumeValueOrDie(); + + EXPECT_TRUE(ShapeUtil::Equal(result->shape(), expected->shape())); + LiteralTestUtil::ExpectEqual(*result, *expected); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index eb2e5dfb37f..dffb53320c4 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -214,6 +214,7 @@ string InstructionSequenceGraph( case HloOpcode::kCeil: case HloOpcode::kClamp: case HloOpcode::kConvert: + case HloOpcode::kCos: case HloOpcode::kDivide: case HloOpcode::kEq: case HloOpcode::kExp: @@ -282,6 +283,10 @@ string InstructionSequenceGraph( // port for each parameter instruction. No need to emit anything in this // case. continue; + case HloOpcode::kBatchNormTraining: + StrAppend(&name, " feature_index=", instruction->feature_index()); + color = kPurple; + break; case HloOpcode::kReduce: StrAppend(&name, " dims=", Join(instruction->dimensions(), ",")); color = kPurple; @@ -313,6 +318,11 @@ string InstructionSequenceGraph( StrAppend(&name, "
", "custom_call_target=", instruction->custom_call_target()); break; + case HloOpcode::kReducePrecision: + // Make ReducePrecision ops a bit more visible, since typically they + // will be inserted as modifications to an existing graph. + color = kDarkRed; + break; } // Create instruction node with appropriate label, shape, and color. @@ -325,8 +335,7 @@ string InstructionSequenceGraph( ShapeUtil::IsEffectiveScalar(instruction->shape())) { auto elem_idx = IndexUtil::LinearIndexToMultidimensionalIndex( instruction->shape(), /*linear_index=*/0); - StrAppend(&label, " = {", - LiteralUtil::GetAsString(instruction->literal(), elem_idx), + StrAppend(&label, " = {", instruction->literal().GetAsString(elem_idx), "}"); } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index ea813c98743..9117ab96536 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -122,6 +122,7 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, case HloOpcode::kBitcast: case HloOpcode::kCeil: case HloOpcode::kCopy: + case HloOpcode::kCos: case HloOpcode::kExp: case HloOpcode::kFloor: case HloOpcode::kIsFinite: @@ -226,6 +227,19 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, return instruction; } +/* static */ std::unique_ptr +HloInstruction::CreateReducePrecision(const Shape& shape, + HloInstruction* operand, + const int exponent_bits, + const int mantissa_bits) { + auto instruction = + WrapUnique(new HloInstruction(HloOpcode::kReducePrecision, shape)); + instruction->AppendOperand(operand); + instruction->exponent_bits_ = exponent_bits; + instruction->mantissa_bits_ = mantissa_bits; + return instruction; +} + /* static */ std::unique_ptr HloInstruction::CreateCrossReplicaSum(const Shape& shape, HloInstruction* operand) { @@ -371,6 +385,22 @@ HloInstruction::CreateDynamicUpdateSlice(const Shape& shape, return instruction; } +/* static */ 
std::unique_ptr +HloInstruction::CreateBatchNormTraining(const Shape& shape, + HloInstruction* operand, + HloInstruction* scale, + HloInstruction* offset, float epsilon, + int64 feature_index) { + auto instruction = + WrapUnique(new HloInstruction(HloOpcode::kBatchNormTraining, shape)); + instruction->AppendOperand(operand); + instruction->AppendOperand(scale); + instruction->AppendOperand(offset); + instruction->epsilon_ = epsilon; + instruction->feature_index_ = feature_index; + return instruction; +} + /* static */ std::unique_ptr HloInstruction::CreateSelectAndScatter( const Shape& shape, HloInstruction* operand, HloComputation* select, @@ -730,6 +760,7 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kBitcast: case HloOpcode::kCeil: case HloOpcode::kCopy: + case HloOpcode::kCos: case HloOpcode::kExp: case HloOpcode::kIsFinite: case HloOpcode::kFloor: @@ -780,6 +811,10 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kConvert: CHECK_EQ(new_operands.size(), 1); return CreateConvert(shape, new_operands[0]); + case HloOpcode::kReducePrecision: + CHECK_EQ(new_operands.size(), 1); + return CreateReducePrecision(shape, new_operands[0], exponent_bits_, + mantissa_bits_); case HloOpcode::kConvolution: CHECK_EQ(new_operands.size(), 2); return CreateConvolve(shape, new_operands[0], new_operands[1], *window_, @@ -838,12 +873,13 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( return CreateWhile(shape, while_condition(), while_body(), new_operands[0]); case HloOpcode::kConstant: - return CreateConstant(LiteralUtil::CloneToUnique(*literal_)); + return CreateConstant(literal_->CloneToUnique()); case HloOpcode::kFusion: return CloneFusionWithNewOperands(shape, new_operands); case HloOpcode::kParameter: return CreateParameter(parameter_number_, shape, parameter_name_); // Unsupported ops for cloning. 
+ case HloOpcode::kBatchNormTraining: case HloOpcode::kRecv: case HloOpcode::kSend: case HloOpcode::kUpdate: @@ -1099,6 +1135,7 @@ bool HloInstruction::Identical( case HloOpcode::kCeil: case HloOpcode::kClamp: case HloOpcode::kCopy: + case HloOpcode::kCos: case HloOpcode::kCrossReplicaSum: case HloOpcode::kDivide: case HloOpcode::kDot: @@ -1141,15 +1178,24 @@ bool HloInstruction::Identical( // different HloComputations. ShapeUtil::Compatible(shape(), other.shape()); + case HloOpcode::kBatchNormTraining: + return feature_index() == other.feature_index() && + epsilon() == other.epsilon(); + // A constant is defined by the value in the literal. case HloOpcode::kConstant: - return LiteralUtil::Equal(literal(), other.literal()); + return literal().Equal(other.literal()); // A convert result is determined by the primitive type that the operand is // converted into. case HloOpcode::kConvert: return shape().element_type() == other.shape().element_type(); + // A reduce-precision operation is determined by the bit sizes. + case HloOpcode::kReducePrecision: + return exponent_bits() == other.exponent_bits() && + mantissa_bits() == other.mantissa_bits(); + // Convolution has a window and dimensions. case HloOpcode::kConvolution: return protobuf_util::ProtobufEquals(window(), other.window()) && @@ -1439,9 +1485,9 @@ string HloInstruction::ToString(bool compact_operands, if (opcode() == HloOpcode::kConstant) { // For constants, show the actual value in place of an empty operand list. if (ShapeUtil::ElementsIn(shape()) <= 10) { - // LiteralUtil::ToString emits multidimensional arrays over multiple + // Literal::ToString emits multidimensional arrays over multiple // lines. Compact this into one line by stripping out white space. 
- string tmp = LiteralUtil::ToString(literal()); + string tmp = literal().ToString(); std::replace(tmp.begin(), tmp.end(), '\n', ' '); std::vector v = tensorflow::str_util::Split(tmp, ' '); bool first = true; @@ -1736,6 +1782,8 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) { switch (opcode_) { case HloOpcode::kAbs: return visitor->HandleAbs(this, operands_[0]); + case HloOpcode::kBatchNormTraining: + return visitor->HandleBatchNormTraining(this); case HloOpcode::kSign: return visitor->HandleSign(this, operands_[0]); case HloOpcode::kConstant: @@ -1758,9 +1806,9 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) { case HloOpcode::kSubtract: return visitor->HandleSubtract(this, operands_[0], operands_[1]); case HloOpcode::kMaximum: - return visitor->HandleMaximum(this, operands_[0], operands_[1]); + return visitor->HandleMaximum(this); case HloOpcode::kMinimum: - return visitor->HandleMinimum(this, operands_[0], operands_[1]); + return visitor->HandleMinimum(this); case HloOpcode::kLogicalAnd: return visitor->HandleLogicalAnd(this, operands_[0], operands_[1]); case HloOpcode::kLogicalOr: @@ -1768,9 +1816,9 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) { case HloOpcode::kConcatenate: return visitor->HandleConcatenate(this, operands_); case HloOpcode::kConvert: - return visitor->HandleConvert(this, operands_[0]); + return visitor->HandleConvert(this); case HloOpcode::kCopy: - return visitor->HandleCopy(this, operands_[0]); + return visitor->HandleCopy(this); case HloOpcode::kMultiply: return visitor->HandleMultiply(this, operands_[0], operands_[1]); case HloOpcode::kDot: @@ -1814,6 +1862,8 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) { return visitor->HandleLog(this, operands_[0]); case HloOpcode::kTanh: return visitor->HandleTanh(this, operands_[0]); + case HloOpcode::kCos: + return visitor->HandleCos(this, operands_[0]); case HloOpcode::kIsFinite: return visitor->HandleIsFinite(this, operands_[0]); case HloOpcode::kLogicalNot: @@ 
-1830,6 +1880,8 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) { return visitor->HandleTranspose(this); case HloOpcode::kReverse: return visitor->HandleReverse(this, operands_[0]); + case HloOpcode::kReducePrecision: + return visitor->HandleReducePrecision(this, operands_[0]); case HloOpcode::kSlice: return visitor->HandleSlice(this, operands_[0]); case HloOpcode::kDynamicSlice: @@ -1868,72 +1920,86 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) { HloOpcodeString(opcode_).c_str()); } -Status HloInstruction::AcceptInternal(DfsHloVisitor* visitor, - const CompareFunction* operand_order, - bool ignore_control_predecessors) { - // Do not visit this HLO node again if it is already visited. - if (visitor->DidVisit(*this)) { - VLOG(3) << "Not visiting HLO " << name() << " as it was already visited."; - return Status::OK(); - } - - // If the instruction is in the visiting state, it means a cycle. - if (visitor->IsVisiting(*this)) { +static Status PushDFSChild(DfsHloVisitor* visitor, + std::vector* dfs_stack, + HloInstruction* parent, HloInstruction* child) { + if (visitor->IsVisiting(*child)) { return FailedPrecondition( "A cycle is detected while visiting instruction %s", - ToString().c_str()); - } - visitor->SetVisiting(*this); - - // Sort operands, if an ordering was provided. 'temp_sorted_operands' must - // live at this scope, since 'operands' will point to it if the operands are - // sorted. The purpose of the 'operands' pointer is to avoid copying the - // operands in the common case where the operands are not sorted. 
- std::vector* operands = &operands_; - std::vector temp_sorted_operands; - if (operand_order != nullptr) { - temp_sorted_operands = operands_; - std::sort(temp_sorted_operands.begin(), temp_sorted_operands.end(), - *operand_order); - operands = &temp_sorted_operands; - } - for (HloInstruction* operand : *operands) { - VLOG(3) << "Going to visit HLO " << operand->name() << " as operand of HLO " - << name(); - TF_RETURN_IF_ERROR(operand->AcceptInternal(visitor, operand_order, - ignore_control_predecessors)); + parent->ToString().c_str()); } - if (!ignore_control_predecessors) { - // This uses the same pointer/vector sorting to avoid extra copies as above. - std::vector* predecessors = &control_predecessors_; - std::vector temp_sorted_predecessors; + if (!visitor->DidVisit(*child)) { + dfs_stack->push_back(child); + } else { + VLOG(3) << "Not visiting HLO " << child->name() + << " as it was already visited."; + } + return Status::OK(); +} + +static Status PostOrderDFS(HloInstruction* root, DfsHloVisitor* visitor, + const HloInstruction::CompareFunction* operand_order, + bool ignore_control_predecessors) { + std::vector dfs_stack; + dfs_stack.push_back(root); + + do { + DCHECK(!dfs_stack.empty()); + + HloInstruction* current_node = dfs_stack.back(); + if (visitor->DidVisit(*current_node)) { + dfs_stack.pop_back(); + VLOG(3) << "Not visiting HLO " << current_node->name() + << " as it was already visited."; + continue; + } + + if (visitor->IsVisiting(*current_node)) { + dfs_stack.pop_back(); + + TF_RETURN_IF_ERROR(visitor->Preprocess(current_node)); + VLOG(2) << "Visiting HLO " << current_node->name(); + TF_RETURN_IF_ERROR(current_node->Visit(visitor)); + visitor->SetVisited(*current_node); + TF_RETURN_IF_ERROR(visitor->Postprocess(current_node)); + continue; + } + + visitor->SetVisiting(*current_node); + + const size_t old_dfs_stack_size = dfs_stack.size(); + + for (HloInstruction* child : current_node->operands()) { + TF_RETURN_IF_ERROR( + PushDFSChild(visitor, 
&dfs_stack, current_node, child)); + } + + if (!ignore_control_predecessors) { + for (HloInstruction* child : current_node->control_predecessors()) { + TF_RETURN_IF_ERROR( + PushDFSChild(visitor, &dfs_stack, current_node, child)); + } + } + if (operand_order != nullptr) { - temp_sorted_predecessors = control_predecessors_; - std::sort(temp_sorted_predecessors.begin(), - temp_sorted_predecessors.end(), *operand_order); - predecessors = &temp_sorted_predecessors; + std::sort(dfs_stack.begin() + old_dfs_stack_size, dfs_stack.end(), + *operand_order); } - for (HloInstruction* control_predecessor : *predecessors) { - VLOG(3) << "Going to visit HLO " << control_predecessor->name() - << " as a control predecessor of HLO " << name(); - TF_RETURN_IF_ERROR(control_predecessor->AcceptInternal( - visitor, operand_order, ignore_control_predecessors)); - } - } - TF_RETURN_IF_ERROR(visitor->Preprocess(this)); - VLOG(2) << "Visiting HLO " << name(); - TF_RETURN_IF_ERROR(Visit(visitor)); - visitor->SetVisited(*this); - return visitor->Postprocess(this); + // This makes the traversal order the same as what you'd expect + // out of a recursive algorithm. 
+ std::reverse(dfs_stack.begin() + old_dfs_stack_size, dfs_stack.end()); + } while (!dfs_stack.empty()); + + return Status::OK(); } Status HloInstruction::Accept(DfsHloVisitor* visitor, bool call_finish_visit, bool ignore_control_predecessors) { VLOG(2) << "HloInstruction::Accept(" << name() << ")"; TF_RETURN_IF_ERROR( - AcceptInternal(visitor, nullptr, ignore_control_predecessors)); + PostOrderDFS(this, visitor, nullptr, ignore_control_predecessors)); if (call_finish_visit) { TF_RETURN_IF_ERROR(visitor->FinishVisit(this)); } @@ -1944,8 +2010,8 @@ Status HloInstruction::AcceptWithOperandOrder( DfsHloVisitor* visitor, const CompareFunction& operand_order, bool call_finish_visit) { VLOG(2) << "HloInstruction::AcceptWithOperandOrder(" << name() << ")"; - TF_RETURN_IF_ERROR(AcceptInternal(visitor, &operand_order, - /*ignore_control_predecessors=*/false)); + TF_RETURN_IF_ERROR(PostOrderDFS(this, visitor, &operand_order, + /*ignore_control_predecessors=*/false)); if (call_finish_visit) { TF_RETURN_IF_ERROR(visitor->FinishVisit(this)); } @@ -2060,12 +2126,14 @@ bool HloInstruction::IsElementwise() const { case HloOpcode::kCeil: case HloOpcode::kConvert: case HloOpcode::kCopy: + case HloOpcode::kCos: case HloOpcode::kExp: case HloOpcode::kFloor: case HloOpcode::kIsFinite: case HloOpcode::kLog: case HloOpcode::kLogicalNot: case HloOpcode::kNegate: + case HloOpcode::kReducePrecision: case HloOpcode::kSign: case HloOpcode::kTanh: return true; @@ -2348,4 +2416,9 @@ void HloInstruction::UniquifyName(NameUniquer* name_uniquer) { name_ = name_uniquer->GetUniqueName(name_); } +void HloInstruction::set_outer_dimension_partitions( + const std::vector& outer_dimension_partitions) { + outer_dimension_partitions_ = outer_dimension_partitions; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index c7cd729934b..2f26a7be2d0 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ 
b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -131,6 +131,13 @@ class HloInstruction { const Window& window, const ConvolutionDimensionNumbers& dimension_numbers); + // Creates a reduce-precision op, where operand is the data to reduce in + // precision, and exponent_bits and mantissa_bits describe the precision to + // reduce it to. + static std::unique_ptr CreateReducePrecision( + const Shape& shape, HloInstruction* operand, const int exponent_bits, + const int mantissa_bits); + // Creates a cross replica sum op. static std::unique_ptr CreateCrossReplicaSum( const Shape& shape, HloInstruction* operand); @@ -209,6 +216,11 @@ class HloInstruction { const Shape& shape, HloInstruction* operand, HloInstruction* init_value, const Window& window, HloComputation* reduce_computation); + // Creates a batch-norm-training instruction. + static std::unique_ptr CreateBatchNormTraining( + const Shape& shape, HloInstruction* operand, HloInstruction* scale, + HloInstruction* offset, float epsilon, int64 feature_index); + // Creates a scatter computation that scatters the `source` array to the // selected indices of each window. static std::unique_ptr CreateSelectAndScatter( @@ -528,6 +540,18 @@ class HloInstruction { // Precondition: opcode() == HloOpcode::kSend or HloOpcode::kRecv int64 channel_id() const { return channel_id_; } + // Returns feature_index field associated with the instruction. The index + // represents the index of the feature dimension. + // + // Precondition: opcode() == HloOpcode::kBatchNormTraining + int64 feature_index() const { return feature_index_; } + + // Returns a epsilon value associated with the instruction. The is a small + // number added to the variance to avoid divide-by-zero error. + // + // Precondition: opcode() == HloOpcode::kBatchNormTraining + int64 epsilon() const { return epsilon_; } + // Returns the infeed configuration string. 
The infeed configuration includes // any metadata needed for the backend compiler (e.g., infeed buffer address) // and is target-dependent. @@ -661,6 +685,22 @@ class HloInstruction { return dynamic_slice_sizes_; } + // Returns the number of exponent bits for a reduce-precision node. + // + // Precondition: opcode() == HloOpcode::kReducePrecision + int32 exponent_bits() const { + CHECK_EQ(HloOpcode::kReducePrecision, opcode_); + return exponent_bits_; + } + + // Returns the number of mantissa bits for a reduce-precision node. + // + // Precondition: opcode() == HloOpcode::kReducePrecision + int32 mantissa_bits() const { + CHECK_EQ(HloOpcode::kReducePrecision, opcode_); + return mantissa_bits_; + } + // Returns data on the window in a windowed operation such as // convolution. const Window& window() const { @@ -782,6 +822,17 @@ class HloInstruction { parent_fusion_instruction_ = fusion_instruction; } + // Get/Set the number of partitions per outer dimension (in order, starting + // with outer-most dimension first). Currently used by the parallel cpu + // backend to partition HLOs into parallel tasks. + // TODO(b/62783254) Replace these methods with a more general way to + // annotate HLOs with backend-specific information. + const std::vector& outer_dimension_partitions() const { + return outer_dimension_partitions_; + } + void set_outer_dimension_partitions( + const std::vector& outer_dimension_partitions); + private: enum class UseKind { kNoUse, kReuse, kUsePermutingElements, kUse }; @@ -818,12 +869,6 @@ class HloInstruction { const Shape& shape, tensorflow::gtl::ArraySlice operands); - // Inner DFS traversal function -- this function being called (rather than - // Accept above) allows us to distinguish the root of the traversal. - Status AcceptInternal(DfsHloVisitor* visitor, - const CompareFunction* operand_order, - bool ignore_control_predecessors); - // CHECKs various invariants of a fusion instruction. 
void CheckFusionInstruction() const; @@ -864,6 +909,10 @@ class HloInstruction { std::vector slice_limits_; std::vector slice_strides_; + // The bit sizes for a reduce-precision operation. + int32 exponent_bits_; + int32 mantissa_bits_; + // Describes the [start, start + size) range size for a dynamic slice // ('start' is specified dynamically in the second operand of the operation). std::vector dynamic_slice_sizes_; @@ -934,6 +983,14 @@ class HloInstruction { // Only present for kRng. RandomDistribution distribution_; + // A small float number added to the variance to avoid divide-by-zero error. + // Only present for kBatchNormTraining. + float epsilon_; + + // An integer value representing the index of the feature dimension. + // Only present for kBatchNormTraining. + int64 feature_index_; + // Represents a unique identifier for each Send/Recv instruction pair. // Only present for kSend or kRecv. int64 channel_id_ = -1; @@ -950,6 +1007,10 @@ class HloInstruction { // Metadata for debugging. OpMetadata metadata_; + // The number of partitions per outer dimension (listed in order from + // outer-most dimension first). 
+ std::vector outer_dimension_partitions_; + TF_DISALLOW_COPY_AND_ASSIGN(HloInstruction); }; diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index bcf81cd8ddf..bb1b477e139 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -232,7 +232,7 @@ TEST_F(HloInstructionTest, MultipleUsersAndOperands) { // ------- auto param0 = HloInstruction::CreateParameter(0, r0f32_, "param0"); auto param1 = HloInstruction::CreateParameter(1, r0f32_, "param1"); - auto c0 = HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f)); + auto c0 = HloInstruction::CreateConstant(Literal::CreateR0(1.1f)); auto addleft = HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, param0.get(), c0.get()); auto addright = HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, @@ -271,7 +271,7 @@ TEST_F(HloInstructionTest, MultipleUsersAndOperandsWithUnaryOps) { // ------- auto param0 = HloInstruction::CreateParameter(0, r0f32_, "param0"); auto param1 = HloInstruction::CreateParameter(1, r0f32_, "param1"); - auto c0 = HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f)); + auto c0 = HloInstruction::CreateConstant(Literal::CreateR0(1.1f)); auto neg1 = HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, c0.get()); auto addleft = HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, param0.get(), neg1.get()); @@ -307,7 +307,7 @@ TEST_F(HloInstructionTest, TrivialMap) { auto param = builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32, "x")); auto value = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param, value)); auto add_f32 = builder.Build(); @@ -349,9 +349,8 @@ TEST_F(HloInstructionTest, TrivialReduce) { // Builds a parameter and an initial value 
and feeds them to the reduce. auto param0 = HloInstruction::CreateParameter(0, f32a100x10, ""); - auto const0 = - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0.0f)); - auto c0 = HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f)); + auto const0 = HloInstruction::CreateConstant(Literal::CreateR0(0.0f)); + auto c0 = HloInstruction::CreateConstant(Literal::CreateR0(1.1f)); auto reduce = HloInstruction::CreateReduce(f32v100, param0.get(), const0.get(), /*dimensions_to_reduce=*/{1}, add_f32.get()); @@ -560,7 +559,7 @@ TEST_F(HloInstructionTest, PostProcessAllVisitedNodes) { TEST_F(HloInstructionTest, SingletonFusionOp) { // Create a fusion instruction containing a single unary operation. auto constant = - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f)); + HloInstruction::CreateConstant(Literal::CreateR0(1.1f)); auto exp = HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant.get()); @@ -574,9 +573,9 @@ TEST_F(HloInstructionTest, SingletonFusionOp) { TEST_F(HloInstructionTest, BinaryFusionOp) { // Create a fusion instruction containing a single binary operation. auto constant1 = - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f)); + HloInstruction::CreateConstant(Literal::CreateR0(1.1f)); auto constant2 = - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.1f)); + HloInstruction::CreateConstant(Literal::CreateR0(42.1f)); auto add = HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, constant1.get(), constant2.get()); @@ -594,7 +593,7 @@ TEST_F(HloInstructionTest, BinaryFusionOp) { TEST_F(HloInstructionTest, ChainFusionOp) { // Create a chain of fused unary ops. 
auto constant = - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f)); + HloInstruction::CreateConstant(Literal::CreateR0(1.1f)); auto exp1 = HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant.get()); auto exp2 = HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, exp1.get()); @@ -613,7 +612,7 @@ TEST_F(HloInstructionTest, ChainFusionOp) { TEST_F(HloInstructionTest, PreserveMetadataInFusionAndClone) { // Create a chain of fused unary ops. auto constant = - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f)); + HloInstruction::CreateConstant(Literal::CreateR0(1.1f)); auto exp1 = HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, constant.get()); auto exp2 = HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, exp1.get()); @@ -644,7 +643,7 @@ TEST_F(HloInstructionTest, FusionOpWithCalledComputations) { std::unique_ptr computation_y = make_map_computation(); auto constant = - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f)); + HloInstruction::CreateConstant(Literal::CreateR0(1.1f)); auto map_1_x = HloInstruction::CreateMap(scalar_shape, {constant.get()}, computation_x.get(), /*static_operands=*/{}); @@ -681,9 +680,9 @@ TEST_F(HloInstructionTest, ComplexFusionOp) { // // Notable complexities are repeated operands in a same instruction, different // shapes, use of value in different expressions. - auto c1 = HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.1f)); - auto c2 = HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.1f)); - auto c3 = HloInstruction::CreateConstant(LiteralUtil::CreateR0(9.0f)); + auto c1 = HloInstruction::CreateConstant(Literal::CreateR0(1.1f)); + auto c2 = HloInstruction::CreateConstant(Literal::CreateR0(2.1f)); + auto c3 = HloInstruction::CreateConstant(Literal::CreateR0(9.0f)); auto add = HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, c1.get(), c2.get()); @@ -732,11 +731,11 @@ TEST_F(HloInstructionTest, IdenticalInstructions) { // Create a set of random constant operands to use below. 
Make them matrices // so dimensions are interesting. auto operand1 = HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}})); + Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}})); auto operand2 = HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{10.0, 20.0}, {30.0, 40.0}})); - auto vector_operand = HloInstruction::CreateConstant( - LiteralUtil::CreateR1({42.0, 123.0})); + Literal::CreateR2({{10.0, 20.0}, {30.0, 40.0}})); + auto vector_operand = + HloInstruction::CreateConstant(Literal::CreateR1({42.0, 123.0})); Shape shape = operand1->shape(); // Convenient short names for the operands. diff --git a/tensorflow/compiler/xla/service/hlo_module_config.cc b/tensorflow/compiler/xla/service/hlo_module_config.cc index a2235a26823..8974deb530c 100644 --- a/tensorflow/compiler/xla/service/hlo_module_config.cc +++ b/tensorflow/compiler/xla/service/hlo_module_config.cc @@ -58,6 +58,10 @@ string HloModuleConfig::compilation_cache_key() const { StrAppend(&key, "::replica_count=", replica_count()); } StrAppend(&key, debug_options_.DebugString()); + if (intra_op_parallelism_threads() > 0) { + StrAppend(&key, "::intra_op_parallelism_threads=", + intra_op_parallelism_threads()); + } return key; } diff --git a/tensorflow/compiler/xla/service/hlo_module_config.h b/tensorflow/compiler/xla/service/hlo_module_config.h index ee32ab9bc4b..2299200b5be 100644 --- a/tensorflow/compiler/xla/service/hlo_module_config.h +++ b/tensorflow/compiler/xla/service/hlo_module_config.h @@ -92,6 +92,15 @@ class HloModuleConfig { debug_options_ = debug_options; } + // Sets/returns the number of intra op threads for this module. + void set_intra_op_parallelism_threads( + const int intra_op_parallelism_threads) { + intra_op_parallelism_threads_ = intra_op_parallelism_threads; + } + int64 intra_op_parallelism_threads() const { + return intra_op_parallelism_threads_; + } + private: // If you add new members, be sure to update compilation_cache_key. 
@@ -116,6 +125,10 @@ class HloModuleConfig { // The number of replicas to compile this binary for. int64 replica_count_ = 1; + // The target maximum parallelism at which to partition HLOs for parallel + // execution on the CPU backend. + int64 intra_op_parallelism_threads_ = -1; + DebugOptions debug_options_; }; diff --git a/tensorflow/compiler/xla/service/hlo_module_test.cc b/tensorflow/compiler/xla/service/hlo_module_test.cc index 870bc729aec..58173bca077 100644 --- a/tensorflow/compiler/xla/service/hlo_module_test.cc +++ b/tensorflow/compiler/xla/service/hlo_module_test.cc @@ -38,7 +38,7 @@ class HloModuleTest : public HloTestBase { std::unique_ptr CreateConstantComputation() { auto builder = HloComputation::Builder("Constant"); builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); return builder.Build(); } diff --git a/tensorflow/compiler/xla/service/hlo_opcode.cc b/tensorflow/compiler/xla/service/hlo_opcode.cc index ceb0cdaa316..4d68d0d0882 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.cc +++ b/tensorflow/compiler/xla/service/hlo_opcode.cc @@ -24,6 +24,8 @@ string HloOpcodeString(HloOpcode opcode) { return "abs"; case HloOpcode::kAdd: return "add"; + case HloOpcode::kBatchNormTraining: + return "batch-norm-training"; case HloOpcode::kBitcast: return "bitcast"; case HloOpcode::kBroadcast: @@ -40,6 +42,8 @@ string HloOpcodeString(HloOpcode opcode) { return "convert"; case HloOpcode::kConvolution: return "convolution"; + case HloOpcode::kCos: + return "cosine"; case HloOpcode::kCrossReplicaSum: return "cross-replica-sum"; case HloOpcode::kCustomCall: @@ -112,6 +116,8 @@ string HloOpcodeString(HloOpcode opcode) { return "recv"; case HloOpcode::kReduce: return "reduce"; + case HloOpcode::kReducePrecision: + return "reduce-precision"; case HloOpcode::kReduceWindow: return "reduce-window"; case HloOpcode::kRemainder: diff --git 
a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index e2cdbfdfa7a..d1263219c01 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -30,6 +30,7 @@ namespace xla { enum class HloOpcode { kAbs, kAdd, + kBatchNormTraining, kBitcast, kBroadcast, kCall, @@ -40,6 +41,7 @@ enum class HloOpcode { kConvert, kConvolution, kCopy, + kCos, kCrossReplicaSum, kCustomCall, kDivide, @@ -74,6 +76,7 @@ enum class HloOpcode { kPower, kRecv, kReduce, + kReducePrecision, kReduceWindow, kRemainder, kReshape, diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc index 72911ae9f91..61e5efa5b63 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering.cc @@ -113,6 +113,20 @@ bool HloOrdering::ExecutesBefore(const HloInstruction* a, // a_ancestor and b_ancestor must be either both null or both non-null. CHECK_NE(b_ancestor, nullptr); CHECK_EQ(a_ancestor->parent(), b_ancestor->parent()); + + // If the common ancestor is a while instruction there is an additional + // ordering criteria which may apply. The condition computation is considered + // to execute before the body computation so if 'a' is in the condition and + // 'b' is in the body, then 'a' executes before 'b'. 
+ if (a_ancestor == b_ancestor && a_ancestor->opcode() == HloOpcode::kWhile) { + const HloComputation* body = a_ancestor->while_body(); + const HloComputation* condition = a_ancestor->while_condition(); + if (call_graph_->InstructionIsNestedIn(a, condition) && + call_graph_->InstructionIsNestedIn(b, body)) { + return true; + } + } + return ExecutesBeforeInSameComputation(a_ancestor, b_ancestor); } diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc index 21d852a51d6..56e36bd705a 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc @@ -101,7 +101,7 @@ TEST_F(HloOrderingTest, InstructionsInDifferentComputations) { auto builder_c = HloComputation::Builder("C"); HloInstruction* c = builder_c.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); HloComputation* computation_c = module->AddEmbeddedComputation(builder_c.Build()); @@ -155,6 +155,68 @@ TEST_F(HloOrderingTest, InstructionsInDifferentComputations) { EXPECT_FALSE(ordering.ExecutesBefore(y, c)); } +TEST_F(HloOrderingTest, InstructionsInWhileComputations) { + // Tests the ordering of instructions in the body and condition of a while + // instruction. 
HLO code: + // + // body(F32[]) %param): + // %negate = Negate(%param) + // + // condition(F32[] %param): + // %convert = Convert(%param) + // + // entry: + // %constant = Constant(1.0) + // return While(%constant, body, condition) + // + auto module = CreateNewModule(); + const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {}); + + auto body_builder = HloComputation::Builder("body"); + auto body_param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "body_param")); + auto negate = body_builder.AddInstruction(HloInstruction::CreateUnary( + scalar_shape, HloOpcode::kNegate, body_param)); + HloComputation* body = module->AddEmbeddedComputation(body_builder.Build()); + + auto cond_builder = HloComputation::Builder("condition"); + auto cond_param = cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "cond_param")); + auto convert = cond_builder.AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::MakeShape(xla::PRED, {}), cond_param)); + HloComputation* condition = + module->AddEmbeddedComputation(cond_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + auto constant = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + auto xla_while = builder.AddInstruction( + HloInstruction::CreateWhile(scalar_shape, condition, body, constant)); + module->AddEntryComputation(builder.Build()); + + DependencyHloOrdering ordering(module.get()); + EXPECT_TRUE(ordering.ExecutesBefore(constant, xla_while)); + EXPECT_TRUE(ordering.ExecutesBefore(constant, cond_param)); + EXPECT_TRUE(ordering.ExecutesBefore(constant, convert)); + EXPECT_TRUE(ordering.ExecutesBefore(constant, body_param)); + EXPECT_TRUE(ordering.ExecutesBefore(constant, negate)); + + // The while should be unordered relative to the body and condition + // instructions. 
+ EXPECT_FALSE(ordering.ExecutesBefore(xla_while, body_param)); + EXPECT_FALSE(ordering.ExecutesBefore(xla_while, cond_param)); + EXPECT_FALSE(ordering.ExecutesBefore(body_param, xla_while)); + EXPECT_FALSE(ordering.ExecutesBefore(cond_param, xla_while)); + + // Condition instructions should be ordered before body instructions. + EXPECT_TRUE(ordering.ExecutesBefore(cond_param, body_param)); + EXPECT_TRUE(ordering.ExecutesBefore(convert, body_param)); + EXPECT_TRUE(ordering.ExecutesBefore(cond_param, negate)); + EXPECT_TRUE(ordering.ExecutesBefore(convert, negate)); + + EXPECT_FALSE(ordering.ExecutesBefore(body_param, cond_param)); +} + class MinimumMemoryForSequenceTest : public HloTestBase {}; TEST_F(MinimumMemoryForSequenceTest, MultiComputation) { diff --git a/tensorflow/compiler/xla/service/hlo_query.cc b/tensorflow/compiler/xla/service/hlo_query.cc index a153d73dbd8..d45038f1f4a 100644 --- a/tensorflow/compiler/xla/service/hlo_query.cc +++ b/tensorflow/compiler/xla/service/hlo_query.cc @@ -25,7 +25,7 @@ namespace hlo_query { bool IsConstantR0F32(HloInstruction* instruction, float* out) { if (instruction->opcode() == HloOpcode::kConstant && ShapeUtil::IsScalarF32(instruction->shape())) { - *out = LiteralUtil::Get(instruction->literal(), {}); + *out = instruction->literal().Get({}); return true; } diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index 2c1b0fff4e6..fb6d8674b6c 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -58,9 +58,8 @@ bool IsRematerializable(const HloInstruction* instruction) { return false; } - // Don't rematerialize instructions with side effects, those with a cost that - // might not be captured by HloCostAnalysis, or instructions which cannot be - // cloned safely. + // Don't rematerialize instructions with side effects or instructions which + // cannot be cloned safely. 
switch (instruction->opcode()) { case HloOpcode::kCall: case HloOpcode::kConstant: @@ -802,23 +801,14 @@ bool MemoryUsageTracker::Check() const { // Computes and returns the cost of rematerializing the given instruction. // Cost per rematerialized instruction is defined as: // -// (flop_count + transcendental_count + element_count) / memory_reduced +// memory_limit_bytes / memory_reduced // -// flop_count: from HloCostAnalysis -// transcendental_count: from HloCostAnalysis -// element_count: number of elements accessed in operands and output of -// instruction -// memory_reduced: The memory usage reduced by rematerializing the -// instruction. -// -// This is a rough estimate of the extra execution time per byte saved by -// rematerializing this instruction for its remaining uses. In general, we -// want the most memory saving for the least latency penalty which is captured -// by this heuristic. +// The idea is to choose the operation that will save the most memory for +// rematerialization and do not worry about how much the compute costs since +// running out of memory is more harmful than taking longer to get the answer. int64 RematerializationCost(const HloInstruction* instruction, const MemoryUsageTracker& memory_tracker, - const HloCostAnalysis& cost_analysis, - int64 memory_reduced) { + int64 memory_reduced, int64 memory_limit_bytes) { // If none of the users of 'instruction' have been placed in the sequence (as // tracked by memory_tracker), then rematerialization of 'instruction' is a // zero-cost move of 'instruction' in the sequence. @@ -830,22 +820,8 @@ int64 RematerializationCost(const HloInstruction* instruction, } CHECK_GT(memory_reduced, 0); - const int64 bytes_accessed = cost_analysis.bytes_accessed(*instruction); - const int64 elements_accessed = - ShapeUtil::IsTuple(instruction->shape()) - ? 
bytes_accessed - : bytes_accessed / ShapeUtil::ByteSizeOfPrimitiveType( - instruction->shape().element_type()); - - // Multiply by 256 to improve precision of cost. Without this factor, - // many instructions such as many elementwise instructions would have - // zero cost because the bytes reduced can be several times greater than - // the element count. - return 256 * - (cost_analysis.flop_count(*instruction) + - cost_analysis.transcendental_count(*instruction) + - elements_accessed) / - memory_reduced; + // Return the inverse of the benefit of rematerialization. + return memory_limit_bytes / memory_reduced; } // Selects and returns the best candidate instruction for rematerialization. @@ -856,8 +832,8 @@ int64 RematerializationCost(const HloInstruction* instruction, HloInstruction* PickRematerializationCandidate( const MemoryUsageTracker& memory_tracker, const InstructionList& instruction_list, - const HloCostAnalysis& cost_analysis, - const tensorflow::gtl::FlatSet& blacklist) { + const tensorflow::gtl::FlatSet& blacklist, + int64 memory_limit_bytes) { HloInstruction* best = nullptr; int64 best_cost = 0; @@ -891,12 +867,12 @@ HloInstruction* PickRematerializationCandidate( if (memory_reduced <= 0) { VLOG(5) << "candidate " << candidate->name() - << " memory reduced = " << memory_reduced << " <= 0"; + << " memory reduced = " << memory_reduced << " <= 0"; continue; } const int cost = RematerializationCost(candidate, memory_tracker, - cost_analysis, memory_reduced); + memory_reduced, memory_limit_bytes); VLOG(5) << "candidate " << candidate->name() << ", memory reduced " << memory_reduced << ", cost per byte " << cost; @@ -1011,7 +987,7 @@ StatusOr HloRematerialization::RematerializeComputation( << ", limit is " << HumanReadableNumBytes(memory_limit_bytes); HloInstruction* best = PickRematerializationCandidate( - memory_tracker, instruction_list, cost_analysis_, blacklist); + memory_tracker, instruction_list, blacklist, memory_limit_bytes); if (best == nullptr) { 
VLOG(3) << "Unable to find rematerialization candidate at program " @@ -1211,11 +1187,6 @@ StatusOr HloRematerialization::Run( VLOG(1) << "Peak memory usage of module (before): " << HumanReadableNumBytes(before_peak_memory); - // Run cost analysis. Operation cost is used in the heuristic for selecting - // instructions for rematerialization. - TF_RETURN_IF_ERROR( - module->entry_computation()->root_instruction()->Accept(&cost_analysis_)); - // Subcomputations called by the entry computation will also be // rematerialized. TF_ASSIGN_OR_RETURN(bool changed, RematerializeComputation( diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h index 1693f93183b..42c279d440b 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.h +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h @@ -18,7 +18,6 @@ #include "tensorflow/compiler/xla/service/buffer_liveness.h" #include "tensorflow/compiler/xla/service/call_graph.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" @@ -61,7 +60,7 @@ class HloRematerialization { protected: HloRematerialization(const ShapeSizeFunction& size_function) - : size_function_(size_function), cost_analysis_(size_function_) {} + : size_function_(size_function) {} ~HloRematerialization() {} // Runs rematerialization on the given module. Returns whether the module was @@ -100,9 +99,6 @@ class HloRematerialization { // Call graph of the hlo_module. std::unique_ptr call_graph_; - // Analysis used for computing the rematerialization cost of instructions. - HloCostAnalysis cost_analysis_; - // The peak memory usage of each computation. 
The map contains only those // computations called from sequential context // (CallContext::kSequential). These values are updated as rematerialization diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc index f306bcc309c..1a861cd16b9 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc @@ -126,7 +126,7 @@ class HloRematerializationTest : public HloTestBase { builder.AddInstruction( HloInstruction::CreateParameter(0, vec1_shape_, "param")); builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); + HloInstruction::CreateConstant(Literal::CreateR0(true))); return builder.Build(); } @@ -215,7 +215,7 @@ TEST_F(HloRematerializationTest, RematerializeAroundWhile) { cond_builder.AddInstruction( HloInstruction::CreateParameter(0, vec1_shape_, "param")); cond_builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); + HloInstruction::CreateConstant(Literal::CreateR0(true))); HloComputation* while_cond = module->AddEmbeddedComputation(cond_builder.Build()); @@ -254,7 +254,7 @@ TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) { cond_builder.AddInstruction( HloInstruction::CreateParameter(0, vec1_shape_, "param")); cond_builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); + HloInstruction::CreateConstant(Literal::CreateR0(true))); HloComputation* while_cond = module->AddEmbeddedComputation(cond_builder.Build()); @@ -289,7 +289,7 @@ TEST_F(HloRematerializationTest, RematerializeNestedComputations) { cond_builder.AddInstruction( HloInstruction::CreateParameter(0, vec1_shape_, "param")); cond_builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); + HloInstruction::CreateConstant(Literal::CreateR0(true))); HloComputation* while_cond = 
module->AddEmbeddedComputation(cond_builder.Build()); diff --git a/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc b/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc index 867ebc7f61a..c98856b1921 100644 --- a/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc +++ b/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc @@ -75,7 +75,7 @@ TEST_F(HloSubcomputationUnificationTest, UnifyIdentities) { module->AddEmbeddedComputation(CreateR0S32IdentityComputation()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(5))); + HloInstruction::CreateConstant(Literal::CreateR0(5))); auto x = builder.AddInstruction( HloInstruction::CreateCall(r0s32_, {constant}, callee1)); auto y = builder.AddInstruction( @@ -110,9 +110,9 @@ TEST_F(HloSubcomputationUnificationTest, UnifyAdditions) { module->AddEmbeddedComputation(CreateR0S32AdditionComputation()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(5))); + HloInstruction::CreateConstant(Literal::CreateR0(5))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(3))); + HloInstruction::CreateConstant(Literal::CreateR0(3))); auto x = builder.AddInstruction( HloInstruction::CreateCall(r0s32_, {constant1, constant2}, callee1)); auto y = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc index 6707b02c5c5..76177462aa4 100644 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc +++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc @@ -171,8 +171,7 @@ void HloTfGraphBuilder::SetNodeAttrs(const HloInstruction* instruction, break; case HloOpcode::kConstant: if (ShapeUtil::IsScalar(instruction->shape())) { - attrs["value"].set_s( - LiteralUtil::GetAsString(instruction->literal(), {})); + 
attrs["value"].set_s(instruction->literal().GetAsString({})); } break; case HloOpcode::kCustomCall: diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc index c2718ea8003..8e9d93e367e 100644 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc +++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc @@ -91,7 +91,7 @@ TEST_F(HloTfGraphBuilderTest, CheckConcatenateDimsAndShapes) { TEST_F(HloTfGraphBuilderTest, CheckScalarValue) { auto builder = HloComputation::Builder("Const"); HloInstruction *instruction = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(123))); + HloInstruction::CreateConstant(Literal::CreateR0(123))); OpMetadata metadata; metadata.set_op_name("x"); metadata.set_op_type("y"); diff --git a/tensorflow/compiler/xla/service/inliner_test.cc b/tensorflow/compiler/xla/service/inliner_test.cc index 2887a8a0a09..84bfbb30c30 100644 --- a/tensorflow/compiler/xla/service/inliner_test.cc +++ b/tensorflow/compiler/xla/service/inliner_test.cc @@ -51,10 +51,10 @@ TEST_F(InlinerTest, MapMax) { auto max_f32 = max_builder.Build(); auto builder = HloComputation::Builder("MapMaxFunction"); - auto lhs = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({1, 2, 3, 4}))); - auto rhs = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({4, 3, 2, 1}))); + auto lhs = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({1, 2, 3, 4}))); + auto rhs = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({4, 3, 2, 1}))); builder.AddInstruction( HloInstruction::CreateMap(lhs->shape(), {lhs, rhs}, max_f32.get())); @@ -70,7 +70,7 @@ TEST_F(InlinerTest, MapMax) { // Verify execution on CPU. 
auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - auto expected = LiteralUtil::CreateR1({4, 3, 3, 4}); + auto expected = Literal::CreateR1({4, 3, 3, 4}); LiteralTestUtil::ExpectEqual(*result, *expected); } @@ -83,12 +83,12 @@ TEST_F(InlinerTest, MapConstant) { HloInstruction::CreateParameter(0, r0f32, "x")); (void)param1; const2_builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0f))); auto const2_f32 = const2_builder.Build(); auto builder = HloComputation::Builder("MapConstFunction"); auto lhs = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1, 2, 3, 4}, {5, 6, 7, 8}}))); + Literal::CreateR2({{1, 2, 3, 4}, {5, 6, 7, 8}}))); builder.AddInstruction( HloInstruction::CreateMap(lhs->shape(), {lhs}, const2_f32.get())); @@ -104,7 +104,7 @@ TEST_F(InlinerTest, MapConstant) { // Verify execution on CPU. auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - auto expected = LiteralUtil::CreateR2({{2, 2, 2, 2}, {2, 2, 2, 2}}); + auto expected = Literal::CreateR2({{2, 2, 2, 2}, {2, 2, 2, 2}}); LiteralTestUtil::ExpectEqual(*result, *expected); } diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index 721640cdbd8..52da222ab9d 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -43,6 +43,7 @@ namespace xla { case HloOpcode::kConstant: case HloOpcode::kConvert: case HloOpcode::kCopy: + case HloOpcode::kCos: case HloOpcode::kDynamicSlice: case HloOpcode::kDynamicUpdateSlice: case HloOpcode::kEq: @@ -64,6 +65,7 @@ namespace xla { case HloOpcode::kNegate: case HloOpcode::kOutfeed: case HloOpcode::kPad: + case HloOpcode::kReducePrecision: case HloOpcode::kReshape: case HloOpcode::kReverse: case HloOpcode::kSelect: @@ -75,6 +77,7 @@ namespace xla { return false; // Expensive instructions. 
+ case HloOpcode::kBatchNormTraining: case HloOpcode::kCall: case HloOpcode::kConvolution: case HloOpcode::kCrossReplicaSum: diff --git a/tensorflow/compiler/xla/service/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/instruction_fusion_test.cc index a2e6c2ae00b..b3e0007dcc2 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion_test.cc @@ -28,7 +28,7 @@ TEST_F(InstructionFusionTest, CostlyProducerAndOperandElementReusingConsumerNotFused) { HloComputation::Builder builder(TestName()); HloInstruction* const0 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(5))); + HloInstruction::CreateConstant(Literal::CreateR0(5))); HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(S32, {}), HloOpcode::kExp, const0)); HloInstruction* broadcast2 = @@ -49,7 +49,7 @@ TEST_F(InstructionFusionTest, NonCostlyProducerAndOperandElementReusingConsumerFused) { HloComputation::Builder builder(TestName()); HloInstruction* const0 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(5))); + HloInstruction::CreateConstant(Literal::CreateR0(5))); HloInstruction* negate1 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(S32, {}), HloOpcode::kNegate, const0)); HloInstruction* broadcast2 = @@ -70,7 +70,7 @@ TEST_F(InstructionFusionTest, CostlyProducerAndNonOperandElementReusingConsumerFused_Reshape) { HloComputation::Builder builder(TestName()); HloInstruction* const0 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(5))); + HloInstruction::CreateConstant(Literal::CreateR0(5))); HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(S32, {}), HloOpcode::kExp, const0)); HloInstruction* reshape2 = builder.AddInstruction( @@ -90,7 +90,7 @@ TEST_F(InstructionFusionTest, 
CostlyProducerAndNonOperandElementReusingConsumerFused_Transpose) { HloComputation::Builder builder(TestName()); HloInstruction* const0 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(5))); + HloInstruction::CreateConstant(Literal::CreateR0(5))); HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(S32, {}), HloOpcode::kExp, const0)); HloInstruction* transpose2 = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index e9e199226a6..ff1493230b3 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -382,7 +382,11 @@ Status LayoutAssignment::AddMandatoryConstraints( // instruction. // TODO(b/31425034): Change infeeds to be more like parameters, with // shapes in the ComputationLayout. - shape_with_layout = &instruction->shape(); + // TODO(b/62477016): When the infeed does not set padding anymore, the + // call to ShapeWithoutPadding can be removed. + Shape infeed_shape = ShapeUtil::ShapeWithoutPadding(instruction->shape()); + TF_RETURN_IF_ERROR( + constraints->SetInstructionLayout(infeed_shape, instruction.get())); } else if (instruction->opcode() == HloOpcode::kOutfeed) { // Constrain the input to the Outfeed instruction to be the expected // layout of the Outfeed. @@ -729,23 +733,18 @@ std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( if (instruction->opcode() == HloOpcode::kReshape) { // Prefer the operand layout that makes the reshape an bitcast. If any // dimension bound is 1 in the operand shape, there may be several such - // layouts. So if 'output_layout' is a MajorToMinor layout, try if the + // layouts. So if 'output_layout' is the default layout, try if the // reshape is a bitcast when using the same layout. This may avoid copy // operations. 
const Shape& output_shape = instruction->shape(); Shape output_shape_with_layout = ShapeUtil::MakeShapeWithLayout( output_shape.element_type(), AsInt64Slice(output_shape.dimensions()), AsInt64Slice(output_layout.minor_to_major())); - const Shape& operand_shape = operand->shape(); - if (LayoutUtil::IsMonotonicWithDim0Major(output_layout)) { - Shape operand_shape_with_layout = - ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( - operand_shape.element_type(), - AsInt64Slice(operand_shape.dimensions())); - if (ShapeUtil::ReshapeIsBitcast(operand_shape_with_layout, - output_shape_with_layout)) { - return MakeUnique(operand_shape_with_layout.layout()); - } + Shape operand_shape = operand->shape(); + *operand_shape.mutable_layout() = + LayoutUtil::GetDefaultLayoutForShape(operand_shape); + if (ShapeUtil::ReshapeIsBitcast(operand_shape, output_shape_with_layout)) { + return MakeUnique(operand_shape.layout()); } auto aligned_operand_shape = ShapeUtil::AlignLayouts(output_shape_with_layout, operand_shape); @@ -759,10 +758,14 @@ std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( if (instruction->opcode() == HloOpcode::kTranspose) { // Pick the operand layout that makes the transpose a bitcast. 
- std::vector perm = - ComposePermutations(instruction->dimensions(), - AsInt64Slice(output_layout.minor_to_major())); - Layout operand_layout = LayoutUtil::MakeLayout(perm); + int64 rank = ShapeUtil::Rank(instruction->shape()); + std::vector new_minor_to_major(rank); + for (int64 i = 0; i < rank; ++i) { + int64 output_dim = output_layout.minor_to_major(i); + int64 operand_dim = instruction->dimensions(output_dim); + new_minor_to_major[i] = operand_dim; + } + Layout operand_layout = LayoutUtil::MakeLayout(new_minor_to_major); TF_CHECK_OK( LayoutUtil::ValidateLayoutForShape(operand_layout, operand->shape())); return MakeUnique(operand_layout); @@ -789,23 +792,18 @@ std::unique_ptr LayoutAssignment::ChooseOutputLayoutFromOperandLayout( if (user->opcode() == HloOpcode::kReshape) { // Prefer the user layout that makes the reshape an bitcast. If any // dimension bound is 1 in the user shape, there may be several such - // layouts. So if 'operand_layout' is a MajorToMinor layout, try if the + // layouts. So if 'operand_layout' is the default layout, try if the // reshape is a bitcast when using the same layout. This may avoid copy // operations. 
Shape operand_shape_with_layout = ShapeUtil::MakeShapeWithLayout( operand->shape().element_type(), AsInt64Slice(operand->shape().dimensions()), AsInt64Slice(operand_layout.minor_to_major())); - const Shape& output_shape = user->shape(); - if (LayoutUtil::IsMonotonicWithDim0Major(operand_layout)) { - Shape output_shape_with_layout = - ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( - output_shape.element_type(), - AsInt64Slice(output_shape.dimensions())); - if (ShapeUtil::ReshapeIsBitcast(output_shape_with_layout, - operand_shape_with_layout)) { - return MakeUnique(output_shape_with_layout.layout()); - } + Shape output_shape = user->shape(); + *output_shape.mutable_layout() = + LayoutUtil::GetDefaultLayoutForShape(output_shape); + if (ShapeUtil::ReshapeIsBitcast(output_shape, operand_shape_with_layout)) { + return MakeUnique(output_shape.layout()); } auto aligned_user_shape = ShapeUtil::AlignLayouts(operand_shape_with_layout, output_shape); @@ -818,14 +816,16 @@ std::unique_ptr LayoutAssignment::ChooseOutputLayoutFromOperandLayout( } if (user->opcode() == HloOpcode::kTranspose) { - // Pick the user layout that makes the reshape a bitcast. - // To become a bitcast, the layouts need to satisfy - // collapsing_order * output_layout = input_layout - // so output_layout = inverse(collapsing_order) * input_layout - std::vector perm = - Permute(InversePermutation(user->dimensions()), - AsInt64Slice(operand_layout.minor_to_major())); - Layout user_layout = LayoutUtil::MakeLayout(perm); + // Pick the user layout that makes the transpose a bitcast. 
+ int64 rank = ShapeUtil::Rank(user->shape()); + std::vector new_minor_to_major(rank); + auto inverse_dimensions = InversePermutation(user->dimensions()); + for (int64 i = 0; i < rank; ++i) { + int64 operand_dim = operand_layout.minor_to_major(i); + int64 user_dim = inverse_dimensions[operand_dim]; + new_minor_to_major[i] = user_dim; + } + Layout user_layout = LayoutUtil::MakeLayout(new_minor_to_major); TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(user_layout, user->shape())); return MakeUnique(user_layout); } @@ -926,7 +926,7 @@ Status LayoutAssignment::PropagateUseConstraintToDefs( ShapeUtil::IsArray(buffer->shape())) { TF_RETURN_IF_ERROR(constraints->SetBufferLayout( ShapeUtil::GetSubshape(shape_layout.shape(), index).layout(), - *buffer)); + *buffer, /*mandatory=*/true)); } } } diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index 6d818cdea0c..f69c043f32b 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -230,7 +230,7 @@ TEST_F(LayoutAssignmentTest, TupleSelect) { HloInstruction::CreateTuple({constant0, constant1})); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(true))); + HloInstruction::CreateConstant(Literal::CreateR0(true))); auto select = builder.AddInstruction(HloInstruction::CreateTernary( tuple0->shape(), HloOpcode::kSelect, pred, tuple0, tuple1)); @@ -264,7 +264,7 @@ TEST_F(LayoutAssignmentTest, ConflictingLayoutTuple) { // tuple and assigning the layouts of the copied arrays as needed. 
auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); + Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); auto inner_tuple = builder.AddInstruction(HloInstruction::CreateTuple({constant})); auto nested_tuple = builder.AddInstruction( @@ -552,6 +552,41 @@ TEST_F(LayoutAssignmentTest, MakeOperandsTheSame) { ElementsAre(1, 0)); } +// Test layout assignment of a transpose into a bitcast based on its operand. +TEST_F(LayoutAssignmentTest, TransposeToBitcastFromOperand) { + auto builder = HloComputation::Builder(TestName()); + Shape input_shape_with_layout = + ShapeUtil::MakeShapeWithLayout(F32, {3, 5, 6, 7}, {2, 0, 3, 1}); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, input_shape_with_layout, "param")); + auto transpose = builder.AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::MakeShape(F32, {6, 7, 3, 5}), param, {2, 3, 0, 1})); + auto module = CreateNewModule(); + HloComputation* computation = + module->AddEntryComputation(builder.Build(transpose)); + ComputationLayout computation_layout(computation->ComputeProgramShape()); + AssignLayouts(module.get(), &computation_layout); + EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(transpose->operand(0)->shape(), + transpose->shape(), {2, 3, 0, 1})); +} +// Test layout assignment of a transpose into a bitcast based on its user. 
+TEST_F(LayoutAssignmentTest, TransposeToBitcastToUser) { + auto builder = HloComputation::Builder(TestName()); + Shape input_shape = ShapeUtil::MakeShape(F32, {3, 5, 6, 7}); + auto constant = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0f))); + auto broadcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(input_shape, constant, {})); + auto transpose = builder.AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::MakeShape(F32, {6, 7, 3, 5}), broadcast, {2, 3, 0, 1})); + auto module = CreateNewModule(); + HloComputation* computation = + module->AddEntryComputation(builder.Build(transpose)); + ComputationLayout computation_layout(computation->ComputeProgramShape()); + AssignLayouts(module.get(), &computation_layout); + EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(transpose->operand(0)->shape(), + transpose->shape(), {2, 3, 0, 1})); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/liveness_util.cc b/tensorflow/compiler/xla/service/liveness_util.cc index 682bf19807b..9c80fb3adbc 100644 --- a/tensorflow/compiler/xla/service/liveness_util.cc +++ b/tensorflow/compiler/xla/service/liveness_util.cc @@ -28,17 +28,6 @@ limitations under the License. namespace xla { -bool DoesNotUseOperandBuffer(const HloInstruction* operand, - const ShapeIndex& index, - const HloInstruction* user) { - CHECK(user->IsUserOf(operand)) - << "user: " << user->ToString() << " operand: " << operand->ToString(); - - // GetTupleElement instructions only access the top-level buffer of their - // operand. 
- return (user->opcode() == HloOpcode::kGetTupleElement && !index.empty()); -} - bool DoesNotUseOperandBuffer(const HloInstruction* operand, const ShapeIndex& index, const HloInstruction* user, @@ -149,18 +138,22 @@ bool HasUniqueFusedUseOfOperandAt( // User and operand can share buffers iff both instructions emit the same shape // and layout, and 'user' meets one of the following qualifications: -// *) Is element-wise. Or... -// *) Is a loop fusion instruction where the only use of 'operand' at 'index' -// in the set 'user.fused_instructions' is a DynamicUpdateSlice fused root -// at operand 0. Or... -// *) Is a kDot -> kAdd (or fused kTransposeDot -> kAdd) output fusion -// instruction where the only use of 'operand' at 'index' in the set -// 'user.fused_instructions' is a kAdd fused root at operand 0 or 1. Or... -// *) The 'user' of 'operand' is DynamicUpdateSlice or While at operand index 0. +// +// (1) Is element-wise. Or... +// (2) Is a loop fusion instruction where the only use of 'operand' at 'index' +// in the set 'user.fused_instructions' is a DynamicUpdateSlice fused root +// at operand 0. Or... +// (3) Is a kDot -> kAdd (or fused kTransposeDot -> kAdd) output fusion +// instruction where the only use of 'operand' at 'index' in the set +// 'user.fused_instructions' is a kAdd fused root at operand 0 or 1. Or... +// (4) The 'user' of 'operand' is DynamicUpdateSlice or While at operand index +// 0. +// +// (2) and (3) can only be determined if points-to analysis is available. 
bool CanShareOperandBufferWithUser( HloInstruction* operand, const ShapeIndex& operand_index, HloInstruction* user, const ShapeIndex& user_index, - const TuplePointsToAnalysis& points_to_analysis) { + const TuplePointsToAnalysis* points_to_analysis) { CHECK(user->IsUserOf(operand)) << "user: " << user->ToString() << " operand: " << operand->ToString(); Shape operand_subshape = @@ -170,7 +163,7 @@ bool CanShareOperandBufferWithUser( if (!ShapeUtil::Equal(operand_subshape, user_subshape)) { return false; } - if (user->opcode() == HloOpcode::kFusion) { + if (points_to_analysis != nullptr && user->opcode() == HloOpcode::kFusion) { if (user->fusion_kind() == HloInstruction::FusionKind::kLoop && user->fused_expression_root()->opcode() == HloOpcode::kDynamicUpdateSlice) { @@ -180,7 +173,7 @@ bool CanShareOperandBufferWithUser( // 'operand_index', and this singleton use is the fused root at operand // index 0. return HasUniqueFusedUseOfOperandAt(operand, operand_index, user, 0, - points_to_analysis); + *points_to_analysis); } else if (user->fusion_kind() == HloInstruction::FusionKind::kOutput && user->fused_expression_root()->opcode() == HloOpcode::kAdd) { // Output fusion with kAdd fused root. @@ -208,7 +201,7 @@ bool CanShareOperandBufferWithUser( // index 'other_add_operand_index'). return HasUniqueFusedUseOfOperandAt(operand, operand_index, user, other_add_operand_index, - points_to_analysis); + *points_to_analysis); } } if (user->opcode() == HloOpcode::kDynamicUpdateSlice || diff --git a/tensorflow/compiler/xla/service/liveness_util.h b/tensorflow/compiler/xla/service/liveness_util.h index 0b01223db73..c7799e5ab5d 100644 --- a/tensorflow/compiler/xla/service/liveness_util.h +++ b/tensorflow/compiler/xla/service/liveness_util.h @@ -34,21 +34,16 @@ bool DoesNotUseOperandBuffer(const HloInstruction* operand, const HloInstruction* user, const TuplePointsToAnalysis& points_to_analysis); -// Overload which does not require points-to analysis. 
The result is more -// conservative (returns false more often). -bool DoesNotUseOperandBuffer(const HloInstruction* operand, - const ShapeIndex& index, - const HloInstruction* user); - // Returns true if 'user' (at 'user_index') can share a buffer with its operand -// 'operand' (at 'operand_index'). -// Returns false otherwise. +// 'operand' (at 'operand_index'). Returns false otherwise. Optionally takes a +// points-to analysis argument. Without the analysis, the result is more +// conservative (returns false more often). // // REQUIRES: 'operand' is an operand of 'user'. bool CanShareOperandBufferWithUser( HloInstruction* operand, const ShapeIndex& operand_index, HloInstruction* user, const ShapeIndex& user_index, - const TuplePointsToAnalysis& points_to_analysis); + const TuplePointsToAnalysis* points_to_analysis = nullptr); } // namespace xla diff --git a/tensorflow/compiler/xla/service/liveness_util_test.cc b/tensorflow/compiler/xla/service/liveness_util_test.cc index bad4be149a6..6a4fde87614 100644 --- a/tensorflow/compiler/xla/service/liveness_util_test.cc +++ b/tensorflow/compiler/xla/service/liveness_util_test.cc @@ -85,9 +85,9 @@ TEST_F(DoesNotUseOperandBufferTest, FusedDynamicUpdateSlice) { // Create a DynamicUpdateSlice instruction of tuple element 1. 
auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({2}))); + HloInstruction::CreateConstant(Literal::CreateR1({2}))); auto update = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({2.f, 2.f, 2.f}))); + Literal::CreateR1({2.f, 2.f, 2.f}))); auto dynamic_update_slice = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( data_shape, gte1, update, starts)); @@ -122,10 +122,10 @@ TEST_F(CanShareOperandBufferWithUserTest, ElementWiseSameShape) { BuildModuleAndRunAnalysis(builder.Build()); - EXPECT_TRUE( - CanShareOperandBufferWithUser(param, {}, exp, {}, *points_to_analysis_)); - EXPECT_TRUE( - CanShareOperandBufferWithUser(exp, {}, log, {}, *points_to_analysis_)); + EXPECT_TRUE(CanShareOperandBufferWithUser(param, {}, exp, {}, + points_to_analysis_.get())); + EXPECT_TRUE(CanShareOperandBufferWithUser(exp, {}, log, {}, + points_to_analysis_.get())); } TEST_F(CanShareOperandBufferWithUserTest, ElementWiseDifferentShape) { @@ -143,9 +143,9 @@ TEST_F(CanShareOperandBufferWithUserTest, ElementWiseDifferentShape) { BuildModuleAndRunAnalysis(builder.Build()); EXPECT_FALSE(CanShareOperandBufferWithUser(param0, {}, result, {}, - *points_to_analysis_)); + points_to_analysis_.get())); EXPECT_FALSE(CanShareOperandBufferWithUser(param1, {}, result, {}, - *points_to_analysis_)); + points_to_analysis_.get())); } TEST_F(CanShareOperandBufferWithUserTest, CopyShares) { @@ -161,10 +161,10 @@ TEST_F(CanShareOperandBufferWithUserTest, CopyShares) { BuildModuleAndRunAnalysis(builder.Build()); - EXPECT_TRUE( - CanShareOperandBufferWithUser(param, {}, exp, {}, *points_to_analysis_)); - EXPECT_TRUE( - CanShareOperandBufferWithUser(exp, {}, copy, {}, *points_to_analysis_)); + EXPECT_TRUE(CanShareOperandBufferWithUser(param, {}, exp, {}, + points_to_analysis_.get())); + EXPECT_TRUE(CanShareOperandBufferWithUser(exp, {}, copy, {}, + points_to_analysis_.get())); } TEST_F(CanShareOperandBufferWithUserTest, 
FusedDynamicUpdateSlice) { @@ -180,9 +180,9 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) { // Create a DynamicUpdateSlice instruction of tuple element 1. auto starts = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({2}))); + HloInstruction::CreateConstant(Literal::CreateR1({2}))); auto update = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({2.f, 2.f, 2.f}))); + Literal::CreateR1({2.f, 2.f, 2.f}))); auto dynamic_update_slice = builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( data_shape, gte1, update, starts)); @@ -197,9 +197,9 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDynamicUpdateSlice) { // The fusion instruction can share with tuple element 1. EXPECT_FALSE(CanShareOperandBufferWithUser(tuple, {0}, fusion, {}, - *points_to_analysis_)); + points_to_analysis_.get())); EXPECT_TRUE(CanShareOperandBufferWithUser(tuple, {1}, fusion, {}, - *points_to_analysis_)); + points_to_analysis_.get())); } TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) { @@ -221,12 +221,12 @@ TEST_F(CanShareOperandBufferWithUserTest, DynamicUpdateSliceCanShare) { // The DynamicUpdateSlice instruction can share with the data operand, but not // with update or starts. 
- EXPECT_TRUE( - CanShareOperandBufferWithUser(data, {}, dus, {}, *points_to_analysis_)); - EXPECT_FALSE( - CanShareOperandBufferWithUser(update, {}, dus, {}, *points_to_analysis_)); - EXPECT_FALSE( - CanShareOperandBufferWithUser(starts, {}, dus, {}, *points_to_analysis_)); + EXPECT_TRUE(CanShareOperandBufferWithUser(data, {}, dus, {}, + points_to_analysis_.get())); + EXPECT_FALSE(CanShareOperandBufferWithUser(update, {}, dus, {}, + points_to_analysis_.get())); + EXPECT_FALSE(CanShareOperandBufferWithUser(starts, {}, dus, {}, + points_to_analysis_.get())); } TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) { @@ -234,15 +234,15 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) { Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); auto a = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1.0, 0.0}, {0.0, 1.0}}))); + Literal::CreateR2({{1.0, 0.0}, {0.0, 1.0}}))); auto b = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); + Literal::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); auto dot = builder.AddInstruction( HloInstruction::CreateBinary(data_shape, HloOpcode::kDot, a, b)); auto one = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto add_operand = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape, one, {1})); @@ -256,7 +256,7 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) { // Output fused dot add should be able to share buffer with 'add_operand'. 
EXPECT_TRUE(CanShareOperandBufferWithUser(add_operand, {}, fusion, {}, - *points_to_analysis_)); + points_to_analysis_.get())); } TEST_F(CanShareOperandBufferWithUserTest, FusedTransposeDotAdd) { @@ -264,9 +264,9 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedTransposeDotAdd) { Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); auto a = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1.0, 0.0}, {0.0, 1.0}}))); + Literal::CreateR2({{1.0, 0.0}, {0.0, 1.0}}))); auto b = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); + Literal::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); auto b_t = builder.AddInstruction( HloInstruction::CreateTranspose(data_shape, b, {1, 0})); @@ -274,7 +274,7 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedTransposeDotAdd) { HloInstruction::CreateBinary(data_shape, HloOpcode::kDot, a, b_t)); auto one = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto add_operand = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape, one, {1})); @@ -292,7 +292,7 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedTransposeDotAdd) { // Output fused transpose-dot-add should be share buffer with 'add_operand'. 
EXPECT_TRUE(CanShareOperandBufferWithUser(add_operand, {}, fusion, {}, - *points_to_analysis_)); + points_to_analysis_.get())); } TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) { @@ -300,7 +300,7 @@ TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) { Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); auto one = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto operand = builder.AddInstruction( HloInstruction::CreateBroadcast(data_shape, one, {1})); @@ -308,7 +308,7 @@ TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) { HloInstruction::CreateReverse(data_shape, operand, {0, 1})); auto two = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); + Literal::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); auto add = builder.AddInstruction( HloInstruction::CreateBinary(data_shape, HloOpcode::kAdd, reverse, two)); @@ -320,7 +320,7 @@ TEST_F(CanShareOperandBufferWithUserTest, OutputFusionCantAliasOperandBuffer) { // Output fused operand->reverse->add cannot alias operand buffer 'operand'. EXPECT_FALSE(CanShareOperandBufferWithUser(operand, {}, fusion, {}, - *points_to_analysis_)); + points_to_analysis_.get())); } TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) { @@ -360,8 +360,8 @@ TEST_F(CanShareOperandBufferWithUserTest, WhileCanShare) { RunAnalysis(); // The While instruction can share with the data operand. 
- EXPECT_TRUE( - CanShareOperandBufferWithUser(data, {}, whil, {}, *points_to_analysis_)); + EXPECT_TRUE(CanShareOperandBufferWithUser(data, {}, whil, {}, + points_to_analysis_.get())); } } // namespace diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD index 12b2762f0ed..c5b3f317b25 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/BUILD +++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD @@ -29,7 +29,6 @@ cc_library( ":ir_array", ":llvm_util", "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla/legacy_flags:alias_analysis_flags", "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:logical_buffer", @@ -47,7 +46,6 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/legacy_flags:llvm_util_flags", "//tensorflow/core:lib", "@llvm//:core", "@llvm//:support", diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc index 02710ff57f6..d4512557745 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc @@ -18,7 +18,6 @@ limitations under the License. 
#include #include "external/llvm/include/llvm/IR/MDBuilder.h" -#include "tensorflow/compiler/xla/legacy_flags/alias_analysis_flags.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" #include "tensorflow/compiler/xla/types.h" @@ -87,12 +86,6 @@ llvm::MDNode* AliasAnalysis::GetAliasDomain() { llvm::MDNode* AliasAnalysis::GetAliasScopeMetadataForBuffer( const BufferAllocation::Slice& buffer_slice, llvm::MDNode* domain) { - legacy_flags::AliasAnalysisFlags* flags = - legacy_flags::GetAliasAnalysisFlags(); - if (!flags->xla_emit_alias_scope) { - return nullptr; - } - // While we could synthesize an alias.scope, doing so is not more profitable // than LLVM's default behavior. if (buffer_slice.allocation() == kParameterAllocation) { @@ -109,12 +102,6 @@ llvm::MDNode* AliasAnalysis::GetAliasScopeMetadataForBuffer( llvm::MDNode* AliasAnalysis::GetNoaliasMetadataForBuffer( const BufferAllocation::Slice& buffer_slice, llvm::MDNode* domain, const BufferAssignment& assignment, const HloInstruction& hlo) { - legacy_flags::AliasAnalysisFlags* flags = - legacy_flags::GetAliasAnalysisFlags(); - if (!flags->xla_emit_alias_scope) { - return nullptr; - } - // We want to construct a list of buffers which: // // 1. Do not alias the given buffer. 
diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc index e401305ae73..5268e89e0f2 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc @@ -85,7 +85,7 @@ IrArray::IrArray(llvm::Value* base_ptr, const Shape& shape) ++depth; } - if (ShapeUtil::Rank(*shape_) == 0) { + if (!ShapeUtil::IsArray(*shape_) || ShapeUtil::IsScalar(*shape_)) { DCHECK(depth == 1 || depth == 0) << depth; } else { DCHECK_EQ(depth, ShapeUtil::Rank(*shape_)) << shape.ShortDebugString(); diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc index ff2f4cd693c..bcc9418d591 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc @@ -22,7 +22,6 @@ limitations under the License. #include "external/llvm/include/llvm/IR/Operator.h" #include "external/llvm/include/llvm/Target/TargetOptions.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/legacy_flags/llvm_util_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" @@ -163,36 +162,36 @@ llvm::Constant* LiteralToConstant(const Literal& literal, int64 dimension_index, llvm::Constant* value; switch (shape.element_type()) { case PRED: - value = llvm::ConstantInt::get( - ir_element_type, LiteralUtil::Get(literal, *multi_index)); + value = llvm::ConstantInt::get(ir_element_type, + literal.Get(*multi_index)); break; case U8: - value = llvm::ConstantInt::get( - ir_element_type, LiteralUtil::Get(literal, *multi_index)); + value = llvm::ConstantInt::get(ir_element_type, + literal.Get(*multi_index)); break; case S32: - value = llvm::ConstantInt::get( - ir_element_type, LiteralUtil::Get(literal, *multi_index)); + value = llvm::ConstantInt::get(ir_element_type, + 
literal.Get(*multi_index)); break; case U32: - value = llvm::ConstantInt::get( - ir_element_type, LiteralUtil::Get(literal, *multi_index)); + value = llvm::ConstantInt::get(ir_element_type, + literal.Get(*multi_index)); break; case S64: - value = llvm::ConstantInt::get( - ir_element_type, LiteralUtil::Get(literal, *multi_index)); + value = llvm::ConstantInt::get(ir_element_type, + literal.Get(*multi_index)); break; case U64: - value = llvm::ConstantInt::get( - ir_element_type, LiteralUtil::Get(literal, *multi_index)); + value = llvm::ConstantInt::get(ir_element_type, + literal.Get(*multi_index)); break; case F32: - value = llvm::ConstantFP::get( - ir_element_type, LiteralUtil::Get(literal, *multi_index)); + value = llvm::ConstantFP::get(ir_element_type, + literal.Get(*multi_index)); break; case F64: - value = llvm::ConstantFP::get( - ir_element_type, LiteralUtil::Get(literal, *multi_index)); + value = llvm::ConstantFP::get(ir_element_type, + literal.Get(*multi_index)); break; default: LOG(FATAL) << "unsupported type " << shape.element_type(); @@ -357,11 +356,6 @@ void EmitLogging(const char* tag, llvm::Value* value, void SetTbaaForInstruction(llvm::Instruction* instruction, Shape shape, bool is_pointer_to) { - legacy_flags::LlvmUtilFlags* flags = legacy_flags::GetLlvmUtilFlags(); - if (!flags->xla_emit_tbaa) { - return; - } - llvm::MDBuilder metadata_builder(instruction->getContext()); llvm::MDNode* root = metadata_builder.createTBAARoot("XLA TBAA"); string type_name; @@ -371,7 +365,7 @@ void SetTbaaForInstruction(llvm::Instruction* instruction, Shape shape, // Scalars do not have layout which makes it permissible to omit an explicit // layout. To make sure that equivalent scalar shapes have the same TBAA, // remove the (meaningless) explicit layout if one is present. 
- if (ShapeUtil::Rank(shape) == 0) { + if (!ShapeUtil::IsArray(shape) || ShapeUtil::IsScalar(shape)) { LayoutUtil::ClearLayout(&shape); } else { CHECK(shape.has_layout()); diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index 131c2ee87b0..90dfe01bb9b 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/legacy_flags/service_flags.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/backend.h" @@ -62,7 +63,6 @@ namespace xla { BackendOptions backend_options; backend_options.set_platform(platform) - .set_number_of_replicas(options.number_of_replicas()) .set_intra_op_parallelism_threads(options.intra_op_parallelism_threads()); TF_ASSIGN_OR_RETURN(std::unique_ptr backend, Backend::CreateBackend(backend_options)); @@ -70,13 +70,15 @@ namespace xla { TF_ASSIGN_OR_RETURN(std::unique_ptr compute_constant_backend, CreateComputeConstantBackend()); std::unique_ptr service(new LocalService( - std::move(backend), std::move(compute_constant_backend))); + options, std::move(backend), std::move(compute_constant_backend))); return std::move(service); } -LocalService::LocalService(std::unique_ptr execute_backend, +LocalService::LocalService(const ServiceOptions& options, + std::unique_ptr execute_backend, std::unique_ptr compute_constant_backend) - : Service(std::move(execute_backend), std::move(compute_constant_backend)) { + : Service(options, std::move(execute_backend), + std::move(compute_constant_backend)) { runs_in_client_process_ = true; } @@ -152,7 +154,12 @@ StatusOr> LocalService::CompileExecutable( // Construct computation layout from the argument layouts. 
auto module_config = MakeUnique(*program_shape); module_config->set_has_hybrid_result(has_hybrid_result); - module_config->set_replica_count(execute_backend_->Replicas().size()); + module_config->set_replica_count(options_.number_of_replicas()); + module_config->set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); + if (execute_backend_->eigen_intra_op_thread_pool() != nullptr) { + module_config->set_intra_op_parallelism_threads( + execute_backend_->eigen_intra_op_thread_pool()->NumThreads()); + } legacy_flags::ServiceFlags* flags = legacy_flags::GetServiceFlags(); if (flags->xla_hlo_profile) { module_config->enable_hlo_profiling(true); diff --git a/tensorflow/compiler/xla/service/local_service.h b/tensorflow/compiler/xla/service/local_service.h index 767a3ab697f..915c8c3072f 100644 --- a/tensorflow/compiler/xla/service/local_service.h +++ b/tensorflow/compiler/xla/service/local_service.h @@ -60,7 +60,8 @@ class LocalService : public Service { const Shape* result_layout, int device_ordinal, bool has_hybrid_result); private: - explicit LocalService(std::unique_ptr backend, + explicit LocalService(const ServiceOptions& options, + std::unique_ptr backend, std::unique_ptr compute_constant_backend); LocalService(const LocalService&) = delete; void operator=(const LocalService&) = delete; diff --git a/tensorflow/compiler/xla/service/reshape_mover_test.cc b/tensorflow/compiler/xla/service/reshape_mover_test.cc index 9becdb2bed4..49c17555202 100644 --- a/tensorflow/compiler/xla/service/reshape_mover_test.cc +++ b/tensorflow/compiler/xla/service/reshape_mover_test.cc @@ -84,7 +84,7 @@ TEST_F(ReshapeMoverTest, 1ConstantAnd1ReshapesOnRngNotMoved) { builder.AddInstruction(HloInstruction::CreateReshape(root_shape, rng0)); auto const1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateFromShape(root_shape))); + HloInstruction::CreateConstant(Literal::CreateFromShape(root_shape))); builder.AddInstruction(HloInstruction::CreateBinary( root_shape, 
HloOpcode::kAdd, reshape0, const1)); @@ -179,9 +179,8 @@ TEST_F(ReshapeMoverTest, EquivalentReshapesMoved) { TEST_F(ReshapeMoverTest, 1ConstantAnd2ReshapesMoved) { HloComputation::Builder builder(TestName()); auto root_shape = ShapeUtil::MakeShape(F32, {2, 3}); - auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR2( - {{true, true, false}, {false, false, true}}))); + auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR2({{true, true, false}, {false, false, true}}))); auto param1 = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {1, 3, 1, 2}), "param1")); @@ -263,12 +262,12 @@ TEST_F(ReshapeMoverTest, 2TrivialConstantReshapeNotMoved) { HloComputation::Builder builder(TestName()); auto root_shape = ShapeUtil::MakeShape(F32, {2, 3}); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}}))); + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}}))); auto reshape0 = builder.AddInstruction(HloInstruction::CreateReshape(root_shape, const0)); auto const1 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}}))); + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}}))); auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape(root_shape, const1)); @@ -318,7 +317,7 @@ TEST_F(ReshapeMoverTest, 1NonTrivialReshapeMoved) { auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( 0, ShapeUtil::MakeShape(F32, {1, 3, 1, 2}), "param0")); auto const1 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}}))); + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}}))); auto reshape0 = builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param0)); builder.AddInstruction(HloInstruction::CreateBinary( @@ -464,7 +463,7 @@ TEST_F(ReshapeMoverTest, ImplicitlyBroadcastReshapeIsNotMovedBug37787999) { auto reshape = 
builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {128, 1}), param0)); Array2D a(128, 1024); - auto literal = LiteralUtil::CreateR2FromArray2D(a); + auto literal = Literal::CreateR2FromArray2D(a); auto constant = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); auto multiply = builder.AddInstruction(HloInstruction::CreateBinary( diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index 85ca7e4e59c..31740757ab6 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -21,6 +21,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/legacy_flags/backend_flags.h" #include "tensorflow/compiler/xla/legacy_flags/service_flags.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/compiler.h" @@ -141,12 +142,13 @@ int ServiceOptions::intra_op_parallelism_threads() const { } BackendOptions backend_options; backend_options.set_platform(platform); - backend_options.set_number_of_replicas(options.number_of_replicas()); TF_ASSIGN_OR_RETURN(execute_backend, Backend::CreateBackend(backend_options)); + TF_ASSIGN_OR_RETURN(std::unique_ptr compute_constant_backend, CreateComputeConstantBackend()); - std::unique_ptr service(new Service( - std::move(execute_backend), std::move(compute_constant_backend))); + std::unique_ptr service( + new Service(options, std::move(execute_backend), + std::move(compute_constant_backend))); return std::move(service); } @@ -158,7 +160,6 @@ Service::CreateComputeConstantBackend() { if (platform->id() == se::host::kHostPlatformId) { BackendOptions backend_options; backend_options.set_platform(platform); - backend_options.set_number_of_replicas(1); return Backend::CreateBackend(backend_options); } } @@ -171,11 +172,24 @@ Service::CreateComputeConstantBackend() { }; } 
-Service::Service(std::unique_ptr execute_backend, +Service::Service(const ServiceOptions& options, + std::unique_ptr execute_backend, std::unique_ptr compute_constant_backend) - : execute_backend_(std::move(execute_backend)), + : options_(options), + execute_backend_(std::move(execute_backend)), compute_constant_backend_(std::move(compute_constant_backend)) { + // TODO(b/32648682): this flag / options update dance will go away once we + // pass the replica count explicitly to the service. + if (options_.number_of_replicas() < 0) { + legacy_flags::BackendFlags* flags = legacy_flags::GetBackendFlags(); + options_.set_number_of_replicas(flags->xla_replicas); + } + if (execute_backend_) { + if (execute_backend_->device_count() > 0) { + CHECK_GE(execute_backend_->device_count(), options_.number_of_replicas()) + << "Requested more replicas than there are devices."; + } LOG(INFO) << Printf( "XLA service %p executing computations on platform %s. Devices:", this, execute_backend_->platform()->Name().c_str()); @@ -325,7 +339,7 @@ StatusOr> Service::CreateModuleConfig( module_config->enable_hlo_profiling(true); } - module_config->set_replica_count(backend->Replicas().size()); + module_config->set_replica_count(options_.number_of_replicas()); module_config->set_seed(execution_options.seed()); module_config->set_debug_options(execution_options.debug_options()); @@ -495,47 +509,55 @@ Service::ExecuteParallelAndRegisterResult( tensorflow::gtl::ArraySlice< std::vector> arguments, - Backend* backend, - tensorflow::gtl::ArraySlice executors, + Backend* backend, tensorflow::gtl::ArraySlice device_handles, tensorflow::gtl::ArraySlice result_tags) { - // TODO(b/33943292): Support for replication when using multiple computations. - TF_RET_CHECK(backend->Replicas().size() == 1); - - // Set up streams. + // Streams where the computation are launched, so we can wait on the streams + // to complete. 
std::vector::SmartPtr> streams; - for (se::StreamExecutor* executor : executors) { - TF_ASSIGN_OR_RETURN(Pool::SmartPtr stream, - backend->BorrowStream(executor)); - streams.push_back(std::move(stream)); - } - - // Set up run options. - std::vector run_options; - for (const Pool::SmartPtr& stream : streams) { - ExecutableRunOptions options; - options.set_stream(stream.get()); - options.set_allocator(backend->memory_allocator()); - options.set_inter_op_thread_pool(backend->inter_op_thread_pool()); - options.set_intra_op_thread_pool( - backend->eigen_intra_op_thread_pool_device()); - run_options.emplace_back(options, backend->StreamBorrower()); - } - - // Asynchronously launch all executables. + // Global data handles for the computation results, one for each computation. std::vector result_handles; - for (tensorflow::gtl::ArraySlice::size_type i = 0; - i < executables.size(); i++) { - TF_ASSIGN_OR_RETURN( - perftools::gputools::DeviceMemoryBase result, - executables[i]->ExecuteAsyncOnStream(&run_options[i], arguments[i])); - result_handles.push_back(allocation_tracker_.Register( - backend, executors[i]->device_ordinal(), result, - executables[i]->result_shape(), result_tags[i])); + + TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment, + backend->computation_placer()->AssignDevices( + options_.number_of_replicas(), executables.size())); + + for (int64 i = 0; i < executables.size(); i++) { + // Stream executors for the replicas of the current computation. + TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*backend, device_handles[i])); + for (int64 replica = 0; replica < replicas.size(); ++replica) { + TF_ASSIGN_OR_RETURN(Pool::SmartPtr stream, + backend->BorrowStream(replicas[replica])); + streams.push_back(std::move(stream)); + + // Set up run options. 
+ ExecutableRunOptions options; + options.set_stream(streams.back().get()); + options.set_allocator(backend->memory_allocator()); + options.set_inter_op_thread_pool(backend->inter_op_thread_pool()); + options.set_intra_op_thread_pool( + backend->eigen_intra_op_thread_pool_device()); + options.set_device_assignment(&device_assignment); + ServiceExecutableRunOptions run_options(options, + backend->StreamBorrower()); + + // Asynchronously launch the computation. + TF_ASSIGN_OR_RETURN( + perftools::gputools::DeviceMemoryBase result, + executables[i]->ExecuteAsyncOnStream(&run_options, arguments[i])); + + // All replicas share the same device address for the result allocation, + // so only one of the replicas need to register the result handle. + if (replica == 0) { + result_handles.push_back(allocation_tracker_.Register( + backend, replicas[0]->device_ordinal(), result, + executables[i]->result_shape(), result_tags[i])); + } + } } // Wait for all executions to complete. - for (int64 i = 0; i < result_handles.size(); ++i) { + for (int64 i = 0; i < streams.size(); ++i) { if (!streams[i]->BlockHostUntilDone()) { return InternalError("failed to complete execution for stream %lld", i); } @@ -550,17 +572,23 @@ StatusOr Service::ExecuteAndRegisterResult( arguments, Backend* backend, perftools::gputools::StreamExecutor* executor, const string& result_tag, ExecutionProfile* profile) { - TF_RET_CHECK(!backend->Replicas().empty()); - // Set up streams. 
std::vector::SmartPtr> streams; - for (se::StreamExecutor* executor : backend->Replicas()) { + TF_ASSIGN_OR_RETURN(auto replicas, + Replicas(*backend, SingleComputationDeviceHandle())); + TF_RET_CHECK(!replicas.empty()); + for (se::StreamExecutor* executor : replicas) { TF_ASSIGN_OR_RETURN(Pool::SmartPtr stream, backend->BorrowStream(executor)); streams.push_back(std::move(stream)); } + TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment, + backend->computation_placer()->AssignDevices( + options_.number_of_replicas(), + /*computation_count=*/1)); + // Set up run options. std::vector run_options; for (const Pool::SmartPtr& stream : streams) { @@ -570,19 +598,20 @@ StatusOr Service::ExecuteAndRegisterResult( options.set_inter_op_thread_pool(backend->inter_op_thread_pool()); options.set_intra_op_thread_pool( backend->eigen_intra_op_thread_pool_device()); + options.set_device_assignment(&device_assignment); run_options.emplace_back(options, backend->StreamBorrower(), backend->inter_op_thread_pool()); } perftools::gputools::DeviceMemoryBase result; - if (backend->Replicas().size() == 1) { + if (options_.number_of_replicas() == 1) { TF_ASSIGN_OR_RETURN( result, executable->ExecuteOnStreamWrapper( &run_options[0], profile, arguments)); } else { std::vector< tensorflow::gtl::ArraySlice> - repeated_arguments(backend->Replicas().size(), arguments); + repeated_arguments(options_.number_of_replicas(), arguments); TF_ASSIGN_OR_RETURN(auto results, executable->ExecuteOnStreams( run_options, repeated_arguments)); @@ -610,25 +639,26 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, std::vector versioned_handles; std::vector> module_configs; std::vector computation_names; + std::vector device_handles; - if (arg->requests_size() > execute_backend_->stream_executors().size()) { + if (arg->requests_size() * options_.number_of_replicas() > + execute_backend_->device_count()) { return FailedPrecondition( "there are not enough stream executors to 
execute %d computations", arg->requests_size()); } for (int64 i = 0; i < arg->requests_size(); ++i) { - // Get the stream executor on which the computation will run. Select the - // specific device if requested, otherwise select the i'th device from the - // list of available stream executors. - se::StreamExecutor* executor; - if (arg->requests(i).has_device_handle()) { - executor = - execute_backend_ - ->stream_executors()[arg->requests(i).device_handle().handle()]; - } else { - executor = execute_backend_->stream_executors()[i]; + // Get the stream executor for the i'th computation. This stream executor + // is one of the executors to run the replicated computation. + if (!arg->requests(i).has_device_handle()) { + return FailedPrecondition( + "device handles must be given to execute parallel computations"); } + TF_ASSIGN_OR_RETURN( + auto replicas, + Replicas(*execute_backend_, arg->requests(i).device_handle())); + se::StreamExecutor* executor = replicas[0]; CHECK(executor != nullptr); // Resolve the UserComputation object associated with the requested @@ -673,6 +703,7 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, module_configs.push_back(std::move(module_config)); computation_names.push_back(user_computation->name()); executors.push_back(executor); + device_handles.push_back(arg->requests(i).device_handle()); } // Build the user computations into HloModules and compile to generate the @@ -692,7 +723,7 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, TF_ASSIGN_OR_RETURN( std::vector outputs, ExecuteParallelAndRegisterResult(executable_ptrs, all_arguments, - execute_backend_.get(), executors, + execute_backend_.get(), device_handles, computation_names)); for (const GlobalDataHandle& output : outputs) { ExecuteResponse response; @@ -706,10 +737,12 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, tensorflow::Status Service::GetDeviceHandles(const 
GetDeviceHandlesRequest* arg, GetDeviceHandlesResponse* result) { - const int64 available_device_count = - execute_backend_->stream_executors().size(); - const int64 replicas = execute_backend_->Replicas().size(); - if (available_device_count < arg->device_count() * replicas) { + const int64 available_device_count = execute_backend_->device_count(); + const int64 replica_count = options_.number_of_replicas(); + if (replica_count <= 0) { + return FailedPrecondition("Replica count must be a positive integer"); + } + if (available_device_count < arg->device_count() * replica_count) { return ResourceExhausted( "Requested device count (%lld) exceeds the number of available devices " "on the target (%lld)", @@ -718,8 +751,8 @@ tensorflow::Status Service::GetDeviceHandles(const GetDeviceHandlesRequest* arg, for (int64 i = 0; i < arg->device_count(); ++i) { DeviceHandle device_handle; - device_handle.set_handle( - execute_backend_->stream_executors()[i * replicas]->device_ordinal()); + device_handle.set_handle(i); + device_handle.set_device_count(arg->device_count()); *result->add_device_handles() = device_handle; } @@ -841,11 +874,14 @@ tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg, execute_backend_->default_stream_executor(), &profile)); - TF_RET_CHECK(!execute_backend_->Replicas().empty()); + TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*execute_backend_, + SingleComputationDeviceHandle())); + TF_RET_CHECK(!replicas.empty()); + // Set up streams. 
std::vector::SmartPtr> streams; - for (se::StreamExecutor* executor : execute_backend_->Replicas()) { + for (se::StreamExecutor* executor : replicas) { TF_ASSIGN_OR_RETURN(Pool::SmartPtr stream, execute_backend_->BorrowStream(executor)); streams.push_back(std::move(stream)); @@ -927,19 +963,20 @@ tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg, Literal literal = Literal(arg->literal()); const Shape& shape = literal.shape(); - if (ShapeUtil::IsTuple(shape) && execute_backend_->Replicas().size() > 1) { + if (ShapeUtil::IsTuple(shape) && options_.number_of_replicas() > 1) { // TODO(b/32990684): Tuple transfers to host end up allocating further // buffers - implement that correctly. return Unimplemented( "Tuple transfers to the device not supported with replication."); } - se::StreamExecutor* stream_executor; + std::vector replicas; if (arg->has_device_handle()) { - TF_ASSIGN_OR_RETURN(stream_executor, execute_backend_->stream_executor( - arg->device_handle().handle())); + TF_ASSIGN_OR_RETURN(replicas, + Replicas(*execute_backend_, arg->device_handle())); } else { - stream_executor = execute_backend_->default_stream_executor(); + TF_ASSIGN_OR_RETURN( + replicas, Replicas(*execute_backend_, SingleComputationDeviceHandle())); } // Allocate memory on the device, using the stream executor. 
The size of the @@ -950,14 +987,12 @@ tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg, TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase allocation, execute_backend_->memory_allocator()->Allocate( - stream_executor->device_ordinal(), allocation_size)); + replicas[0]->device_ordinal(), allocation_size)); *result->mutable_data() = allocation_tracker_.Register( - execute_backend_.get(), stream_executor->device_ordinal(), allocation, - shape, StrCat("TransferToServer literal of size ", allocation_size)); + execute_backend_.get(), replicas[0]->device_ordinal(), allocation, shape, + StrCat("TransferToServer literal of size ", allocation_size)); - TF_ASSIGN_OR_RETURN(auto replicas, execute_backend_->Replicas( - stream_executor->device_ordinal())); for (se::StreamExecutor* executor : replicas) { TF_RETURN_IF_ERROR( execute_backend_->transfer_manager()->TransferLiteralToDevice( @@ -968,7 +1003,7 @@ tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg, tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg, TransferToInfeedResponse* result) { - const int64 replica_count = execute_backend_->Replicas().size(); + const int64 replica_count = options_.number_of_replicas(); if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) { return FailedPrecondition( "%s", @@ -980,11 +1015,14 @@ tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg, se::StreamExecutor* executor; if (arg->has_device_handle()) { - TF_ASSIGN_OR_RETURN(auto replicas, execute_backend_->Replicas( - arg->device_handle().handle())); + TF_ASSIGN_OR_RETURN(auto replicas, + Replicas(*execute_backend_, arg->device_handle())); executor = replicas[arg->replica_id()]; } else { - executor = execute_backend_->Replicas()[arg->replica_id()]; + TF_ASSIGN_OR_RETURN( + auto replicas, + Replicas(*execute_backend_, SingleComputationDeviceHandle())); + executor = replicas[arg->replica_id()]; } return 
execute_backend_->transfer_manager()->TransferLiteralToInfeed( @@ -994,7 +1032,7 @@ tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg, tensorflow::Status Service::TransferFromOutfeed( const TransferFromOutfeedRequest* arg, TransferFromOutfeedResponse* result) { - const int64 replica_count = execute_backend_->Replicas().size(); + const int64 replica_count = options_.number_of_replicas(); if (arg->replica_id() < 0 || arg->replica_id() >= replica_count) { return FailedPrecondition( "The replica_id=%lld on TransferFromOutfeedRequest not in range [0, " @@ -1004,11 +1042,14 @@ tensorflow::Status Service::TransferFromOutfeed( se::StreamExecutor* executor; if (arg->has_device_handle()) { - TF_ASSIGN_OR_RETURN(auto replicas, execute_backend_->Replicas( - arg->device_handle().handle())); + TF_ASSIGN_OR_RETURN(auto replicas, + Replicas(*execute_backend_, arg->device_handle())); executor = replicas[arg->replica_id()]; } else { - executor = execute_backend_->Replicas()[arg->replica_id()]; + TF_ASSIGN_OR_RETURN( + auto replicas, + Replicas(*execute_backend_, SingleComputationDeviceHandle())); + executor = replicas[arg->replica_id()]; } Literal literal; @@ -1195,6 +1236,10 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) { StatusOr handle_status; switch (arg->op_case()) { + case OpRequest::kBatchNormTrainingRequest: + handle_status = computation->AddBatchNormTrainingInstruction( + arg->batch_norm_training_request()); + break; case OpRequest::kBinaryOpRequest: handle_status = computation->AddBinaryInstruction(arg->binary_op_request()); @@ -1277,6 +1322,11 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) { computation->AddReduceInstruction(arg->reduce_request(), *to_apply); break; } + case OpRequest::kReducePrecisionRequest: { + handle_status = computation->AddReducePrecisionInstruction( + arg->reduce_precision_request()); + break; + } case OpRequest::kReduceWindowRequest: { 
TF_ASSIGN_OR_RETURN(UserComputation * to_apply, computation_tracker_.Resolve( @@ -1383,4 +1433,28 @@ tensorflow::Status Service::LoadComputationSnapshot( return tensorflow::Status::OK(); } +DeviceHandle Service::SingleComputationDeviceHandle() const { + DeviceHandle device_handle; + device_handle.set_handle(0); + device_handle.set_device_count(1); + return device_handle; +} + +StatusOr> Service::Replicas( + const Backend& backend, const DeviceHandle& device_handle) const { + std::vector replicas; + for (int replica = 0; replica < options_.number_of_replicas(); ++replica) { + // From the computation placer, find out the device ids of the replicas for + // the given device handle. + TF_ASSIGN_OR_RETURN( + int device_ordinal, + backend.computation_placer()->DeviceId(replica, device_handle.handle(), + options_.number_of_replicas(), + device_handle.device_count())); + TF_ASSIGN_OR_RETURN(auto executor, backend.stream_executor(device_ordinal)); + replicas.push_back(executor); + } + return replicas; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h index abd1281bdd0..968fb53b347 100644 --- a/tensorflow/compiler/xla/service/service.h +++ b/tensorflow/compiler/xla/service/service.h @@ -248,7 +248,7 @@ class Service : public ServiceInterface { // The constructor is private. Use the NewService factory to create new // service objects. 
- Service(std::unique_ptr backend, + Service(const ServiceOptions& options, std::unique_ptr backend, std::unique_ptr compute_constant_backend); static StatusOr> CreateComputeConstantBackend(); @@ -319,8 +319,7 @@ class Service : public ServiceInterface { std::vector> arguments, Backend* backend, - tensorflow::gtl::ArraySlice - executors, + tensorflow::gtl::ArraySlice device_handles, tensorflow::gtl::ArraySlice result_tags); // Returns an HLO dumper for use in the compiler (it refers to flags @@ -346,6 +345,18 @@ class Service : public ServiceInterface { tensorflow::Status ValidateResultShapeWithLayout( const Shape& shape_with_layout, const Shape& result_shape) const; + // Returns the stream executors assigned to the replicas represented by the + // given device handle. Each device_handle is a virtual replicated device that + // represents a set of physical devices for the replicas. + StatusOr> Replicas( + const Backend& backend, const DeviceHandle& device_handle) const; + + // Returns the device handle that represents the replicated device for a + // single computation that is not model-parallelized. + DeviceHandle SingleComputationDeviceHandle() const; + + ServiceOptions options_; + // Tracks computations built via the API. 
ComputationTracker computation_tracker_; diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index d6436cf988d..8d0dc1edac7 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -184,6 +184,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, switch (operation) { case UNOP_FLOOR: case UNOP_CEIL: + case UNOP_COS: case UNOP_EXP: case UNOP_LOG: case UNOP_TANH: @@ -297,6 +298,30 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, return ShapeUtil::ChangeElementType(operand_shape, new_element_type); } +/* static */ StatusOr ShapeInference::InferReducePrecisionShape( + const Shape& operand_shape, const int exponent_bits, + const int mantissa_bits) { + if (!ShapeUtil::ElementIsFloating(operand_shape)) { + return InvalidArgument( + "expected element type in shape to be floating point for " + "ReducePrecision operation; got %s", + PrimitiveType_Name(operand_shape.element_type()).c_str()); + } + if (exponent_bits < 1) { + // One exponent bit is necessary to distinguish 0 from infinity. Having + // no exponent bits doesn't produce a sensible number, so we require at + // least one. + return InvalidArgument("expected exponent_bits >= 1; got %d", + exponent_bits); + } + if (mantissa_bits < 0) { + // A number with no mantissa bits is still meaningful, however. 
+ return InvalidArgument("expected non-negative mantissa_bits; got %d", + mantissa_bits); + } + return operand_shape; +} + /* static */ StatusOr ShapeInference::InferPadShape( const Shape& operand_shape, const Shape& padding_value_shape, const PaddingConfig& padding_config) { @@ -754,6 +779,109 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( AsInt64Slice(arg_shape->dimensions())); } +/* static */ StatusOr ShapeInference::InferBatchNormTrainingShape( + const Shape& operand_shape, const Shape& offset_shape, + const Shape& scale_shape, int64 feature_index) { + TF_RETURN_IF_ERROR( + ExpectNotTupleOrOpaque(operand_shape, "operand of batch norm training")); + TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque( + offset_shape, "offset input of batch norm training")); + TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque( + scale_shape, "scale input of batch norm training")); + + TF_RET_CHECK(ShapeUtil::ValidateShape(operand_shape) == + tensorflow::Status::OK()); + TF_RET_CHECK(ShapeUtil::ValidateShape(offset_shape) == + tensorflow::Status::OK()); + TF_RET_CHECK(ShapeUtil::ValidateShape(scale_shape) == + tensorflow::Status::OK()); + + if (feature_index >= ShapeUtil::Rank(operand_shape)) { + return InvalidArgument( + "Expected feature_index of batch-norm-training to be " + "smaller than the rank of operand_shape; " + "got feature_index %lld, and rank %lld", + feature_index, ShapeUtil::Rank(offset_shape)); + } + + if (feature_index < 0) { + return InvalidArgument( + "Expected feature_index of batch-norm-training to " + "be a non-negative number, got %lld", + feature_index); + } + + if (ShapeUtil::Rank(operand_shape) < 1) { + return InvalidArgument( + "Expected the rank of operand to " + "batch-norm-training to be at least 1; got %lld", + ShapeUtil::Rank(offset_shape)); + } + + if (ShapeUtil::Rank(offset_shape) != 1) { + return InvalidArgument( + "Offset input of batch-norm-training must have" + " rank 1, but has rank %lld.", + ShapeUtil::Rank(offset_shape)); + } + + if 
(ShapeUtil::Rank(scale_shape) != 1) { + return InvalidArgument( + "Scale input of batch-norm-training must have" + " rank 1, but has rank %lld.", + ShapeUtil::Rank(scale_shape)); + } + + if (!ShapeUtil::ElementIsFloating(operand_shape)) { + return InvalidArgument( + "The operand to batch-norm-training must have a floating point " + "element type, but the shape is %s", + PrimitiveType_Name(operand_shape.element_type()).c_str()); + } + + if (!ShapeUtil::SameElementType(offset_shape, operand_shape)) { + return InvalidArgument( + "The inputs should have the same element type for batch-norm-training, " + "but the shape of offset factor is %s " + "and the shape of operand is %s", + PrimitiveType_Name(offset_shape.element_type()).c_str(), + PrimitiveType_Name(operand_shape.element_type()).c_str()); + } + + if (!ShapeUtil::SameElementType(scale_shape, operand_shape)) { + return InvalidArgument( + "The inputs should have the same element type for batch-norm-training, " + "but the shape of scale factor is %s " + "and the shape of operand is %s", + PrimitiveType_Name(scale_shape.element_type()).c_str(), + PrimitiveType_Name(operand_shape.element_type()).c_str()); + } + + const int64 feature_count = operand_shape.dimensions(feature_index); + Shape output_shape_for_mean_and_var = + ShapeUtil::MakeShape(operand_shape.element_type(), {feature_count}); + + if (ShapeUtil::GetDimension(offset_shape, 0) != feature_count) { + return InvalidArgument( + "The size of offset factor should be the same as feature count," + "but the size of offset factor is %lld " + "and the feature count is %lld", + ShapeUtil::GetDimension(offset_shape, 0), feature_count); + } + + if (ShapeUtil::GetDimension(scale_shape, 0) != feature_count) { + return InvalidArgument( + "The size of scale factor should be the same as feature count," + "but the size of scale factor is %lld " + "and the feature count is %lld", + ShapeUtil::GetDimension(scale_shape, 0), feature_count); + } + + return 
ShapeUtil::MakeTupleShape({operand_shape, + output_shape_for_mean_and_var, + output_shape_for_mean_and_var}); +} + /* static */ StatusOr ShapeInference::InferConvolveShape( const Shape& lhs, const Shape& rhs, const Window& window, const ConvolutionDimensionNumbers& dnums) { diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index 0d270f99794..42e4c7d39d2 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -64,6 +64,13 @@ class ShapeInference { tensorflow::gtl::ArraySlice arg_shapes, const ProgramShape& to_apply); + // Infers the shape produced by InferBatchNormTraining with the given + // operands. + static StatusOr InferBatchNormTrainingShape(const Shape& operand_shape, + const Shape& offset_shape, + const Shape& scale_shape, + int64 feature_index); + // Infers the shape produced by applying the given convolutional // filter (rhs) to lhs in the way specified by the fields on window. static StatusOr InferConvolveShape( @@ -165,6 +172,12 @@ class ShapeInference { static StatusOr InferConvertShape(const Shape& operand_shape, PrimitiveType new_element_type); + // Helper that validates the input data type for a reduce-precision operation, + // and returns the result shape. + static StatusOr InferReducePrecisionShape(const Shape& operand_shape, + const int exponent_bits, + const int mantissa_bits); + // Helper that infers the shape produced by a pad operation based on the // padding configuration. 
static StatusOr InferPadShape(const Shape& operand_shape, diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h index 15f6b7bfb4a..c79ffa9cd73 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.h +++ b/tensorflow/compiler/xla/service/transfer_manager.h @@ -65,6 +65,17 @@ class TransferManager { perftools::gputools::StreamExecutor* executor, const Literal& literal) = 0; + // Transfer a memory block of the given size from 'source' buffer to the + // Infeed interface of the device using the given executor. + // + // size is the size to transfer from source in bytes. + // + // source is the source data that must be in the target-dependent layout that + // the Infeed HLO used in the computation expects. + virtual Status TransferBufferToInfeed( + perftools::gputools::StreamExecutor* executor, int64 size, + const void* source) = 0; + // Transfers the given literal from the Outfeed interface of the device, // using the given executor. 
virtual Status TransferLiteralFromOutfeed( diff --git a/tensorflow/compiler/xla/service/transfer_manager_test.cc b/tensorflow/compiler/xla/service/transfer_manager_test.cc index ca38601d919..29ecef9510c 100644 --- a/tensorflow/compiler/xla/service/transfer_manager_test.cc +++ b/tensorflow/compiler/xla/service/transfer_manager_test.cc @@ -55,7 +55,7 @@ class CpuTransferManagerTest : public ::testing::Test { TEST_F(CpuTransferManagerTest, TransferR0U32ToDevice) { std::vector storage(sizeof(uint32), '\x00'); se::DeviceMemoryBase memptr(storage.data(), storage.size()); - std::unique_ptr literal = LiteralUtil::CreateR0(42); + std::unique_ptr literal = Literal::CreateR0(42); TF_CHECK_OK(transfer_manager_.TransferLiteralToDevice(stream_exec_, *literal, &memptr)); @@ -66,7 +66,7 @@ TEST_F(CpuTransferManagerTest, TransferR1F32ToDevice) { std::vector storage(4 * sizeof(float), '\x00'); se::DeviceMemoryBase memptr(storage.data(), storage.size()); std::unique_ptr literal = - LiteralUtil::CreateR1({1.25f, 2.5f, -17.0f, -20.125f}); + Literal::CreateR1({1.25f, 2.5f, -17.0f, -20.125f}); TF_CHECK_OK(transfer_manager_.TransferLiteralToDevice(stream_exec_, *literal, &memptr)); @@ -80,7 +80,7 @@ TEST_F(CpuTransferManagerTest, TransferR1U8ToDevice) { std::vector storage(16, '\x00'); se::DeviceMemoryBase memptr(storage.data(), storage.size()); const char* str = "0123456789abcdef"; - std::unique_ptr literal = LiteralUtil::CreateR1U8(str); + std::unique_ptr literal = Literal::CreateR1U8(str); TF_CHECK_OK(transfer_manager_.TransferLiteralToDevice(stream_exec_, *literal, &memptr)); diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc index c72d127ea86..9520c42d280 100644 --- a/tensorflow/compiler/xla/service/transpose_folding_test.cc +++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc @@ -92,11 +92,11 @@ TEST_F(TransposeFoldingTest, FoldDotTransposeConstant) { auto builder = 
HloComputation::Builder("entry_computation"); // 2x1 HloInstruction* const0 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR2({{1}, {2}}))); + HloInstruction::CreateConstant(Literal::CreateR2({{1}, {2}}))); // 3x2 HloInstruction* const1 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}}))); + Literal::CreateR2({{1, 2}, {3, 4}, {5, 6}}))); HloInstruction* transpose0 = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {1, 2}), const0, {1, 0})); @@ -130,11 +130,11 @@ TEST_F(TransposeFoldingTest, FuseDotWithConstantOperands) { auto builder = HloComputation::Builder("entry"); // (1.0 + 2.0) * (2.0 - 3.0) HloInstruction* const1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); HloInstruction* const2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); HloInstruction* const3 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.0))); + HloInstruction::CreateConstant(Literal::CreateR0(3.0))); HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary( const1->shape(), HloOpcode::kAdd, const1, const2)); HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary( diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc index ad6f015c70e..8d68398450c 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc @@ -243,12 +243,11 @@ Status TuplePointsToAnalysis::HandleGetTupleElement( return Status::OK(); } -Status TuplePointsToAnalysis::HandleCopy(HloInstruction* copy, - HloInstruction* operand) { +Status TuplePointsToAnalysis::HandleCopy(HloInstruction* copy) 
{ // A kCopy instruction performs a shallow copy of the operand. The top-level // buffer (index={}) is newly created, but all other buffers (in the case of a // tuple shape) come from the operand - PointsToSet& points_to_set = CreateCopiedPointsToSet(copy, operand); + PointsToSet& points_to_set = CreateCopiedPointsToSet(copy, copy->operand(0)); points_to_set.mutable_element(/*index=*/{})->clear(); points_to_set.AddPointedToBuffer(NewLogicalBuffer(copy, /*index=*/{}), /*index=*/{}); diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h index 4d7fc7cbc9e..bab4235a287 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h @@ -208,7 +208,7 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { Status HandleGetTupleElement(HloInstruction* get_tuple_element, HloInstruction* operand) override; Status HandleBitcast(HloInstruction* bitcast) override; - Status HandleCopy(HloInstruction* copy, HloInstruction* operand) override; + Status HandleCopy(HloInstruction* copy) override; Status HandleSelect(HloInstruction* select, HloInstruction* pred, HloInstruction* on_true, HloInstruction* on_false) override; diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc index 9909c11929d..cd79e63cafc 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc @@ -124,9 +124,9 @@ class TuplePointsToAnalysisTest : public HloTestBase { TEST_F(TuplePointsToAnalysisTest, SimpleTuple) { auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - 
HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); @@ -177,14 +177,14 @@ TEST_F(TuplePointsToAnalysisTest, NestedTuple) { // tuple. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto inner_tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.0))); + HloInstruction::CreateConstant(Literal::CreateR0(3.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({inner_tuple, constant3})); @@ -238,14 +238,14 @@ TEST_F(TuplePointsToAnalysisTest, GetTupleElement) { // tuple. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto inner_tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto constant3 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(3.0))); + HloInstruction::CreateConstant(Literal::CreateR0(3.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({inner_tuple, constant3})); @@ -270,7 +270,7 @@ TEST_F(TuplePointsToAnalysisTest, DuplicatedElement) { // Create a tuple which contains duplicate elements. 
auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant, constant, constant})); @@ -291,9 +291,9 @@ TEST_F(TuplePointsToAnalysisTest, TupleCopy) { // the same. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto copy = builder.AddInstruction( @@ -318,16 +318,16 @@ TEST_F(TuplePointsToAnalysisTest, TupleSelect) { // set containing the union of both sides. 
auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto tuple1 = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto tuple2 = builder.AddInstruction( HloInstruction::CreateTuple({constant2, constant2})); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloInstruction::CreateConstant(Literal::CreateR0(false))); auto select = builder.AddInstruction(HloInstruction::CreateTernary( tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2)); @@ -356,7 +356,7 @@ TEST_F(TuplePointsToAnalysisTest, SelectTupleParameters) { auto param1 = builder.AddInstruction( HloInstruction::CreateParameter(1, tuple_shape, "param1")); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloInstruction::CreateConstant(Literal::CreateR0(false))); auto select = builder.AddInstruction(HloInstruction::CreateTernary( tuple_shape, HloOpcode::kSelect, pred, param0, param1)); auto copy = builder.AddInstruction( @@ -396,16 +396,16 @@ TEST_F(TuplePointsToAnalysisTest, UnambiguousTupleSelect) { // Select from two identical tuples. The result should not be ambiguous. 
auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto tuple1 = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto tuple2 = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloInstruction::CreateConstant(Literal::CreateR0(false))); auto select = builder.AddInstruction(HloInstruction::CreateTernary( tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2)); @@ -427,9 +427,9 @@ TEST_F(TuplePointsToAnalysisTest, NestedTupleSelect) { // the right values. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto inner_tuple1 = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto inner_tuple2 = builder.AddInstruction( @@ -441,7 +441,7 @@ TEST_F(TuplePointsToAnalysisTest, NestedTupleSelect) { builder.AddInstruction(HloInstruction::CreateTuple({inner_tuple2})); auto pred = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(false))); + HloInstruction::CreateConstant(Literal::CreateR0(false))); auto select = builder.AddInstruction(HloInstruction::CreateTernary( tuple1->shape(), HloOpcode::kSelect, pred, tuple1, tuple2)); @@ -474,9 +474,9 @@ TEST_F(TuplePointsToAnalysisTest, TupleWithBitcast) { // have the operand 
of the bitcast in its points-to set. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto bitcast = builder.AddInstruction(HloInstruction::CreateUnary( constant2->shape(), HloOpcode::kBitcast, constant2)); auto tuple = @@ -510,10 +510,9 @@ TEST_F(TuplePointsToAnalysisTest, PointsToTupleConstantElements) { // Construct a tuple constant and kCopy it. Verify the points-to set of the // copy correctly correctly points into the nested elements of the constant. auto builder = HloComputation::Builder(TestName()); - auto tuple_constant = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2({{1.0}, {2.0}}).get(), - LiteralUtil::CreateR1({2.0, 42}).get()}))); + auto tuple_constant = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::MakeTuple({Literal::CreateR2({{1.0}, {2.0}}).get(), + Literal::CreateR1({2.0, 42}).get()}))); auto copy = builder.AddInstruction(HloInstruction::CreateUnary( tuple_constant->shape(), HloOpcode::kCopy, tuple_constant)); @@ -533,9 +532,9 @@ TEST_F(TuplePointsToAnalysisTest, BufferAliases) { // times. Verify buffer alias sets. 
auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); auto constant2 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(2.0))); + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); auto inner_tuple = builder.AddInstruction( HloInstruction::CreateTuple({constant1, constant2})); auto tuple = builder.AddInstruction( @@ -574,7 +573,7 @@ class FusionPointsToAnalysisTest : public TuplePointsToAnalysisTest { auto tuple_element1 = builder.AddInstruction( HloInstruction::CreateGetTupleElement(update_shape, tuple_param0, 1)); auto ones = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({1.f, 1.f, 1.f, 1.f}))); + Literal::CreateR1({1.f, 1.f, 1.f, 1.f}))); // Create 'update' = Add(GetTupleElement(tuple_param0, 1), ones) auto update = builder.AddInstruction(HloInstruction::CreateBinary( update_shape, HloOpcode::kAdd, tuple_element1, ones)); diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc index 4aba8875161..92b8c7bb210 100644 --- a/tensorflow/compiler/xla/service/user_computation.cc +++ b/tensorflow/compiler/xla/service/user_computation.cc @@ -49,6 +49,8 @@ HloOpcode UnaryOperationToHloOpcode(UnaryOperation unop) { return HloOpcode::kAbs; case UNOP_CEIL: return HloOpcode::kCeil; + case UNOP_COS: + return HloOpcode::kCos; case UNOP_EXP: return HloOpcode::kExp; case UNOP_FLOOR: @@ -465,6 +467,45 @@ StatusOr UserComputation::AddReduceInstruction( return handle; } +StatusOr +UserComputation::AddBatchNormTrainingInstruction( + const BatchNormTrainingRequest& batch_norm_training_request) { + tensorflow::mutex_lock lock(mutex_); + + TF_ASSIGN_OR_RETURN(const OperationRequest* operand, + LookUpRequest(batch_norm_training_request.operand())); + + TF_ASSIGN_OR_RETURN(const OperationRequest* 
scale, + LookUpRequest(batch_norm_training_request.scale())); + + TF_ASSIGN_OR_RETURN(const OperationRequest* offset, + LookUpRequest(batch_norm_training_request.offset())); + + ComputationDataHandle handle = CreateComputationDataHandle(); + + OperationRequest& request = + (*session_computation_.mutable_requests())[handle.handle()]; + + TF_ASSIGN_OR_RETURN( + Shape inferred_shape, + ShapeInference::InferBatchNormTrainingShape( + operand->output_shape(), scale->output_shape(), + offset->output_shape(), batch_norm_training_request.feature_index())); + + *request.mutable_output_shape() = inferred_shape; + + *request.mutable_output_handle() = handle; + + *request.mutable_request()->mutable_batch_norm_training_request() = + batch_norm_training_request; + + VLOG(1) << "AddBatchNormTrainingInstruction (" << GetVersionedHandleInternal() + << "), data handle " << handle.handle() << ": " + << batch_norm_training_request.ShortDebugString(); + + return handle; +} + StatusOr UserComputation::AddReduceWindowInstruction( const ReduceWindowRequest& reduce_window_request, const UserComputation& to_apply_computation) { @@ -841,6 +882,34 @@ StatusOr UserComputation::AddConvertInstruction( return handle; } +StatusOr UserComputation::AddReducePrecisionInstruction( + const ReducePrecisionRequest& reduce_precision_request) { + tensorflow::mutex_lock lock(mutex_); + + TF_ASSIGN_OR_RETURN(const OperationRequest* operand, + LookUpRequest(reduce_precision_request.operand())); + + TF_ASSIGN_OR_RETURN( + Shape new_shape, + ShapeInference::InferReducePrecisionShape( + operand->output_shape(), reduce_precision_request.exponent_bits(), + reduce_precision_request.mantissa_bits())); + + ComputationDataHandle handle = CreateComputationDataHandle(); + + OperationRequest& request = + (*session_computation_.mutable_requests())[handle.handle()]; + *request.mutable_output_handle() = handle; + *request.mutable_output_shape() = new_shape; + *request.mutable_request()->mutable_reduce_precision_request() = + 
reduce_precision_request; + + VLOG(1) << "AddReducePrecisionInstruction (" << GetVersionedHandleInternal() + << "), data handle " << handle.handle() << ": " + << reduce_precision_request.ShortDebugString(); + return handle; +} + StatusOr UserComputation::AddConvolveInstruction( const ConvolveRequest& convolve_request) { tensorflow::mutex_lock lock(mutex_); @@ -1556,6 +1625,19 @@ void ConstantVisitor(const SessionComputation& session_computation, break; } + case OpRequest::kBatchNormTrainingRequest: { + const BatchNormTrainingRequest& batch_norm_training_request = + request.request().batch_norm_training_request(); + ConstantVisitor(session_computation, + batch_norm_training_request.operand(), visited, + is_constant); + ConstantVisitor(session_computation, batch_norm_training_request.scale(), + visited, is_constant); + ConstantVisitor(session_computation, batch_norm_training_request.offset(), + visited, is_constant); + break; + } + case OpRequest::kBinaryOpRequest: { const BinaryOpRequest& binary_op_request = request.request().binary_op_request(); @@ -1824,7 +1906,6 @@ Status UserComputation::CheckParametersAreContiguous( } } - auto program_shape = MakeUnique(); for (int64 i = 0; i < parameter_requests.size(); ++i) { auto it = parameter_requests.find(i); if (it == parameter_requests.end()) { @@ -1964,6 +2045,16 @@ static void ForEachOperand( break; } + case OpRequest::kBatchNormTrainingRequest: { + const BatchNormTrainingRequest& batch_norm_training_request = + request.request().batch_norm_training_request(); + + apply(batch_norm_training_request.operand()); + apply(batch_norm_training_request.scale()); + apply(batch_norm_training_request.offset()); + break; + } + case OpRequest::kCrossReplicaSumRequest: { const CrossReplicaSumRequest& cross_replica_sum_request = request.request().cross_replica_sum_request(); @@ -2117,6 +2208,13 @@ static void ForEachOperand( break; } + case OpRequest::kReducePrecisionRequest: { + const ReducePrecisionRequest& 
reduce_precision_request = + request.request().reduce_precision_request(); + apply(reduce_precision_request.operand()); + break; + } + case OpRequest::kTraceRequest: { const TraceRequest& trace_request = request.request().trace_request(); apply(trace_request.operand()); @@ -2276,7 +2374,7 @@ void ComputationLowerer::Visit( const ConstantRequest& constant_request = request.request().constant_request(); hlo_instruction = add_instruction(HloInstruction::CreateConstant( - LiteralUtil::CloneToUnique(Literal(constant_request.literal())))); + Literal(constant_request.literal()).CloneToUnique())); break; } @@ -2457,6 +2555,23 @@ void ComputationLowerer::Visit( break; } + case OpRequest::kBatchNormTrainingRequest: { + const BatchNormTrainingRequest& batch_norm_training_request = + request.request().batch_norm_training_request(); + HloInstruction* operand = + lookup_instruction(batch_norm_training_request.operand()); + HloInstruction* scale = + lookup_instruction(batch_norm_training_request.scale()); + HloInstruction* offset = + lookup_instruction(batch_norm_training_request.offset()); + + hlo_instruction = add_instruction(HloInstruction::CreateBatchNormTraining( + request.output_shape(), operand, scale, offset, + batch_norm_training_request.epsilon(), + batch_norm_training_request.feature_index())); + break; + } + case OpRequest::kBroadcastRequest: { const BroadcastRequest& broadcast_request = request.request().broadcast_request(); @@ -2688,6 +2803,18 @@ void ComputationLowerer::Visit( break; } + case OpRequest::kReducePrecisionRequest: { + const ReducePrecisionRequest& reduce_precision_request = + request.request().reduce_precision_request(); + HloInstruction* operand = + lookup_instruction(reduce_precision_request.operand()); + auto exponent_bits = reduce_precision_request.exponent_bits(); + auto mantissa_bits = reduce_precision_request.mantissa_bits(); + hlo_instruction = add_instruction(HloInstruction::CreateReducePrecision( + request.output_shape(), operand, 
exponent_bits, mantissa_bits)); + break; + } + case OpRequest::kTraceRequest: { const TraceRequest& trace_request = request.request().trace_request(); HloInstruction* operand = lookup_instruction(trace_request.operand()); diff --git a/tensorflow/compiler/xla/service/user_computation.h b/tensorflow/compiler/xla/service/user_computation.h index fb5425ae61a..9bb7bf491a9 100644 --- a/tensorflow/compiler/xla/service/user_computation.h +++ b/tensorflow/compiler/xla/service/user_computation.h @@ -84,6 +84,10 @@ class UserComputation { StatusOr AddUnaryInstruction( const UnaryOpRequest& unary_request); + // Enqueues a batch norm training instruction onto this user computation. + StatusOr AddBatchNormTrainingInstruction( + const BatchNormTrainingRequest& batch_norm_training_request); + // Enqueues a binary instruction onto this user computation. // Returns an error status if the operand indices are out of bounds. StatusOr AddBinaryInstruction( @@ -112,6 +116,10 @@ class UserComputation { const MapRequest& map_request, const UserComputation& to_apply_computation); + // Enqueues a reduce-precision instruction onto this user computation. + StatusOr AddReducePrecisionInstruction( + const ReducePrecisionRequest& reduce_precision_request); + // Enqueues a convolution instruction onto this user computation. 
StatusOr AddConvolveInstruction( const ConvolveRequest& convolve_request); diff --git a/tensorflow/compiler/xla/service/user_computation_test.cc b/tensorflow/compiler/xla/service/user_computation_test.cc index ea691201263..41bb641f430 100644 --- a/tensorflow/compiler/xla/service/user_computation_test.cc +++ b/tensorflow/compiler/xla/service/user_computation_test.cc @@ -50,7 +50,7 @@ TEST_F(UserComputationTest, SimpleComputation) { ConstantRequest constant_request; *constant_request.mutable_literal() = - LiteralUtil::CreateR1({123.0f, 42.0f})->ToProto(); + Literal::CreateR1({123.0f, 42.0f})->ToProto(); TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle constant_handle, computation.AddConstantInstruction(constant_request)); @@ -161,12 +161,12 @@ TEST_F(UserComputationTest, EliminateScalarBroadcast) { ConstantRequest a_request; *a_request.mutable_literal() = - LiteralUtil::CreateR1({123.0f, 42.0f})->ToProto(); + Literal::CreateR1({123.0f, 42.0f})->ToProto(); TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle a_handle, computation.AddConstantInstruction(a_request)); ConstantRequest b_request; - *b_request.mutable_literal() = LiteralUtil::CreateR0(1.0f)->ToProto(); + *b_request.mutable_literal() = Literal::CreateR0(1.0f)->ToProto(); TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle b_handle, computation.AddConstantInstruction(b_request)); diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index ee49a9ae5f5..057905a4311 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -105,6 +105,11 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { return equal; } +/* static */ int64 ShapeUtil::Rank(const Shape& shape) { + CHECK(!ShapeUtil::IsTuple(shape)) << "Tuples do not have a rank"; + return shape.dimensions_size(); +} + /* static */ int64 ShapeUtil::TrueRank(const Shape& shape) { int64 accum = 0; for (int64 dimension : shape.dimensions()) { @@ -165,6 +170,17 @@ bool 
CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { } return MakeShapeWithMonotonicDim0MajorLayout(shape.element_type(), dims); } + +/* static */ Shape ShapeUtil::ShapeWithoutPadding(const Shape& shape) { + Shape result = shape; + ForEachMutableSubshape(&result, [](Shape* subshape, const ShapeIndex& index) { + auto layout = subshape->mutable_layout(); + layout->clear_padding_value(); + layout->clear_padded_dimensions(); + }); + return result; +} + /* static */ void ShapeUtil::PopulateShape( PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions, Shape* shape) { @@ -270,7 +286,7 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { } /* static */ bool ShapeUtil::IsNil(const Shape& shape) { - return IsEmptyTuple(shape) || HasZeroElements(shape); + return IsTuple(shape) ? IsEmptyTuple(shape) : HasZeroElements(shape); } /* static */ int64 ShapeUtil::TupleElementCount(const Shape& shape) { @@ -323,6 +339,7 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { } /* static */ int64 ShapeUtil::ElementsIn(const Shape& shape) { + CHECK(!IsTuple(shape)); CHECK_EQ(shape.dimensions_size(), Rank(shape)); return std::accumulate( shape.dimensions().begin(), shape.dimensions().end(), 1LL, @@ -534,11 +551,6 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { /* static */ Status ShapeUtil::ValidateShapeWithOptionalLayoutInternal( const Shape& shape) { if (shape.element_type() == TUPLE) { - // Tuple shape. 
- if (Rank(shape) != 0) { - return InvalidArgument("tuples must be rank-0; got rank %lld", - Rank(shape)); - } if (shape.dimensions_size() != 0) { return InvalidArgument("tuples must not have dimensions specified"); } diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index 853be6b4cb8..fa34bfc951d 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -93,6 +93,7 @@ class ShapeUtil { public: // Returns the number of elements are contained within the provided shape; // e.g. for rank 0 (scalars) the result is always 1. + // Precondition: !IsTuple(shape) static int64 ElementsIn(const Shape& shape); // Returns true if 'shape' has zero elements. @@ -144,7 +145,8 @@ class ShapeUtil { static bool Equal(const Shape& lhs, const Shape& rhs); // Returns the rank (number of dimensions) of the given shape. - static int64 Rank(const Shape& shape) { return shape.dimensions_size(); } + // Precondition: !IsTuple(shape) + static int64 Rank(const Shape& shape); // Returns the number of dimensions for which the dimension is not (trivially) // 1. e.g., f32[2x1x1] has a true rank of 1D, the other dimensions are just @@ -220,6 +222,9 @@ class ShapeUtil { // elements with a different shape. static Shape NormalizeShapeToMonotonicDim0MajorLayout(const Shape& shape); + // Returns a new shape that has all padding values cleared. + static Shape ShapeWithoutPadding(const Shape& shape); + // As MakeShape, but the object to write to is passed in. 
static void PopulateShape(PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions, diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 13dd1a30b60..a11ac0bec62 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -99,6 +99,7 @@ cc_library( "//tensorflow/compiler/xla/service:backend", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:computation_layout", + "//tensorflow/compiler/xla/service:computation_placer", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_execution_profile", @@ -116,8 +117,12 @@ cc_binary( name = "local_client_aot_test_helper", srcs = ["local_client_aot_test_helper.cc"], deps = [ + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/service/cpu:cpu_compiler", + "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", ], ) @@ -139,6 +144,7 @@ cc_library( "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:array4d", + "//tensorflow/compiler/xla:execution_options_util", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -151,7 +157,6 @@ cc_library( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", @@ -171,6 +176,7 @@ cc_library( deps = [ "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", + 
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:backend", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:executable", @@ -196,12 +202,14 @@ cc_library( "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/service:computation_placer", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:local_service", "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/service:shaped_buffer", "//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/core:core_cpu_internal", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//third_party/eigen3", @@ -213,12 +221,13 @@ xla_test( srcs = ["bad_rng_shape_validation_test.cc"], deps = [ "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/core:lib", @@ -233,12 +242,12 @@ xla_test( "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - 
"//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/core:test", @@ -255,7 +264,6 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/core:test", @@ -268,6 +276,7 @@ xla_test( deps = [ "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", @@ -275,7 +284,6 @@ xla_test( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/tests:client_library_test_base", @@ -291,7 +299,6 @@ xla_test( deps = [ "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -307,6 +314,7 @@ xla_test( "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", 
"//tensorflow/compiler/xla:xla_proto", @@ -315,7 +323,6 @@ xla_test( "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -339,7 +346,6 @@ xla_test( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -356,7 +362,6 @@ xla_test( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/core:test", @@ -371,7 +376,7 @@ xla_test( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -388,7 +393,7 @@ xla_test( "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + 
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -409,7 +414,7 @@ xla_test( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -422,12 +427,13 @@ xla_test( srcs = ["deallocation_test.cc"], deps = [ "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -441,13 +447,14 @@ xla_test( "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -471,7 
+478,6 @@ xla_test( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/legacy_flags:user_computation_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", @@ -490,8 +496,8 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", "//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/legacy_flags:layout_util_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -518,8 +524,8 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", "//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/legacy_flags:layout_util_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -551,8 +557,8 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", "//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/legacy_flags:layout_util_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", 
"//tensorflow/compiler/xla/tests:literal_test_util", @@ -583,8 +589,8 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", "//tensorflow/compiler/xla/legacy_flags:cpu_runtime_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/legacy_flags:layout_util_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -605,7 +611,7 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -624,7 +630,7 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -650,7 +656,7 @@ xla_test( "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:padding", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -677,7 +683,7 @@ xla_test( 
"//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:padding", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -694,12 +700,13 @@ xla_test( "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:padding", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -721,7 +728,7 @@ xla_test( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -738,7 +745,7 @@ xla_test( "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -754,7 
+761,7 @@ xla_test( "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:test", @@ -768,11 +775,13 @@ xla_test( deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:reference_util", + "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", + "//tensorflow/compiler/xla/service:computation_placer", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:local_service", "//tensorflow/compiler/xla/service:platform_util", @@ -799,7 +808,7 @@ xla_test( "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:test", @@ -816,7 +825,7 @@ xla_test( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", 
"//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:test", @@ -842,7 +851,7 @@ xla_test( "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -865,7 +874,7 @@ xla_test( "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:padding", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -889,7 +898,7 @@ xla_test( "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:padding", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -906,7 +915,7 @@ xla_test( "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -921,10 +930,11 @@ xla_test( deps = [ "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test_helpers", 
"//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:computation_builder", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -941,11 +951,12 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", + "//tensorflow/core:test", ], ) @@ -958,7 +969,7 @@ xla_test( "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:test", @@ -976,7 +987,7 @@ xla_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/legacy_flags:user_computation_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -995,7 +1006,7 @@ xla_test( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", - 
"//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -1009,7 +1020,7 @@ xla_test( deps = [ "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:test", @@ -1022,7 +1033,7 @@ xla_test( deps = [ "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:test", @@ -1044,7 +1055,7 @@ xla_test( "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -1058,11 +1069,12 @@ xla_test( deps = [ "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + 
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/core:lib", "//tensorflow/core:test", @@ -1081,13 +1093,14 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -1103,7 +1116,7 @@ xla_test( "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -1125,7 +1138,7 @@ xla_test( "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -1142,11 +1155,12 @@ xla_test( "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", 
"//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:test", @@ -1161,7 +1175,6 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -1184,7 +1197,7 @@ xla_test( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", @@ -1199,7 +1212,7 @@ xla_test( deps = [ "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -1215,13 +1228,14 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", 
"//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", "//tensorflow/core:lib", @@ -1240,7 +1254,7 @@ xla_test( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:test_utils", @@ -1262,7 +1276,7 @@ xla_test( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:session_proto", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -1279,7 +1293,7 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -1298,7 +1312,7 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", 
"//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -1315,8 +1329,14 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:computation", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:platform_util", + "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -1333,6 +1353,7 @@ cc_test( linkstatic = 1, deps = [ "//tensorflow/compiler/xla:executable_run_options", + "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", ], @@ -1347,9 +1368,9 @@ cc_test( ":local_client_test_base", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/service:computation_tracker", "//tensorflow/compiler/xla/service:local_service", - "//tensorflow/core:lib", "//tensorflow/core:test_main", ], ) @@ -1365,7 +1386,7 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", 
"//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -1381,7 +1402,7 @@ xla_test( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", @@ -1407,7 +1428,7 @@ xla_test( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:global_data", "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags", + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/core:lib", @@ -1415,6 +1436,16 @@ xla_test( ], ) +xla_test( + name = "deep_graph_test", + srcs = ["deep_graph_test.cc"], + deps = [ + "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", + "//tensorflow/compiler/xla/legacy_flags:user_computation_flags", + "//tensorflow/compiler/xla/tests:client_library_test_base", + ], +) + cc_test( name = "literal_test_util_test", srcs = ["literal_test_util_test.cc"], diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index c07f2745fe9..554042b35f7 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -26,7 +26,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/legacy_flags/user_computation_flags.h" #include "tensorflow/compiler/xla/literal_util.h" @@ -158,13 +157,13 @@ TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) { b_values.push_back(2 * i / static_cast(count + 2)); } - std::unique_ptr a_literal = LiteralUtil::CreateR1({a_values}); + std::unique_ptr a_literal = Literal::CreateR1({a_values}); std::unique_ptr a_data = client_->TransferToServer(*a_literal).ConsumeValueOrDie(); auto a_constant = builder.ConstantR1(a_values); auto a_param = builder.Parameter(0, a_literal->shape(), "a_param"); - std::unique_ptr b_literal = LiteralUtil::CreateR1({b_values}); + std::unique_ptr b_literal = Literal::CreateR1({b_values}); std::unique_ptr b_data = client_->TransferToServer(*b_literal).ConsumeValueOrDie(); auto b_constant = builder.Parameter(1, a_literal->shape(), "b_param"); @@ -804,7 +803,7 @@ TEST_F(ArrayElementwiseOpTest, PowSpecialF32) { std::vector values = {1.0f, 2.0f, 3.2f, -4.0f}; std::vector exponents = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f}; - std::unique_ptr param_literal = LiteralUtil::CreateR1(values); + std::unique_ptr param_literal = Literal::CreateR1(values); std::unique_ptr param_data = client_->TransferToServer(*param_literal).ConsumeValueOrDie(); @@ -1241,12 +1240,12 @@ TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) { ComputationBuilder builder(client_, TestName()); std::unique_ptr param0_literal = - LiteralUtil::CreateR1({1.1f, 2.2f, 3.3f, 5.5f}); + Literal::CreateR1({1.1f, 2.2f, 3.3f, 5.5f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_literal = - LiteralUtil::CreateR1({7.2f, 2.3f, 3.4f, 
5.6f}); + Literal::CreateR1({7.2f, 2.3f, 3.4f, 5.6f}); std::unique_ptr param1_data = client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); @@ -1263,12 +1262,12 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersZeroElementF32s) { ComputationBuilder builder(client_, TestName()); std::unique_ptr param0_literal = - LiteralUtil::CreateR3FromArray3D(Array3D(0, 7, 0)); + Literal::CreateR3FromArray3D(Array3D(0, 7, 0)); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_literal = - LiteralUtil::CreateR3FromArray3D(Array3D(0, 7, 0)); + Literal::CreateR3FromArray3D(Array3D(0, 7, 0)); std::unique_ptr param1_data = client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); @@ -1285,7 +1284,7 @@ TEST_F(ArrayElementwiseOpTest, AddParameterToConstantF32s) { ComputationBuilder builder(client_, TestName()); std::unique_ptr param0_literal = - LiteralUtil::CreateR1({1.1f, 2.2f, 3.3f, 5.5f}); + Literal::CreateR1({1.1f, 2.2f, 3.3f, 5.5f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -1297,6 +1296,15 @@ TEST_F(ArrayElementwiseOpTest, AddParameterToConstantF32s) { {param0_data.get()}, error_spec_); } +XLA_TEST_F(ArrayElementwiseOpTest, CosF32s) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({3.14159f, 0.0f, 1.570796f, -0.78539f}); + auto result = builder.Cos(a); + + ComputeAndCompareR1(&builder, {-1.0f, 1.0f, 0.0f, 0.707107f}, {}, + error_spec_); +} + TEST_F(ArrayElementwiseOpTest, TanhF32s) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR1({-2.5f, 3.14f, 2.25f}); @@ -1447,9 +1455,9 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Eq) { auto cmp_dim_1 = builder.Eq(v, m, /*broadcast_dimensions=*/{0}); auto result = builder.Tuple({cmp_dim_0, cmp_dim_1}); - auto expected = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2({{true, true}, {true, false}}).get(), - 
LiteralUtil::CreateR2({{true, false}, {false, false}}).get()}); + auto expected = Literal::MakeTuple( + {Literal::CreateR2({{true, true}, {true, false}}).get(), + Literal::CreateR2({{true, false}, {false, false}}).get()}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } @@ -1802,7 +1810,7 @@ TEST_F(ArrayElementwiseOpTest, R4_16x16x2x2_Plus_R1_16) { std::iota(r1.begin(), r1.end(), 1.0); ComputationBuilder builder(client_, TestName()); - std::unique_ptr a_literal = LiteralUtil::CreateR4FromArray4D(r4); + std::unique_ptr a_literal = Literal::CreateR4FromArray4D(r4); *a_literal->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1, 2, 3}); auto a = builder.ConstantLiteral(*a_literal); @@ -1838,8 +1846,8 @@ TEST_F(ArrayElementwiseOpTest, CannotAddOpaques) { // broadcast. TEST_F(ArrayElementwiseOpTest, ImplictBroadcastInFusedExpressions) { ComputationBuilder builder(client_, TestName()); - auto x_literal = LiteralUtil::CreateR1({1, 2, 3}); - auto y_literal = LiteralUtil::CreateR1({4, 5}); + auto x_literal = Literal::CreateR1({1, 2, 3}); + auto y_literal = Literal::CreateR1({4, 5}); auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie(); auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie(); @@ -1856,13 +1864,35 @@ INSTANTIATE_TEST_CASE_P(ArrayElementwiseOpTestParamCount, ArrayElementwiseOpTestParamCount, ::testing::Values(127, 128, 129, 17 * 4096)); +XLA_TEST_F(ArrayElementwiseOpTest, ReducePrecisionNoOpF32) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({-2.5f, 25.5f}); + auto reduce_precision = builder.ReducePrecision(a, 8, 23); + + ComputeAndCompareR1(&builder, {-2.5f, 25.5f}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, ReducePrecisionNoOpParamF32) { + ComputationBuilder builder(client_, TestName()); + + std::vector a_values = {-2.5f, 25.5f}; + + std::unique_ptr a_literal = Literal::CreateR1({a_values}); + std::unique_ptr a_data = + 
client_->TransferToServer(*a_literal).ConsumeValueOrDie(); + auto a_param = builder.Parameter(0, a_literal->shape(), "a_param"); + + auto reduce_precision = builder.ReducePrecision(a_param, 8, 23); + + ComputeAndCompareR1(&builder, {-2.5f, 25.5f}, {a_data.get()}); +} + } // namespace } // namespace xla int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::legacy_flags::AppendUserComputationFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/tests/axpy_simple_test.cc b/tensorflow/compiler/xla/tests/axpy_simple_test.cc index a1ca1de584f..67dbc913b42 100644 --- a/tensorflow/compiler/xla/tests/axpy_simple_test.cc +++ b/tensorflow/compiler/xla/tests/axpy_simple_test.cc @@ -17,7 +17,6 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -76,7 +75,6 @@ TEST_F(AxpySimpleTest, AxpyTenValues) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc b/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc index ea58491038c..02be0b5ab83 100644 --- 
a/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc +++ b/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc @@ -21,7 +21,6 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" @@ -70,7 +69,6 @@ TEST_F(BadRngShapeValidationTest, ShapeWithoutLayoutIsOk) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/batch_normalization_test.cc b/tensorflow/compiler/xla/tests/batch_normalization_test.cc index 6a47f1b718a..9f9f5412301 100644 --- a/tensorflow/compiler/xla/tests/batch_normalization_test.cc +++ b/tensorflow/compiler/xla/tests/batch_normalization_test.cc @@ -23,7 +23,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -48,7 +47,7 @@ class BatchNormalizationTest : public ClientLibraryTestBase { {5.0f, 4.4f}, // p2 }); input_array_.FillWithPZ(pz); - input_literal_ = *LiteralUtil::CreateR4FromArray4D(input_array_); + input_literal_ = *Literal::CreateR4FromArray4D(input_array_); CHECK_EQ(kSamples, input_array_.planes()); CHECK_EQ(kZ, input_array_.depth()); CHECK_EQ(kY, input_array_.height()); @@ -196,7 +195,6 @@ TEST_F(BatchNormalizationTest, SpecComparisonForward) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/binop_scaling_test.cc b/tensorflow/compiler/xla/tests/binop_scaling_test.cc index 5e3b70702dd..e6b853c2e4e 100644 --- a/tensorflow/compiler/xla/tests/binop_scaling_test.cc +++ b/tensorflow/compiler/xla/tests/binop_scaling_test.cc @@ -17,7 +17,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" @@ -143,7 +142,6 @@ TEST_F(BinopScalingTest, R4PlusR0S32) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc index 25fe04a930e..aab2c746344 100644 --- a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc +++ b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc @@ -21,7 +21,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/legacy_flags/user_computation_flags.h" #include "tensorflow/compiler/xla/literal_util.h" @@ -63,9 +62,8 @@ class BroadcastSimpleTest : public ClientLibraryTestBase { Array3D* r3_array, float start, float end, int seed) { *r3_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major); r3_array->FillRandom(start, end, seed); - auto r3_data = - LiteralUtil::Relayout(*LiteralUtil::CreateR3FromArray3D(*r3_array), - LayoutUtil::MakeLayout(minor_to_major)); + auto r3_data = Literal::CreateR3FromArray3D(*r3_array)->Relayout( + LayoutUtil::MakeLayout(minor_to_major)); std::unique_ptr r3_global_data = client_->TransferToServer(*r3_data).ConsumeValueOrDie(); return r3_global_data; @@ -77,9 +75,8 @@ class BroadcastSimpleTest : public ClientLibraryTestBase { Array2D* r2_array, float start, float end, int seed) { *r2_shape = ShapeUtil::MakeShapeWithLayout(F32, bounds, minor_to_major); r2_array->FillRandom(start, end, seed); - auto r2_data = - LiteralUtil::Relayout(*LiteralUtil::CreateR2FromArray2D(*r2_array), - LayoutUtil::MakeLayout(minor_to_major)); + auto r2_data = Literal::CreateR2FromArray2D(*r2_array)->Relayout( + LayoutUtil::MakeLayout(minor_to_major)); std::unique_ptr r2_global_data = client_->TransferToServer(*r2_data).ConsumeValueOrDie(); return r2_global_data; @@ -217,13 +214,13 @@ XLA_TEST_F(BroadcastSimpleTest, InDimensionAndDegenerateBroadcasting) { ComputationBuilder b(client_, TestName()); b.Add(b.ConstantR2({{1.0, 5.0}}), - b.ConstantLiteral(*LiteralUtil::CreateR3( + b.ConstantLiteral(*Literal::CreateR3( {{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})), /*broadcast_dimensions=*/{1, 2}); auto expected = - 
LiteralUtil::CreateR3({{{3.0, 7.0}, {4.0, 8.0}, {5.0, 9.0}}, - {{6.0, 10.0}, {7.0, 11.0}, {8.0, 12.0}}}); + Literal::CreateR3({{{3.0, 7.0}, {4.0, 8.0}, {5.0, 9.0}}, + {{6.0, 10.0}, {7.0, 11.0}, {8.0, 12.0}}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } @@ -292,7 +289,7 @@ XLA_TEST_P(BroadcastR3ImplicitTest, Doit) { } } } - auto expected = LiteralUtil::CreateR3FromArray3D(expected_array); + auto expected = Literal::CreateR3FromArray3D(expected_array); ComputeAndCompareLiteral( &builder, *expected, {r3_implicit_global_data.get(), r3_global_data.get()}, @@ -317,7 +314,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1_2) { b.Add(r3h, r1h); auto expected = - LiteralUtil::CreateR3({{{2, 3}, {4, 5}}, {{7, 8}, {9, 10}}}); + Literal::CreateR3({{{2, 3}, {4, 5}}, {{7, 8}, {9, 10}}}); ComputeAndCompareLiteral(&b, *expected, {r3.get(), r1.get()}, ErrorSpec(0.0001)); @@ -325,81 +322,79 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1_2) { XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1) { ComputationBuilder b(client_, TestName()); - auto r1 = b.ConstantLiteral(*LiteralUtil::CreateR3({{{1, 2}}})); + auto r1 = b.ConstantLiteral(*Literal::CreateR3({{{1, 2}}})); auto r3 = b.ConstantLiteral( - *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); b.Add(r3, r1); auto expected = - LiteralUtil::CreateR3({{{2, 4}, {4, 6}}, {{6, 8}, {8, 10}}}); + Literal::CreateR3({{{2, 4}, {4, 6}}, {{6, 8}, {8, 10}}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_2) { ComputationBuilder b(client_, TestName()); - auto r1 = b.ConstantLiteral(*LiteralUtil::CreateR3({{{1}, {2}}})); + auto r1 = b.ConstantLiteral(*Literal::CreateR3({{{1}, {2}}})); auto r3 = b.ConstantLiteral( - *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); b.Add(r3, r1); auto 
expected = - LiteralUtil::CreateR3({{{2, 3}, {5, 6}}, {{6, 7}, {9, 10}}}); + Literal::CreateR3({{{2, 3}, {5, 6}}, {{6, 7}, {9, 10}}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0) { ComputationBuilder b(client_, TestName()); - auto r1 = - b.ConstantLiteral(*LiteralUtil::CreateR3({{{1, 2}, {3, 4}}})); + auto r1 = b.ConstantLiteral(*Literal::CreateR3({{{1, 2}, {3, 4}}})); auto r3 = b.ConstantLiteral( - *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); b.Add(r3, r1); auto expected = - LiteralUtil::CreateR3({{{2, 4}, {6, 8}}, {{6, 8}, {10, 12}}}); + Literal::CreateR3({{{2, 4}, {6, 8}}, {{6, 8}, {10, 12}}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1) { ComputationBuilder b(client_, TestName()); - auto r1 = - b.ConstantLiteral(*LiteralUtil::CreateR3({{{1, 2}}, {{3, 4}}})); + auto r1 = b.ConstantLiteral(*Literal::CreateR3({{{1, 2}}, {{3, 4}}})); auto r3 = b.ConstantLiteral( - *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); b.Add(r3, r1); auto expected = - LiteralUtil::CreateR3({{{2, 4}, {4, 6}}, {{8, 10}, {10, 12}}}); + Literal::CreateR3({{{2, 4}, {4, 6}}, {{8, 10}, {10, 12}}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_2) { ComputationBuilder b(client_, TestName()); - auto r1 = b.ConstantLiteral( - *LiteralUtil::CreateR3({{{1}, {2}}, {{3}, {4}}})); + auto r1 = + b.ConstantLiteral(*Literal::CreateR3({{{1}, {2}}, {{3}, {4}}})); auto r3 = b.ConstantLiteral( - *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); b.Add(r3, r1); auto expected = - LiteralUtil::CreateR3({{{2, 3}, {5, 6}}, {{8, 9}, {11, 12}}}); + Literal::CreateR3({{{2, 
3}, {5, 6}}, {{8, 9}, {11, 12}}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1_2) { ComputationBuilder b(client_, TestName()); - auto r1 = b.ConstantLiteral(*LiteralUtil::CreateR3({{{1}}})); + auto r1 = b.ConstantLiteral(*Literal::CreateR3({{{1}}})); auto r3 = b.ConstantLiteral( - *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); b.Add(r3, r1); auto expected = - LiteralUtil::CreateR3({{{2, 3}, {4, 5}}, {{6, 7}, {8, 9}}}); + Literal::CreateR3({{{2, 3}, {4, 5}}, {{6, 7}, {8, 9}}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } @@ -541,7 +536,7 @@ XLA_TEST_P(BroadcastR2ImplicitTest, Doit) { *v = ApplyOpToFloats(spec.op2, tmp, v3); }); - auto expected = LiteralUtil::CreateR2FromArray2D(expected_array); + auto expected = Literal::CreateR2FromArray2D(expected_array); ComputeAndCompareLiteral( &builder, *expected, {r2_implicit_global_data1.get(), r2_global_data.get(), @@ -555,22 +550,22 @@ INSTANTIATE_TEST_CASE_P(BroadcastR2ImplicitTestInstances, XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_0) { ComputationBuilder b(client_, TestName()); - auto r1 = b.ConstantLiteral(*LiteralUtil::CreateR2({{1, 2}})); - auto r2 = b.ConstantLiteral(*LiteralUtil::CreateR2({{1, 2}, {3, 4}})); + auto r1 = b.ConstantLiteral(*Literal::CreateR2({{1, 2}})); + auto r2 = b.ConstantLiteral(*Literal::CreateR2({{1, 2}, {3, 4}})); b.Add(r2, r1); - auto expected = LiteralUtil::CreateR2({{2, 4}, {4, 6}}); + auto expected = Literal::CreateR2({{2, 4}, {4, 6}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_1) { ComputationBuilder b(client_, TestName()); - auto r1 = b.ConstantLiteral(*LiteralUtil::CreateR2({{1}, {2}})); - auto r2 = b.ConstantLiteral(*LiteralUtil::CreateR2({{1, 2}, {3, 4}})); + auto r1 = b.ConstantLiteral(*Literal::CreateR2({{1}, {2}})); + auto 
r2 = b.ConstantLiteral(*Literal::CreateR2({{1, 2}, {3, 4}})); b.Add(r2, r1); - auto expected = LiteralUtil::CreateR2({{2, 3}, {5, 6}}); + auto expected = Literal::CreateR2({{2, 3}, {5, 6}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } @@ -579,11 +574,11 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim0) { ComputationBuilder b(client_, TestName()); auto r1 = b.ConstantR1({10, 20}); auto r3 = b.ConstantLiteral( - *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); b.Add(r3, r1, {0}); - auto expected = LiteralUtil::CreateR3( - {{{11, 12}, {13, 14}}, {{25, 26}, {27, 28}}}); + auto expected = + Literal::CreateR3({{{11, 12}, {13, 14}}, {{25, 26}, {27, 28}}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } @@ -592,11 +587,11 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim1) { ComputationBuilder b(client_, TestName()); auto r1 = b.ConstantR1({10, 20}); auto r3 = b.ConstantLiteral( - *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); b.Add(r1, r3, {1}); - auto expected = LiteralUtil::CreateR3( - {{{11, 12}, {23, 24}}, {{15, 16}, {27, 28}}}); + auto expected = + Literal::CreateR3({{{11, 12}, {23, 24}}, {{15, 16}, {27, 28}}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } @@ -605,11 +600,11 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim2) { ComputationBuilder b(client_, TestName()); auto r1 = b.ConstantR1({10, 20}); auto r3 = b.ConstantLiteral( - *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); b.Add(r1, r3, {2}); - auto expected = LiteralUtil::CreateR3( - {{{11, 22}, {13, 24}}, {{15, 26}, {17, 28}}}); + auto expected = + Literal::CreateR3({{{11, 22}, {13, 24}}, {{15, 26}, {17, 28}}}); ComputeAndCompareLiteral(&b, *expected, {}, ErrorSpec(0.0001)); } @@ -620,7 +615,7 @@ 
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) { auto r1_1 = b.ConstantR1({100, 200}); auto r1_2 = b.ConstantR1({10, 20}); auto r3 = b.ConstantLiteral( - *LiteralUtil::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); + *Literal::CreateR3({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}})); for (int i = 0; i < 3; ++i) { r3 = b.Add(r1_0, r3, {0}); r3 = b.Add(r3, r1_1, {1}); @@ -628,7 +623,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) { } r3 = b.Mul(r3, b.ConstantR0(-2)); - auto expected = LiteralUtil::CreateR3( + auto expected = Literal::CreateR3( {{{-6 * 1110 - 2, -6 * 1120 - 4}, {-6 * 1210 - 6, -6 * 1220 - 8}}, {{-6 * 2110 - 10, -6 * 2120 - 12}, {-6 * 2210 - 14, -6 * 2220 - 16}}}); @@ -649,7 +644,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAllWithScalarBroadcast) { } r3 = b.Mul(r3, b.ConstantR0(-1)); - auto expected = LiteralUtil::CreateR3( + auto expected = Literal::CreateR3( {{{-3 * 1110 - 3, -3 * 1120 - 3}, {-3 * 1210 - 3, -3 * 1220 - 3}}, {{-3 * 2110 - 3, -3 * 2120 - 3}, {-3 * 2210 - 3, -3 * 2220 - 3}}}); @@ -662,7 +657,7 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidBinaryAndDegenerateBroadcasting) { ComputationBuilder b(client_, TestName()); b.Add(b.ConstantR2({{1.0, 5.0}, {1.0, 5.0}}), - b.ConstantLiteral(*LiteralUtil::CreateR3( + b.ConstantLiteral(*Literal::CreateR3( {{{2.0}, {3.0}, {4.0}}, {{5.0}, {6.0}, {7.0}}})), /*broadcast_dimensions=*/{1, 2}); @@ -704,7 +699,6 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidDegenerateBroadcasting) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::legacy_flags::AppendUserComputationFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/tests/broadcast_test.cc b/tensorflow/compiler/xla/tests/broadcast_test.cc index 96a329a9bd8..dc1443f5363 100644 --- 
a/tensorflow/compiler/xla/tests/broadcast_test.cc +++ b/tensorflow/compiler/xla/tests/broadcast_test.cc @@ -16,7 +16,6 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" @@ -39,7 +38,7 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarToScalar) { // Test degenerate case of broadcasting a scalar into a scalar. auto builder = HloComputation::Builder(TestName()); auto input = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0))); builder.AddInstruction(HloInstruction::CreateBroadcast( ShapeUtil::MakeShape(F32, {}), input, {})); @@ -48,14 +47,14 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarToScalar) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - LiteralTestUtil::ExpectNear(*LiteralUtil::CreateR0(42.0), *result, + LiteralTestUtil::ExpectNear(*Literal::CreateR0(42.0), *result, error_spec_); } XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) { auto builder = HloComputation::Builder(TestName()); auto input = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0))); builder.AddInstruction(HloInstruction::CreateBroadcast( ShapeUtil::MakeShape(F32, {2, 2}), input, {})); @@ -65,14 +64,14 @@ XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) { auto result = ExecuteAndTransfer(std::move(hlo_module), {}); LiteralTestUtil::ExpectNear( - *LiteralUtil::CreateR2({{42.0, 42.0}, {42.0, 42.0}}), *result, + *Literal::CreateR2({{42.0, 42.0}, {42.0, 42.0}}), *result, error_spec_); } XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) { auto builder = HloComputation::Builder(TestName()); auto input = 
builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({1.0, 2.0, 3.0}))); + Literal::CreateR1({1.0, 2.0, 3.0}))); // Broadcast vector in both dimension 0 and dimension 1. Join them in a tuple // to enable testing of the results. @@ -88,18 +87,18 @@ XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) { auto result = ExecuteAndTransfer(std::move(hlo_module), {}); LiteralTestUtil::ExpectNear( - *LiteralUtil::CreateR2({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}}), + *Literal::CreateR2({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}}), result->tuple_literals(0), error_spec_); LiteralTestUtil::ExpectNear( - *LiteralUtil::CreateR2({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}}), + *Literal::CreateR2({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}}), result->tuple_literals(1), error_spec_); } XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) { auto builder = HloComputation::Builder(TestName()); auto input = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); + Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); builder.AddInstruction(HloInstruction::CreateBroadcast( ShapeUtil::MakeShape(F32, {2, 2}), input, {0, 1})); @@ -109,7 +108,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) { auto result = ExecuteAndTransfer(std::move(hlo_module), {}); LiteralTestUtil::ExpectNear( - *LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}), *result, + *Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}), *result, error_spec_); } @@ -118,7 +117,7 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) { // the dimensions, ie transpose. 
auto builder = HloComputation::Builder(TestName()); auto input = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); + Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); builder.AddInstruction(HloInstruction::CreateBroadcast( ShapeUtil::MakeShape(F32, {2, 2}), input, {1, 0})); @@ -128,14 +127,14 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) { auto result = ExecuteAndTransfer(std::move(hlo_module), {}); LiteralTestUtil::ExpectNear( - *LiteralUtil::CreateR2({{1.0, 3.0}, {2.0, 4.0}}), *result, + *Literal::CreateR2({{1.0, 3.0}, {2.0, 4.0}}), *result, error_spec_); } XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) { auto builder = HloComputation::Builder(TestName()); auto input = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); + Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); builder.AddInstruction(HloInstruction::CreateBroadcast( ShapeUtil::MakeShape(F32, {2, 3, 2}), input, {0, 2})); @@ -145,15 +144,15 @@ XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) { auto result = ExecuteAndTransfer(std::move(hlo_module), {}); LiteralTestUtil::ExpectNear( - *LiteralUtil::CreateR3({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}}, - {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}), + *Literal::CreateR3({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}}, + {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}}), *result, error_spec_); } TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) { auto builder = HloComputation::Builder(TestName()); auto input = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({1.0, 2.0}))); + HloInstruction::CreateConstant(Literal::CreateR1({1.0, 2.0}))); // Broadcast vector in dimension 1. 
builder.AddInstruction(HloInstruction::CreateBroadcast( @@ -168,8 +167,8 @@ TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) { Array2D pz({{1, 2}, {1, 2}}); expected.FillWithPZ(pz); - LiteralTestUtil::ExpectNear( - *LiteralUtil::CreateR4FromArray4D(expected), *result, error_spec_); + LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D(expected), + *result, error_spec_); } TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) { @@ -178,7 +177,7 @@ TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) { int64 r1_size = input_data.size(); std::iota(input_data.begin(), input_data.end(), 0.0f); auto input = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1(input_data))); + HloInstruction::CreateConstant(Literal::CreateR1(input_data))); // Broadcast vector in dimension 3. builder.AddInstruction(HloInstruction::CreateBroadcast( @@ -198,8 +197,8 @@ TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) { } expected.FillWithYX(yx); - LiteralTestUtil::ExpectNear( - *LiteralUtil::CreateR4FromArray4D(expected), *result, error_spec_); + LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D(expected), + *result, error_spec_); } XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) { @@ -209,7 +208,7 @@ XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) { std::vector r1_array(64, 42.0); auto input = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1(r1_array))); + HloInstruction::CreateConstant(Literal::CreateR1(r1_array))); // Broadcast vector in dimension 1. 
builder.AddInstruction(HloInstruction::CreateBroadcast( @@ -220,14 +219,14 @@ XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - LiteralTestUtil::ExpectNear(*LiteralUtil::CreateR4FromArray4D(r4_array), - *result, error_spec_); + LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D(r4_array), *result, + error_spec_); } TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) { auto builder = HloComputation::Builder(TestName()); auto input = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(1.0f))); builder.AddInstruction(HloInstruction::CreateBroadcast( ShapeUtil::MakeShape(F32, {64, 64, 3, 3}), input, {})); @@ -240,15 +239,15 @@ TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) { Array4D expected(64, 64, 3, 3); expected.Fill(1.0f); - LiteralTestUtil::ExpectNear( - *LiteralUtil::CreateR4FromArray4D(expected), *result, error_spec_); + LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D(expected), + *result, error_spec_); } TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) { auto builder = HloComputation::Builder(TestName()); Array2D to_broadcast({{1.0f, 2.0f}, {3.0f, 4.0f}}); auto input = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2FromArray2D(to_broadcast))); + Literal::CreateR2FromArray2D(to_broadcast))); // Broadcast vector in dimensions 2 and 3. 
builder.AddInstruction(HloInstruction::CreateBroadcast( @@ -262,8 +261,8 @@ TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) { Array4D expected(3, 3, 2, 2); expected.FillWithYX(to_broadcast); - LiteralTestUtil::ExpectNear( - *LiteralUtil::CreateR4FromArray4D(expected), *result, error_spec_); + LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D(expected), + *result, error_spec_); } TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) { @@ -282,7 +281,7 @@ TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) { } } auto input = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR3FromArray3D(input_vals))); + Literal::CreateR3FromArray3D(input_vals))); // Broadcast vector in dimensions 2 and 3. builder.AddInstruction(HloInstruction::CreateBroadcast( @@ -293,8 +292,8 @@ TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) { hlo_module->AddEntryComputation(builder.Build()); auto result = ExecuteAndTransfer(std::move(hlo_module), {}); - LiteralTestUtil::ExpectNear( - *LiteralUtil::CreateR4FromArray4D(expected), *result, error_spec_); + LiteralTestUtil::ExpectNear(*Literal::CreateR4FromArray4D(expected), + *result, error_spec_); } } // namespace @@ -302,7 +301,6 @@ TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) { int main(int argc, char** argv) { std::vector flag_list; - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/tests/build_defs.bzl b/tensorflow/compiler/xla/tests/build_defs.bzl index 1f61743451a..50edd8ea5b9 100644 --- a/tensorflow/compiler/xla/tests/build_defs.bzl +++ b/tensorflow/compiler/xla/tests/build_defs.bzl @@ -2,11 +2,25 @@ load("@local_config_cuda//cuda:build_defs.bzl", "cuda_is_configured") -def all_backends(): +all_backends = ["cpu", "cpu_parallel", "gpu"] + +def 
filter_backends(backends): + """Removes "gpu" from a backend list if CUDA is not enabled. + + This allows us to simply hardcode lists including "gpu" here and in the + BUILD file, without causing failures when CUDA isn't enabled.' + + Args: + backends: A list of backends to filter. + + Returns: + The filtered list of backends. + """ if cuda_is_configured(): - return ["cpu", "cpu_parallel", "gpu"] + return backends else: - return ["cpu", "cpu_parallel"] + return [backend for backend in backends if backend != "gpu"] + def xla_test(name, srcs, @@ -81,7 +95,7 @@ def xla_test(name, """ test_names = [] if not backends: - backends = all_backends() + backends = all_backends native.cc_library( name="%s_lib" % name, @@ -91,7 +105,7 @@ def xla_test(name, deps=deps + ["//tensorflow/compiler/xla/tests:test_macros_header"], ) - for backend in backends: + for backend in filter_backends(backends): test_name = "%s_%s" % (name, backend) this_backend_tags = ["xla_%s" % backend] this_backend_copts = [] @@ -127,16 +141,16 @@ def xla_test(name, def generate_backend_suites(backends=[]): if not backends: - backends = all_backends() - for backend in backends: + backends = all_backends + for backend in filter_backends(backends): native.test_suite(name="%s_tests" % backend, tags = ["xla_%s" % backend]) def generate_backend_test_macros(backends=[]): if not backends: - backends = all_backends() - for backend in backends: + backends = all_backends + for backend in filter_backends(backends): native.cc_library( name="test_macros_%s" % backend, testonly = True, diff --git a/tensorflow/compiler/xla/tests/call_test.cc b/tensorflow/compiler/xla/tests/call_test.cc index 55701c62db2..086199fda14 100644 --- a/tensorflow/compiler/xla/tests/call_test.cc +++ b/tensorflow/compiler/xla/tests/call_test.cc @@ -18,7 +18,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/computation_builder.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -78,7 +77,7 @@ class CallOpTest : public ClientLibraryTestBase { XLA_TEST_F(CallOpTest, DISABLED_ON_GPU(CallR0F32IdentityScalar)) { ComputationBuilder builder(client_, TestName()); Computation callee = CreateR0F32IdentityComputation(); - auto constant = builder.ConstantLiteral(*LiteralUtil::CreateR0(42.0)); + auto constant = builder.ConstantLiteral(*Literal::CreateR0(42.0)); builder.Call(callee, {constant}); ComputeAndCompareR0(&builder, 42.0, {}, ErrorSpec(0.01f)); @@ -87,8 +86,8 @@ XLA_TEST_F(CallOpTest, DISABLED_ON_GPU(CallR0F32IdentityScalar)) { XLA_TEST_F(CallOpTest, DISABLED_ON_GPU(CallR1S0F32AddArray)) { ComputationBuilder builder(client_, TestName()); Computation callee = CreateR1S0F32AdditionComputation(); - auto x = builder.ConstantLiteral(*LiteralUtil::CreateR1({})); - auto y = builder.ConstantLiteral(*LiteralUtil::CreateR1({})); + auto x = builder.ConstantLiteral(*Literal::CreateR1({})); + auto y = builder.ConstantLiteral(*Literal::CreateR1({})); builder.Call(callee, {x, y}); ComputeAndCompareR1(&builder, {}, {}, ErrorSpec(0.01f)); @@ -97,8 +96,8 @@ XLA_TEST_F(CallOpTest, DISABLED_ON_GPU(CallR1S0F32AddArray)) { XLA_TEST_F(CallOpTest, DISABLED_ON_GPU(CallR1S2F32AddArray)) { ComputationBuilder builder(client_, TestName()); Computation callee = CreateR1S2F32AdditionComputation(); - auto x = builder.ConstantLiteral(*LiteralUtil::CreateR1({1.0f, 2.0f})); - auto y = builder.ConstantLiteral(*LiteralUtil::CreateR1({2.0f, 3.0f})); + auto x = builder.ConstantLiteral(*Literal::CreateR1({1.0f, 2.0f})); + auto y = builder.ConstantLiteral(*Literal::CreateR1({2.0f, 3.0f})); builder.Call(callee, {x, y}); 
ComputeAndCompareR1(&builder, {3.0f, 5.0f}, {}, ErrorSpec(0.01f)); @@ -107,8 +106,8 @@ XLA_TEST_F(CallOpTest, DISABLED_ON_GPU(CallR1S2F32AddArray)) { XLA_TEST_F(CallOpTest, DISABLED_ON_GPU(CallR0F32Tuple)) { ComputationBuilder builder(client_, TestName()); Computation callee = CreateR0F32TupleComputation(); - auto elem = LiteralUtil::CreateR0(42.0); - auto tuple = LiteralUtil::MakeTuple({elem.get()}); + auto elem = Literal::CreateR0(42.0); + auto tuple = Literal::MakeTuple({elem.get()}); builder.Call(callee, {builder.ConstantLiteral(*elem)}); ComputeAndCompareTuple(&builder, *tuple, {}, ErrorSpec(0.01f)); @@ -120,7 +119,6 @@ XLA_TEST_F(CallOpTest, DISABLED_ON_GPU(CallR0F32Tuple)) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc index 4825eaf19dc..2f4ad22f5bf 100644 --- a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc +++ b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc @@ -18,7 +18,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -38,7 +37,7 @@ class CheckExecutionArityTest : public ClientLibraryTestBase {}; TEST_F(CheckExecutionArityTest, TwoParamComputationNumArguments) { ComputationBuilder builder(client_, "add_two_params"); - auto param_literal = LiteralUtil::CreateR1({1.1f, 2.2f}); + auto param_literal = Literal::CreateR1({1.1f, 2.2f}); auto p0 = builder.Parameter(0, param_literal->shape(), "param0"); auto p1 = builder.Parameter(1, param_literal->shape(), "param1"); @@ -55,18 +54,20 @@ TEST_F(CheckExecutionArityTest, TwoParamComputationNumArguments) { // The arity of the UserComputation is 2 arguments. Execution will succeed // with 2 arguments, but fail with a different number. 
- auto result_two_args = - client_->Execute(computation, {param0_data.get(), param1_data.get()}); + auto result_two_args = client_->Execute( + computation, {param0_data.get(), param1_data.get()}, &execution_options_); ASSERT_IS_OK(result_two_args.status()); - auto result_one_arg = client_->Execute(computation, {param0_data.get()}); + auto result_one_arg = + client_->Execute(computation, {param0_data.get()}, &execution_options_); ASSERT_FALSE(result_one_arg.ok()); ASSERT_EQ(result_one_arg.status().code(), tensorflow::error::INVALID_ARGUMENT); ASSERT_THAT(result_one_arg.status().error_message(), ContainsRegex("takes 2")); - auto result_zero_args = client_->Execute(computation, {}); + auto result_zero_args = + client_->Execute(computation, {}, &execution_options_); ASSERT_FALSE(result_zero_args.ok()); ASSERT_EQ(result_zero_args.status().code(), tensorflow::error::INVALID_ARGUMENT); @@ -85,35 +86,38 @@ XLA_TEST_F(CheckExecutionArityTest, CheckArgumentShapes) { ASSERT_IS_OK(computation_status.status()); auto computation = computation_status.ConsumeValueOrDie(); - auto f32_literal = LiteralUtil::CreateR0(1.1f); + auto f32_literal = Literal::CreateR0(1.1f); auto f32_data = client_->TransferToServer(*f32_literal).ConsumeValueOrDie(); - auto f32_4_literal = LiteralUtil::CreateR1({1.0f, 2.0f, 3.0f, 4.0f}); + auto f32_4_literal = Literal::CreateR1({1.0f, 2.0f, 3.0f, 4.0f}); auto f32_4_data = client_->TransferToServer(*f32_4_literal).ConsumeValueOrDie(); - auto u8_4_literal = LiteralUtil::CreateR1U8("hola"); + auto u8_4_literal = Literal::CreateR1U8("hola"); auto u8_4_data = client_->TransferToServer(*u8_4_literal).ConsumeValueOrDie(); // Match - auto status = - client_->Execute(computation, {f32_data.get(), f32_4_data.get()}); + auto status = client_->Execute( + computation, {f32_data.get(), f32_4_data.get()}, &execution_options_); ASSERT_IS_OK(status.status()); // Shape mismatch in parameter 0 - status = client_->Execute(computation, {f32_4_data.get(), f32_4_data.get()}); + 
status = client_->Execute(computation, {f32_4_data.get(), f32_4_data.get()}, + &execution_options_); ASSERT_FALSE(status.ok()); ASSERT_EQ(status.status().code(), tensorflow::error::INVALID_ARGUMENT); ASSERT_THAT(status.status().error_message(), ContainsRegex("expects parameter 0")); // Shape mismatch in parameter 1 (rank) - status = client_->Execute(computation, {f32_data.get(), f32_data.get()}); + status = client_->Execute(computation, {f32_data.get(), f32_data.get()}, + &execution_options_); ASSERT_FALSE(status.ok()); ASSERT_EQ(status.status().code(), tensorflow::error::INVALID_ARGUMENT); ASSERT_THAT(status.status().error_message(), ContainsRegex("expects parameter 1")); // Shape mismatch in parameter 1 (element type) - status = client_->Execute(computation, {f32_data.get(), u8_4_data.get()}); + status = client_->Execute(computation, {f32_data.get(), u8_4_data.get()}, + &execution_options_); ASSERT_FALSE(status.ok()); ASSERT_EQ(status.status().code(), tensorflow::error::INVALID_ARGUMENT); ASSERT_THAT(status.status().error_message(), @@ -126,7 +130,6 @@ XLA_TEST_F(CheckExecutionArityTest, CheckArgumentShapes) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index b96bb8f8469..4e2e0c7776d 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -20,7 +20,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -45,10 +45,8 @@ Client* GetOrCreateLocalClientOrDie(se::Platform* platform) { } // namespace ClientLibraryTestBase::ClientLibraryTestBase(se::Platform* platform) - : client_(GetOrCreateLocalClientOrDie(platform)) { - *(execution_options_.mutable_debug_options()) = - legacy_flags::GetDebugOptionsFromFlags(); - + : client_(GetOrCreateLocalClientOrDie(platform)), + execution_options_(CreateDefaultExecutionOptions()) { // Disabling constant_folding so that tests (usually written using Constants) // will exercise the intended code paths, instead of being constant folded. // @@ -72,12 +70,9 @@ StatusOr> ClientLibraryTestBase::Execute( } StatusOr> ClientLibraryTestBase::ExecuteAndTransfer( - ComputationBuilder* builder, + const Computation& computation, tensorflow::gtl::ArraySlice arguments, const Shape* shape_with_output_layout) { - // Build the computation, as a convenience. - TF_ASSIGN_OR_RETURN(auto computation, builder->Build()); - ExecutionOptions execution_options = execution_options_; if (shape_with_output_layout != nullptr) { *execution_options.mutable_shape_with_output_layout() = @@ -87,6 +82,15 @@ StatusOr> ClientLibraryTestBase::ExecuteAndTransfer( &execution_options); } +StatusOr> ClientLibraryTestBase::ExecuteAndTransfer( + ComputationBuilder* builder, + tensorflow::gtl::ArraySlice arguments, + const Shape* shape_with_output_layout) { + // Build the computation, as a convenience. 
+ TF_ASSIGN_OR_RETURN(auto computation, builder->Build()); + return ExecuteAndTransfer(computation, arguments, shape_with_output_layout); +} + std::unique_ptr ClientLibraryTestBase::ExecuteOrDie( ComputationBuilder* builder, tensorflow::gtl::ArraySlice arguments) { @@ -113,14 +117,14 @@ string ClientLibraryTestBase::ExecuteToString( if (!result.ok()) { return result.status().ToString(); } else { - return LiteralUtil::ToString(*result.ValueOrDie()); + return result.ValueOrDie()->ToString(); } } void ClientLibraryTestBase::ComputeAndCompareR1( ComputationBuilder* builder, const tensorflow::core::Bitmap& expected, tensorflow::gtl::ArraySlice arguments) { - std::unique_ptr expected_literal = LiteralUtil::CreateR1(expected); + std::unique_ptr expected_literal = Literal::CreateR1(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments); } @@ -179,10 +183,10 @@ void ClientLibraryTestBase::ComputeAndCompareR1U8( auto actual = actual_status.ConsumeValueOrDie(); // Turn the expected value into a literal. - std::unique_ptr expected_literal = LiteralUtil::CreateR1U8(expected); + std::unique_ptr expected_literal = Literal::CreateR1U8(expected); - VLOG(1) << "expected: " << LiteralUtil::ToString(*expected_literal); - VLOG(1) << "actual: " << LiteralUtil::ToString(*actual); + VLOG(1) << "expected: " << expected_literal->ToString(); + VLOG(1) << "actual: " << actual->ToString(); EXPECT_EQ(expected, actual->u8s_string()); } diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index f9e1082ebb4..763ff099654 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -66,7 +66,9 @@ class ClientLibraryTestBase : public ::testing::Test { // TODO(b/25566808): Add helper that populates a literal from a testdata file. - // Convenience methods for building and running a computation from a builder. 
+ // Convenience methods for building and running a computation with the member + // execution options. Modify execution_options_ in your test if you want to + // customize the options. StatusOr> Execute( ComputationBuilder* builder, tensorflow::gtl::ArraySlice arguments); @@ -74,6 +76,10 @@ class ClientLibraryTestBase : public ::testing::Test { ComputationBuilder* builder, tensorflow::gtl::ArraySlice arguments, const Shape* shape_with_output_layout = nullptr); + StatusOr> ExecuteAndTransfer( + const Computation& computation, + tensorflow::gtl::ArraySlice arguments, + const Shape* shape_with_output_layout = nullptr); // Convenience OrDie variants of above methods. std::unique_ptr ExecuteOrDie( @@ -278,7 +284,7 @@ void ClientLibraryTestBase::ComputeAndCompareR0( ComputationBuilder* builder, NativeT expected, tensorflow::gtl::ArraySlice arguments) { std::unique_ptr expected_literal = - LiteralUtil::CreateR0(expected); + Literal::CreateR0(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments); } @@ -291,7 +297,7 @@ void ClientLibraryTestBase::ComputeAndCompareR0( std::is_same::value, "Floating point type required when specifying an ErrorSpec"); std::unique_ptr expected_literal = - LiteralUtil::CreateR0(expected); + Literal::CreateR0(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments, error); } @@ -301,7 +307,7 @@ void ClientLibraryTestBase::ComputeAndCompareR1( ComputationBuilder* builder, tensorflow::gtl::ArraySlice expected, tensorflow::gtl::ArraySlice arguments) { std::unique_ptr expected_literal = - LiteralUtil::CreateR1(expected); + Literal::CreateR1(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments); } @@ -314,7 +320,7 @@ void ClientLibraryTestBase::ComputeAndCompareR1( std::is_same::value, "Floating point type required when specifying an ErrorSpec"); std::unique_ptr expected_literal = - LiteralUtil::CreateR1(expected); + 
Literal::CreateR1(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments, error); } @@ -324,7 +330,7 @@ void ClientLibraryTestBase::ComputeAndCompareR2( ComputationBuilder* builder, const Array2D& expected, tensorflow::gtl::ArraySlice arguments) { std::unique_ptr expected_literal = - LiteralUtil::CreateR2FromArray2D(expected); + Literal::CreateR2FromArray2D(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments); } @@ -337,7 +343,7 @@ void ClientLibraryTestBase::ComputeAndCompareR2( std::is_same::value, "Floating point type required when specifying an ErrorSpec"); std::unique_ptr expected_literal = - LiteralUtil::CreateR2FromArray2D(expected); + Literal::CreateR2FromArray2D(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments, error); } @@ -347,7 +353,7 @@ void ClientLibraryTestBase::ComputeAndCompareR3( ComputationBuilder* builder, const Array3D& expected, tensorflow::gtl::ArraySlice arguments) { std::unique_ptr expected_literal = - LiteralUtil::CreateR3FromArray3D(expected); + Literal::CreateR3FromArray3D(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments); } @@ -360,7 +366,7 @@ void ClientLibraryTestBase::ComputeAndCompareR3( std::is_same::value, "Floating point type required when specifying an ErrorSpec"); std::unique_ptr expected_literal = - LiteralUtil::CreateR3FromArray3D(expected); + Literal::CreateR3FromArray3D(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments, error); } @@ -370,7 +376,7 @@ void ClientLibraryTestBase::ComputeAndCompareR4( ComputationBuilder* builder, const Array4D& expected, tensorflow::gtl::ArraySlice arguments) { std::unique_ptr expected_literal = - LiteralUtil::CreateR4FromArray4D(expected); + Literal::CreateR4FromArray4D(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments); 
} @@ -383,7 +389,7 @@ void ClientLibraryTestBase::ComputeAndCompareR4( std::is_same::value, "Floating point type required when specifying an ErrorSpec"); std::unique_ptr expected_literal = - LiteralUtil::CreateR4FromArray4D(expected); + Literal::CreateR4FromArray4D(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, arguments, error); } @@ -392,7 +398,7 @@ template std::unique_ptr ClientLibraryTestBase::CreateR0Parameter( NativeT value, int64 parameter_number, const string& name, ComputationBuilder* builder, ComputationDataHandle* data_handle) { - std::unique_ptr literal = LiteralUtil::CreateR0(value); + std::unique_ptr literal = Literal::CreateR0(value); std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); *data_handle = builder->Parameter(parameter_number, literal->shape(), name); @@ -404,7 +410,7 @@ std::unique_ptr ClientLibraryTestBase::CreateR1Parameter( tensorflow::gtl::ArraySlice values, int64 parameter_number, const string& name, ComputationBuilder* builder, ComputationDataHandle* data_handle) { - std::unique_ptr literal = LiteralUtil::CreateR1(values); + std::unique_ptr literal = Literal::CreateR1(values); std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); *data_handle = builder->Parameter(parameter_number, literal->shape(), name); @@ -416,7 +422,7 @@ std::unique_ptr ClientLibraryTestBase::CreateR2Parameter( const Array2D& array_2d, int64 parameter_number, const string& name, ComputationBuilder* builder, ComputationDataHandle* data_handle) { - std::unique_ptr literal = LiteralUtil::CreateR2FromArray2D(array_2d); + std::unique_ptr literal = Literal::CreateR2FromArray2D(array_2d); std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); *data_handle = builder->Parameter(parameter_number, literal->shape(), name); @@ -428,7 +434,7 @@ std::unique_ptr ClientLibraryTestBase::CreateR3Parameter( const Array3D& array_3d, int64 parameter_number, const 
string& name, ComputationBuilder* builder, ComputationDataHandle* data_handle) { - std::unique_ptr literal = LiteralUtil::CreateR3FromArray3D(array_3d); + std::unique_ptr literal = Literal::CreateR3FromArray3D(array_3d); std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); *data_handle = builder->Parameter(parameter_number, literal->shape(), name); diff --git a/tensorflow/compiler/xla/tests/client_test.cc b/tensorflow/compiler/xla/tests/client_test.cc index 1247804dae0..e84a6ce7102 100644 --- a/tensorflow/compiler/xla/tests/client_test.cc +++ b/tensorflow/compiler/xla/tests/client_test.cc @@ -19,7 +19,6 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -47,7 +46,7 @@ TEST_F(ClientTest, ExecuteWithLayout) { auto computation = b.Build(); ASSERT_TRUE(computation.ok()) << computation.status(); - ExecutionOptions execution_options; + ExecutionOptions execution_options = execution_options_; *execution_options.mutable_shape_with_output_layout() = ShapeUtil::MakeShapeWithLayout(S32, /*dimensions=*/{2, 2}, execute_layout); @@ -77,7 +76,7 @@ TEST_F(ClientTest, ExecuteWithTupleLayout) { auto computation = b.Build(); ASSERT_TRUE(computation.ok()) << computation.status(); - ExecutionOptions execution_options; + ExecutionOptions execution_options = execution_options_; // Create a result shape with one element column major and the other row // major. 
*execution_options.mutable_shape_with_output_layout() = @@ -115,7 +114,6 @@ TEST_F(ClientTest, ExecuteWithTupleLayout) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/codegen_test_base.cc b/tensorflow/compiler/xla/tests/codegen_test_base.cc index cc3eb0e8d46..f48dc50708a 100644 --- a/tensorflow/compiler/xla/tests/codegen_test_base.cc +++ b/tensorflow/compiler/xla/tests/codegen_test_base.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/compiler.h" @@ -32,6 +33,17 @@ limitations under the License. namespace xla { +std::unique_ptr CodegenTestBase::CreateNewModuleWithEmbeddedIr( + bool ftz) { + HloModuleConfig config; + auto debug_options = legacy_flags::GetDebugOptionsFromFlags(); + debug_options.set_xla_embed_ir_in_executable(true); + debug_options.set_xla_gpu_ftz(ftz); + config.set_debug_options(debug_options); + return MakeUnique(TestName(), VersionedComputationHandle(), + config); +} + void CodegenTestBase::CompileAndVerifyIr(std::unique_ptr hlo_module, const string& pattern) { std::unique_ptr executable = diff --git a/tensorflow/compiler/xla/tests/codegen_test_base.h b/tensorflow/compiler/xla/tests/codegen_test_base.h index 50c04531070..fa073cd91ee 100644 --- a/tensorflow/compiler/xla/tests/codegen_test_base.h +++ b/tensorflow/compiler/xla/tests/codegen_test_base.h @@ -28,7 +28,11 @@ namespace xla { // Tests that verify IR emitted by the CPU/GPU backend is as expected. 
class CodegenTestBase : public HloTestBase { protected: - CodegenTestBase() {} + // Like HloTestBase::CreateNewModule, but also sets the "embed ir in + // executable" flag to true, since this is needed for codegen tests. + // The optional ftz flags configures whether these modules have their ftz + // option turned on. + std::unique_ptr CreateNewModuleWithEmbeddedIr(bool ftz = false); // Returns the embedded LLVM IR from the given executable. Codegen tests must // override this method, but execution tests do not have to because they do diff --git a/tensorflow/compiler/xla/tests/compilation_cache_test.cc b/tensorflow/compiler/xla/tests/compilation_cache_test.cc index 18ea9714d1a..7038afc5b1f 100644 --- a/tensorflow/compiler/xla/tests/compilation_cache_test.cc +++ b/tensorflow/compiler/xla/tests/compilation_cache_test.cc @@ -21,7 +21,6 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -48,10 +47,10 @@ class CompilationCacheTest : public ClientLibraryTestBase { std::unique_ptr result = client_ ->ExecuteAndTransfer(computation, arguments, - /*execution_options=*/nullptr, + /*execution_options=*/&execution_options_, &execution_profile) .ConsumeValueOrDie(); - LiteralTestUtil::ExpectNear(*LiteralUtil::CreateR0(expected_result), + LiteralTestUtil::ExpectNear(*Literal::CreateR0(expected_result), *result, error_spec_); EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit()); } @@ -62,14 +61,13 @@ class CompilationCacheTest : public ClientLibraryTestBase { std::initializer_list> expected_result, bool expect_cache_hit) { ExecutionProfile execution_profile; - auto 
data_handle = - client_ - ->Execute(computation, arguments, /*execution_options=*/nullptr, - &execution_profile) - .ConsumeValueOrDie(); + auto data_handle = client_ + ->Execute(computation, arguments, + &execution_options_, &execution_profile) + .ConsumeValueOrDie(); std::unique_ptr result = client_->Transfer(*data_handle).ConsumeValueOrDie(); - LiteralTestUtil::ExpectNear(*LiteralUtil::CreateR2(expected_result), + LiteralTestUtil::ExpectNear(*Literal::CreateR2(expected_result), *result, error_spec_); EXPECT_EQ(expect_cache_hit, execution_profile.compilation_cache_hit()); } @@ -89,13 +87,13 @@ XLA_TEST_F(CompilationCacheTest, ComputationCalledMultipleTimes) { XLA_TEST_F(CompilationCacheTest, ComputationCalledWithDifferentParameters) { std::unique_ptr data_42 = - client_->TransferToServer(*LiteralUtil::CreateR0(42.0f)) + client_->TransferToServer(*Literal::CreateR0(42.0f)) .ConsumeValueOrDie(); std::unique_ptr data_123 = - client_->TransferToServer(*LiteralUtil::CreateR0(123.0f)) + client_->TransferToServer(*Literal::CreateR0(123.0f)) .ConsumeValueOrDie(); std::unique_ptr data_456 = - client_->TransferToServer(*LiteralUtil::CreateR0(456.0f)) + client_->TransferToServer(*Literal::CreateR0(456.0f)) .ConsumeValueOrDie(); ComputationBuilder builder(client_, TestName()); @@ -205,7 +203,6 @@ XLA_TEST_F(CompilationCacheTest, MutatedComputation) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/compute_constant_test.cc b/tensorflow/compiler/xla/tests/compute_constant_test.cc index 13c78fb1633..4384c9b3149 100644 --- a/tensorflow/compiler/xla/tests/compute_constant_test.cc +++ b/tensorflow/compiler/xla/tests/compute_constant_test.cc @@ -22,7 +22,6 @@ 
limitations under the License. #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -86,7 +85,7 @@ class ComputeConstantTest : public ::testing::Test { ComputationBuilder* builder) { TF_ASSIGN_OR_RETURN(auto literal, ComputeConstantLiteral(client, operand, builder)); - return LiteralUtil::Get(*literal, {}); + return literal->Get({}); } bool IsConstant(const ComputationDataHandle& operand, @@ -211,7 +210,7 @@ TEST_F(ComputeConstantTest, NonScalarAdd) { auto computed = ComputeConstantLiteral(client, computation, &b); ASSERT_TRUE(computed.ok()) << computed.status(); std::unique_ptr expected_literal = - LiteralUtil::CreateR1({4, 6}); + Literal::CreateR1({4, 6}); LiteralTestUtil::ExpectEqual(*expected_literal, *computed.ValueOrDie()); } } @@ -225,7 +224,7 @@ TEST_F(ComputeConstantTest, IntegerDivide) { auto computed = ComputeConstantLiteral(client, computation, &b); ASSERT_TRUE(computed.ok()) << computed.status(); - std::unique_ptr expected_literal = LiteralUtil::CreateR0(5); + std::unique_ptr expected_literal = Literal::CreateR0(5); LiteralTestUtil::ExpectEqual(*expected_literal, *computed.ValueOrDie()); } } @@ -291,7 +290,6 @@ TEST_F(ComputeConstantTest, DISABLED_ON_CPU(ReuseComputedConstant)) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/concat_test.cc 
b/tensorflow/compiler/xla/tests/concat_test.cc index a7034930bc9..c5d88ad6a08 100644 --- a/tensorflow/compiler/xla/tests/concat_test.cc +++ b/tensorflow/compiler/xla/tests/concat_test.cc @@ -21,7 +21,6 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -518,8 +517,8 @@ TEST_P(ConcatR2BinaryTest, DoIt) { // concat XLA_TEST_F(ConcatTest, ConcatOperandsOfSameOperand) { auto f32_scalar = ShapeUtil::MakeShape(xla::F32, {}); - auto x_literal = LiteralUtil::CreateR0(2.f); - auto y_literal = LiteralUtil::CreateR0(3.f); + auto x_literal = Literal::CreateR0(2.f); + auto y_literal = Literal::CreateR0(3.f); auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie(); auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie(); @@ -540,9 +539,9 @@ XLA_TEST_F(ConcatTest, ConcatOperandsOfSameOperand) { // produces the correct result in rank 1. 
XLA_TEST_F(ConcatTest, ConcatBroadcastArgument) { auto f32_scalar = ShapeUtil::MakeShape(xla::F32, {}); - auto x_literal = LiteralUtil::CreateR1({2.0f, 3.0f, 5.0f, 6.0f}); - auto y_literal = LiteralUtil::CreateR0(1.5f); - auto z_literal = LiteralUtil::CreateR0(5.5f); + auto x_literal = Literal::CreateR1({2.0f, 3.0f, 5.0f, 6.0f}); + auto y_literal = Literal::CreateR0(1.5f); + auto z_literal = Literal::CreateR0(5.5f); auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie(); auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie(); auto z_data = client_->TransferToServer(*z_literal).ConsumeValueOrDie(); @@ -568,9 +567,9 @@ XLA_TEST_F(ConcatTest, ConcatBroadcastArgument) { XLA_TEST_F(ConcatTest, ConcatBroadcastArgumentR3) { auto f32_scalar = ShapeUtil::MakeShape(xla::F32, {}); Array3D x3d(3, 5, 7, 3.14f); - auto x_literal = LiteralUtil::CreateR3FromArray3D(x3d); - auto y_literal = LiteralUtil::CreateR0(1.5f); - auto z_literal = LiteralUtil::CreateR0(5.5f); + auto x_literal = Literal::CreateR3FromArray3D(x3d); + auto y_literal = Literal::CreateR0(1.5f); + auto z_literal = Literal::CreateR0(5.5f); auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie(); auto y_data = client_->TransferToServer(*y_literal).ConsumeValueOrDie(); auto z_data = client_->TransferToServer(*z_literal).ConsumeValueOrDie(); @@ -607,7 +606,6 @@ INSTANTIATE_TEST_CASE_P(ConcatR2BinaryTestInstantiation, ConcatR2BinaryTest, int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/constants_test.cc b/tensorflow/compiler/xla/tests/constants_test.cc index 1c065de8ba7..7c276c8c8d0 100644 --- a/tensorflow/compiler/xla/tests/constants_test.cc +++ 
b/tensorflow/compiler/xla/tests/constants_test.cc @@ -23,7 +23,6 @@ limitations under the License. #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" @@ -113,7 +112,7 @@ TEST_F(ConstantsTest, Small_2x2) { TEST_F(ConstantsTest, Empty_3x0x2) { ComputationBuilder builder(client_, TestName()); auto constant = builder.ConstantLiteral( - *LiteralUtil::CreateR3FromArray3D(Array3D(3, 0, 2))); + *Literal::CreateR3FromArray3D(Array3D(3, 0, 2))); ComputeAndCompareR3(&builder, Array3D(3, 0, 2), {}); } @@ -128,8 +127,8 @@ TEST_F(ConstantsTest, Small_2x2x2) { {{5.f, 6.f}, // y0 {7.f, 8.f}}, // y1 }); - auto constant = builder.ConstantLiteral( - *LiteralUtil::CreateR3FromArray3D(array3d)); + auto constant = + builder.ConstantLiteral(*Literal::CreateR3FromArray3D(array3d)); ComputeAndCompareR3(&builder, array3d, {}); } @@ -143,7 +142,7 @@ TEST_F(ConstantsTest, Small_3x2x1x1) { {5.0f, 4.4f}, // p2 }); input_array.FillWithPZ(pz); - Literal input_literal = *LiteralUtil::CreateR4FromArray4D(input_array); + Literal input_literal = *Literal::CreateR4FromArray4D(input_array); { ComputationBuilder builder(client_, TestName()); @@ -161,9 +160,9 @@ TEST_F(ConstantsTest, Small_3x2x1x1) { // TODO(b/29263943): Support tuple constants. 
TEST_F(ConstantsTest, DISABLED_TupleConstant) { ComputationBuilder builder(client_, TestName()); - builder.ConstantLiteral(*LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2({{1.0}, {2.0}}).get(), - LiteralUtil::CreateR1({2.0, 42}).get()})); + builder.ConstantLiteral( + *Literal::MakeTuple({Literal::CreateR2({{1.0}, {2.0}}).get(), + Literal::CreateR1({2.0, 42}).get()})); std::unique_ptr result = ExecuteAndTransferOrDie(&builder, {}); @@ -179,7 +178,6 @@ TEST_F(ConstantsTest, DISABLED_TupleConstant) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/convert_test.cc b/tensorflow/compiler/xla/tests/convert_test.cc index 6d379797250..f9652ed8cfc 100644 --- a/tensorflow/compiler/xla/tests/convert_test.cc +++ b/tensorflow/compiler/xla/tests/convert_test.cc @@ -20,7 +20,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" @@ -197,7 +196,6 @@ TEST_F(ConvertTest, ConvertReshape) { int main(int argc, char** argv) { std::vector flag_list; - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc index 0b09416a747..fb50d9b0ebf 100644 --- a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc @@ -21,7 +21,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/padding.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/reference_util.h" @@ -63,8 +62,7 @@ XLA_TEST_F(ConvolutionDimensionNumbersTest, auto weight_array = MakeUnique>(4, 3, 1, 1); weight_array->FillWithMultiples(0.2); auto weight_data = - client_ - ->TransferToServer(*LiteralUtil::CreateR4FromArray4D(*weight_array)) + client_->TransferToServer(*Literal::CreateR4FromArray4D(*weight_array)) .ConsumeValueOrDie(); ComputationBuilder builder(client_, TestName()); @@ -102,7 +100,6 @@ XLA_TEST_F(ConvolutionDimensionNumbersTest, int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc index ec19469fa66..a110082f9a5 100644 --- a/tensorflow/compiler/xla/tests/convolution_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_test.cc @@ -25,7 +25,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/padding.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" @@ -115,10 +114,10 @@ TEST_F(ConvolutionTest, Convolve_1x1x1x2_1x1x1x2_Valid) { ReferenceUtil::ConvArray4D(input, filter, {1, 1}, Padding::kValid); auto input_literal = - client_->TransferToServer(*LiteralUtil::CreateR4FromArray4D(input)) + client_->TransferToServer(*Literal::CreateR4FromArray4D(input)) .ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*LiteralUtil::CreateR4FromArray4D(filter)) + client_->TransferToServer(*Literal::CreateR4FromArray4D(filter)) .ConsumeValueOrDie(); ComputeAndCompareR4(&builder, *aexpected, @@ -158,10 +157,10 @@ TEST_F(ConvolutionTest, Convolve_1x1x4x4_1x1x2x2_Valid) { ReferenceUtil::ConvArray4D(input, filter, {1, 1}, Padding::kValid); auto input_literal = - client_->TransferToServer(*LiteralUtil::CreateR4FromArray4D(input)) + client_->TransferToServer(*Literal::CreateR4FromArray4D(input)) .ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*LiteralUtil::CreateR4FromArray4D(filter)) + client_->TransferToServer(*Literal::CreateR4FromArray4D(filter)) .ConsumeValueOrDie(); ComputeAndCompareR4(&builder, *aexpected, @@ -201,10 +200,10 @@ TEST_F(ConvolutionTest, Convolve_1x1x4x4_1x1x2x2_Same) { ReferenceUtil::ConvArray4D(input, filter, {1, 1}, Padding::kSame); auto input_literal = - client_->TransferToServer(*LiteralUtil::CreateR4FromArray4D(input)) + client_->TransferToServer(*Literal::CreateR4FromArray4D(input)) .ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*LiteralUtil::CreateR4FromArray4D(filter)) + client_->TransferToServer(*Literal::CreateR4FromArray4D(filter)) 
.ConsumeValueOrDie(); ComputeAndCompareR4(&builder, *aexpected, @@ -246,10 +245,10 @@ TEST_F(ConvolutionTest, Convolve_1x1x4x4_1x1x3x3_Same) { ReferenceUtil::ConvArray4D(input, filter, {1, 1}, Padding::kSame); auto input_literal = - client_->TransferToServer(*LiteralUtil::CreateR4FromArray4D(input)) + client_->TransferToServer(*Literal::CreateR4FromArray4D(input)) .ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*LiteralUtil::CreateR4FromArray4D(filter)) + client_->TransferToServer(*Literal::CreateR4FromArray4D(filter)) .ConsumeValueOrDie(); ComputeAndCompareR4(&builder, *aexpected, @@ -273,10 +272,10 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_Valid) { Array3D expected({{{510, 610, 710, 810}}}); auto input_literal = - client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(input)) + client_->TransferToServer(*Literal::CreateR3FromArray3D(input)) .ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*LiteralUtil::CreateR3FromArray3D(filter)) + client_->TransferToServer(*Literal::CreateR3FromArray3D(filter)) .ConsumeValueOrDie(); ComputeAndCompareR3(&builder, expected, @@ -313,21 +312,18 @@ XLA_TEST_F(ConvolutionTest, Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid) { std::vector input_elems(ShapeUtil::ElementsIn(input_shape)); std::iota(input_elems.begin(), input_elems.end(), 1.0f); - auto input_r1 = LiteralUtil::CreateR1(input_elems); - auto input_r5 = - LiteralUtil::Reshape(*input_r1, input_dims).ConsumeValueOrDie(); + auto input_r1 = Literal::CreateR1(input_elems); + auto input_r5 = input_r1->Reshape(input_dims).ConsumeValueOrDie(); std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape)); std::iota(filter_elems.begin(), filter_elems.end(), 1.0f); - auto filter_r1 = LiteralUtil::CreateR1(filter_elems); - auto filter_r5 = - LiteralUtil::Reshape(*filter_r1, filter_dims).ConsumeValueOrDie(); + auto filter_r1 = Literal::CreateR1(filter_elems); + auto filter_r5 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie(); - 
auto expected_r1 = LiteralUtil::CreateR1( + auto expected_r1 = Literal::CreateR1( {19554, 19962, 20370, 22110, 22590, 23070, 34890, 35730, 36570, 37446, 38358, 39270, 50226, 51498, 52770, 52782, 54126, 55470}); - auto expected_r5 = - LiteralUtil::Reshape(*expected_r1, {1, 3, 1, 2, 3}).ConsumeValueOrDie(); + auto expected_r5 = expected_r1->Reshape({1, 3, 1, 2, 3}).ConsumeValueOrDie(); auto input_literal = client_->TransferToServer(*input_r5).ConsumeValueOrDie(); auto filter_literal = @@ -344,7 +340,6 @@ XLA_TEST_F(ConvolutionTest, Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/convolution_variants_test.cc b/tensorflow/compiler/xla/tests/convolution_variants_test.cc index b5afc2498da..c8e74aa01a5 100644 --- a/tensorflow/compiler/xla/tests/convolution_variants_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_variants_test.cc @@ -28,7 +28,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/padding.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/reference_util.h" @@ -1312,20 +1311,19 @@ TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding1D) { TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding3D) { ComputationBuilder builder(client_, TestName()); - auto gradients_flat = LiteralUtil::CreateR1({1}); + auto gradients_flat = Literal::CreateR1({1}); auto gradients_literal = - LiteralUtil::Reshape(*gradients_flat, {1, 1, 1, 1, 1}) - .ConsumeValueOrDie(); + gradients_flat->Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie(); auto gradients = builder.ConstantLiteral(*gradients_literal); - auto weights_flat = LiteralUtil::CreateR1({1, 10, 100}); + auto weights_flat = Literal::CreateR1({1, 10, 100}); auto weights_literal = - LiteralUtil::Reshape(*weights_flat, {1, 1, 1, 1, 3}).ConsumeValueOrDie(); + weights_flat->Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie(); auto weights = builder.ConstantLiteral(*weights_literal); - auto expected_flat = LiteralUtil::CreateR1({10}); + auto expected_flat = Literal::CreateR1({10}); auto expected_literal = - LiteralUtil::Reshape(*expected_flat, {1, 1, 1, 1, 1}).ConsumeValueOrDie(); + expected_flat->Reshape({1, 1, 1, 1, 1}).ConsumeValueOrDie(); auto mirrored_weights = builder.Rev(weights, {2, 3, 4}); builder.ConvWithGeneralPadding(gradients, mirrored_weights, @@ -1337,21 +1335,19 @@ TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding3D) { TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding3D) { ComputationBuilder builder(client_, TestName()); - auto activations_flat = LiteralUtil::CreateR1({1, 2, 3, 4}); + auto activations_flat = Literal::CreateR1({1, 2, 3, 4}); auto 
activations_literal = - LiteralUtil::Reshape(*activations_flat, {1, 1, 1, 1, 4}) - .ConsumeValueOrDie(); + activations_flat->Reshape({1, 1, 1, 1, 4}).ConsumeValueOrDie(); auto activations = builder.ConstantLiteral(*activations_literal); - auto gradients_flat = LiteralUtil::CreateR1({100, 10, 1}); + auto gradients_flat = Literal::CreateR1({100, 10, 1}); auto gradients_literal = - LiteralUtil::Reshape(*gradients_flat, {1, 1, 1, 1, 3}) - .ConsumeValueOrDie(); + gradients_flat->Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie(); auto gradients = builder.ConstantLiteral(*gradients_literal); - auto expected_flat = LiteralUtil::CreateR1({13, 24, 130}); + auto expected_flat = Literal::CreateR1({13, 24, 130}); auto expected_literal = - LiteralUtil::Reshape(*expected_flat, {1, 1, 1, 1, 3}).ConsumeValueOrDie(); + expected_flat->Reshape({1, 1, 1, 1, 3}).ConsumeValueOrDie(); auto forward_conv = builder.ConvGeneralDilated( activations, gradients, @@ -1370,7 +1366,6 @@ TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding3D) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/copy_test.cc b/tensorflow/compiler/xla/tests/copy_test.cc index 4c2413d0fe4..76ae280f1a0 100644 --- a/tensorflow/compiler/xla/tests/copy_test.cc +++ b/tensorflow/compiler/xla/tests/copy_test.cc @@ -17,7 +17,6 @@ limitations under the License. 
#include #include "tensorflow/compiler/xla/array2d.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" @@ -58,39 +57,34 @@ class CopyOpTest : public HloTestBase { tensorflow::gtl::ArraySlice permutation); }; -TEST_F(CopyOpTest, CopyR0Bool) { - TestCopyOp(*LiteralUtil::CreateR0(true)); -} +TEST_F(CopyOpTest, CopyR0Bool) { TestCopyOp(*Literal::CreateR0(true)); } -TEST_F(CopyOpTest, CopyR1S0U32) { - TestCopyOp(*LiteralUtil::CreateR1({})); -} +TEST_F(CopyOpTest, CopyR1S0U32) { TestCopyOp(*Literal::CreateR1({})); } TEST_F(CopyOpTest, CopyR1S3U32) { - TestCopyOp(*LiteralUtil::CreateR1({1, 2, 3})); + TestCopyOp(*Literal::CreateR1({1, 2, 3})); } TEST_F(CopyOpTest, CopyR3F32_2x2x3) { - TestCopyOp( - *LiteralUtil::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, - {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}})); + TestCopyOp(*Literal::CreateR3({{{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}, + {{1.1f, 2.1f, 3.1f}, {6.1f, 3.5f, 2.8f}}})); } TEST_F(CopyOpTest, CopyR4S32_2x2x3x2) { - TestCopyOp(*LiteralUtil::CreateR4( + TestCopyOp(*Literal::CreateR4( {{{{1, -2}, {-4, 5}, {6, 7}}, {{8, 9}, {10, 11}, {12, 13}}}, {{{10, 3}, {7, -2}, {3, 6}}, {{2, 5}, {-11, 5}, {-2, -5}}}})); } TEST_F(CopyOpTest, CopyR4S32_0x2x3x2) { - TestCopyOp(*LiteralUtil::CreateR4FromArray4D(Array4D(0, 2, 3, 2))); + TestCopyOp(*Literal::CreateR4FromArray4D(Array4D(0, 2, 3, 2))); } TEST_F(CopyOpTest, CopyParameterScalar) { auto builder = HloComputation::Builder(TestName()); // Copy literal to device to use as parameter. 
- auto literal = LiteralUtil::CreateR0(42.0); + auto literal = Literal::CreateR0(42.0); Shape shape = literal->shape(); auto constant_device_base = TransferToDevice(*literal); @@ -112,7 +106,7 @@ TEST_F(CopyOpTest, CopyParameterScalar) { TEST_F(CopyOpTest, CopyConstantR2Twice) { auto builder = HloComputation::Builder(TestName()); - auto literal = LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto literal = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); auto constant = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); @@ -134,7 +128,7 @@ TEST_F(CopyOpTest, CopyConstantR2DifferentLayouts) { HloComputation::Builder builder(TestName()); std::unique_ptr literal = - LiteralUtil::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); // Reverse the minor-to-major order of the literal. Layout* literal_layout = literal->mutable_shape()->mutable_layout(); ASSERT_EQ(2, literal_layout->minor_to_major_size()); @@ -170,7 +164,7 @@ void CopyOpTest::TestCopyConstantLayout021(size_t n1, size_t n2, size_t n3) { HloComputation::Builder builder(TestName()); - std::unique_ptr literal = LiteralUtil::CreateR3FromArray3D(a); + std::unique_ptr literal = Literal::CreateR3FromArray3D(a); HloInstruction* constant = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); @@ -204,7 +198,7 @@ void CopyOpTest::TestCopyConstantLayoutR4( HloComputation::Builder builder(TestName()); - std::unique_ptr literal = LiteralUtil::CreateR4FromArray4D(a); + std::unique_ptr literal = Literal::CreateR4FromArray4D(a); HloInstruction* constant = builder.AddInstruction( HloInstruction::CreateConstant(std::move(literal))); @@ -247,7 +241,7 @@ using CopyOpClientTest = ClientLibraryTestBase; XLA_TEST_F(CopyOpClientTest, Copy0x0) { Shape in_shape = ShapeUtil::MakeShapeWithLayout(F32, {0, 0}, {0, 1}); Shape out_shape = ShapeUtil::MakeShapeWithLayout(F32, {0, 0}, {1, 0}); - auto empty = LiteralUtil::CreateFromShape(in_shape); + auto empty = 
Literal::CreateFromShape(in_shape); ComputationBuilder builder(client_, TestName()); auto param0 = builder.Parameter(0, in_shape, "input"); @@ -263,7 +257,6 @@ XLA_TEST_F(CopyOpClientTest, Copy0x0) { int main(int argc, char** argv) { std::vector flag_list; - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/tests/custom_call_test.cc b/tensorflow/compiler/xla/tests/custom_call_test.cc index 32232acf6e3..73772fdec02 100644 --- a/tensorflow/compiler/xla/tests/custom_call_test.cc +++ b/tensorflow/compiler/xla/tests/custom_call_test.cc @@ -16,7 +16,6 @@ limitations under the License. #include #include -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" @@ -68,7 +67,7 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR0F32Add2)) { auto builder = HloComputation::Builder(TestName()); auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(42.0f))); + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); builder.AddInstruction( HloInstruction::CreateCustomCall(r0f32_, {constant}, "R0F32Add2")); @@ -89,7 +88,7 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR2F32Reduce)) { array(1, 1) = 4.0f; auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR2FromArray2D(array))); + HloInstruction::CreateConstant(Literal::CreateR2FromArray2D(array))); builder.AddInstruction( HloInstruction::CreateCustomCall(r0f32_, {constant}, "R2F32ReduceSum")); @@ -105,7 +104,7 @@ XLA_TEST_F(CustomCallTest, auto b = HloComputation::Builder(TestName()); auto input = b.AddInstruction( - 
HloInstruction::CreateConstant(LiteralUtil::CreateR2FromArray2D( + HloInstruction::CreateConstant(Literal::CreateR2FromArray2D( Array2D{{1.0f, 2.0f}, {3.0f, 4.0f}}))); auto incremented = b.AddInstruction(HloInstruction::CreateCustomCall( ShapeUtil::MakeShape(F32, {1, 2, 2}), {input}, "Add1ToValues")); @@ -129,7 +128,6 @@ XLA_TEST_F(CustomCallTest, int main(int argc, char** argv) { std::vector flag_list; - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/tests/deallocation_test.cc b/tensorflow/compiler/xla/tests/deallocation_test.cc index 074753bf6f8..0c7c3a8ff66 100644 --- a/tensorflow/compiler/xla/tests/deallocation_test.cc +++ b/tensorflow/compiler/xla/tests/deallocation_test.cc @@ -19,7 +19,6 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" @@ -42,7 +41,8 @@ class DeallocationTest : public ClientLibraryTestBase { tensorflow::gtl::ArraySlice arguments) { Computation computation = builder->Build().ConsumeValueOrDie(); auto global_data = - client_->Execute(computation, arguments).ConsumeValueOrDie(); + client_->Execute(computation, arguments, &execution_options_) + .ConsumeValueOrDie(); TF_CHECK_OK(client_->Transfer(*global_data).status()); return global_data; } @@ -143,7 +143,6 @@ XLA_TEST_F(DeallocationTest, DeallocateNestedTuple) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); 
- xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc index fcddffc1e13..3d6a995a245 100644 --- a/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc +++ b/tensorflow/compiler/xla/tests/deconstruct_tuple_test.cc @@ -20,7 +20,6 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -48,7 +47,8 @@ class DeconstructTupleTest : public ClientLibraryTestBase { tensorflow::gtl::ArraySlice arguments) { Computation computation = builder->Build().ConsumeValueOrDie(); auto global_data = - client_->Execute(computation, arguments).ConsumeValueOrDie(); + client_->Execute(computation, arguments, &execution_options_) + .ConsumeValueOrDie(); TF_CHECK_OK(client_->Transfer(*global_data).status()); return global_data; } @@ -173,7 +173,7 @@ TEST_F(DeconstructTupleTest, DeconstructNonTuple) { XLA_TEST_F(DeconstructTupleTest, DeconstructTupleFromParam) { ComputationBuilder builder(client_, TestName()); std::unique_ptr param0_literal = - LiteralUtil::CreateR1({3.14f, -100.25f}); + Literal::CreateR1({3.14f, -100.25f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); auto p = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2}), "param0"); @@ -205,7 +205,6 @@ XLA_TEST_F(DeconstructTupleTest, DeconstructNestedTuple) { int main(int argc, char** argv) { std::vector 
flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/deep_graph_test.cc b/tensorflow/compiler/xla/tests/deep_graph_test.cc new file mode 100644 index 00000000000..7a5601ada30 --- /dev/null +++ b/tensorflow/compiler/xla/tests/deep_graph_test.cc @@ -0,0 +1,58 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" +#include "tensorflow/compiler/xla/legacy_flags/user_computation_flags.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" + +namespace xla { +namespace { +TEST_F(ClientLibraryTestBase, DeepGraph) { + // TODO(b/62624812): To trigger the stack overflow this test is + // intended to track, we need to set kDepth to 20000. + // Unfortunately, setting it that high causes the test to time out. 
+ const int kDepth = 200; + ComputationBuilder b(client_, TestName()); + ComputationDataHandle x; + ComputationDataHandle y; + auto x_data = CreateR0Parameter(3, 0, "x", &b, &x); + auto y_data = CreateR0Parameter(1, 1, "y", &b, &y); + ComputationDataHandle z = x; + for (int i = 0; i < kDepth; ++i) { + z = b.Add(z, y); + } + ComputeAndCompareR0(&b, /*expected=*/kDepth + 3, + {x_data.get(), y_data.get()}); +} +} // namespace +} // namespace xla + +int main(int argc, char** argv) { + std::vector flag_list; + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); + xla::legacy_flags::AppendUserComputationFlags(&flag_list); + xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parse_result) { + LOG(ERROR) << "\n" << usage; + return 2; + } + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index 754eec1b1ed..ac64e2ee9e7 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -20,7 +20,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/cpu_runtime_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/legacy_flags/layout_util_flags.h" @@ -186,14 +185,14 @@ void DotOperationTest::TestMatrixDot(int M, int K, int N, bool lhs_row_major, bool rhs_row_major) { std::unique_ptr> lhs_data = MakeLinspaceArray2D(0.0, 1.0, M, K); - std::unique_ptr lhs_lit = LiteralUtil::CreateR2FromArray2DWithLayout( + std::unique_ptr lhs_lit = Literal::CreateR2FromArray2DWithLayout( *lhs_data, LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(lhs_row_major))); auto lhs_handle = client_->TransferToServer(*lhs_lit).ConsumeValueOrDie(); std::unique_ptr> rhs_data = MakeLinspaceArray2D(0.0, 1.0, K, N); - std::unique_ptr rhs_lit = LiteralUtil::CreateR2FromArray2DWithLayout( + std::unique_ptr rhs_lit = Literal::CreateR2FromArray2DWithLayout( *rhs_data, LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(rhs_row_major))); auto rhs_handle = client_->TransferToServer(*rhs_lit).ConsumeValueOrDie(); @@ -380,12 +379,12 @@ XLA_TEST_F(DotOperationTest, BatchMatMul) { builder.Reshape(out_flat, {0, 1, 2}, {2, 2, 2, 2}); auto x_data = client_ - ->TransferToServer(*LiteralUtil::CreateR4( + ->TransferToServer(*Literal::CreateR4( {{{{1000, 100}, {10, 1}}, {{2000, 200}, {20, 2}}}, {{{3000, 300}, {30, 3}}, {{4000, 400}, {40, 4}}}})) .ConsumeValueOrDie(); auto y_data = client_ - ->TransferToServer(*LiteralUtil::CreateR4( + ->TransferToServer(*Literal::CreateR4( {{{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}, {{{11, 22}, {33, 44}}, {{55, 66}, {77, 88}}}})) .ConsumeValueOrDie(); @@ -416,14 +415,14 @@ TEST_F(DotOperationTest, TransposeFolding) { auto lhs_handle = client_ ->TransferToServer( - 
*LiteralUtil::CreateR2FromArray2DWithLayout( + *Literal::CreateR2FromArray2DWithLayout( *lhs, LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(row_major)))) .ConsumeValueOrDie(); auto rhs_handle = client_ ->TransferToServer( - *LiteralUtil::CreateR2FromArray2DWithLayout( + *Literal::CreateR2FromArray2DWithLayout( *rhs, LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(row_major)))) .ConsumeValueOrDie(); @@ -463,7 +462,6 @@ int main(int argc, char** argv) { xla::legacy_flags::AppendLayoutUtilFlags(&flag_list); xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::legacy_flags::AppendCpuRuntimeFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc index b7bb1792f3b..f653766f39d 100644 --- a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc +++ b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc @@ -21,7 +21,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" @@ -389,8 +388,8 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { template void DumpArray(const string& name, const Array3D values) { std::unique_ptr literal = - LiteralUtil::CreateR3FromArray3D(values); - LOG(INFO) << name << ":" << LiteralUtil::ToString(*literal); + Literal::CreateR3FromArray3D(values); + LOG(INFO) << name << ":" << literal->ToString(); } }; @@ -470,7 +469,7 @@ void BM_DynamicSlice(int num_iters) { ComputationBuilder builder(client, "DynamicSlice"); // Create input as a constant: shape [1, 2, 3, 4] - auto input_literal = LiteralUtil::CreateR4( + auto input_literal = Literal::CreateR4( {{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}}); auto input = builder.ConstantLiteral(*input_literal); @@ -488,7 +487,7 @@ void BM_DynamicSlice(int num_iters) { &allocator, 0) .ConsumeValueOrDie(); - auto start_indices_literal = LiteralUtil::CreateR1({0, 1, 2, 3}); + auto start_indices_literal = Literal::CreateR1({0, 1, 2, 3}); ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice( executors[device_ordinal], *start_indices_literal, buffer->mutable_buffer({}))); @@ -521,7 +520,6 @@ BENCHMARK(BM_DynamicSlice); int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git 
a/tensorflow/compiler/xla/tests/floor_ceil_test.cc b/tensorflow/compiler/xla/tests/floor_ceil_test.cc index 80267e5459d..90c5aa65592 100644 --- a/tensorflow/compiler/xla/tests/floor_ceil_test.cc +++ b/tensorflow/compiler/xla/tests/floor_ceil_test.cc @@ -18,7 +18,6 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -114,7 +113,6 @@ TEST_F(FloorCeilTest, R0Ceil) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/fmax_test.cc b/tensorflow/compiler/xla/tests/fmax_test.cc index ee4e92505d9..9c86c65e5bb 100644 --- a/tensorflow/compiler/xla/tests/fmax_test.cc +++ b/tensorflow/compiler/xla/tests/fmax_test.cc @@ -17,7 +17,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -47,7 +46,6 @@ TEST_F(FmaxSimpleTest, FmaxTenValues) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc index fa36381267e..7803d234fdf 100644 --- a/tensorflow/compiler/xla/tests/fusion_test.cc +++ b/tensorflow/compiler/xla/tests/fusion_test.cc @@ -20,7 +20,9 @@ limitations under the License. #include #include "tensorflow/compiler/xla/array2d.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/primitive_util.h" @@ -29,7 +31,9 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -37,10 +41,13 @@ limitations under the License. #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/test_benchmark.h" #include "tensorflow/core/platform/types.h" using tensorflow::gtl::ArraySlice; +namespace se = ::perftools::gputools; + namespace xla { namespace { @@ -81,7 +88,7 @@ class FusionTest : public HloTestBase { HloInstruction* hlos[4]; for (int i = 0; i < Arity; ++i) { hlos[i + 1] = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2FromArray2D(operand_data[i]))); + Literal::CreateR2FromArray2D(operand_data[i]))); } auto answer_shape = ShapeUtil::MakeShape(prim_type, {test_width, test_height}); @@ -107,7 +114,7 @@ class FusionTest : public HloTestBase { ArraySlice(hlos, 0, Arity + 1), HloInstruction::FusionKind::kLoop); - auto expected = LiteralUtil::CreateR2FromArray2D(answer_data); + auto expected = Literal::CreateR2FromArray2D(answer_data); auto actual = ExecuteAndTransfer(std::move(hlo_module), {}); if (primitive_util::IsFloatingPointType(prim_type)) { LiteralTestUtil::ExpectNear(*expected, *actual, ErrorSpec(1e-4)); @@ -178,28 +185,27 @@ XLA_TEST_F(FusionTest, Test) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1.0}, {2.0}, {3.0}}))); + 
Literal::CreateR2({{1.0}, {2.0}, {3.0}}))); auto const1 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{-1.0}, {-1.0}, {-1.0}}))); + Literal::CreateR2({{-1.0}, {-1.0}, {-1.0}}))); auto add2 = builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(F32, {3, 1}), HloOpcode::kAdd, const0, const1)); auto reshape3 = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {1, 3}), add2, {1, 0})); auto const4 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1.62, 2.72, 3.14}}))); + Literal::CreateR2({{1.62, 2.72, 3.14}}))); auto concat5 = builder.AddInstruction(HloInstruction::CreateConcatenate( ShapeUtil::MakeShape(F32, {2, 3}), {reshape3, const4}, 0)); auto const6 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1.0, 1.0, 1.0}, {0.0, 0.0, 0.0}}))); + Literal::CreateR2({{1.0, 1.0, 1.0}, {0.0, 0.0, 0.0}}))); auto negate7 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(F32, {2, 3}), HloOpcode::kNegate, const6)); auto add8 = builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(F32, {2, 3}), HloOpcode::kAdd, concat5, negate7)); auto const9 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{0.5, 0.5, 0.5}, {0.5, 0.5, 0.5}}))); - auto const10 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR2( - {{true, false, true}, {false, true, false}}))); + Literal::CreateR2({{0.5, 0.5, 0.5}, {0.5, 0.5, 0.5}}))); + auto const10 = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR2({{true, false, true}, {false, true, false}}))); auto select11 = builder.AddInstruction( HloInstruction::CreateTernary(ShapeUtil::MakeShape(F32, {2, 3}), HloOpcode::kSelect, const10, add8, const9)); @@ -214,7 +220,7 @@ XLA_TEST_F(FusionTest, Test) { const4, reshape3, add2, const1, const0}, HloInstruction::FusionKind::kLoop); - 
LiteralTestUtil::ExpectNear(*LiteralUtil::CreateR2({{0.5}, {2.72}}), + LiteralTestUtil::ExpectNear(*Literal::CreateR2({{0.5}, {2.72}}), *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)); } @@ -226,11 +232,11 @@ XLA_TEST_F(FusionTest, Parameter) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1.0, 2.0, 3.0}}))); + Literal::CreateR2({{1.0, 2.0, 3.0}}))); auto copy1 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(F32, {1, 3}), HloOpcode::kCopy, const0)); auto const2 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{-2.0, -2.0, -2.0}}))); + Literal::CreateR2({{-2.0, -2.0, -2.0}}))); // add3 = copy1 + const2 = const0 + const2 = {1,2,3} + {-2,-2,-2} = {-1,0,+1} auto add3 = builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(F32, {1, 3}), HloOpcode::kAdd, copy1, const2)); @@ -240,7 +246,7 @@ XLA_TEST_F(FusionTest, Parameter) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{add3, const2}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectNear(*LiteralUtil::CreateR2({{-1.0, 0.0, 1.0}}), + LiteralTestUtil::ExpectNear(*Literal::CreateR2({{-1.0, 0.0, 1.0}}), *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)); } @@ -249,9 +255,9 @@ XLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const_vector = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({1.0, 2.0, 3.0}))); + Literal::CreateR1({1.0, 2.0, 3.0}))); auto const_array = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{-1.0, -2.0, -4.0}, {10.0, 20.0, 30.0}}))); + Literal::CreateR2({{-1.0, -2.0, -4.0}, {10.0, 20.0, 30.0}}))); auto broadcast = builder.AddInstruction( 
HloInstruction::CreateBroadcast(const_array->shape(), const_vector, {1})); // add2 = broadcast(const_vector) + const_array @@ -265,7 +271,7 @@ XLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) { HloInstruction::FusionKind::kLoop); LiteralTestUtil::ExpectNear( - *LiteralUtil::CreateR2({{0.0, 0.0, -1.0}, {11.0, 22.0, 33.0}}), + *Literal::CreateR2({{0.0, 0.0, -1.0}, {11.0, 22.0, 33.0}}), *ExecuteAndTransfer(std::move(hlo_module), {}), ErrorSpec(1e-4)); } @@ -273,13 +279,13 @@ XLA_TEST_F(FusionTest, ReshapeToScalar) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto single_element_array = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR2({{5}}))); + HloInstruction::CreateConstant(Literal::CreateR2({{5}}))); auto reshape = builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(S32, {}), single_element_array)); hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*LiteralUtil::CreateR0(5), + LiteralTestUtil::ExpectEqual(*Literal::CreateR0(5), *ExecuteAndTransfer(std::move(hlo_module), {})); } @@ -287,14 +293,14 @@ XLA_TEST_F(FusionTest, Reshape_3by2_1by2by3) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}}))); + Literal::CreateR2({{1, 2}, {3, 4}, {5, 6}}))); auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(S32, {1, 2, 3}), const0)); hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); LiteralTestUtil::ExpectEqual( - *LiteralUtil::CreateR3({{{1, 2, 3}, {4, 5, 6}}}), + *Literal::CreateR3({{{1, 2, 3}, {4, 5, 6}}}), *ExecuteAndTransfer(std::move(hlo_module), {})); } @@ 
-302,14 +308,14 @@ XLA_TEST_F(FusionTest, Reshape_1by2by3_3by2) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR3({{{1, 2, 3}, {4, 5, 6}}}))); + Literal::CreateR3({{{1, 2, 3}, {4, 5, 6}}}))); auto reshape1 = builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {3, 2}), const0)); hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); LiteralTestUtil::ExpectEqual( - *LiteralUtil::CreateR2({{1, 2}, {3, 4}, {5, 6}}), + *Literal::CreateR2({{1, 2}, {3, 4}, {5, 6}}), *ExecuteAndTransfer(std::move(hlo_module), {})); } @@ -317,13 +323,13 @@ XLA_TEST_F(FusionTest, Reshape_1by1by1_) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR3({{{7}}}))); + HloInstruction::CreateConstant(Literal::CreateR3({{{7}}}))); auto reshape1 = builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {}), const0)); hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*LiteralUtil::CreateR0(7), + LiteralTestUtil::ExpectEqual(*Literal::CreateR0(7), *ExecuteAndTransfer(std::move(hlo_module), {})); } @@ -331,13 +337,13 @@ XLA_TEST_F(FusionTest, Reshape__1by1by1) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(7))); + HloInstruction::CreateConstant(Literal::CreateR0(7))); auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(S32, {1, 1, 1}), const0)); hlo_module->AddEntryComputation(builder.Build()) 
->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*LiteralUtil::CreateR3({{{7}}}), + LiteralTestUtil::ExpectEqual(*Literal::CreateR3({{{7}}}), *ExecuteAndTransfer(std::move(hlo_module), {})); } @@ -345,13 +351,13 @@ XLA_TEST_F(FusionTest, Reshape__) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(7))); + HloInstruction::CreateConstant(Literal::CreateR0(7))); auto reshape1 = builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {}), const0)); hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*LiteralUtil::CreateR0(7), + LiteralTestUtil::ExpectEqual(*Literal::CreateR0(7), *ExecuteAndTransfer(std::move(hlo_module), {})); } @@ -359,14 +365,14 @@ XLA_TEST_F(FusionTest, Reshape_3by3_3by3) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}))); + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}))); auto reshape1 = builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(S32, {3, 3}), const0)); hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); LiteralTestUtil::ExpectEqual( - *LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}), + *Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}), *ExecuteAndTransfer(std::move(hlo_module), {})); } @@ -374,14 +380,14 @@ XLA_TEST_F(FusionTest, Transpose_2by3) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = 
builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}}))); + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}}))); auto reshape1 = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(S32, {3, 2}), const0, {1, 0})); hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); LiteralTestUtil::ExpectEqual( - *LiteralUtil::CreateR2({{1, 4}, {2, 5}, {3, 6}}), + *Literal::CreateR2({{1, 4}, {2, 5}, {3, 6}}), *ExecuteAndTransfer(std::move(hlo_module), {})); } @@ -389,14 +395,14 @@ XLA_TEST_F(FusionTest, Transpose_3by3) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}))); + Literal::CreateR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}))); auto reshape1 = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(S32, {3, 3}), const0, {1, 0})); hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reshape1}, HloInstruction::FusionKind::kLoop); LiteralTestUtil::ExpectEqual( - *LiteralUtil::CreateR2({{1, 4, 7}, {2, 5, 8}, {3, 6, 9}}), + *Literal::CreateR2({{1, 4, 7}, {2, 5, 8}, {3, 6, 9}}), *ExecuteAndTransfer(std::move(hlo_module), {})); } @@ -404,14 +410,14 @@ XLA_TEST_F(FusionTest, Reverse) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR1({1, 2, 3}))); + HloInstruction::CreateConstant(Literal::CreateR1({1, 2, 3}))); auto reverse1 = builder.AddInstruction(HloInstruction::CreateReverse( ShapeUtil::MakeShape(S32, {3}), const0, {0})); hlo_module->AddEntryComputation(builder.Build()) ->CreateFusionInstruction(/*instructions_to_fuse=*/{reverse1}, 
HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*LiteralUtil::CreateR1({3, 2, 1}), + LiteralTestUtil::ExpectEqual(*Literal::CreateR1({3, 2, 1}), *ExecuteAndTransfer(std::move(hlo_module), {})); } @@ -430,10 +436,10 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) { auto hlo_module = CreateNewModule(); auto builder = HloComputation::Builder(TestName()); - auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({1, 2, 4, 8}))); + auto const0 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({1, 2, 4, 8}))); auto const1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + HloInstruction::CreateConstant(Literal::CreateR0(0))); auto reduce2 = builder.AddInstruction(HloInstruction::CreateReduce( ShapeUtil::MakeShape(S32, {}), const0, const1, {0}, hlo_module->AddEmbeddedComputation(MakeReduceTestComputation()))); @@ -441,7 +447,7 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(Reduce)) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{reduce2}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*LiteralUtil::CreateR0(15), + LiteralTestUtil::ExpectEqual(*Literal::CreateR0(15), *ExecuteAndTransfer(std::move(hlo_module), {})); } @@ -449,10 +455,10 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) { auto hlo_module = CreateNewModule(); auto builder = HloComputation::Builder(TestName()); - auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR1({1, 2, 4, 8}))); + auto const0 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({1, 2, 4, 8}))); auto const1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(0))); + HloInstruction::CreateConstant(Literal::CreateR0(0))); auto reduce2 = builder.AddInstruction(HloInstruction::CreateReduce( ShapeUtil::MakeShape(S32, {}), const0, const1, {0}, 
hlo_module->AddEmbeddedComputation(MakeReduceTestComputation()))); @@ -462,7 +468,7 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceImplicitBroadcast)) { ->CreateFusionInstruction(/*instructions_to_fuse=*/{negate3, reduce2}, HloInstruction::FusionKind::kLoop); - LiteralTestUtil::ExpectEqual(*LiteralUtil::CreateR1({-15}), + LiteralTestUtil::ExpectEqual(*Literal::CreateR1({-15}), *ExecuteAndTransfer(std::move(hlo_module), {})); } @@ -470,9 +476,9 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - LiteralUtil::CreateR2({{2, 3, 5}, {7, 11, 13}, {17, 19, 23}}))); + Literal::CreateR2({{2, 3, 5}, {7, 11, 13}, {17, 19, 23}}))); auto const1 = builder.AddInstruction( - HloInstruction::CreateConstant(LiteralUtil::CreateR0(1))); + HloInstruction::CreateConstant(Literal::CreateR0(1))); Window window; ASSERT_TRUE( tensorflow::protobuf::TextFormat::ParseFromString("dimensions:{\n" @@ -512,7 +518,7 @@ XLA_TEST_F(FusionTest, DISABLED_ON_CPU(ReduceWindow)) { HloInstruction::FusionKind::kLoop); LiteralTestUtil::ExpectEqual( - *LiteralUtil::CreateR2({{462, 2145}, {24871, 62491}}), + *Literal::CreateR2({{462, 2145}, {24871, 62491}}), *ExecuteAndTransfer(std::move(hlo_module), {})); } @@ -568,12 +574,66 @@ XLA_TEST_F(FusionTest, Clamp2D) { TestElementwise2D(HloOpcode::kClamp); } +void BM_ParallelFusion(int num_iters) { + // Simple element-wise computation to benchmark parallel task partitioning. 
+ tensorflow::testing::StopTiming(); + + se::Platform* platform = PlatformUtil::GetDefaultPlatform().ValueOrDie(); + auto executors = PlatformUtil::GetStreamExecutors(platform).ValueOrDie(); + StreamExecutorMemoryAllocator allocator(platform, executors); + + const int64 intra_op_parallelism_threads = 16; + xla::LocalClientOptions client_options; + client_options.set_platform(platform); + client_options.set_intra_op_parallelism_threads(intra_op_parallelism_threads); + auto client = + ClientLibrary::GetOrCreateLocalClient(client_options).ValueOrDie(); + + const int64 dim_size = 1024; + // Create a simple fusable elementwise computation. + ComputationBuilder builder(client, "ParallelFusion"); + Shape input_shape = ShapeUtil::MakeShape(F32, {dim_size, dim_size}); + auto input0 = builder.Broadcast(builder.ConstantR0(1.5f), + AsInt64Slice(input_shape.dimensions())); + auto input1 = builder.Broadcast(builder.ConstantR0(2.0f), + AsInt64Slice(input_shape.dimensions())); + auto input2 = builder.Broadcast(builder.ConstantR0(3.0f), + AsInt64Slice(input_shape.dimensions())); + auto x = builder.Mul(input0, input1); + auto y = builder.Add(x, input2); + auto computation = builder.Build().ConsumeValueOrDie(); + + std::unique_ptr executable = + client->Compile(computation, {}, ExecutableBuildOptions()) + .ConsumeValueOrDie(); + + // Run some warm-up executions. + ExecutableRunOptions options; + options.set_allocator(&allocator); + const int kWarmups = 2; + for (int i = 0; i < kWarmups; ++i) { + auto result = executable->Run({}, options); + ASSERT_TRUE(result.ok()); + } + + // Run benchmark. 
+ tensorflow::testing::BytesProcessed(static_cast(num_iters) * dim_size * + dim_size * sizeof(float)); + tensorflow::testing::UseRealTime(); + tensorflow::testing::StartTiming(); + for (int i = 0; i < num_iters; ++i) { + auto result = executable->Run({}, options); + ASSERT_TRUE(result.ok()); + } +} + +BENCHMARK(BM_ParallelFusion); + } // namespace } // namespace xla int main(int argc, char** argv) { std::vector flag_list; - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); @@ -586,5 +646,6 @@ int main(int argc, char** argv) { LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; return 2; } + tensorflow::testing::RunBenchmarks(); return RUN_ALL_TESTS(); } diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index 5f7b7aa434e..354e4a84c5c 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -77,6 +77,7 @@ HloTestBase::~HloTestBase() { } } +/* static */ std::unique_ptr HloTestBase::CreateNewModule() { HloModuleConfig config; config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index 98bc35ae528..2b6a2e9672c 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -48,7 +48,7 @@ class HloTestBase : public ::testing::Test { // TestName() for its name; it will also automatically populate its debug // options from command-line flags. It's recommended to use this method to // create all HloModules for tests. - std::unique_ptr CreateNewModule(); + static std::unique_ptr CreateNewModule(); // Executes the given module and returns a global data handle. 
StatusOr Execute( diff --git a/tensorflow/compiler/xla/tests/literal_test_util.cc b/tensorflow/compiler/xla/tests/literal_test_util.cc index eb979ad189d..69c12cc437b 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util.cc @@ -41,20 +41,25 @@ namespace xla { /* static */ void LiteralTestUtil::AssertEqualShapes(const Shape& expected, const Shape& actual) { - ASSERT_EQ(ShapeUtil::Rank(expected), ShapeUtil::Rank(actual)); - ASSERT_EQ(expected.element_type(), actual.element_type()) - << PrimitiveType_Name(expected.element_type()) << " vs " - << PrimitiveType_Name(actual.element_type()); - ASSERT_EQ(expected.dimensions_size(), actual.dimensions_size()); - for (int i = 0; i < expected.dimensions_size(); ++i) { - ASSERT_EQ(expected.dimensions(i), actual.dimensions(i)) - << "mismatch in dimension #" << i - << " expected: " << ShapeUtil::HumanString(expected) - << " actual: " << ShapeUtil::HumanString(actual); - } - ASSERT_EQ(expected.tuple_shapes_size(), actual.tuple_shapes_size()); - for (int i = 0; i < expected.tuple_shapes_size(); ++i) { - AssertEqualShapes(expected.tuple_shapes(i), actual.tuple_shapes(i)); + ASSERT_EQ(ShapeUtil::IsTuple(expected), ShapeUtil::IsTuple(actual)); + if (ShapeUtil::IsTuple(expected)) { + ASSERT_EQ(ShapeUtil::TupleElementCount(expected), + ShapeUtil::TupleElementCount(actual)); + for (int i = 0; i < expected.tuple_shapes_size(); ++i) { + AssertEqualShapes(expected.tuple_shapes(i), actual.tuple_shapes(i)); + } + } else { + ASSERT_EQ(ShapeUtil::Rank(expected), ShapeUtil::Rank(actual)); + ASSERT_EQ(expected.element_type(), actual.element_type()) + << PrimitiveType_Name(expected.element_type()) << " vs " + << PrimitiveType_Name(actual.element_type()); + ASSERT_EQ(expected.dimensions_size(), actual.dimensions_size()); + for (int i = 0; i < expected.dimensions_size(); ++i) { + ASSERT_EQ(expected.dimensions(i), actual.dimensions(i)) + << "mismatch in dimension #" << i + << " expected: " 
<< ShapeUtil::HumanString(expected) + << " actual: " << ShapeUtil::HumanString(actual); + } } } @@ -128,8 +133,8 @@ bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual, tensorflow::gtl::MutableArraySlice multi_index, int64 dimension) { if (dimension == expected.shape().dimensions_size()) { - NativeT expected_value = LiteralUtil::Get(expected, multi_index); - NativeT actual_value = LiteralUtil::Get(actual, multi_index); + NativeT expected_value = expected.Get(multi_index); + NativeT actual_value = actual.Get(multi_index); ::testing::AssertionResult result = CompareEqual(expected_value, actual_value); return result; // Defines implicit coersion to bool. @@ -148,10 +153,10 @@ bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual, /* static */ void LiteralTestUtil::ExpectEqual(const Literal& expected, const Literal& actual) { - EXPECT_TRUE(Equal(expected, actual)) << "expected:\n" - << LiteralUtil::ToString(expected) - << "\n\tvs actual:\n" - << LiteralUtil::ToString(actual); + EXPECT_TRUE(Equal(expected, actual)) + << "expected:\n" + << expected.ToString() << "\n\tvs actual:\n" + << actual.ToString(); } /* static */ void LiteralTestUtil::ExpectNotEqual(const Literal& expected, @@ -161,8 +166,8 @@ bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual, /* static */ ::testing::AssertionResult LiteralTestUtil::Equal( const Literal& expected, const Literal& actual) { - VLOG(1) << "expected: " << LiteralUtil::ToString(expected); - VLOG(1) << "actual: " << LiteralUtil::ToString(actual); + VLOG(1) << "expected: " << expected.ToString(); + VLOG(1) << "actual: " << actual.ToString(); AssertEqualShapes(expected.shape(), actual.shape()); std::vector multi_index(expected.shape().dimensions_size(), 0); @@ -210,8 +215,8 @@ bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual, ::testing::AssertionResult result = ::testing::AssertionSuccess(); if (!match) { result = ::testing::AssertionFailure() - << "expected: 
" << LiteralUtil::ToString(expected) - << "\nactual: " << LiteralUtil::ToString(actual); + << "expected: " << expected.ToString() + << "\nactual: " << actual.ToString(); VLOG(1) << result.message(); } return result; @@ -219,8 +224,8 @@ bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual, /* static */ void LiteralTestUtil::ExpectEqualTuple(const Literal& expected, const Literal& actual) { - VLOG(1) << "expected: " << LiteralUtil::ToString(expected); - VLOG(1) << "actual: " << LiteralUtil::ToString(actual); + VLOG(1) << "expected: " << expected.ToString(); + VLOG(1) << "actual: " << actual.ToString(); ASSERT_TRUE(ShapeUtil::IsTuple(expected.shape())); ASSERT_TRUE(ShapeUtil::IsTuple(actual.shape())); @@ -247,8 +252,8 @@ class NearComparator { // within the error bound. Emits useful log messages and dumps literals to // temporary files on failure. Returns true if literals match. bool ExpectNear(const Literal& expected, const Literal& actual) { - VLOG(1) << "expected: " << LiteralUtil::ToString(expected); - VLOG(1) << "actual: " << LiteralUtil::ToString(actual); + VLOG(1) << "expected: " << expected.ToString(); + VLOG(1) << "actual: " << actual.ToString(); LiteralTestUtil::AssertEqualShapes(expected.shape(), actual.shape()); @@ -282,9 +287,9 @@ class NearComparator { if (num_miscompares_ > 0) { if (!VLOG_IS_ON(1)) { LOG(INFO) << "expected: " << ShapeUtil::HumanString(expected.shape()) - << " " << LiteralUtil::ToString(expected); + << " " << expected.ToString(); LOG(INFO) << "actual: " << ShapeUtil::HumanString(actual.shape()) - << " " << LiteralUtil::ToString(actual); + << " " << actual.ToString(); } EXPECT_TRUE(num_miscompares_ == 0) << "\nmax relative mismatch at index " @@ -369,10 +374,9 @@ class NearComparator { void ExpectLiteralsNear(const Literal& expected, const Literal& actual, int64 dimension) { if (dimension == expected.shape().dimensions_size()) { - bool near = - ExpectValuesNear(LiteralUtil::Get(expected, multi_index_), - 
LiteralUtil::Get(actual, multi_index_)); - LiteralUtil::Set(&miscompares_, multi_index_, !near); + bool near = ExpectValuesNear(expected.Get(multi_index_), + actual.Get(multi_index_)); + miscompares_.Set(multi_index_, !near); } else { for (int64 i = 0; i < expected.shape().dimensions(dimension); ++i) { multi_index_[dimension] = i; @@ -437,8 +441,8 @@ class NearComparator { /* static */ ::testing::AssertionResult LiteralTestUtil::NearTuple( const Literal& expected, const Literal& actual, const ErrorSpec& error) { - VLOG(1) << "expected: " << LiteralUtil::ToString(expected); - VLOG(1) << "actual: " << LiteralUtil::ToString(actual); + VLOG(1) << "expected: " << expected.ToString(); + VLOG(1) << "actual: " << actual.ToString(); if (!ShapeUtil::IsTuple(expected.shape()) || !ShapeUtil::IsTuple(actual.shape())) { @@ -504,8 +508,7 @@ class NearComparator { *shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major); // Allocate space in the new literal. - LiteralUtil::Reserve(ShapeUtil::ElementsIn(literal.shape()), - new_literal.get()); + new_literal.get()->Reserve(ShapeUtil::ElementsIn(literal.shape())); // Copy data into new literal, element-by-element. 
for (int64 i = 0; i < ShapeUtil::ElementsIn(literal.shape()); ++i) { @@ -515,44 +518,36 @@ class NearComparator { IndexUtil::LinearIndexToMultidimensionalIndex(shape_with_layout, i); switch (literal.shape().element_type()) { case PRED: - LiteralUtil::Set( - new_literal.get(), to_multi_index, - LiteralUtil::Get(literal, from_multi_index)); + new_literal.get()->Set(to_multi_index, + literal.Get(from_multi_index)); break; case U8: - LiteralUtil::Set( - new_literal.get(), to_multi_index, - LiteralUtil::Get(literal, from_multi_index)); + new_literal.get()->Set(to_multi_index, + literal.Get(from_multi_index)); break; case U32: - LiteralUtil::Set( - new_literal.get(), to_multi_index, - LiteralUtil::Get(literal, from_multi_index)); + new_literal.get()->Set(to_multi_index, + literal.Get(from_multi_index)); break; case S32: - LiteralUtil::Set( - new_literal.get(), to_multi_index, - LiteralUtil::Get(literal, from_multi_index)); + new_literal.get()->Set(to_multi_index, + literal.Get(from_multi_index)); break; case U64: - LiteralUtil::Set( - new_literal.get(), to_multi_index, - LiteralUtil::Get(literal, from_multi_index)); + new_literal.get()->Set(to_multi_index, + literal.Get(from_multi_index)); break; case S64: - LiteralUtil::Set( - new_literal.get(), to_multi_index, - LiteralUtil::Get(literal, from_multi_index)); + new_literal.get()->Set(to_multi_index, + literal.Get(from_multi_index)); break; case F32: - LiteralUtil::Set( - new_literal.get(), to_multi_index, - LiteralUtil::Get(literal, from_multi_index)); + new_literal.get()->Set(to_multi_index, + literal.Get(from_multi_index)); break; case F64: - LiteralUtil::Set( - new_literal.get(), to_multi_index, - LiteralUtil::Get(literal, from_multi_index)); + new_literal.get()->Set(to_multi_index, + literal.Get(from_multi_index)); break; default: LOG(FATAL) << "Unhandled primitive element type: " diff --git a/tensorflow/compiler/xla/tests/literal_test_util.h b/tensorflow/compiler/xla/tests/literal_test_util.h index 
a8b07a2c5d1..0def25f34e5 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.h +++ b/tensorflow/compiler/xla/tests/literal_test_util.h @@ -210,20 +210,20 @@ class LiteralTestUtil { template /* static */ void LiteralTestUtil::ExpectR0Equal(NativeT expected, const Literal& actual) { - ExpectEqual(*LiteralUtil::CreateR0(expected), actual); + ExpectEqual(*Literal::CreateR0(expected), actual); } template /* static */ void LiteralTestUtil::ExpectR1Equal( tensorflow::gtl::ArraySlice expected, const Literal& actual) { - ExpectEqual(*LiteralUtil::CreateR1(expected), actual); + ExpectEqual(*Literal::CreateR1(expected), actual); } template /* static */ void LiteralTestUtil::ExpectR2Equal( std::initializer_list> expected, const Literal& actual) { - ExpectEqual(*LiteralUtil::CreateR2(expected), actual); + ExpectEqual(*Literal::CreateR2(expected), actual); } template @@ -231,46 +231,46 @@ template std::initializer_list>> expected, const Literal& actual) { - ExpectEqual(*LiteralUtil::CreateR3(expected), actual); + ExpectEqual(*Literal::CreateR3(expected), actual); } template /* static */ void LiteralTestUtil::ExpectR2EqualArray2D( const Array2D& expected, const Literal& actual) { - ExpectEqual(*LiteralUtil::CreateR2FromArray2D(expected), actual); + ExpectEqual(*Literal::CreateR2FromArray2D(expected), actual); } template /* static */ void LiteralTestUtil::ExpectR3EqualArray3D( const Array3D& expected, const Literal& actual) { - ExpectEqual(*LiteralUtil::CreateR3FromArray3D(expected), actual); + ExpectEqual(*Literal::CreateR3FromArray3D(expected), actual); } template /* static */ void LiteralTestUtil::ExpectR4EqualArray4D( const Array4D& expected, const Literal& actual) { - ExpectEqual(*LiteralUtil::CreateR4FromArray4D(expected), actual); + ExpectEqual(*Literal::CreateR4FromArray4D(expected), actual); } template /* static */ void LiteralTestUtil::ExpectR0Near(NativeT expected, const Literal& actual, const ErrorSpec& error) { - ExpectNear(*LiteralUtil::CreateR0(expected), 
actual, error); + ExpectNear(*Literal::CreateR0(expected), actual, error); } template /* static */ void LiteralTestUtil::ExpectR1Near( tensorflow::gtl::ArraySlice expected, const Literal& actual, const ErrorSpec& error) { - ExpectNear(*LiteralUtil::CreateR1(expected), actual, error); + ExpectNear(*Literal::CreateR1(expected), actual, error); } template /* static */ void LiteralTestUtil::ExpectR2Near( std::initializer_list> expected, const Literal& actual, const ErrorSpec& error) { - ExpectNear(*LiteralUtil::CreateR2(expected), actual, error); + ExpectNear(*Literal::CreateR2(expected), actual, error); } template @@ -278,28 +278,28 @@ template std::initializer_list>> expected, const Literal& actual, const ErrorSpec& error) { - ExpectNear(*LiteralUtil::CreateR3(expected), actual, error); + ExpectNear(*Literal::CreateR3(expected), actual, error); } template /* static */ void LiteralTestUtil::ExpectR2NearArray2D( const Array2D& expected, const Literal& actual, const ErrorSpec& error) { - ExpectNear(*LiteralUtil::CreateR2FromArray2D(expected), actual, error); + ExpectNear(*Literal::CreateR2FromArray2D(expected), actual, error); } template /* static */ void LiteralTestUtil::ExpectR3NearArray3D( const Array3D& expected, const Literal& actual, const ErrorSpec& error) { - ExpectNear(*LiteralUtil::CreateR3FromArray3D(expected), actual, error); + ExpectNear(*Literal::CreateR3FromArray3D(expected), actual, error); } template /* static */ void LiteralTestUtil::ExpectR4NearArray4D( const Array4D& expected, const Literal& actual, const ErrorSpec& error) { - ExpectNear(*LiteralUtil::CreateR4FromArray4D(expected), actual, error); + ExpectNear(*Literal::CreateR4FromArray4D(expected), actual, error); } template @@ -309,9 +309,9 @@ LiteralTestUtil::CreateRandomLiteral( const std::function)>& generator) { using NativeT = typename primitive_util::PrimitiveTypeToNative::type; TF_RET_CHECK(shape.element_type() == type); - std::unique_ptr literal = LiteralUtil::CreateFromShape(shape); - 
TF_RETURN_IF_ERROR(LiteralUtil::Populate( - literal.get(), [&](tensorflow::gtl::ArraySlice indexes) { + std::unique_ptr literal = Literal::CreateFromShape(shape); + TF_RETURN_IF_ERROR(literal.get()->Populate( + [&](tensorflow::gtl::ArraySlice indexes) { return generator(indexes); })); return std::move(literal); diff --git a/tensorflow/compiler/xla/tests/literal_test_util_test.cc b/tensorflow/compiler/xla/tests/literal_test_util_test.cc index a94f45f73b7..2acf27ed390 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util_test.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util_test.cc @@ -31,9 +31,8 @@ namespace xla { namespace { TEST(LiteralTestUtilTest, ComparesEqualTuplesEqual) { - std::unique_ptr literal = LiteralUtil::MakeTuple({ - LiteralUtil::CreateR0(42).get(), - LiteralUtil::CreateR0(64).get(), + std::unique_ptr literal = Literal::MakeTuple({ + Literal::CreateR0(42).get(), Literal::CreateR0(64).get(), }); LiteralTestUtil::ExpectEqual(*literal, *literal); } @@ -43,13 +42,11 @@ TEST(LiteralTestUtilTest, ComparesUnequalTuplesUnequal) { // un-fail an assertion failure. The CHECK-failure is death, so we can make a // death assertion. 
auto unequal_things_are_equal = [] { - std::unique_ptr lhs = LiteralUtil::MakeTuple({ - LiteralUtil::CreateR0(42).get(), - LiteralUtil::CreateR0(64).get(), + std::unique_ptr lhs = Literal::MakeTuple({ + Literal::CreateR0(42).get(), Literal::CreateR0(64).get(), }); - std::unique_ptr rhs = LiteralUtil::MakeTuple({ - LiteralUtil::CreateR0(64).get(), - LiteralUtil::CreateR0(42).get(), + std::unique_ptr rhs = Literal::MakeTuple({ + Literal::CreateR0(64).get(), Literal::CreateR0(42).get(), }); CHECK(LiteralTestUtil::Equal(*lhs, *rhs)) << "LHS and RHS are unequal"; }; @@ -58,8 +55,8 @@ TEST(LiteralTestUtilTest, ComparesUnequalTuplesUnequal) { TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) { auto dummy_lambda = [] { - auto two = LiteralUtil::CreateR0(2); - auto four = LiteralUtil::CreateR0(4); + auto two = Literal::CreateR0(2); + auto four = Literal::CreateR0(4); ErrorSpec error(0.001); CHECK(LiteralTestUtil::Near(*two, *four, error)) << "two is not near four"; }; @@ -88,11 +85,11 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) { &literal_proto)); Literal literal(literal_proto); if (result.find("expected") != string::npos) { - EXPECT_EQ("2", LiteralUtil::ToString(literal)); + EXPECT_EQ("2", literal.ToString()); } else if (result.find("actual") != string::npos) { - EXPECT_EQ("4", LiteralUtil::ToString(literal)); + EXPECT_EQ("4", literal.ToString()); } else if (result.find("miscompares") != string::npos) { - EXPECT_EQ("true", LiteralUtil::ToString(literal)); + EXPECT_EQ("true", literal.ToString()); } else { FAIL() << "unknown file in temporary directory: " << result; } diff --git a/tensorflow/compiler/xla/tests/log_test.cc b/tensorflow/compiler/xla/tests/log_test.cc index 796f43ea4ed..4cb383a78df 100644 --- a/tensorflow/compiler/xla/tests/log_test.cc +++ b/tensorflow/compiler/xla/tests/log_test.cc @@ -18,7 +18,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -62,7 +61,6 @@ TEST_F(LogTest, LogTenValues) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/map_test.cc b/tensorflow/compiler/xla/tests/map_test.cc index e4dbd6864a3..ffa87348a00 100644 --- a/tensorflow/compiler/xla/tests/map_test.cc +++ b/tensorflow/compiler/xla/tests/map_test.cc @@ -21,7 +21,6 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -170,7 +169,7 @@ class MapTest : public ClientLibraryTestBase { TEST_F(MapTest, MapEachElemPlusOneR0) { // Applies lambda (x) (+ x 1)) to an input scalar. 
ComputationBuilder builder(client_, TestName()); - std::unique_ptr param0_literal = LiteralUtil::CreateR0(42.0); + std::unique_ptr param0_literal = Literal::CreateR0(42.0); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -184,7 +183,7 @@ TEST_F(MapTest, MapEachElemPlusOneR0) { XLA_TEST_F(MapTest, MapEachElemPlusOneR1S0) { // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 0. ComputationBuilder builder(client_, TestName()); - std::unique_ptr param0_literal = LiteralUtil::CreateR1({}); + std::unique_ptr param0_literal = Literal::CreateR1({}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -199,7 +198,7 @@ TEST_F(MapTest, MapEachElemPlusOneR1S4) { // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 4. ComputationBuilder builder(client_, TestName()); std::unique_ptr param0_literal = - LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); + Literal::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -213,7 +212,7 @@ TEST_F(MapTest, MapEachElemPlusOneR1S4) { TEST_F(MapTest, MapEachF32ElementToS32Constant) { ComputationBuilder builder(client_, TestName()); std::unique_ptr param0_literal = - LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); + Literal::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -226,7 +225,7 @@ TEST_F(MapTest, MapEachF32ElementToS32Constant) { TEST_F(MapTest, MapEachF32ElementToU32Constant) { ComputationBuilder builder(client_, TestName()); std::unique_ptr param0_literal = - LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); + Literal::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -240,7 +239,7 @@ TEST_F(MapTest, MapEachElemLongerChainR1) { // Maps (lambda (x) (* (+ x 1) x)) onto an 
input R1F32 vector. ComputationBuilder builder(client_, TestName()); std::unique_ptr param0_literal = - LiteralUtil::CreateR1({2.6f, -5.1f, 0.1f, 0.2f, 999.0f, 255.5f}); + Literal::CreateR1({2.6f, -5.1f, 0.1f, 0.2f, 999.0f, 255.5f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -256,7 +255,7 @@ XLA_TEST_F(MapTest, MapMultipleMapsR1S0) { // Maps (lambda (x) (+ x 1)) onto an input R1F32 vector of length 0, and then // maps (lambda (x) (* x 2)) on the result. ComputationBuilder builder(client_, TestName()); - std::unique_ptr param0_literal = LiteralUtil::CreateR1({}); + std::unique_ptr param0_literal = Literal::CreateR1({}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -273,7 +272,7 @@ TEST_F(MapTest, MapMultipleMapsR1S4) { // maps (lambda (x) (* x 2)) on the result. ComputationBuilder builder(client_, TestName()); std::unique_ptr param0_literal = - LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); + Literal::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -288,7 +287,7 @@ TEST_F(MapTest, MapMultipleMapsR1S4) { TEST_F(MapTest, MapEachElemPlusOneR2) { // Maps (lambda (x) (+ x 1)) onto an input R2F32 vector. ComputationBuilder builder(client_, TestName()); - std::unique_ptr param0_literal = LiteralUtil::CreateR2( + std::unique_ptr param0_literal = Literal::CreateR2( {{13.25f, 14.0f}, {-7.1f, -7.2f}, {-8.8f, 8.8f}}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -385,11 +384,11 @@ TEST_F(MapTest, MapBinaryAdder) { // Maps (lambda (x y) (+ x y)) onto two R1F32 vectors. 
ComputationBuilder builder(client_, TestName()); std::unique_ptr param0_literal = - LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); + Literal::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_literal = - LiteralUtil::CreateR1({5.1f, 4.4f, -0.1f, -5.5f}); + Literal::CreateR1({5.1f, 4.4f, -0.1f, -5.5f}); std::unique_ptr param1_data = client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); @@ -434,12 +433,12 @@ XLA_TEST_F(MapTest, AddWithMixedLayouts) { XLA_TEST_F(MapTest, AddR3_3x0x2) { ComputationBuilder builder(client_, TestName()); std::unique_ptr param0_literal = - LiteralUtil::CreateR3FromArray3D(Array3D(3, 0, 2)); + Literal::CreateR3FromArray3D(Array3D(3, 0, 2)); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_literal = - LiteralUtil::CreateR3FromArray3D(Array3D(3, 0, 2)); + Literal::CreateR3FromArray3D(Array3D(3, 0, 2)); std::unique_ptr param1_data = client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); @@ -456,15 +455,15 @@ TEST_F(MapTest, MapTernaryAdder) { // Maps (lambda (x y z) (+ x y z)) onto three R1F32 vectors. 
ComputationBuilder builder(client_, TestName()); std::unique_ptr param0_literal = - LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); + Literal::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_literal = - LiteralUtil::CreateR1({5.1f, 4.4f, -0.1f, -5.5f}); + Literal::CreateR1({5.1f, 4.4f, -0.1f, -5.5f}); std::unique_ptr param1_data = client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); std::unique_ptr param2_literal = - LiteralUtil::CreateR1({-10.0f, -100.0f, -900.0f, -400.0f}); + Literal::CreateR1({-10.0f, -100.0f, -900.0f, -400.0f}); std::unique_ptr param2_data = client_->TransferToServer(*param2_literal).ConsumeValueOrDie(); @@ -517,11 +516,11 @@ TEST_F(MapTest, MapOperantionWithBuildError) { auto error_add = sub_builder->BuildAndNoteError(); std::unique_ptr param0_literal = - LiteralUtil::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); + Literal::CreateR1({2.2f, 3.3f, 4.4f, 5.5f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_literal = - LiteralUtil::CreateR1({5.1f, 4.4f, -0.1f, -5.5f}); + Literal::CreateR1({5.1f, 4.4f, -0.1f, -5.5f}); std::unique_ptr param1_data = client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); @@ -554,8 +553,8 @@ TEST_F(MapTestWithFullOpt, MapScalarPower) { sub_builder->Pow(x, y); auto power = sub_builder->BuildAndNoteError(); - std::unique_ptr param0_literal = LiteralUtil::CreateR0(2.0f); - std::unique_ptr param1_literal = LiteralUtil::CreateR0(5.0f); + std::unique_ptr param0_literal = Literal::CreateR0(2.0f); + std::unique_ptr param1_literal = Literal::CreateR0(5.0f); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_data = @@ -581,8 +580,8 @@ TEST_F(MapTestWithFullOpt, MapSubtractOppositeOrder) { sub_builder->Sub(y, x); // note that this is y - x, not x - y auto sub_opposite = 
sub_builder->BuildAndNoteError(); - std::unique_ptr param0_literal = LiteralUtil::CreateR0(2.0f); - std::unique_ptr param1_literal = LiteralUtil::CreateR0(5.0f); + std::unique_ptr param0_literal = Literal::CreateR0(2.0f); + std::unique_ptr param1_literal = Literal::CreateR0(5.0f); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); std::unique_ptr param1_data = @@ -606,7 +605,7 @@ TEST_F(MapTestWithFullOpt, MapSquare) { sub_builder->Mul(x, x); auto square = sub_builder->BuildAndNoteError(); - std::unique_ptr param0_literal = LiteralUtil::CreateR0(10.0f); + std::unique_ptr param0_literal = Literal::CreateR0(10.0f); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -622,7 +621,6 @@ TEST_F(MapTestWithFullOpt, MapSquare) { int main(int argc, char** argv) { std::vector flag_list; - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc index 51261f0ac1c..717e9cd4948 100644 --- a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc +++ b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc @@ -21,7 +21,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" @@ -88,8 +87,8 @@ TEST_F(MatOpsSimpleTest, ExpTwoByTwoValues) { builder.Exp(data); std::unique_ptr expected = - LiteralUtil::CreateR2({{2.71828, 1.00000}, // row 0 - {0.36788, 1.64872}}); // row 1 + Literal::CreateR2({{2.71828, 1.00000}, // row 0 + {0.36788, 1.64872}}); // row 1 ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-5)); } @@ -116,8 +115,8 @@ TEST_F(MatOpsSimpleTest, MapTwoByTwo) { auto map = builder.Map({data}, add_half); std::unique_ptr expected = - LiteralUtil::CreateR2({{1.5, 0.5}, // row 0 - {-0.5, 1.0}}); // row 1 + Literal::CreateR2({{1.5, 0.5}, // row 0 + {-0.5, 1.0}}); // row 1 ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-5)); } @@ -134,8 +133,8 @@ TEST_F(MatOpsSimpleTest, MaxTwoByTwoValues) { auto max = builder.Max(lhs, rhs); std::unique_ptr expected = - LiteralUtil::CreateR2({{7.0, 6.0}, // row 0 - {3.0, -4.0}}); // row 1 + Literal::CreateR2({{7.0, 6.0}, // row 0 + {3.0, -4.0}}); // row 1 ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6)); } @@ -181,14 +180,12 @@ TEST_P(MatOpsDotAddTest, Dot_Add_2x2_2x2) { TF_ASSIGN_OR_ASSERT_OK( auto lhs_handle, - client_->TransferToServer( - *LiteralUtil::CreateR2FromArray2DWithLayout( - lhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); + client_->TransferToServer(*Literal::CreateR2FromArray2DWithLayout( + lhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); TF_ASSIGN_OR_ASSERT_OK( auto rhs_handle, - client_->TransferToServer( - *LiteralUtil::CreateR2FromArray2DWithLayout( - rhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); + 
client_->TransferToServer(*Literal::CreateR2FromArray2DWithLayout( + rhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); ComputationBuilder builder(client_, TestName()); auto lhs_arg = builder.Parameter(0, lhs_shape, "lhs"); @@ -218,7 +215,6 @@ INSTANTIATE_TEST_CASE_P(MatOpsDotAddTestInstances, MatOpsDotAddTest, int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc b/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc index 4929e25c580..56c15e5ff72 100644 --- a/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc +++ b/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc @@ -21,7 +21,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -60,7 +59,6 @@ XLA_TEST_F(SliceTest, Slice3D) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/pad_test.cc b/tensorflow/compiler/xla/tests/pad_test.cc index 4922bbf21c4..e270a0477fe 100644 --- a/tensorflow/compiler/xla/tests/pad_test.cc +++ b/tensorflow/compiler/xla/tests/pad_test.cc @@ -21,7 +21,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/reference_util.h" @@ -183,8 +182,8 @@ TEST_F(PadTest, Pad4DFloatArrayMinorFirstSmall) { const float pad_value = -5.123f; Array4D input_array(1, 1, 2, 3, {1, 2, 3, 4, 5, 6}); - auto input = LiteralUtil::CreateR4FromArray4D(input_array); - input = LiteralUtil::Relayout(*input, layout); + auto input = Literal::CreateR4FromArray4D(input_array); + input = input->Relayout(layout); b.Pad(b.ConstantLiteral(*input), b.ConstantR0(pad_value), padding_config); @@ -228,8 +227,8 @@ XLA_TEST_F(PadTest, Pad4DFloatArrayMinorFirstNonTrivialMinorDimensions) { input_array(0, 0, 0, 0) = 1.0f; input_array(0, 24, 6, 6) = 2.0f; input_array(0, 17, 2, 5) = 3.0f; - auto input = LiteralUtil::CreateR4FromArray4D(input_array); - input = LiteralUtil::Relayout(*input, layout); + auto input = Literal::CreateR4FromArray4D(input_array); + input = input->Relayout(layout); b.Pad(b.ConstantLiteral(*input), b.ConstantR0(pad_value), padding_config); @@ -308,7 +307,7 @@ XLA_TEST_F(PadTest, Large2DPad) { auto ones = MakeUnique>(4, 4); ones->Fill(1.0f); - auto input_literal = LiteralUtil::CreateR2FromArray2D(*ones); + auto input_literal = Literal::CreateR2FromArray2D(*ones); std::unique_ptr input_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -334,7 +333,7 @@ XLA_TEST_F(PadTest, AllTypes2DPad) { auto operand = MakeUnique>(in_rows, in_cols); operand->FillUnique(0.0f); - auto input_literal = LiteralUtil::CreateR2FromArray2D(*operand); + auto input_literal = Literal::CreateR2FromArray2D(*operand); std::unique_ptr input_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ 
-365,7 +364,7 @@ XLA_TEST_F(PadTest, High2DPad) { auto operand = MakeUnique>(in_rows, in_cols); operand->FillUnique(1.0f); - auto input_literal = LiteralUtil::CreateR2FromArray2D(*operand); + auto input_literal = Literal::CreateR2FromArray2D(*operand); auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f); std::unique_ptr input_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -397,7 +396,7 @@ XLA_TEST_F(PadTest, NegativePadding2D) { auto operand = MakeUnique>(in_rows, in_cols); operand->FillUnique(1.0f); - auto input_literal = LiteralUtil::CreateR2FromArray2D(*operand); + auto input_literal = Literal::CreateR2FromArray2D(*operand); auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f); std::unique_ptr input_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -429,7 +428,7 @@ XLA_TEST_F(PadTest, NegativeAndInteriorPadding2D) { auto operand = MakeUnique>(in_rows, in_cols); operand->FillUnique(1.0f); - auto input_literal = LiteralUtil::CreateR2FromArray2D(*operand); + auto input_literal = Literal::CreateR2FromArray2D(*operand); auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f); std::unique_ptr input_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -453,7 +452,7 @@ XLA_TEST_F(PadTest, ReducePad) { auto ones = MakeUnique>(2, 2, 2, 2); ones->Fill(1.0); - auto input_literal = LiteralUtil::CreateR4FromArray4D(*ones); + auto input_literal = Literal::CreateR4FromArray4D(*ones); std::unique_ptr input_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -470,7 +469,6 @@ XLA_TEST_F(PadTest, ReducePad) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if 
(!parse_result) { diff --git a/tensorflow/compiler/xla/tests/params_test.cc b/tensorflow/compiler/xla/tests/params_test.cc index 3e1bfcd3090..a7692fceb47 100644 --- a/tensorflow/compiler/xla/tests/params_test.cc +++ b/tensorflow/compiler/xla/tests/params_test.cc @@ -24,7 +24,6 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -44,8 +43,7 @@ class ParamsTest : public ClientLibraryTestBase {}; XLA_TEST_F(ParamsTest, ConstantR0F32Param) { ComputationBuilder builder(client_, TestName()); - std::unique_ptr param0_literal = - LiteralUtil::CreateR0(3.14159f); + std::unique_ptr param0_literal = Literal::CreateR0(3.14159f); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -57,7 +55,7 @@ XLA_TEST_F(ParamsTest, ConstantR0F32Param) { XLA_TEST_F(ParamsTest, ConstantR1S0F32Param) { ComputationBuilder builder(client_, TestName()); - std::unique_ptr param0_literal = LiteralUtil::CreateR1({}); + std::unique_ptr param0_literal = Literal::CreateR1({}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -70,7 +68,7 @@ XLA_TEST_F(ParamsTest, ConstantR1S0F32Param) { XLA_TEST_F(ParamsTest, ConstantR1S2F32Param) { ComputationBuilder builder(client_, TestName()); std::unique_ptr param0_literal = - LiteralUtil::CreateR1({3.14f, -100.25f}); + Literal::CreateR1({3.14f, -100.25f}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -83,7 +81,7 @@ XLA_TEST_F(ParamsTest, ConstantR1S2F32Param) { XLA_TEST_F(ParamsTest, ConstantR1U8Param) { ComputationBuilder builder(client_, 
TestName()); string str("hello world"); - std::unique_ptr param0_literal = LiteralUtil::CreateR1U8(str); + std::unique_ptr param0_literal = Literal::CreateR1U8(str); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -96,7 +94,7 @@ XLA_TEST_F(ParamsTest, ConstantR1U8Param) { XLA_TEST_F(ParamsTest, ConstantR2_3x0_F32Param) { ComputationBuilder builder(client_, TestName()); std::unique_ptr param0_literal = - LiteralUtil::CreateR2FromArray2D(Array2D(3, 0)); + Literal::CreateR2FromArray2D(Array2D(3, 0)); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -108,7 +106,7 @@ XLA_TEST_F(ParamsTest, ConstantR2_3x0_F32Param) { XLA_TEST_F(ParamsTest, ConstantR2F32Param) { ComputationBuilder builder(client_, TestName()); - std::unique_ptr param0_literal = LiteralUtil::CreateR2( + std::unique_ptr param0_literal = Literal::CreateR2( {{3.14f, -100.25f}, {7e8f, 7e-9f}, {30.3f, -100.0f}}); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -124,12 +122,12 @@ XLA_TEST_F(ParamsTest, ConstantR2F32Param) { XLA_TEST_F(ParamsTest, TwoParameters) { ComputationBuilder builder(client_, TestName()); - std::unique_ptr literal0 = LiteralUtil::CreateR1({1, 2}); + std::unique_ptr literal0 = Literal::CreateR1({1, 2}); std::unique_ptr param0_data = client_->TransferToServer(*literal0).ConsumeValueOrDie(); auto param0 = builder.Parameter(0, literal0->shape(), "param0"); - std::unique_ptr literal1 = LiteralUtil::CreateR1({10, 20}); + std::unique_ptr literal1 = Literal::CreateR1({10, 20}); std::unique_ptr param1_data = client_->TransferToServer(*literal1).ConsumeValueOrDie(); auto param1 = builder.Parameter(1, literal1->shape(), "param1"); @@ -155,7 +153,7 @@ XLA_TEST_F(ParamsTest, TwoParameters) { XLA_TEST_F(ParamsTest, MissingParameter) { // Test that an error is returned when a computation with an incomplete set of // parameters (parameter numbers not contiguous 
from 0) is executed. - std::unique_ptr literal = LiteralUtil::CreateR0(3.14159f); + std::unique_ptr literal = Literal::CreateR0(3.14159f); std::unique_ptr data = client_->TransferToServer(*literal).ConsumeValueOrDie(); @@ -173,12 +171,12 @@ XLA_TEST_F(ParamsTest, MissingParameter) { XLA_TEST_F(ParamsTest, UnusedParameter) { ComputationBuilder builder(client_, TestName()); - std::unique_ptr literal0 = LiteralUtil::CreateR1({1, 2}); + std::unique_ptr literal0 = Literal::CreateR1({1, 2}); std::unique_ptr param0_data = client_->TransferToServer(*literal0).ConsumeValueOrDie(); auto param0 = builder.Parameter(0, literal0->shape(), "param0"); - std::unique_ptr literal1 = LiteralUtil::CreateR1({10, 20}); + std::unique_ptr literal1 = Literal::CreateR1({10, 20}); std::unique_ptr param1_data = client_->TransferToServer(*literal1).ConsumeValueOrDie(); auto param1 = builder.Parameter(1, literal1->shape(), "param1"); @@ -193,12 +191,11 @@ XLA_TEST_F(ParamsTest, UnusedParametersInUnusedExpression) { // unused expression. 
ComputationBuilder builder(client_, TestName()); - std::unique_ptr literal0 = LiteralUtil::CreateR1({1, 2}); + std::unique_ptr literal0 = Literal::CreateR1({1, 2}); std::unique_ptr param0_data = client_->TransferToServer(*literal0).ConsumeValueOrDie(); - std::unique_ptr literal1 = - LiteralUtil::CreateR1({10, 20, 30}); + std::unique_ptr literal1 = Literal::CreateR1({10, 20, 30}); std::unique_ptr param1_data = client_->TransferToServer(*literal1).ConsumeValueOrDie(); @@ -238,7 +235,7 @@ XLA_TEST_F(ParamsTest, HundredLargeR1Parameters) { std::vector sum_value = {{entry0, entry1}}; sum_value.resize(size); - std::unique_ptr literal = LiteralUtil::CreateR1(sum_value); + std::unique_ptr literal = Literal::CreateR1(sum_value); param_data_owner.push_back( client_->TransferToServer(*literal).ConsumeValueOrDie()); ComputationDataHandle param = @@ -268,9 +265,9 @@ XLA_TEST_F(ParamsTest, std::unique_ptr data = client_ - ->TransferToServer(*LiteralUtil::MakeTuple({ - LiteralUtil::CreateR1({1, 2, 3}).get(), - LiteralUtil::CreateR1({4, 5, 6}).get(), + ->TransferToServer(*Literal::MakeTuple({ + Literal::CreateR1({1, 2, 3}).get(), + Literal::CreateR1({4, 5, 6}).get(), })) .ConsumeValueOrDie(); @@ -282,7 +279,7 @@ XLA_TEST_F(ParamsTest, // Verifies that passing a 2x2 with {0, 1} layout returns the same value back // when (transferred to the server and) passed through a parameter. XLA_TEST_F(ParamsTest, R2_2x2_Layout_01) { - std::unique_ptr literal = LiteralUtil::CreateR2({ + std::unique_ptr literal = Literal::CreateR2({ {1, 2}, {3, 4}, }); *literal->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1}); @@ -296,7 +293,7 @@ XLA_TEST_F(ParamsTest, R2_2x2_Layout_01) { // As above, but for {1, 0} layout. 
XLA_TEST_F(ParamsTest, R2_2x2_Layout_10) { - std::unique_ptr literal = LiteralUtil::CreateR2({ + std::unique_ptr literal = Literal::CreateR2({ {1, 3}, {2, 4}, }); *literal->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({1, 0}); @@ -309,7 +306,7 @@ XLA_TEST_F(ParamsTest, R2_2x2_Layout_10) { } XLA_TEST_F(ParamsTest, R2_2x2_TryToPassReverseLayoutToParameter) { - std::unique_ptr literal = LiteralUtil::CreateR2({ + std::unique_ptr literal = Literal::CreateR2({ {1, 3}, {2, 4}, }); const Shape original = literal->shape(); @@ -322,7 +319,7 @@ XLA_TEST_F(ParamsTest, R2_2x2_TryToPassReverseLayoutToParameter) { std::reverse(original_layout.begin(), original_layout.end()); *literal->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout(original_layout); - ASSERT_EQ(2, LiteralUtil::Get(*literal, {0, 1})); + ASSERT_EQ(2, literal->Get({0, 1})); } // Use the original shape in building the computation. ComputationBuilder builder(client_, TestName()); @@ -344,7 +341,6 @@ XLA_TEST_F(ParamsTest, R2_2x2_TryToPassReverseLayoutToParameter) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/pred_test.cc b/tensorflow/compiler/xla/tests/pred_test.cc index b031725d8ab..d865297ae61 100644 --- a/tensorflow/compiler/xla/tests/pred_test.cc +++ b/tensorflow/compiler/xla/tests/pred_test.cc @@ -20,7 +20,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -142,7 +141,6 @@ TEST_F(PredTest, AnyR2False) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/prng_test.cc b/tensorflow/compiler/xla/tests/prng_test.cc index 5117478bfd5..ed994fda450 100644 --- a/tensorflow/compiler/xla/tests/prng_test.cc +++ b/tensorflow/compiler/xla/tests/prng_test.cc @@ -17,7 +17,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/primitive_util.h" @@ -58,11 +57,10 @@ void PrngTest::UniformTest(T a, T b, tensorflow::gtl::ArraySlice dims) { SetSeed(42); auto actual = ExecuteAndTransferOrDie(&builder, /*arguments=*/{}); EXPECT_THAT(dims, ::testing::ElementsAreArray(actual->shape().dimensions())); - LiteralUtil::EachCell(*actual, - [=](tensorflow::gtl::ArraySlice, T value) { - EXPECT_LE(a, value); - EXPECT_LT(value, b); - }); + actual->EachCell([=](tensorflow::gtl::ArraySlice, T value) { + EXPECT_LE(a, value); + EXPECT_LT(value, b); + }); } void PrngTest::BernoulliTest(float p, tensorflow::gtl::ArraySlice dims) { @@ -71,7 +69,7 @@ void PrngTest::BernoulliTest(float p, tensorflow::gtl::ArraySlice dims) { builder.RngBernoulli(builder.ConstantR0(p), shape); TF_ASSIGN_OR_ASSERT_OK(auto computation, builder.Build()); - ExecutionOptions execution_options; + ExecutionOptions execution_options = execution_options_; execution_options.set_seed(42); TF_ASSIGN_OR_ASSERT_OK( auto actual, @@ -79,8 +77,8 @@ void PrngTest::BernoulliTest(float p, tensorflow::gtl::ArraySlice dims) { &execution_options)); EXPECT_THAT(dims, ::testing::ElementsAreArray(actual->shape().dimensions())); int32 sum = 0; - LiteralUtil::EachCell( - *actual, [&sum](tensorflow::gtl::ArraySlice, uint32 value) { + actual->EachCell( + [&sum](tensorflow::gtl::ArraySlice, uint32 value) { EXPECT_TRUE(value == 0 || value == 1); sum += value; }); @@ -124,10 +122,8 @@ double PrngTest::UniformChiSquared(int32 range_size, int32 expected_count) { SetSeed(42); auto actual = ExecuteAndTransferOrDie(&builder, /*arguments=*/{}); std::vector counts(range_size, 0); - LiteralUtil::EachCell( - *actual, 
[&counts](tensorflow::gtl::ArraySlice, int32 value) { - ++counts[value]; - }); + actual->EachCell([&counts](tensorflow::gtl::ArraySlice, + int32 value) { ++counts[value]; }); int64 sum = 0; for (int32 i = 0; i < range_size; ++i) { sum += Square(static_cast(counts[i] - expected_count)); @@ -170,7 +166,7 @@ XLA_TEST_F(PrngTest, MapUsingRng) { ComputationBuilder builder(client_, TestName()); std::unique_ptr param0_literal = - LiteralUtil::CreateR1({2.2f, 5.3f, 4.4f, 5.5f}); + Literal::CreateR1({2.2f, 5.3f, 4.4f, 5.5f}); TF_ASSIGN_OR_ASSERT_OK(std::unique_ptr param0_data, client_->TransferToServer(*param0_literal)); @@ -180,7 +176,7 @@ XLA_TEST_F(PrngTest, MapUsingRng) { TF_ASSIGN_OR_ASSERT_OK(auto computation, builder.Build()); - ExecutionOptions execution_options; + ExecutionOptions execution_options = execution_options_; execution_options.set_seed(125); TF_ASSIGN_OR_ASSERT_OK( auto actual, @@ -209,10 +205,10 @@ XLA_TEST_F(PrngTest, PassInGlobalRngSeed) { return builder.Build(); }; - ExecutionOptions execution_options1; + ExecutionOptions execution_options1 = execution_options_; execution_options1.set_seed(42); - ExecutionOptions execution_options2; + ExecutionOptions execution_options2 = execution_options_; execution_options2.set_seed(65); std::unique_ptr result1; @@ -247,9 +243,11 @@ XLA_TEST_F(PrngTest, PassInGlobalRngSeed) { client_->ExecuteAndTransfer(computation, /*arguments=*/{}, &execution_options2)); TF_ASSIGN_OR_ASSERT_OK( - result5, client_->ExecuteAndTransfer(computation, /*arguments=*/{})); + result5, client_->ExecuteAndTransfer(computation, /*arguments=*/{}, + &execution_options_)); TF_ASSIGN_OR_ASSERT_OK( - result6, client_->ExecuteAndTransfer(computation, /*arguments=*/{})); + result6, client_->ExecuteAndTransfer(computation, /*arguments=*/{}, + &execution_options_)); } LiteralTestUtil::ExpectEqual(*result1, *result2); @@ -273,13 +271,23 @@ XLA_TEST_F(PrngTest, TenValuesN01) { // TODO(b/25995601): Test that resultant values are reasonable } 
+XLA_TEST_F(PrngTest, RngUniformCrash) { + ComputationBuilder builder(client_, TestName()); + + // This used to crash XLA during LLVM IR generation for CPUs. + auto rng_uniform = builder.RngUniform(builder.ConstantR0(0), + builder.ConstantR0(1000 * 1000), + ShapeUtil::MakeShape(S32, {})); + SetSeed(0); + ExecuteAndTransferOrDie(&builder, /*arguments=*/{}); +} + } // namespace } // namespace xla int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/query_inferred_shape_test.cc b/tensorflow/compiler/xla/tests/query_inferred_shape_test.cc index 4a02567a1a2..0078733e197 100644 --- a/tensorflow/compiler/xla/tests/query_inferred_shape_test.cc +++ b/tensorflow/compiler/xla/tests/query_inferred_shape_test.cc @@ -17,7 +17,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -46,7 +45,6 @@ TEST_F(QueryInferredShapeTest, OnePlusOneShape) { int main(int argc, char** argv) { std::vector flag_list; - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc index ff24177520e..ac65a47afa5 100644 --- a/tensorflow/compiler/xla/tests/reduce_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_test.cc @@ -40,7 +40,6 @@ limitations under the License. #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/reference_util.h" @@ -64,12 +63,12 @@ class ReduceTest : public ClientLibraryTestBase { ReduceTest() { // Implementation note: laid out z >> y >> x by default. 
// clang-format off - literal_2d_ = LiteralUtil::CreateR2({ + literal_2d_ = Literal::CreateR2({ // x0 x1 x2 { 1.f, 2.f, 3.f}, // y0 { 4.f, 5.f, 6.f}, // y1 }); - literal_3d_ = LiteralUtil::CreateR3Projected({ + literal_3d_ = Literal::CreateR3Projected({ // x0 x1 x2 { 1.f, 2.f, 3.f}, // y0 { 4.f, 5.f, 6.f}, // y1 @@ -98,7 +97,7 @@ class ReduceTest : public ClientLibraryTestBase { } } std::unique_ptr input_literal = - LiteralUtil::CreateR1(AsSlice(input_data)); + Literal::CreateR1(AsSlice(input_data)); std::unique_ptr input_global_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -130,7 +129,7 @@ class ReduceTest : public ClientLibraryTestBase { builder.Reduce(pred_values, init_value, reduce, /*dimensions_to_reduce=*/{0}); - std::unique_ptr input_literal = LiteralUtil::CreateR1(input_data); + std::unique_ptr input_literal = Literal::CreateR1(input_data); std::unique_ptr input_global_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -157,9 +156,9 @@ class ReduceTest : public ClientLibraryTestBase { Array2D input_data(rows, cols); input_data.FillRandom(3.14f, 0.04); std::unique_ptr input_literal = - LiteralUtil::CreateR2FromArray2D(input_data); - input_literal = LiteralUtil::Relayout( - *input_literal, LayoutUtil::MakeLayout({minor, major})); + Literal::CreateR2FromArray2D(input_data); + input_literal = + input_literal->Relayout(LayoutUtil::MakeLayout({minor, major})); std::unique_ptr input_global_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -185,9 +184,9 @@ class ReduceTest : public ClientLibraryTestBase { Array2D input_data(rows, cols); input_data.FillRandom(3.14f, 0.04); std::unique_ptr input_literal = - LiteralUtil::CreateR2FromArray2D(input_data); - input_literal = LiteralUtil::Relayout( - *input_literal, LayoutUtil::MakeLayout({minor, major})); + Literal::CreateR2FromArray2D(input_data); + input_literal = + input_literal->Relayout(LayoutUtil::MakeLayout({minor, major})); std::unique_ptr 
input_global_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -306,9 +305,8 @@ XLA_TEST_F(ReduceTest, ReduceElementwiseR2_111x50_To_R1) { Array2D input_data(rows, cols); input_data.FillRandom(3.14f, 0.04); std::unique_ptr input_literal = - LiteralUtil::CreateR2FromArray2D(input_data); - input_literal = - LiteralUtil::Relayout(*input_literal, LayoutUtil::MakeLayout({0, 1})); + Literal::CreateR2FromArray2D(input_data); + input_literal = input_literal->Relayout(LayoutUtil::MakeLayout({0, 1})); std::unique_ptr input_global_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -339,9 +337,8 @@ XLA_TEST_F(ReduceTest, TransposeAndReduceElementwiseR2_111x50_To_R1) { Array2D input_data(rows, cols); input_data.FillRandom(3.14f, 0.04); std::unique_ptr input_literal = - LiteralUtil::CreateR2FromArray2D(input_data); - input_literal = - LiteralUtil::Relayout(*input_literal, LayoutUtil::MakeLayout({0, 1})); + Literal::CreateR2FromArray2D(input_data); + input_literal = input_literal->Relayout(LayoutUtil::MakeLayout({0, 1})); std::unique_ptr input_global_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -372,7 +369,7 @@ XLA_TEST_F(ReduceTest, Reshape_111x2x25Reduce_111x50_To_R1) { Array3D input_data(rows, 2, cols / 2); input_data.FillRandom(3.14f, 0.04); std::unique_ptr input_literal = - LiteralUtil::CreateR3FromArray3D(input_data); + Literal::CreateR3FromArray3D(input_data); std::unique_ptr input_global_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -435,7 +432,7 @@ XLA_TEST_F(ReduceTest, MaxReduce2DToR0) { auto max = CreateScalarMaxComputation(F32, &builder); Array2D input(300, 250); input.FillRandom(214.0f); - auto input_literal = LiteralUtil::CreateR2FromArray2D(input); + auto input_literal = Literal::CreateR2FromArray2D(input); builder.Reduce(builder.ConstantLiteral(*input_literal), builder.ConstantR0(FLT_MIN), max, {0, 1}); auto input_max = FLT_MIN; @@ -450,7 +447,7 @@ 
XLA_TEST_F(ReduceTest, MinReduce2DToR0) { auto min = CreateScalarMinComputation(F32, &builder); Array2D input(150, 130); input.FillRandom(214.0f); - auto input_literal = LiteralUtil::CreateR2FromArray2D(input); + auto input_literal = Literal::CreateR2FromArray2D(input); builder.Reduce(builder.ConstantLiteral(*input_literal), builder.ConstantR0(FLT_MAX), min, {0, 1}); @@ -580,9 +577,9 @@ XLA_TEST_P(ReduceR3ToR2Test, ReduceR3ToR2) { Array3D input_array(bounds[0], bounds[1], bounds[2]); input_array.FillRandom(3.14f, 0.05); - auto input_literal = LiteralUtil::CreateR3FromArray3D(input_array); - input_literal = LiteralUtil::Relayout( - *input_literal, LayoutUtil::MakeLayout(GetParam().layout)); + auto input_literal = Literal::CreateR3FromArray3D(input_array); + input_literal = + input_literal->Relayout(LayoutUtil::MakeLayout(GetParam().layout)); std::unique_ptr input_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -630,7 +627,6 @@ INSTANTIATE_TEST_CASE_P( int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc index ec7b47bc283..6b4bceb4377 100644 --- a/tensorflow/compiler/xla/tests/reduce_window_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc @@ -25,7 +25,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/padding.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -58,7 +57,7 @@ class ReduceWindowTest : public ClientLibraryTestBase { tensorflow::gtl::ArraySlice window_strides, Padding padding) { builder_.ReduceWindow( - input, builder_.ConstantLiteral(LiteralUtil::MinValue(F32)), + input, builder_.ConstantLiteral(Literal::MinValue(F32)), CreateScalarMax(), window_dimensions, window_strides, padding); } @@ -67,7 +66,7 @@ class ReduceWindowTest : public ClientLibraryTestBase { tensorflow::gtl::ArraySlice window_strides, Padding padding) { builder_.ReduceWindow(input, - builder_.ConstantLiteral(LiteralUtil::MaxValue(F32)), + builder_.ConstantLiteral(Literal::MaxValue(F32)), CreateScalarMinComputation(F32, &builder_), window_dimensions, window_strides, padding); } @@ -476,7 +475,6 @@ XLA_TEST_F(ReduceWindowTest, NonstandardReduceFunction) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/replay_test.cc b/tensorflow/compiler/xla/tests/replay_test.cc index 7c6700feef8..cb7f54ea01c 100644 --- a/tensorflow/compiler/xla/tests/replay_test.cc +++ b/tensorflow/compiler/xla/tests/replay_test.cc @@ -19,7 +19,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/protobuf_util.h" @@ -61,7 +60,8 @@ TEST_F(ReplayTest, TwoPlusTwoReplay) { // Run it. std::unique_ptr literal = - client_->ExecuteAndTransfer(replayed, /*arguments=*/{}) + client_ + ->ExecuteAndTransfer(replayed, /*arguments=*/{}, &execution_options_) .ConsumeValueOrDie(); // Expect 4. @@ -92,15 +92,16 @@ XLA_TEST_F(ReplayTest, XPlusYReplayWithParameters) { // Run it. std::unique_ptr x_data = - client_->TransferToServer(*LiteralUtil::CreateR0(2)) + client_->TransferToServer(*Literal::CreateR0(2)) .ConsumeValueOrDie(); std::unique_ptr y_data = - client_->TransferToServer(*LiteralUtil::CreateR0(3)) + client_->TransferToServer(*Literal::CreateR0(3)) .ConsumeValueOrDie(); std::unique_ptr literal = client_ ->ExecuteAndTransfer(replayed, - /*arguments=*/{x_data.get(), y_data.get()}) + /*arguments=*/{x_data.get(), y_data.get()}, + &execution_options_) .ConsumeValueOrDie(); // Expect 5. @@ -141,7 +142,8 @@ TEST_F(ReplayTest, MapPlusTwoOverR1) { // Run it. std::unique_ptr literal = - client_->ExecuteAndTransfer(replayed, /*arguments=*/{}) + client_ + ->ExecuteAndTransfer(replayed, /*arguments=*/{}, &execution_options_) .ConsumeValueOrDie(); // Expect result. 
@@ -154,7 +156,6 @@ TEST_F(ReplayTest, MapPlusTwoOverR1) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/reshape_motion_test.cc b/tensorflow/compiler/xla/tests/reshape_motion_test.cc index c9817bc23d8..3051562455f 100644 --- a/tensorflow/compiler/xla/tests/reshape_motion_test.cc +++ b/tensorflow/compiler/xla/tests/reshape_motion_test.cc @@ -25,7 +25,6 @@ limitations under the License. #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/reference_util.h" @@ -63,7 +62,6 @@ TEST_F(ReshapeMotionTest, ElementwiseOfReshapesWithNonSameInputShapes) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/reshape_test.cc b/tensorflow/compiler/xla/tests/reshape_test.cc index ae7d07727b1..6748d196c1a 100644 --- a/tensorflow/compiler/xla/tests/reshape_test.cc +++ b/tensorflow/compiler/xla/tests/reshape_test.cc @@ -25,7 +25,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/reference_util.h" @@ -71,7 +70,7 @@ XLA_TEST_F(ReshapeTest, SingleElementArrayToScalar) { XLA_TEST_F(ReshapeTest, ScalarToSingleElementArray) { ComputationBuilder builder(client_, TestName()); - std::unique_ptr param0_literal = LiteralUtil::CreateR0(1.0f); + std::unique_ptr param0_literal = Literal::CreateR0(1.0f); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -99,7 +98,7 @@ XLA_TEST_F(ReshapeTest, DISABLED_ON_GPU(Trivial0x3WithParameter)) { ComputationBuilder builder(client_, TestName()); std::unique_ptr param0_literal = - LiteralUtil::CreateR2FromArray2D(Array2D(0, 3)); + Literal::CreateR2FromArray2D(Array2D(0, 3)); std::unique_ptr param0_data = client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); @@ -403,7 +402,7 @@ XLA_TEST_F(ReshapeTest, FullyConnectedCollapseDesugared) { XLA_TEST_F(ReshapeTest, ToScalar) { for (int rank = 0; rank < 8; ++rank) { ComputationBuilder b(client_, TestName()); - auto input = LiteralUtil::CreateR1({83.0f}); + auto input = Literal::CreateR1({83.0f}); std::vector ones(rank, 1); // this is {1, ..., 1}. 
std::vector dimensions(rank); std::iota(dimensions.begin(), dimensions.end(), 0); @@ -435,7 +434,7 @@ XLA_TEST_F(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { builder.Reshape(a, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 8}); // clang-format off - auto literal = LiteralUtil::CreateR4FromArray4DWithLayout(Array4D{ + auto literal = Literal::CreateR4FromArray4DWithLayout(Array4D{ { { {0, 1}, @@ -467,7 +466,7 @@ XLA_TEST_F(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { }); Computation computation = builder.Build().ConsumeValueOrDie(); - ExecutionOptions execution_options; + ExecutionOptions execution_options = execution_options_; *execution_options.mutable_shape_with_output_layout() = ShapeUtil::MakeShapeWithLayout(F32, {2, 8}, {1, 0}); std::unique_ptr actual = @@ -475,12 +474,12 @@ XLA_TEST_F(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { ->ExecuteAndTransfer(computation, {input.get()}, &execution_options) .ConsumeValueOrDie(); std::unique_ptr expected = - LiteralUtil::CreateR2FromArray2D(expected_array); + Literal::CreateR2FromArray2D(expected_array); LiteralTestUtil::ExpectEqual(*expected, *actual); } XLA_TEST_F(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) { - std::unique_ptr input = LiteralUtil::CreateR2({ + std::unique_ptr input = Literal::CreateR2({ {0, 1, 2, 3, 4, 5, 6, 7}, {100, 101, 102, 103, 104, 105, 106, 107}, {200, 201, 202, 203, 204, 205, 206, 207}, @@ -508,7 +507,7 @@ XLA_TEST_F(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) { // Tests R2->R4 reshape with the reshape dimensions {1, 0}. 
XLA_TEST_F(ReshapeTest, R2ToR4_3x8_To_3x2x1x4_Dimensions_10) { - std::unique_ptr input = LiteralUtil::CreateR2({ + std::unique_ptr input = Literal::CreateR2({ {0, 1, 2, 3, 4, 5, 6, 7}, {100, 101, 102, 103, 104, 105, 106, 107}, {200, 201, 202, 203, 204, 205, 206, 207}, @@ -542,7 +541,7 @@ XLA_TEST_F(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) { [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( + Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); std::unique_ptr input_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -565,7 +564,7 @@ XLA_TEST_F(ReshapeTest, R4ToR2_2x1x4x1_To_4x2) { [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( + Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); std::unique_ptr input_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -589,7 +588,7 @@ XLA_TEST_F(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) { [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( + Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); std::unique_ptr input_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -603,7 +602,7 @@ XLA_TEST_F(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) { expected_array(indices[0], indices[2] * 30 + indices[1] * 3 + indices[3]) = *cell; }); - auto expected = LiteralUtil::CreateR2FromArray2D(expected_array); + auto expected = Literal::CreateR2FromArray2D(expected_array); ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}); } @@ -615,7 +614,7 @@ 
XLA_TEST_F(ReshapeTest, NoopReshape) { [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( + Literal::CreateR4FromArray4DWithLayout( input_array, LayoutUtil::MakeLayout({1, 2, 3, 0})); std::unique_ptr input_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -626,7 +625,7 @@ XLA_TEST_F(ReshapeTest, NoopReshape) { /*new_sizes=*/{7, 2, 3, 5}); Computation computation = builder.Build().ConsumeValueOrDie(); - ExecutionOptions execution_options; + ExecutionOptions execution_options = execution_options_; *execution_options.mutable_shape_with_output_layout() = ShapeUtil::MakeShapeWithLayout(F32, {7, 2, 3, 5}, {2, 3, 0, 1}); std::unique_ptr output_literal = @@ -642,7 +641,7 @@ XLA_TEST_F(ReshapeTest, NoopReshape) { } XLA_TEST_F(ReshapeTest, R4ToR4Reshape_Trivial) { - auto literal_1x2x3x4 = LiteralUtil::CreateR4( + auto literal_1x2x3x4 = Literal::CreateR4( {{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}}); @@ -655,7 +654,7 @@ XLA_TEST_F(ReshapeTest, R4ToR4Reshape_Trivial) { } XLA_TEST_F(ReshapeTest, R4ToR4Reshape) { - auto literal_1x2x3x4 = LiteralUtil::CreateR4( + auto literal_1x2x3x4 = Literal::CreateR4( {{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}}); @@ -665,7 +664,7 @@ XLA_TEST_F(ReshapeTest, R4ToR4Reshape) { /*new_sizes=*/{2, 4, 3, 1}); // clang-format off - auto expected_2x4x3x1 = LiteralUtil::CreateR4( + auto expected_2x4x3x1 = Literal::CreateR4( {{{{1}, {5}, {9}}, {{2}, {6}, {10}}, {{3}, {7}, {11}}, @@ -689,7 +688,7 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeSimple) { [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( + 
Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); std::unique_ptr input_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -698,9 +697,9 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeSimple) { auto a = builder.Parameter(0, input_literal->shape(), "a"); builder.Reshape(a, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds); - std::unique_ptr expected = LiteralUtil::Relayout( - *LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal), - LayoutUtil::MakeLayout({3, 2, 1, 0})); + std::unique_ptr expected = + LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal) + ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); // Specify the requested output shape explicitly to ensure that this reshape // actually corresponds to a two minor transpose. @@ -718,7 +717,7 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) { [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( + Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); std::unique_ptr input_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -727,9 +726,9 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) { auto a = builder.Parameter(0, input_literal->shape(), "a"); builder.Reshape(a, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds); - std::unique_ptr expected = LiteralUtil::Relayout( - *LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal), - LayoutUtil::MakeLayout({3, 2, 1, 0})); + std::unique_ptr expected = + LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal) + ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); // Specify the requested output shape explicitly to ensure that this reshape // actually corresponds to a two minor transpose. 
@@ -747,7 +746,7 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) { [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( + Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); std::unique_ptr input_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -756,9 +755,9 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) { auto a = builder.Parameter(0, input_literal->shape(), "a"); builder.Reshape(a, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds); - std::unique_ptr expected = LiteralUtil::Relayout( - *LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal), - LayoutUtil::MakeLayout({3, 2, 1, 0})); + std::unique_ptr expected = + LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal) + ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); // Specify the requested output shape explicitly to ensure that this reshape // actually corresponds to a two minor transpose. 
@@ -777,7 +776,7 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) { [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( + Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); std::unique_ptr input_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -786,9 +785,9 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) { auto a = builder.Parameter(0, input_literal->shape(), "a"); builder.Reshape(a, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds); - std::unique_ptr expected = LiteralUtil::Relayout( - *LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal), - LayoutUtil::MakeLayout({3, 2, 1, 0})); + std::unique_ptr expected = + LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal) + ->Relayout(LayoutUtil::MakeLayout({3, 2, 1, 0})); // Specify the requested output shape explicitly to ensure that this reshape // actually corresponds to a two minor transpose. 
@@ -806,7 +805,7 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeTrivialR2) { [&rng, &distribution](tensorflow::gtl::ArraySlice /* indices */, float* cell) { *cell = distribution(rng); }); std::unique_ptr input_literal = - LiteralUtil::CreateR4FromArray4DWithLayout( + Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({0, 1, 2, 3})); std::unique_ptr input_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -815,9 +814,9 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeTrivialR2) { auto a = builder.Parameter(0, input_literal->shape(), "a"); builder.Reshape(a, /*dimensions=*/{1, 0, 2, 3}, /*new_sizes=*/new_bounds); - std::unique_ptr expected = LiteralUtil::Relayout( - *LiteralTestUtil::Reshape(new_bounds, {1, 0, 2, 3}, *input_literal), - input_literal->shape().layout()); + std::unique_ptr expected = + LiteralTestUtil::Reshape(new_bounds, {1, 0, 2, 3}, *input_literal) + ->Relayout(input_literal->shape().layout()); // Specify the requested output shape explicitly to ensure that this reshape // actually corresponds to a two minor transpose. @@ -831,7 +830,6 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeTrivialR2) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/reverse_test.cc b/tensorflow/compiler/xla/tests/reverse_test.cc index 5ca9702380f..2f72fc0729a 100644 --- a/tensorflow/compiler/xla/tests/reverse_test.cc +++ b/tensorflow/compiler/xla/tests/reverse_test.cc @@ -19,7 +19,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -159,7 +158,6 @@ TEST_F(ReverseTest, Reverse4DFloatArrayOnDim01) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc b/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc index 05ce22fc359..5b4c05c6733 100644 --- a/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc +++ b/tensorflow/compiler/xla/tests/round_trip_packed_literal_test.cc @@ -18,7 +18,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/packed_literal_reader.h" @@ -66,8 +65,8 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR1F32Length2) { reader.Read(ShapeUtil::MakeShape(F32, {2})).ConsumeValueOrDie(); EXPECT_TRUE(reader.IsExhausted()); - EXPECT_EQ(42.0, LiteralUtil::Get(*actual, {0})); - EXPECT_EQ(24.0, LiteralUtil::Get(*actual, {1})); + EXPECT_EQ(42.0, actual->Get({0})); + EXPECT_EQ(24.0, actual->Get({1})); } TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim0Minor) { @@ -96,10 +95,10 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim0Minor) { .ConsumeValueOrDie(); EXPECT_TRUE(reader.IsExhausted()); - EXPECT_EQ(42.0f, LiteralUtil::Get(*actual, {0, 0})); - EXPECT_EQ(24.0f, LiteralUtil::Get(*actual, {0, 1})); - EXPECT_EQ(64.0f, LiteralUtil::Get(*actual, {1, 0})); - EXPECT_EQ(46.0f, LiteralUtil::Get(*actual, {1, 1})); + EXPECT_EQ(42.0f, actual->Get({0, 0})); + EXPECT_EQ(24.0f, actual->Get({0, 1})); + EXPECT_EQ(64.0f, actual->Get({1, 0})); + EXPECT_EQ(46.0f, actual->Get({1, 1})); std::unique_ptr round_tripped = RoundTripToServer(*actual); LiteralTestUtil::ExpectEqual(*round_tripped, *actual); @@ -131,10 +130,10 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim1Minor) { .ConsumeValueOrDie(); EXPECT_TRUE(reader.IsExhausted()); - EXPECT_EQ(42.0f, LiteralUtil::Get(*actual, {0, 0})); - EXPECT_EQ(24.0f, LiteralUtil::Get(*actual, {1, 0})); - EXPECT_EQ(64.0f, LiteralUtil::Get(*actual, {0, 1})); - EXPECT_EQ(46.0f, LiteralUtil::Get(*actual, {1, 1})); + EXPECT_EQ(42.0f, actual->Get({0, 0})); + EXPECT_EQ(24.0f, actual->Get({1, 0})); + EXPECT_EQ(64.0f, actual->Get({0, 1})); + EXPECT_EQ(46.0f, 
actual->Get({1, 1})); std::unique_ptr round_tripped = RoundTripToServer(*actual); LiteralTestUtil::ExpectEqual(*round_tripped, *actual); @@ -146,7 +145,6 @@ TEST_F(RoundTripPackedLiteralTest, RoundTripsR2F32Size2x2Dim1Minor) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc b/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc index f0760241cdb..e6a6b7b37a4 100644 --- a/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc +++ b/tensorflow/compiler/xla/tests/round_trip_transfer_test.cc @@ -23,7 +23,6 @@ limitations under the License. #include "tensorflow/compiler/xla/array4d.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -48,62 +47,61 @@ class RoundTripTransferTest : public ClientLibraryTestBase { }; TEST_F(RoundTripTransferTest, R0S32) { - RoundTripTest(*LiteralUtil::CreateR0(42)); + RoundTripTest(*Literal::CreateR0(42)); } TEST_F(RoundTripTransferTest, R0F32) { - RoundTripTest(*LiteralUtil::CreateR0(42.0)); + RoundTripTest(*Literal::CreateR0(42.0)); } TEST_F(RoundTripTransferTest, R1F32_Len0) { - RoundTripTest(*LiteralUtil::CreateR1({})); + RoundTripTest(*Literal::CreateR1({})); } TEST_F(RoundTripTransferTest, R1F32_Len2) { - RoundTripTest(*LiteralUtil::CreateR1({42.0, 64.0})); + RoundTripTest(*Literal::CreateR1({42.0, 64.0})); } TEST_F(RoundTripTransferTest, R1F32_Len256) { std::vector values(256); 
std::iota(values.begin(), values.end(), 1.0); - RoundTripTest(*LiteralUtil::CreateR1(values)); + RoundTripTest(*Literal::CreateR1(values)); } TEST_F(RoundTripTransferTest, R1F32_Len1024) { std::vector values(1024); std::iota(values.begin(), values.end(), 1.0); - RoundTripTest(*LiteralUtil::CreateR1(values)); + RoundTripTest(*Literal::CreateR1(values)); } TEST_F(RoundTripTransferTest, R1F32_Len1025) { std::vector values(1025); std::iota(values.begin(), values.end(), 1.0); - RoundTripTest(*LiteralUtil::CreateR1(values)); + RoundTripTest(*Literal::CreateR1(values)); } TEST_F(RoundTripTransferTest, R1F32_Len4096) { std::vector values(4096); std::iota(values.begin(), values.end(), 1.0); - RoundTripTest(*LiteralUtil::CreateR1(values)); + RoundTripTest(*Literal::CreateR1(values)); } TEST_F(RoundTripTransferTest, R2F32_Len10x0) { - RoundTripTest( - *LiteralUtil::CreateR2FromArray2D(Array2D(10, 0))); + RoundTripTest(*Literal::CreateR2FromArray2D(Array2D(10, 0))); } TEST_F(RoundTripTransferTest, R2F32_Len2x2) { - RoundTripTest(*LiteralUtil::CreateR2({{42.0, 64.0}, {77.0, 88.0}})); + RoundTripTest(*Literal::CreateR2({{42.0, 64.0}, {77.0, 88.0}})); } TEST_F(RoundTripTransferTest, R3F32) { RoundTripTest( - *LiteralUtil::CreateR3({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}}, - {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}})); + *Literal::CreateR3({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}}, + {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}})); } TEST_F(RoundTripTransferTest, R4F32) { - RoundTripTest(*LiteralUtil::CreateR4({{ + RoundTripTest(*Literal::CreateR4({{ {{10, 11, 12, 13}, {14, 15, 16, 17}}, {{18, 19, 20, 21}, {22, 23, 24, 25}}, {{26, 27, 28, 29}, {30, 31, 32, 33}}, @@ -111,36 +109,33 @@ TEST_F(RoundTripTransferTest, R4F32) { } TEST_F(RoundTripTransferTest, EmptyTuple) { - RoundTripTest(*LiteralUtil::MakeTuple({})); + RoundTripTest(*Literal::MakeTuple({})); } TEST_F(RoundTripTransferTest, TupleOfR1F32) { - RoundTripTest( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({1, 2}).get(), - 
LiteralUtil::CreateR1({3, 4}).get()})); + RoundTripTest(*Literal::MakeTuple({Literal::CreateR1({1, 2}).get(), + Literal::CreateR1({3, 4}).get()})); } TEST_F(RoundTripTransferTest, TupleOfR1F32_Len0_Len2) { - RoundTripTest( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR1({}).get(), - LiteralUtil::CreateR1({3, 4}).get()})); + RoundTripTest(*Literal::MakeTuple({Literal::CreateR1({}).get(), + Literal::CreateR1({3, 4}).get()})); } TEST_F(RoundTripTransferTest, TupleOfR0F32AndR1S32) { - RoundTripTest( - *LiteralUtil::MakeTuple({LiteralUtil::CreateR0(1.0).get(), - LiteralUtil::CreateR1({2, 3}).get()})); + RoundTripTest(*Literal::MakeTuple({Literal::CreateR0(1.0).get(), + Literal::CreateR1({2, 3}).get()})); } // Below two tests are added to identify the cost of large data transfers. TEST_F(RoundTripTransferTest, R2F32_Large) { - RoundTripTest(*LiteralUtil::CreateR2F32Linspace(-1.0f, 1.0f, 512, 512)); + RoundTripTest(*Literal::CreateR2F32Linspace(-1.0f, 1.0f, 512, 512)); } TEST_F(RoundTripTransferTest, R4F32_Large) { Array4D array4d(2, 2, 256, 256); array4d.FillWithMultiples(1.0f); - RoundTripTest(*LiteralUtil::CreateR4FromArray4D(array4d)); + RoundTripTest(*Literal::CreateR4FromArray4D(array4d)); } } // namespace @@ -149,7 +144,6 @@ TEST_F(RoundTripTransferTest, R4F32_Large) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc index 47a39ffbbc4..07bd00f0154 100644 --- a/tensorflow/compiler/xla/tests/scalar_computations_test.cc +++ b/tensorflow/compiler/xla/tests/scalar_computations_test.cc @@ -20,7 +20,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -212,9 +211,9 @@ TEST_F(ScalarComputationsTest, MulThreeScalarsS32) { TEST_F(ScalarComputationsTest, MulThreeScalarsF32Params) { ComputationBuilder builder(client_, TestName()); - std::unique_ptr a_literal = LiteralUtil::CreateR0(2.1f); - std::unique_ptr b_literal = LiteralUtil::CreateR0(5.5f); - std::unique_ptr c_literal = LiteralUtil::CreateR0(0.5f); + std::unique_ptr a_literal = Literal::CreateR0(2.1f); + std::unique_ptr b_literal = Literal::CreateR0(5.5f); + std::unique_ptr c_literal = Literal::CreateR0(0.5f); std::unique_ptr a_data = client_->TransferToServer(*a_literal).ConsumeValueOrDie(); @@ -361,8 +360,8 @@ TEST_F(ScalarComputationsTest, DivU32s) { for (uint32 divisor : vals) { if (divisor != 0) { for (uint32 dividend : vals) { - auto dividend_literal = LiteralUtil::CreateR0(dividend); - auto divisor_literal = LiteralUtil::CreateR0(divisor); + auto dividend_literal = Literal::CreateR0(dividend); + auto divisor_literal = Literal::CreateR0(divisor); TF_ASSIGN_OR_ASSERT_OK(auto dividend_data, client_->TransferToServer(*dividend_literal)); TF_ASSIGN_OR_ASSERT_OK(auto divisor_data, @@ -373,8 +372,7 @@ TEST_F(ScalarComputationsTest, DivU32s) { {dividend_data.get(), divisor_data.get()}, &execution_options_) .ConsumeValueOrDie(); - auto expected_literal = - LiteralUtil::CreateR0(dividend / divisor); + auto expected_literal = Literal::CreateR0(dividend / divisor); LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); } } @@ -403,8 +401,8 @@ TEST_F(ScalarComputationsTest, RemU32s) { for (uint32 divisor : vals) { if (divisor != 0) { for 
(uint32 dividend : vals) { - auto dividend_literal = LiteralUtil::CreateR0(dividend); - auto divisor_literal = LiteralUtil::CreateR0(divisor); + auto dividend_literal = Literal::CreateR0(dividend); + auto divisor_literal = Literal::CreateR0(divisor); TF_ASSIGN_OR_ASSERT_OK(auto dividend_data, client_->TransferToServer(*dividend_literal)); TF_ASSIGN_OR_ASSERT_OK(auto divisor_data, @@ -415,8 +413,7 @@ TEST_F(ScalarComputationsTest, RemU32s) { {dividend_data.get(), divisor_data.get()}, &execution_options_) .ConsumeValueOrDie(); - auto expected_literal = - LiteralUtil::CreateR0(dividend % divisor); + auto expected_literal = Literal::CreateR0(dividend % divisor); LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); } } @@ -428,7 +425,7 @@ TEST_F(ScalarComputationsTest, RemainderTwoScalarsNonConstDividendS32) { auto x = builder.Parameter(0, ShapeUtil::MakeShape(S32, {}), "x"); builder.Rem(x, builder.ConstantR0(80000)); - std::unique_ptr literal = LiteralUtil::CreateR0(87919); + std::unique_ptr literal = Literal::CreateR0(87919); TF_ASSIGN_OR_ASSERT_OK(auto input_data, client_->TransferToServer(*literal)); ComputeAndCompareR0(&builder, 7919, {input_data.get()}); } @@ -764,7 +761,7 @@ TEST_F(ScalarComputationsTest, ComplicatedArithmeticExpressionS32) { TEST_F(ScalarComputationsTest, SqrtF320) { ComputationBuilder builder(client_, TestName()); - Literal zero_literal = LiteralUtil::Zero(PrimitiveType::F32); + Literal zero_literal = Literal::Zero(PrimitiveType::F32); std::unique_ptr zero_data = client_->TransferToServer(zero_literal).ConsumeValueOrDie(); @@ -782,7 +779,6 @@ TEST_F(ScalarComputationsTest, SqrtF320) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git 
a/tensorflow/compiler/xla/tests/select_and_scatter_test.cc b/tensorflow/compiler/xla/tests/select_and_scatter_test.cc index 36110da2478..de89588042e 100644 --- a/tensorflow/compiler/xla/tests/select_and_scatter_test.cc +++ b/tensorflow/compiler/xla/tests/select_and_scatter_test.cc @@ -25,7 +25,6 @@ limitations under the License. #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/padding.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/reference_util.h" @@ -381,7 +380,6 @@ XLA_TEST_F(SelectAndScatterTest, R1F32OverlappingWindowMinScatter) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/select_test.cc b/tensorflow/compiler/xla/tests/select_test.cc index 5eb4fee8ed2..6b48116b6e1 100644 --- a/tensorflow/compiler/xla/tests/select_test.cc +++ b/tensorflow/compiler/xla/tests/select_test.cc @@ -19,7 +19,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -262,7 +261,6 @@ TEST_F(SelectTest, SelectR1F32WithScalarPredicateFalse) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/set_return_value_test.cc b/tensorflow/compiler/xla/tests/set_return_value_test.cc index 25bb915be56..38fc27f200c 100644 --- a/tensorflow/compiler/xla/tests/set_return_value_test.cc +++ b/tensorflow/compiler/xla/tests/set_return_value_test.cc @@ -17,7 +17,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -102,7 +101,6 @@ TEST_F(SetReturnValueTest, SetValueMultipleTimesAndModify) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc index 70345c300cc..5e7d4756624 100644 --- a/tensorflow/compiler/xla/tests/slice_test.cc +++ b/tensorflow/compiler/xla/tests/slice_test.cc @@ -21,7 +21,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" @@ -269,7 +268,6 @@ INSTANTIATE_TEST_CASE_P( int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/test_utils.h b/tensorflow/compiler/xla/tests/test_utils.h index 6a23df4d3c3..f3a522b05eb 100644 --- a/tensorflow/compiler/xla/tests/test_utils.h +++ b/tensorflow/compiler/xla/tests/test_utils.h @@ -61,7 +61,7 @@ std::unique_ptr CreateR2LiteralWithLayout( auto literal = MakeUnique(); const int64 d0 = values.size(); const int64 d1 = values.begin()->size(); - LiteralUtil::PopulateWithValue(0, {d0, d1}, literal.get()); + literal.get()->PopulateWithValue(0, {d0, d1}); *literal->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout(minor_to_major); TF_CHECK_OK(ShapeUtil::ValidateShape(literal->shape())); @@ -70,7 +70,7 @@ std::unique_ptr CreateR2LiteralWithLayout( for (auto inner_list : values) { int64 dim1 = 0; for (auto value : inner_list) { - LiteralUtil::Set(literal.get(), {dim0, dim1}, value); + literal.get()->Set({dim0, dim1}, value); ++dim1; } ++dim0; @@ -88,7 +88,7 @@ std::unique_ptr CreateR3LiteralWithLayout( const int64 d0 = values.size(); const int64 d1 = values.begin()->size(); const int64 d2 = values.begin()->begin()->size(); - LiteralUtil::PopulateWithValue(0, {d0, d1, d2}, literal.get()); + 
literal.get()->PopulateWithValue(0, {d0, d1, d2}); *literal->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout(minor_to_major); TF_CHECK_OK(ShapeUtil::ValidateShape(literal->shape())); @@ -99,7 +99,7 @@ std::unique_ptr CreateR3LiteralWithLayout( for (auto inner_inner_list : inner_list) { int64 dim2 = 0; for (auto value : inner_inner_list) { - LiteralUtil::Set(literal.get(), {dim0, dim1, dim2}, value); + literal.get()->Set({dim0, dim1, dim2}, value); ++dim2; } ++dim1; diff --git a/tensorflow/compiler/xla/tests/transpose_test.cc b/tensorflow/compiler/xla/tests/transpose_test.cc index e4951c42010..07c0f073e86 100644 --- a/tensorflow/compiler/xla/tests/transpose_test.cc +++ b/tensorflow/compiler/xla/tests/transpose_test.cc @@ -18,7 +18,6 @@ limitations under the License. #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" @@ -189,7 +188,6 @@ TEST_F(TransposeTest, TransposeConstant021_MultipleTilesPerLayer) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc index 6309e712973..be3cdbca090 100644 --- a/tensorflow/compiler/xla/tests/tuple_test.cc +++ b/tensorflow/compiler/xla/tests/tuple_test.cc @@ -20,7 +20,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/computation.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -54,10 +53,10 @@ XLA_TEST_F(TupleTest, TupleCreate) { builder.ConstantR1(constant_vector), builder.ConstantR2(constant_matrix)}); - auto expected = LiteralUtil::MakeTuple( - {LiteralUtil::CreateR0(constant_scalar).get(), - LiteralUtil::CreateR1(constant_vector).get(), - LiteralUtil::CreateR2(constant_matrix).get()}); + auto expected = + Literal::MakeTuple({Literal::CreateR0(constant_scalar).get(), + Literal::CreateR1(constant_vector).get(), + Literal::CreateR2(constant_matrix).get()}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } @@ -68,9 +67,8 @@ XLA_TEST_F(TupleTest, TupleCreateWithZeroElementEntry) { auto result = builder.Tuple( {builder.ConstantR0(7.0), builder.ConstantR1({})}); - auto expected = - LiteralUtil::MakeTuple({LiteralUtil::CreateR0(7.0).get(), - LiteralUtil::CreateR1({}).get()}); + auto expected = Literal::MakeTuple({Literal::CreateR0(7.0).get(), + Literal::CreateR1({}).get()}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } @@ -78,7 +76,7 @@ XLA_TEST_F(TupleTest, TupleCreateWithZeroElementEntry) { XLA_TEST_F(TupleTest, EmptyTupleCreate) { ComputationBuilder builder(client_, TestName()); auto result = builder.Tuple({}); - auto expected = LiteralUtil::MakeTuple({}); + auto expected = Literal::MakeTuple({}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } @@ -147,12 +145,37 @@ XLA_TEST_F(TupleTest, TupleGTEToTuple) { builder.ConstantR2(constant_matrix)}); auto new_tuple = builder.Tuple({builder.GetTupleElement(tuple_data, 1), builder.GetTupleElement(tuple_data, 0)}); - auto expected = 
LiteralUtil::MakeTuple( - {LiteralUtil::CreateR2(constant_matrix).get(), - LiteralUtil::CreateR1(constant_vector).get()}); + auto expected = + Literal::MakeTuple({Literal::CreateR2(constant_matrix).get(), + Literal::CreateR1(constant_vector).get()}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } +XLA_TEST_F(TupleTest, SelectBetweenPredTuples) { + ComputationBuilder b(client_, TestName()); + ComputationDataHandle v1, v2; + + for (bool direction : {false, true}) { + std::unique_ptr v1_data = + CreateR0Parameter(0.0f, /*parameter_number=*/0, /*name=*/"v1", + /*builder=*/&b, /*data_handle=*/&v1); + std::unique_ptr v2_data = + CreateR0Parameter(1.0f, /*parameter_number=*/1, /*name=*/"v2", + /*builder=*/&b, /*data_handle=*/&v2); + auto v1_gt = b.Gt(v1, v2); // false + auto v2_gt = b.Gt(v2, v1); // true + auto v1_v2 = b.Tuple({v1_gt, v2_gt}); // {false, true} + auto v2_v1 = b.Tuple({v2_gt, v1_gt}); // {true, false} + auto select = b.Select(direction ? v1_gt : v2_gt, v1_v2, v2_v1); + auto expected = + Literal::MakeTuple({Literal::CreateR0(direction).get(), + Literal::CreateR0(!direction).get()}); + + ComputeAndCompareTuple(&b, *expected, {v1_data.get(), v2_data.get()}, + error_spec_); + } +} + // Builds two new tuples from an existing tuple (by means of GetTupleElement), // then adds up the components of the new tuples. 
XLA_TEST_F(TupleTest, TupleGTEToTupleToGTEAdd) { @@ -213,9 +236,8 @@ XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesOnFalse)) { auto select = builder.Select(builder.ConstantR0(false), tuple12, tuple21); - auto expected = - LiteralUtil::MakeTuple({LiteralUtil::CreateR1(vec2).get(), - LiteralUtil::CreateR1(vec1).get()}); + auto expected = Literal::MakeTuple({Literal::CreateR1(vec2).get(), + Literal::CreateR1(vec1).get()}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } @@ -259,9 +281,8 @@ XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesOnTrue)) { auto select = builder.Select(builder.ConstantR0(true), tuple12, tuple21); - auto expected = - LiteralUtil::MakeTuple({LiteralUtil::CreateR1(vec1).get(), - LiteralUtil::CreateR1(vec2).get()}); + auto expected = Literal::MakeTuple({Literal::CreateR1(vec1).get(), + Literal::CreateR1(vec2).get()}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } @@ -340,9 +361,8 @@ XLA_TEST_F(TupleTest, auto select = builder.Select(builder.ConstantR0(false), tuple12, tuple21); - auto expected = - LiteralUtil::MakeTuple({LiteralUtil::CreateR1(vec2).get(), - LiteralUtil::CreateR1(vec1).get()}); + auto expected = Literal::MakeTuple({Literal::CreateR1(vec2).get(), + Literal::CreateR1(vec1).get()}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } @@ -353,13 +373,13 @@ XLA_TEST_F(TupleTest, NestedTuples) { auto outer_tuple = builder.Tuple({inner_tuple, builder.ConstantR1({22.0, 44.0})}); - auto expected_v1 = LiteralUtil::CreateR1({1.0, 2.0}); - auto expected_s = LiteralUtil::CreateR0(42.0); + auto expected_v1 = Literal::CreateR1({1.0, 2.0}); + auto expected_s = Literal::CreateR0(42.0); auto expected_inner_tuple = - LiteralUtil::MakeTuple({expected_v1.get(), expected_s.get()}); - auto expected_v2 = LiteralUtil::CreateR1({22.0, 44.0}); + Literal::MakeTuple({expected_v1.get(), expected_s.get()}); + auto expected_v2 = Literal::CreateR1({22.0, 44.0}); auto expected = - 
LiteralUtil::MakeTuple({expected_inner_tuple.get(), expected_v2.get()}); + Literal::MakeTuple({expected_inner_tuple.get(), expected_v2.get()}); ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } @@ -379,14 +399,14 @@ XLA_TEST_F(TupleTest, GetTupleElementOfNestedTuple) { std::unique_ptr data = client_ - ->TransferToServer(*LiteralUtil::MakeTuple({ - LiteralUtil::MakeTuple( + ->TransferToServer(*Literal::MakeTuple({ + Literal::MakeTuple( { - LiteralUtil::CreateR1({1.0, 2.0, 3.0}).get(), - LiteralUtil::CreateR1({4.0, 5.0, 6.0}).get(), + Literal::CreateR1({1.0, 2.0, 3.0}).get(), + Literal::CreateR1({4.0, 5.0, 6.0}).get(), }) .get(), - LiteralUtil::CreateR1({7.0, 8.0, 9.0}).get(), + Literal::CreateR1({7.0, 8.0, 9.0}).get(), })) .ConsumeValueOrDie(); @@ -401,7 +421,6 @@ XLA_TEST_F(TupleTest, GetTupleElementOfNestedTuple) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/unary_op_test.cc b/tensorflow/compiler/xla/tests/unary_op_test.cc index 61110d5b4cd..d35d9ecdeb6 100644 --- a/tensorflow/compiler/xla/tests/unary_op_test.cc +++ b/tensorflow/compiler/xla/tests/unary_op_test.cc @@ -19,7 +19,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -165,7 +164,6 @@ TEST_F(UnaryOpTest, SignAbsTestR2) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc b/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc index 26a08953b15..079dbb06117 100644 --- a/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc +++ b/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc @@ -22,7 +22,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" @@ -221,7 +220,6 @@ TEST_F(VecOpsReduceTest, AddReduceR3F32AllDims) { int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc index efde45375fd..b2e0c796bde 100644 --- a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc +++ b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc @@ -23,7 +23,6 @@ limitations under the License. 
#include "tensorflow/compiler/xla/client/global_data.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -441,7 +440,6 @@ XLA_TEST_F(VecOpsSimpleTest, VectorPredicateNotEqual) { int main(int argc, char** argv) { std::vector flag_list; - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc index 5f917797744..afa7d871c0e 100644 --- a/tensorflow/compiler/xla/tests/while_test.cc +++ b/tensorflow/compiler/xla/tests/while_test.cc @@ -22,7 +22,6 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/local_client.h" -#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/platform_util.h" @@ -82,6 +81,70 @@ TEST_F(WhileTest, WhileWithScalarResult) { ComputeAndCompareR0(&builder, 5, {}); } +TEST_F(WhileTest, WhileWithScalarResultNonConstInit) { + auto result_shape = ShapeUtil::MakeShape(S32, {}); + auto orig_shape = ShapeUtil::MakeShape(S32, {2}); + + // Create a computation for the condition: repeat for 5 iterations. 
+ Computation condition; + { + ComputationBuilder builder(client_, "condition"); + auto prev = builder.Parameter(0, result_shape, "prev"); + builder.Gt(builder.ConstantR0(5), prev); + condition = builder.Build().ConsumeValueOrDie(); + } + + // Create a computation for the body: add 1 to the result variable. + Computation body; + { + ComputationBuilder builder(client_, "body"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto input = builder.ConstantR0(1); + auto result = builder.Add(input, prev); + body = builder.Build().ConsumeValueOrDie(); + } + + // Create a While node with computations for the condition and the body. + ComputationBuilder builder(client_, TestName()); + auto init = builder.Reduce(builder.ConstantR1(2, 1), + builder.ConstantR0(0), + CreateScalarAddComputation(S32, &builder), {0}); + auto result = builder.While(condition, body, init); + auto shape = builder.GetShape(result).ConsumeValueOrDie(); + + ComputeAndCompareR0(&builder, 5, {}); +} + +TEST_F(WhileTest, WhileWithPredicateResult) { + auto result_shape = ShapeUtil::MakeShape(PRED, {}); + + // Create a computation for the condition: run until condition is true. + Computation condition; + { + ComputationBuilder builder(client_, "condition"); + auto prev = builder.Parameter(0, result_shape, "prev"); + builder.Ne(builder.ConstantR0(true), prev); + condition = builder.Build().ConsumeValueOrDie(); + } + + // Create a computation for the body: or condition with true. + Computation body; + { + ComputationBuilder builder(client_, "body"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto result = builder.LogicalOr(prev, builder.ConstantR0(true)); + body = builder.Build().ConsumeValueOrDie(); + } + + // Create a While node with computations for the condition and the body. 
+ ComputationBuilder builder(client_, TestName()); + auto init = builder.Ne(builder.ConstantR0(false), + builder.ConstantR0(true)); + auto result = builder.While(condition, body, init); + + ComputeAndCompareR0(&builder, true, {}); +} + // Tests a while node when the result type T is a vector. // // All constants are chosen to produce exact results. @@ -240,15 +303,62 @@ TEST_F(WhileTest, WhileWithTupleResult) { VLOG(2) << "while = " << ShapeUtil::HumanString( *builder.GetShape(result).ConsumeValueOrDie()); - auto expected_counter = LiteralUtil::CreateR0(5); - auto expected_data = LiteralUtil::CreateR1( + auto expected_counter = Literal::CreateR0(5); + auto expected_data = Literal::CreateR1( {5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f, 5.0f}); auto expected = - LiteralUtil::MakeTuple({expected_counter.get(), expected_data.get()}); + Literal::MakeTuple({expected_counter.get(), expected_data.get()}); VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); } +TEST_F(WhileTest, WhileWithPredicateTupleResult) { + std::vector shape_elements = {ShapeUtil::MakeShape(S32, {}), + ShapeUtil::MakeShape(PRED, {})}; + Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements); + + // Create a computation for the condition. + // Repeat for 5 iterations. + Computation condition; + { + ComputationBuilder builder(client_, "condition"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto iteration = builder.GetTupleElement(prev, 0); + builder.Gt(builder.ConstantR0(5), iteration); + condition = builder.Build().ConsumeValueOrDie(); + } + + // Create a computation for the body. 
+ // Add 1 to the iteration variable and or the predicate with true + Computation body; + { + ComputationBuilder builder(client_, "body"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto iteration = builder.GetTupleElement(prev, 0); + auto pred = builder.GetTupleElement(prev, 1); + auto new_pred = builder.LogicalOr(pred, builder.ConstantR0(true)); + auto result = builder.Tuple( + {builder.Add(iteration, builder.ConstantR0(1)), new_pred}); + body = builder.Build().ConsumeValueOrDie(); + } + + // Create a While node with computations for the condition and the body. + ComputationBuilder builder(client_, "while"); + auto init = builder.Tuple({builder.ConstantR0(0), + builder.Ne(builder.ConstantR0(false), + builder.ConstantR0(true))}); + auto result = builder.While(condition, body, init); + VLOG(2) << "while = " + << ShapeUtil::HumanString( + *builder.GetShape(result).ConsumeValueOrDie()); + + auto expected_counter = Literal::CreateR0(5); + auto expected_predicate = Literal::CreateR0(true); + auto expected = + Literal::MakeTuple({expected_counter.get(), expected_predicate.get()}); + ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0)); +} + // Tests two while nodes when the result type T is a Tuple and the second // while node uses the result of the first while node which is used in two // nodes. 
@@ -525,11 +635,11 @@ XLA_TEST_F(WhileTest, WhileWithDynamicUpdateSlice) { << ShapeUtil::HumanString( *builder.GetShape(result).ConsumeValueOrDie()); - auto expected_counter = LiteralUtil::CreateR0(5); - auto expected_data = LiteralUtil::CreateR1( + auto expected_counter = Literal::CreateR0(5); + auto expected_data = Literal::CreateR1( {1.0f, 1.0f, 2.0f, 2.0f, 3.0f, 3.0f, 4.0f, 4.0f, 5.0f, 5.0f}); auto expected = - LiteralUtil::MakeTuple({expected_counter.get(), expected_data.get()}); + Literal::MakeTuple({expected_counter.get(), expected_data.get()}); VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); } @@ -589,7 +699,7 @@ TEST_F(WhileTest, WhileWithPrngScalarResult) { for (int i = 1; i < 4; ++i) { TF_ASSIGN_OR_ASSERT_OK(auto computation, while_loop(i)); - ExecutionOptions execution_options; + ExecutionOptions execution_options = execution_options_; execution_options.set_seed(65); TF_ASSIGN_OR_ASSERT_OK( auto result, @@ -743,7 +853,6 @@ BENCHMARK(BM_WhileLoop); int main(int argc, char** argv) { std::vector flag_list; xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); - xla::legacy_flags::AppendCpuCompilerFlags(&flag_list); xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list); if (!parse_result) { diff --git a/tensorflow/compiler/xla/text_literal_reader.cc b/tensorflow/compiler/xla/text_literal_reader.cc index 78762724678..afdc6726f17 100644 --- a/tensorflow/compiler/xla/text_literal_reader.cc +++ b/tensorflow/compiler/xla/text_literal_reader.cc @@ -104,8 +104,8 @@ StatusOr> TextLiteralReader::ReadAllLines() { auto result = MakeUnique(); const float fill = std::numeric_limits::quiet_NaN(); - LiteralUtil::PopulateWithValue(fill, AsInt64Slice(shape.dimensions()), - result.get()); + result.get()->PopulateWithValue(fill, + AsInt64Slice(shape.dimensions())); std::vector pieces; std::vector 
coordinates; std::vector coordinate_values; @@ -147,7 +147,7 @@ StatusOr> TextLiteralReader::ReadAllLines() { "\"%s\"", shape.dimensions_size(), coordinate_values.size(), line.c_str()); } - LiteralUtil::Set(result.get(), coordinate_values, value); + result.get()->Set(coordinate_values, value); } return std::move(result); } diff --git a/tensorflow/compiler/xla/text_literal_reader_test.cc b/tensorflow/compiler/xla/text_literal_reader_test.cc index a167d80f73b..23070b66387 100644 --- a/tensorflow/compiler/xla/text_literal_reader_test.cc +++ b/tensorflow/compiler/xla/text_literal_reader_test.cc @@ -46,12 +46,12 @@ TEST(TextLiteralReaderTest, ReadsR3File) { TextLiteralReader::ReadPath(fname).ConsumeValueOrDie(); EXPECT_TRUE( ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {1, 2, 3}), literal->shape())); - EXPECT_EQ(42.5, LiteralUtil::Get(*literal, {0, 0, 0})); - EXPECT_EQ(43.5, LiteralUtil::Get(*literal, {0, 0, 1})); - EXPECT_EQ(44.5, LiteralUtil::Get(*literal, {0, 0, 2})); - EXPECT_EQ(45.5, LiteralUtil::Get(*literal, {0, 1, 0})); - EXPECT_EQ(46.5, LiteralUtil::Get(*literal, {0, 1, 1})); - EXPECT_EQ(47.5, LiteralUtil::Get(*literal, {0, 1, 2})); + EXPECT_EQ(42.5, literal->Get({0, 0, 0})); + EXPECT_EQ(43.5, literal->Get({0, 0, 1})); + EXPECT_EQ(44.5, literal->Get({0, 0, 2})); + EXPECT_EQ(45.5, literal->Get({0, 1, 0})); + EXPECT_EQ(46.5, literal->Get({0, 1, 1})); + EXPECT_EQ(47.5, literal->Get({0, 1, 2})); } } // namespace diff --git a/tensorflow/compiler/xla/text_literal_writer.cc b/tensorflow/compiler/xla/text_literal_writer.cc index a5097e41cb3..3fee467594d 100644 --- a/tensorflow/compiler/xla/text_literal_writer.cc +++ b/tensorflow/compiler/xla/text_literal_writer.cc @@ -45,9 +45,9 @@ namespace xla { tensorflow::Status status; tensorflow::WritableFile* f_ptr = f.get(); - LiteralUtil::EachCellAsString( - literal, [f_ptr, &status](tensorflow::gtl::ArraySlice indices, - const string& value) { + literal.EachCellAsString( + [f_ptr, &status](tensorflow::gtl::ArraySlice indices, + 
const string& value) { if (!status.ok()) { return; } diff --git a/tensorflow/compiler/xla/text_literal_writer_test.cc b/tensorflow/compiler/xla/text_literal_writer_test.cc index 177ae4ea036..70cf2fb1b8a 100644 --- a/tensorflow/compiler/xla/text_literal_writer_test.cc +++ b/tensorflow/compiler/xla/text_literal_writer_test.cc @@ -30,7 +30,7 @@ namespace xla { namespace { TEST(TextLiteralWriterTest, WritesFloatLiteral) { - auto literal = LiteralUtil::CreateR2({ + auto literal = Literal::CreateR2({ {3.14, 2.17}, {1.23, 4.56}, }); string path = diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc index 3a75bf64954..6228ca34c08 100644 --- a/tensorflow/compiler/xla/tools/replay_computation.cc +++ b/tensorflow/compiler/xla/tools/replay_computation.cc @@ -98,11 +98,11 @@ void RealMain(tensorflow::gtl::ArraySlice args, bool use_fake_data) { std::unique_ptr result = result_status.ConsumeValueOrDie(); fprintf(stdout, "%s: %s :: %s:%s\n", arg, module.entry().name().c_str(), ShapeUtil::HumanString(result->shape()).c_str(), - LiteralUtil::ToString(*result).c_str()); + result->ToString().c_str()); if (module.has_result()) { fprintf(stdout, "was %s:%s\n", ShapeUtil::HumanString(module.result().shape()).c_str(), - LiteralUtil::ToString(Literal(module.result())).c_str()); + Literal(module.result()).ToString().c_str()); } } } diff --git a/tensorflow/compiler/xla/tools/show_literal.cc b/tensorflow/compiler/xla/tools/show_literal.cc index b6538f5de07..b50cb5e28ea 100644 --- a/tensorflow/compiler/xla/tools/show_literal.cc +++ b/tensorflow/compiler/xla/tools/show_literal.cc @@ -42,5 +42,5 @@ int main(int argc, char **argv) { &literal_proto)); xla::Literal literal(literal_proto); LOG(INFO) << "literal: " << literal_proto.ShortDebugString(); - fprintf(stderr, "%s\n", xla::LiteralUtil::ToString(literal).c_str()); + fprintf(stderr, "%s\n", literal.ToString().c_str()); } diff --git 
a/tensorflow/compiler/xla/tools/show_text_literal.cc b/tensorflow/compiler/xla/tools/show_text_literal.cc index 2d983b407c6..bbe9902aa17 100644 --- a/tensorflow/compiler/xla/tools/show_text_literal.cc +++ b/tensorflow/compiler/xla/tools/show_text_literal.cc @@ -40,7 +40,7 @@ int main(int argc, char **argv) { xla::TextLiteralReader::ReadPath(argv[1]).ConsumeValueOrDie(); LOG(INFO) << "literal: " << literal->ShortDebugString(); - fprintf(stderr, "%s\n", xla::LiteralUtil::ToString(*literal).c_str()); + fprintf(stderr, "%s\n", literal->ToString().c_str()); if (literal->shape().element_type() == xla::F32) { float min = *std::min_element(literal->f32s().begin(), literal->f32s().end()); diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index 4c3cd321f68..b6289f8e1cb 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -45,9 +45,21 @@ message DebugOptions { // - Assuming that +0 and -0 are indistinguishable. bool xla_enable_fast_math = 4; + // Embed the compiler IR as a string in the executable. + bool xla_embed_ir_in_executable = 5; + + // Dump compilation artifacts as JSON into this directory. + string xla_dump_debug_json_to = 6; + + // Path to directory with cuda/ptx tools and libraries. + string xla_gpu_cuda_data_dir = 7; + + // Enable flush-to-zero semantics in the GPU backend. + bool xla_gpu_ftz = 8; + // Extra options to pass to the compilation backend; specific interpretation // of these values is left to the backend. - map xla_backend_extra_options = 5; + map xla_backend_extra_options = 500; } // These settings control how XLA compiles and/or runs code. Not all settings diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index 9470e6c3b26..b53bf98e1c0 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -53,7 +53,8 @@ enum PrimitiveType { // computation; e.g. 
a computation that returns weights and biases may have a // signature that results in a tuple like (f32[784x2000], f32[2000]) // - // Tuples are currently special in that they may only be rank 0. + // If a shape proto has the tuple element type, it may not have any entries + // in the dimensions field. TUPLE = 13; // An opaque type used for passing context specific data to a custom @@ -255,11 +256,15 @@ message ComputationDataHandle { int64 handle = 1; } -// Handle given to a user that represents a device to execute a computation. -// When replication is enabled, the device handle represents the device for the -// replica id 0. +// Handle given to a user that represents a replicated virtual device. Each +// replicated device represents N physical devices for execution where N is the +// number of replicas. message DeviceHandle { int64 handle = 1; + + // The number of model-parallel virtual devices that communicate via XLA + // Send/Recv instructions. + int64 device_count = 2; } // Handle given to a user to represent a channel between two computations @@ -269,6 +274,21 @@ message ChannelHandle { int64 handle = 1; } +// DeviceAssignmentProto is a serialized form of DeviceAssignment class, which +// represents the device ids assigned to a set of replicated computations. +// See xla::DeviceAssignment class comment for more details. +message DeviceAssignmentProto { + int32 replica_count = 1; + int32 computation_count = 2; + + // Each logical computation runs on replica_count physical devices. + // ComputationDevice represents the device ids assinged to the replicas. + message ComputationDevice { + repeated int32 replica_device_ids = 1; + } + repeated ComputationDevice computation_devices = 3; +} + // Literals are used when the server and client need to exchange materialized // data / results. Literals are also used to describe constants used in // computations. 
@@ -463,6 +483,14 @@ message ReduceWindowRequest { ComputationHandle to_apply = 5; } +message BatchNormTrainingRequest { + ComputationDataHandle operand = 1; + ComputationDataHandle scale = 2; + ComputationDataHandle offset = 3; + float epsilon = 4; + int64 feature_index = 5; +} + message CrossReplicaSumRequest { ComputationDataHandle operand = 2; } @@ -596,6 +624,9 @@ enum UnaryOperation { // Elementwise, tests if values are finite (not NaN or inf) UNOP_IS_FINITE = 11; + + // Elementwise, computes the cosine of x. + UNOP_COS = 12; } message UnaryOpRequest { @@ -713,6 +744,12 @@ message VariadicOpRequest { repeated ComputationDataHandle operands = 3; } +message ReducePrecisionRequest { + ComputationDataHandle operand = 1; + int32 exponent_bits = 2; + int32 mantissa_bits = 3; +} + message SendRequest { ComputationDataHandle operand = 1; ChannelHandle channel_handle = 2; @@ -744,6 +781,7 @@ message OpRequest { MapRequest map_request = 15; PadRequest pad_request = 16; ParameterRequest parameter_request = 17; + ReducePrecisionRequest reduce_precision_request = 36; ReduceRequest reduce_request = 18; ReduceWindowRequest reduce_window_request = 19; ReshapeRequest reshape_request = 20; @@ -760,7 +798,8 @@ message OpRequest { SendRequest send_request = 30; RecvRequest recv_request = 31; OutfeedRequest outfeed_request = 32; - // Next: 35 + BatchNormTrainingRequest batch_norm_training_request = 35; + // Next: 37 } } diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index b99933ff9b5..b1eff737992 100755 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -53,6 +53,7 @@ py_library( "//tensorflow/contrib/nn:nn_py", "//tensorflow/contrib/opt:opt_py", "//tensorflow/contrib/quantization:quantization_py", + "//tensorflow/contrib/remote_fused_graph/pylib:remote_fused_graph_ops_py", "//tensorflow/contrib/rnn:rnn_py", "//tensorflow/contrib/saved_model:saved_model_py", "//tensorflow/contrib/seq2seq:seq2seq_py", @@ -70,6 +71,9 @@ py_library( 
"//tensorflow/contrib/testing:testing_py", "//tensorflow/contrib/text:text_py", "//tensorflow/contrib/tfprof", + "//tensorflow/contrib/tpu:tpu_estimator", + "//tensorflow/contrib/tpu:tpu_helper_library", + "//tensorflow/contrib/tpu:tpu_py", "//tensorflow/contrib/training:training_py", "//tensorflow/contrib/util:util_py", ], @@ -105,6 +109,7 @@ cc_library( "//tensorflow/contrib/seq2seq:beam_search_ops_op_lib", "//tensorflow/contrib/tensor_forest:tensor_forest_ops_op_lib", "//tensorflow/contrib/text:all_ops", + "//tensorflow/contrib/tpu:all_ops", ], ) diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py index a94e809c139..16bc533436a 100644 --- a/tensorflow/contrib/__init__.py +++ b/tensorflow/contrib/__init__.py @@ -65,14 +65,16 @@ from tensorflow.contrib import tensor_forest from tensorflow.contrib import tensorboard from tensorflow.contrib import testing from tensorflow.contrib import tfprof +from tensorflow.contrib import tpu from tensorflow.contrib import training from tensorflow.contrib import util from tensorflow.contrib.ndlstm import python as ndlstm +from tensorflow.contrib.remote_fused_graph import pylib as remote_fused_graph from tensorflow.contrib.specs import python as specs from tensorflow.python.util.lazy_loader import LazyLoader -ffmpeg = LazyLoader("ffmpeg", globals(), - "tensorflow.contrib.ffmpeg") +ffmpeg = LazyLoader("ffmpeg", + globals(), "tensorflow.contrib.ffmpeg") del LazyLoader del absolute_import diff --git a/tensorflow/contrib/bayesflow/BUILD b/tensorflow/contrib/bayesflow/BUILD index 04288a1934d..3edcb6bc20e 100644 --- a/tensorflow/contrib/bayesflow/BUILD +++ b/tensorflow/contrib/bayesflow/BUILD @@ -16,7 +16,6 @@ py_library( srcs = ["__init__.py"] + glob(["python/ops/*.py"]), srcs_version = "PY2AND3", deps = [ - "//tensorflow/contrib/distributions:distributions_py", "//tensorflow/python:array_ops", "//tensorflow/python:check_ops", "//tensorflow/python:framework_for_generated_wrappers", @@ -26,11 +25,30 @@ 
py_library( "//tensorflow/python:training", "//tensorflow/python:util", "//tensorflow/python:variable_scope", + "//tensorflow/python/ops/distributions", "//third_party/py/numpy", "@six_archive//:six", ], ) +cuda_py_test( + name = "csiszar_divergence_test", + size = "small", + srcs = ["python/kernel_tests/csiszar_divergence_test.py"], + additional_deps = [ + ":bayesflow_py", + "//third_party/py/numpy", + "//tensorflow/contrib/distributions:distributions_py", + "//tensorflow/contrib/layers:layers_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:nn_ops", + "//tensorflow/python:platform_test", + ], +) + cuda_py_test( name = "entropy_test", size = "medium", diff --git a/tensorflow/contrib/bayesflow/__init__.py b/tensorflow/contrib/bayesflow/__init__.py index dcda7377002..65a17d742c1 100644 --- a/tensorflow/contrib/bayesflow/__init__.py +++ b/tensorflow/contrib/bayesflow/__init__.py @@ -21,6 +21,7 @@ from __future__ import division from __future__ import print_function # pylint: disable=unused-import,line-too-long +from tensorflow.contrib.bayesflow.python.ops import csiszar_divergence from tensorflow.contrib.bayesflow.python.ops import entropy from tensorflow.contrib.bayesflow.python.ops import monte_carlo from tensorflow.contrib.bayesflow.python.ops import stochastic_gradient_estimators @@ -33,7 +34,7 @@ from tensorflow.contrib.bayesflow.python.ops import variational_inference from tensorflow.python.util.all_util import remove_undocumented -_allowed_symbols = ['entropy', 'monte_carlo', +_allowed_symbols = ['csiszar_divergence', 'entropy', 'monte_carlo', 'special_math', 'stochastic_gradient_estimators', 'stochastic_graph', 'stochastic_tensor', 'stochastic_variables', 'variational_inference'] diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/csiszar_divergence_test.py 
b/tensorflow/contrib/bayesflow/python/kernel_tests/csiszar_divergence_test.py new file mode 100644 index 00000000000..fabf7a9b779 --- /dev/null +++ b/tensorflow/contrib/bayesflow/python/kernel_tests/csiszar_divergence_test.py @@ -0,0 +1,661 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Csiszar Divergence Ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.bayesflow.python.ops import csiszar_divergence_impl +from tensorflow.contrib.distributions.python.ops import mvn_diag as mvn_diag_lib +from tensorflow.contrib.distributions.python.ops import mvn_full_covariance as mvn_full_lib +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops.distributions import kullback_leibler +from tensorflow.python.ops.distributions import normal as normal_lib +from tensorflow.python.platform import test + + +cd = csiszar_divergence_impl + + +class AmariAlphaTest(test.TestCase): + + def setUp(self): + self._logu = np.linspace(-10., 10, 100) + self._u = np.exp(self._logu) + 
+ def test_at_zero(self): + for alpha in [-1., 0., 1., 2.]: + for normalized in [True, False]: + with self.test_session(graph=ops.Graph()): + self.assertAllClose( + cd.amari_alpha(0., alpha=alpha, + self_normalized=normalized).eval(), + 0.) + + def test_correct_when_alpha0(self): + with self.test_session(): + self.assertAllClose( + cd.amari_alpha(self._logu, alpha=0.).eval(), + -self._logu) + + self.assertAllClose( + cd.amari_alpha(self._logu, alpha=0., self_normalized=True).eval(), + -self._logu + (self._u - 1.)) + + def test_correct_when_alpha1(self): + with self.test_session(): + self.assertAllClose( + cd.amari_alpha(self._logu, alpha=1.).eval(), + self._u * self._logu) + + self.assertAllClose( + cd.amari_alpha(self._logu, alpha=1., self_normalized=True).eval(), + self._u * self._logu - (self._u - 1.)) + + def test_correct_when_alpha_not_01(self): + for alpha in [-2, -1., -0.5, 0.5, 2.]: + with self.test_session(graph=ops.Graph()): + self.assertAllClose( + cd.amari_alpha(self._logu, + alpha=alpha, + self_normalized=False).eval(), + ((self._u**alpha - 1)) / (alpha * (alpha - 1.))) + + self.assertAllClose( + cd.amari_alpha(self._logu, + alpha=alpha, + self_normalized=True).eval(), + ((self._u**alpha - 1.) + - alpha * (self._u - 1)) / (alpha * (alpha - 1.))) + + +class KLReverseTest(test.TestCase): + + def setUp(self): + self._logu = np.linspace(-10., 10, 100) + self._u = np.exp(self._logu) + + def test_at_zero(self): + for normalized in [True, False]: + with self.test_session(graph=ops.Graph()): + self.assertAllClose( + cd.kl_reverse(0., self_normalized=normalized).eval(), + 0.) 
+ + def test_correct(self): + with self.test_session(): + self.assertAllClose( + cd.kl_reverse(self._logu).eval(), + -self._logu) + + self.assertAllClose( + cd.kl_reverse(self._logu, self_normalized=True).eval(), + -self._logu + (self._u - 1.)) + + +class KLForwardTest(test.TestCase): + + def setUp(self): + self._logu = np.linspace(-10., 10, 100) + self._u = np.exp(self._logu) + + def test_at_zero(self): + for normalized in [True, False]: + with self.test_session(graph=ops.Graph()): + self.assertAllClose( + cd.kl_forward(0., self_normalized=normalized).eval(), + 0.) + + def test_correct(self): + with self.test_session(): + self.assertAllClose( + cd.kl_forward(self._logu).eval(), + self._u * self._logu) + + self.assertAllClose( + cd.kl_forward(self._logu, self_normalized=True).eval(), + self._u * self._logu - (self._u - 1.)) + + +class JensenShannonTest(test.TestCase): + + def setUp(self): + self._logu = np.linspace(-10., 10, 100) + self._u = np.exp(self._logu) + + def test_at_zero(self): + with self.test_session(): + self.assertAllClose(cd.jensen_shannon(0.).eval(), np.log(0.25)) + + def test_symmetric(self): + with self.test_session(): + self.assertAllClose( + cd.jensen_shannon(self._logu).eval(), + cd.symmetrized_csiszar_function( + self._logu, cd.jensen_shannon).eval()) + + self.assertAllClose( + cd.jensen_shannon(self._logu, self_normalized=True).eval(), + cd.symmetrized_csiszar_function( + self._logu, + lambda x: cd.jensen_shannon(x, self_normalized=True)).eval()) + + def test_correct(self): + with self.test_session(): + self.assertAllClose( + cd.jensen_shannon(self._logu).eval(), + (self._u * self._logu + - (1 + self._u) * np.log1p(self._u))) + + self.assertAllClose( + cd.jensen_shannon(self._logu, self_normalized=True).eval(), + (self._u * self._logu + - (1 + self._u) * np.log((1 + self._u) / 2))) + + +class ArithmeticGeometricMeanTest(test.TestCase): + + def setUp(self): + self._logu = np.linspace(-10., 10, 100) + self._u = np.exp(self._logu) + + def 
test_at_zero(self): + with self.test_session(): + self.assertAllClose(cd.arithmetic_geometric(0.).eval(), np.log(4)) + self.assertAllClose( + cd.arithmetic_geometric(0., self_normalized=True).eval(), 0.) + + def test_symmetric(self): + with self.test_session(): + self.assertAllClose( + cd.arithmetic_geometric(self._logu).eval(), + cd.symmetrized_csiszar_function( + self._logu, cd.arithmetic_geometric).eval()) + + def test_correct(self): + with self.test_session(): + self.assertAllClose( + cd.arithmetic_geometric(self._logu).eval(), + (1. + self._u) * np.log((1. + self._u) / np.sqrt(self._u))) + + self.assertAllClose( + cd.arithmetic_geometric(self._logu, self_normalized=True).eval(), + (1. + self._u) * np.log(0.5 * (1. + self._u) / np.sqrt(self._u))) + + +class TotalVariationTest(test.TestCase): + + def setUp(self): + self._logu = np.linspace(-10., 10, 100) + self._u = np.exp(self._logu) + + def test_at_zero(self): + with self.test_session(): + self.assertAllClose(cd.total_variation(0.).eval(), 0.) + + def test_correct(self): + with self.test_session(): + self.assertAllClose( + cd.total_variation(self._logu).eval(), + 0.5 * np.abs(self._u - 1)) + + +class PearsonTest(test.TestCase): + + def setUp(self): + self._logu = np.linspace(-10., 10, 100) + self._u = np.exp(self._logu) + + def test_at_zero(self): + with self.test_session(): + self.assertAllClose(cd.pearson(0.).eval(), 0.) + + def test_correct(self): + with self.test_session(): + self.assertAllClose( + cd.pearson(self._logu).eval(), + np.square(self._u - 1)) + + +class SquaredHellingerTest(test.TestCase): + + def setUp(self): + self._logu = np.linspace(-10., 10, 100) + self._u = np.exp(self._logu) + + def test_at_zero(self): + with self.test_session(): + self.assertAllClose(cd.squared_hellinger(0.).eval(), 0.) 
+ + def test_symmetric(self): + with self.test_session(): + self.assertAllClose( + cd.squared_hellinger(self._logu).eval(), + cd.symmetrized_csiszar_function( + self._logu, cd.squared_hellinger).eval()) + + def test_correct(self): + with self.test_session(): + self.assertAllClose( + cd.squared_hellinger(self._logu).eval(), + np.square(np.sqrt(self._u) - 1)) + + +class TriangularTest(test.TestCase): + + def setUp(self): + self._logu = np.linspace(-10., 10, 100) + self._u = np.exp(self._logu) + + def test_at_zero(self): + with self.test_session(): + self.assertAllClose(cd.triangular(0.).eval(), 0.) + + def test_symmetric(self): + with self.test_session(): + self.assertAllClose( + cd.triangular(self._logu).eval(), + cd.symmetrized_csiszar_function( + self._logu, cd.triangular).eval()) + + def test_correct(self): + with self.test_session(): + self.assertAllClose( + cd.triangular(self._logu).eval(), + np.square(self._u - 1) / (1 + self._u)) + + +class Log1pAbsTest(test.TestCase): + + def setUp(self): + self._logu = np.linspace(-10., 10, 100) + self._u = np.exp(self._logu) + + def test_at_zero(self): + with self.test_session(): + self.assertAllClose(cd.log1p_abs(0.).eval(), 0.) + + def test_correct(self): + with self.test_session(): + self.assertAllClose( + cd.log1p_abs(self._logu).eval(), + self._u**(np.sign(self._u - 1)) - 1) + + +class JeffreysTest(test.TestCase): + + def setUp(self): + self._logu = np.linspace(-10., 10, 100) + self._u = np.exp(self._logu) + + def test_at_zero(self): + with self.test_session(): + self.assertAllClose(cd.jeffreys(0.).eval(), 0.) 
+ + def test_symmetric(self): + with self.test_session(): + self.assertAllClose( + cd.jeffreys(self._logu).eval(), + cd.symmetrized_csiszar_function( + self._logu, cd.jeffreys).eval()) + + def test_correct(self): + with self.test_session(): + self.assertAllClose( + cd.jeffreys(self._logu).eval(), + 0.5 * (self._u * self._logu - self._logu)) + + +class ChiSquareTest(test.TestCase): + + def setUp(self): + self._logu = np.linspace(-10., 10, 100) + self._u = np.exp(self._logu) + + def test_at_zero(self): + with self.test_session(): + self.assertAllClose(cd.chi_square(0.).eval(), 0.) + + def test_correct(self): + with self.test_session(): + self.assertAllClose( + cd.chi_square(self._logu).eval(), + self._u**2 - 1) + + +class ModifiedGanTest(test.TestCase): + + def setUp(self): + self._logu = np.linspace(-10., 10, 100) + self._u = np.exp(self._logu) + + def test_at_zero(self): + with self.test_session(): + self.assertAllClose( + cd.modified_gan(0.).eval(), np.log(2)) + self.assertAllClose( + cd.modified_gan(0., self_normalized=True).eval(), np.log(2)) + + def test_correct(self): + with self.test_session(): + self.assertAllClose( + cd.modified_gan(self._logu).eval(), + np.log1p(self._u) - self._logu) + + self.assertAllClose( + cd.modified_gan(self._logu, self_normalized=True).eval(), + np.log1p(self._u) - self._logu + 0.5 * (self._u - 1)) + + +class SymmetrizedCsiszarFunctionTest(test.TestCase): + + def setUp(self): + self._logu = np.linspace(-10., 10., 100) + self._u = np.exp(self._logu) + + def test_jensen_shannon(self): + with self.test_session(): + + # The following functions come from the claim made in the + # symmetrized_csiszar_function docstring. + def js1(logu): + return (-logu + - (1. + math_ops.exp(logu)) * ( + nn_ops.softplus(logu))) + + def js2(logu): + return 2. 
* (math_ops.exp(logu) * ( + logu - nn_ops.softplus(logu))) + + self.assertAllClose( + cd.symmetrized_csiszar_function(self._logu, js1).eval(), + cd.jensen_shannon(self._logu).eval()) + + self.assertAllClose( + cd.symmetrized_csiszar_function(self._logu, js2).eval(), + cd.jensen_shannon(self._logu).eval()) + + def test_jeffreys(self): + with self.test_session(): + self.assertAllClose( + cd.symmetrized_csiszar_function(self._logu, cd.kl_reverse).eval(), + cd.jeffreys(self._logu).eval()) + + self.assertAllClose( + cd.symmetrized_csiszar_function(self._logu, cd.kl_forward).eval(), + cd.jeffreys(self._logu).eval()) + + +class DualCsiszarFunctionTest(test.TestCase): + + def setUp(self): + self._logu = np.linspace(-10., 10., 100) + self._u = np.exp(self._logu) + + def test_kl_forward(self): + with self.test_session(): + self.assertAllClose( + cd.dual_csiszar_function(self._logu, cd.kl_forward).eval(), + cd.kl_reverse(self._logu).eval()) + + def test_kl_reverse(self): + with self.test_session(): + self.assertAllClose( + cd.dual_csiszar_function(self._logu, cd.kl_reverse).eval(), + cd.kl_forward(self._logu).eval()) + + +class MonteCarloCsiszarFDivergenceTest(test.TestCase): + + def test_kl_forward(self): + with self.test_session() as sess: + q = normal_lib.Normal( + loc=np.ones(6), + scale=np.array([0.5, 1.0, 1.5, 2.0, 2.5, 3.0])) + + p = normal_lib.Normal(loc=q.loc + 0.1, scale=q.scale - 0.2) + + approx_kl = cd.monte_carlo_csiszar_f_divergence( + f=cd.kl_forward, + p=p, + q=q, + num_draws=int(1e5), + seed=1) + + approx_kl_self_normalized = cd.monte_carlo_csiszar_f_divergence( + f=lambda logu: cd.kl_forward(logu, self_normalized=True), + p=p, + q=q, + num_draws=int(1e5), + seed=1) + + exact_kl = kullback_leibler.kl_divergence(p, q) + + [approx_kl_, approx_kl_self_normalized_, exact_kl_] = sess.run([ + approx_kl, approx_kl_self_normalized, exact_kl]) + + self.assertAllClose(approx_kl_, exact_kl_, + rtol=0.08, atol=0.) 
+ + self.assertAllClose(approx_kl_self_normalized_, exact_kl_, + rtol=0.02, atol=0.) + + def test_kl_reverse(self): + with self.test_session() as sess: + + q = normal_lib.Normal( + loc=np.ones(6), + scale=np.array([0.5, 1.0, 1.5, 2.0, 2.5, 3.0])) + + p = normal_lib.Normal(loc=q.loc + 0.1, scale=q.scale - 0.2) + + approx_kl = cd.monte_carlo_csiszar_f_divergence( + f=cd.kl_reverse, + p=p, + q=q, + num_draws=int(1e5), + seed=1) + + approx_kl_self_normalized = cd.monte_carlo_csiszar_f_divergence( + f=lambda logu: cd.kl_reverse(logu, self_normalized=True), + p=p, + q=q, + num_draws=int(1e5), + seed=1) + + exact_kl = kullback_leibler.kl_divergence(q, p) + + [approx_kl_, approx_kl_self_normalized_, exact_kl_] = sess.run([ + approx_kl, approx_kl_self_normalized, exact_kl]) + + self.assertAllClose(approx_kl_, exact_kl_, + rtol=0.07, atol=0.) + + self.assertAllClose(approx_kl_self_normalized_, exact_kl_, + rtol=0.02, atol=0.) + + def _tridiag(self, d, diag_value, offdiag_value): + """d x d matrix with given value on diag, and one super/sub diag.""" + diag_mat = linalg_ops.eye(d) * (diag_value - offdiag_value) + three_bands = array_ops.matrix_band_part( + array_ops.fill([d, d], offdiag_value), 1, 1) + return diag_mat + three_bands + + def test_kl_reverse_multidim(self): + + with self.test_session() as sess: + d = 5 # Dimension + + p = mvn_full_lib.MultivariateNormalFullCovariance( + covariance_matrix=self._tridiag(d, diag_value=1, offdiag_value=0.5)) + + q = mvn_diag_lib.MultivariateNormalDiag(scale_diag=[0.5]*d) + + approx_kl = cd.monte_carlo_csiszar_f_divergence( + f=cd.kl_reverse, + p=p, + q=q, + num_draws=int(1e5), + seed=1) + + approx_kl_self_normalized = cd.monte_carlo_csiszar_f_divergence( + f=lambda logu: cd.kl_reverse(logu, self_normalized=True), + p=p, + q=q, + num_draws=int(1e5), + seed=1) + + exact_kl = kullback_leibler.kl_divergence(q, p) + + [approx_kl_, approx_kl_self_normalized_, exact_kl_] = sess.run([ + approx_kl, approx_kl_self_normalized, exact_kl]) + + 
self.assertAllClose(approx_kl_, exact_kl_, + rtol=0.02, atol=0.) + + self.assertAllClose(approx_kl_self_normalized_, exact_kl_, + rtol=0.08, atol=0.) + + def test_kl_forward_multidim(self): + + with self.test_session() as sess: + d = 5 # Dimension + + p = mvn_full_lib.MultivariateNormalFullCovariance( + covariance_matrix=self._tridiag(d, diag_value=1, offdiag_value=0.5)) + + # Variance is very high when approximating Forward KL, so we make + # scale_diag larger than in test_kl_reverse_multidim. This ensures q + # "covers" p and thus Var_q[p/q] is smaller. + q = mvn_diag_lib.MultivariateNormalDiag(scale_diag=[1.]*d) + + approx_kl = cd.monte_carlo_csiszar_f_divergence( + f=cd.kl_forward, + p=p, + q=q, + num_draws=int(1e5), + seed=1) + + approx_kl_self_normalized = cd.monte_carlo_csiszar_f_divergence( + f=lambda logu: cd.kl_forward(logu, self_normalized=True), + p=p, + q=q, + num_draws=int(1e5), + seed=1) + + exact_kl = kullback_leibler.kl_divergence(p, q) + + [approx_kl_, approx_kl_self_normalized_, exact_kl_] = sess.run([ + approx_kl, approx_kl_self_normalized, exact_kl]) + + self.assertAllClose(approx_kl_, exact_kl_, + rtol=0.06, atol=0.) + + self.assertAllClose(approx_kl_self_normalized_, exact_kl_, + rtol=0.05, atol=0.) + + def test_score_trick(self): + + with self.test_session() as sess: + d = 5 # Dimension + num_draws = int(1e5) + seed = 1 + + p = mvn_full_lib.MultivariateNormalFullCovariance( + covariance_matrix=self._tridiag(d, diag_value=1, offdiag_value=0.5)) + + # Variance is very high when approximating Forward KL, so we make + # scale_diag larger than in test_kl_reverse_multidim. This ensures q + # "covers" p and thus Var_q[p/q] is smaller. + s = array_ops.constant(1.) 
+ q = mvn_diag_lib.MultivariateNormalDiag( + scale_diag=array_ops.tile([s], [d])) + + approx_kl = cd.monte_carlo_csiszar_f_divergence( + f=cd.kl_reverse, + p=p, + q=q, + num_draws=num_draws, + seed=seed) + + approx_kl_self_normalized = cd.monte_carlo_csiszar_f_divergence( + f=lambda logu: cd.kl_reverse(logu, self_normalized=True), + p=p, + q=q, + num_draws=num_draws, + seed=seed) + + approx_kl_score_trick = cd.monte_carlo_csiszar_f_divergence( + f=cd.kl_reverse, + p=p, + q=q, + num_draws=num_draws, + use_reparametrization=False, + seed=seed) + + approx_kl_self_normalized_score_trick = ( + cd.monte_carlo_csiszar_f_divergence( + f=lambda logu: cd.kl_reverse(logu, self_normalized=True), + p=p, + q=q, + num_draws=num_draws, + use_reparametrization=False, + seed=seed)) + + exact_kl = kullback_leibler.kl_divergence(q, p) + + grad = lambda fs: gradients_impl.gradients(fs, s)[0] + + [ + approx_kl_, + approx_kl_self_normalized_, + approx_kl_score_trick_, + approx_kl_self_normalized_score_trick_, + exact_kl_, + ] = sess.run([ + grad(approx_kl), + grad(approx_kl_self_normalized), + grad(approx_kl_score_trick), + grad(approx_kl_self_normalized_score_trick), + grad(exact_kl), + ]) + + self.assertAllClose( + approx_kl_, exact_kl_, + rtol=0.06, atol=0.) + + self.assertAllClose( + approx_kl_self_normalized_, exact_kl_, + rtol=0.05, atol=0.) + + self.assertAllClose( + approx_kl_score_trick_, exact_kl_, + rtol=0.06, atol=0.) + + self.assertAllClose( + approx_kl_self_normalized_score_trick_, exact_kl_, + rtol=0.05, atol=0.) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence.py b/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence.py new file mode 100644 index 00000000000..5440df7dbfc --- /dev/null +++ b/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence.py @@ -0,0 +1,49 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Csiszar f-Divergence and helpers. + +See ${python/contrib.bayesflow.csiszar_divergence}. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# go/tf-wildcard-import +# pylint: disable=wildcard-import +from tensorflow.contrib.bayesflow.python.ops.csiszar_divergence_impl import * +# pylint: enable=wildcard-import +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = [ + 'amari_alpha', + 'arithmetic_geometric', + 'chi_square', + 'dual_csiszar_function', + 'jeffreys', + 'jensen_shannon', + 'kl_forward', + 'kl_reverse', + 'log1p_abs', + 'modified_gan', + 'monte_carlo_csiszar_f_divergence', + 'pearson', + 'squared_hellinger', + 'symmetrized_csiszar_function', + 'total_variation', + 'triangular', +] + +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence_impl.py b/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence_impl.py new file mode 100644 index 00000000000..7b51d8d9322 --- /dev/null +++ b/tensorflow/contrib/bayesflow/python/ops/csiszar_divergence_impl.py @@ -0,0 +1,854 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Csiszar f-Divergence and helpers. + +@@amari_alpha +@@arithmetic_geometric +@@chi_square +@@dual_csiszar_function +@@jeffreys +@@jensen_shannon +@@kl_forward +@@kl_reverse +@@log1p_abs +@@modified_gan +@@monte_carlo_csiszar_f_divergence +@@pearson +@@squared_hellinger +@@symmetrized_csiszar_function +@@total_variation +@@triangular + +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib import framework as contrib_framework +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops.distributions import distribution + + +def amari_alpha(logu, alpha=1., self_normalized=False, name=None): + """The Amari-alpha Csiszar-function in log-space. + + A Csiszar-function is a member of, + + ```none + F = { f:R_+ to R : f convex }. + ``` + + When `self_normalized = True`, the Amari-alpha Csiszar-function is: + + ```none + f(u) = { -log(u) + (u - 1), alpha = 0 + { u log(u) - (u - 1), alpha = 1 + { [(u**alpha - 1) - alpha (u - 1)] / (alpha (alpha - 1)), otherwise + ``` + + When `self_normalized = False` the `(u - 1)` terms are omitted. 
+ + Warning: when `alpha != 0` and/or `self_normalized = True` this function makes + non-log-space calculations and may therefore be numerically unstable for + `|logu| >> 0`. + + For more information, see: + A. Cichocki and S. Amari. "Families of Alpha-Beta-and GammaDivergences: + Flexible and Robust Measures of Similarities." Entropy, vol. 12, no. 6, pp. + 1532-1568, 2010. + + Args: + logu: Floating-type `Tensor` representing `log(u)` from above. + alpha: Floating-type Python scalar. (See Mathematical Details for meaning.) + self_normalized: Python `bool` indicating whether `f'(u=1)=0`. When + `f'(u=1)=0` the implied Csiszar f-Divergence remains non-negative even + when `p, q` are unnormalized measures. + name: Python `str` name prefixed to Ops created by this function. + + Returns: + amari_alpha_of_u: Floating-type `Tensor` of the Csiszar-function evaluated + at `u = exp(logu)`. + + Raises: + TypeError: if `alpha` is `None` or a `Tensor`. + TypeError: if `self_normalized` is `None` or a `Tensor`. + """ + with ops.name_scope(name, "amari_alpha", [logu]): + if alpha is None or contrib_framework.is_tensor(alpha): + raise TypeError("`alpha` cannot be `None` or `Tensor` type.") + if self_normalized is None or contrib_framework.is_tensor(self_normalized): + raise TypeError("`self_normalized` cannot be `None` or `Tensor` type.") + + logu = ops.convert_to_tensor(logu, name="logu") + + if alpha == 0.: + f = -logu + elif alpha == 1.: + f = math_ops.exp(logu) * logu + else: + f = math_ops.expm1(alpha * logu) / (alpha * (alpha - 1.)) + + if not self_normalized: + return f + + if alpha == 0.: + return f + math_ops.expm1(logu) + elif alpha == 1.: + return f - math_ops.expm1(logu) + else: + return f - math_ops.expm1(logu) / (alpha - 1.) + + +def kl_reverse(logu, self_normalized=False, name=None): + """The reverse Kullback-Leibler Csiszar-function in log-space. + + A Csiszar-function is a member of, + + ```none + F = { f:R_+ to R : f convex }. 
+ ``` + + When `self_normalized = True`, the KL-reverse Csiszar-function is: + + ```none + f(u) = -log(u) + (u - 1) + ``` + + When `self_normalized = False` the `(u - 1)` term is omitted. + + Observe that as an f-Divergence, this Csiszar-function implies: + + ```none + D_f[p, q] = KL[q, p] + ``` + + The KL is "reverse" because in maximum likelihood we think of minimizing `q` + as in `KL[p, q]`. + + Warning: when self_normalized = True` this function makes non-log-space + calculations and may therefore be numerically unstable for `|logu| >> 0`. + + Args: + logu: Floating-type `Tensor` representing `log(u)` from above. + self_normalized: Python `bool` indicating whether `f'(u=1)=0`. When + `f'(u=1)=0` the implied Csiszar f-Divergence remains non-negative even + when `p, q` are unnormalized measures. + name: Python `str` name prefixed to Ops created by this function. + + Returns: + kl_reverse_of_u: Floating-type `Tensor` of the Csiszar-function evaluated at + `u = exp(logu)`. + + Raises: + TypeError: if `self_normalized` is `None` or a `Tensor`. + """ + + with ops.name_scope(name, "kl_reverse", [logu]): + return amari_alpha(logu, alpha=0., self_normalized=self_normalized) + + +def kl_forward(logu, self_normalized=False, name=None): + """The forward Kullback-Leibler Csiszar-function in log-space. + + A Csiszar-function is a member of, + + ```none + F = { f:R_+ to R : f convex }. + ``` + + When `self_normalized = True`, the KL-forward Csiszar-function is: + + ```none + f(u) = u log(u) - (u - 1) + ``` + + When `self_normalized = False` the `(u - 1)` term is omitted. + + Observe that as an f-Divergence, this Csiszar-function implies: + + ```none + D_f[p, q] = KL[p, q] + ``` + + The KL is "forward" because in maximum likelihood we think of minimizing `q` + as in `KL[p, q]`. + + Warning: this function makes non-log-space calculations and may therefore be + numerically unstable for `|logu| >> 0`. + + Args: + logu: Floating-type `Tensor` representing `log(u)` from above. 
+ self_normalized: Python `bool` indicating whether `f'(u=1)=0`. When + `f'(u=1)=0` the implied Csiszar f-Divergence remains non-negative even + when `p, q` are unnormalized measures. + name: Python `str` name prefixed to Ops created by this function. + + Returns: + kl_forward_of_u: Floating-type `Tensor` of the Csiszar-function evaluated at + `u = exp(logu)`. + + Raises: + TypeError: if `self_normalized` is `None` or a `Tensor`. + """ + + with ops.name_scope(name, "kl_forward", [logu]): + return amari_alpha(logu, alpha=1., self_normalized=self_normalized) + + +def jensen_shannon(logu, self_normalized=False, name=None): + """The Jensen-Shannon Csiszar-function in log-space. + + A Csiszar-function is a member of, + + ```none + F = { f:R_+ to R : f convex }. + ``` + + When `self_normalized = True`, the Jensen-Shannon Csiszar-function is: + + ```none + f(u) = u log(u) - (1 + u) log(1 + u) + (u + 1) log(2) + ``` + + When `self_normalized = False` the `(u + 1) log(2)` term is omitted. + + Observe that as an f-Divergence, this Csiszar-function implies: + + ```none + D_f[p, q] = KL[p, m] + KL[q, m] + m(x) = 0.5 p(x) + 0.5 q(x) + ``` + + In a sense, this divergence is the "reverse" of the Arithmetic-Geometric + f-Divergence. + + This Csiszar-function induces a symmetric f-Divergence, i.e., + `D_f[p, q] = D_f[q, p]`. + + Warning: this function makes non-log-space calculations and may therefore be + numerically unstable for `|logu| >> 0`. + + For more information, see: + Lin, J. "Divergence measures based on the Shannon entropy." IEEE Trans. + Inf. Th., 37, 145-151, 1991. + + Args: + logu: Floating-type `Tensor` representing `log(u)` from above. + self_normalized: Python `bool` indicating whether `f'(u=1)=0`. When + `f'(u=1)=0` the implied Csiszar f-Divergence remains non-negative even + when `p, q` are unnormalized measures. + name: Python `str` name prefixed to Ops created by this function. 
+ + Returns: + jensen_shannon_of_u: Floating-type `Tensor` of the Csiszar-function + evaluated at `u = exp(logu)`. + """ + + with ops.name_scope(name, "jensen_shannon", [logu]): + logu = ops.convert_to_tensor(logu, name="logu") + npdt = logu.dtype.as_numpy_dtype + y = nn_ops.softplus(logu) + if self_normalized: + y -= np.log(2).astype(npdt) + return math_ops.exp(logu) * logu - (1. + math_ops.exp(logu)) * y + + +def arithmetic_geometric(logu, self_normalized=False, name=None): + """The Arithmetic-Geometric Csiszar-function in log-space. + + A Csiszar-function is a member of, + + ```none + F = { f:R_+ to R : f convex }. + ``` + + When `self_normalized = True` the Arithmetic-Geometric Csiszar-function is: + + ```none + f(u) = (1 + u) log( (1 + u) / sqrt(u) ) - (1 + u) log(2) + ``` + + When `self_normalized = False` the `(1 + u) log(2)` term is omitted. + + Observe that as an f-Divergence, this Csiszar-function implies: + + ```none + D_f[p, q] = KL[m, p] + KL[m, q] + m(x) = 0.5 p(x) + 0.5 q(x) + ``` + + In a sense, this divergence is the "reverse" of the Jensen-Shannon + f-Divergence. + + This Csiszar-function induces a symmetric f-Divergence, i.e., + `D_f[p, q] = D_f[q, p]`. + + Warning: this function makes non-log-space calculations and may therefore be + numerically unstable for `|logu| >> 0`. + + Args: + logu: Floating-type `Tensor` representing `log(u)` from above. + self_normalized: Python `bool` indicating whether `f'(u=1)=0`. When + `f'(u=1)=0` the implied Csiszar f-Divergence remains non-negative even + when `p, q` are unnormalized measures. + name: Python `str` name prefixed to Ops created by this function. + + Returns: + arithmetic_geometric_of_u: Floating-type `Tensor` of the + Csiszar-function evaluated at `u = exp(logu)`. 
+ """ + + with ops.name_scope(name, "arithmetic_geometric", [logu]): + logu = ops.convert_to_tensor(logu, name="logu") + y = nn_ops.softplus(logu) - 0.5 * logu + if self_normalized: + y -= np.log(2.).astype(logu.dtype.as_numpy_dtype) + return (1. + math_ops.exp(logu)) * y + + +def total_variation(logu, name=None): + """The Total Variation Csiszar-function in log-space. + + A Csiszar-function is a member of, + + ```none + F = { f:R_+ to R : f convex }. + ``` + + The Total-Variation Csiszar-function is: + + ```none + f(u) = 0.5 |u - 1| + ``` + + Warning: this function makes non-log-space calculations and may therefore be + numerically unstable for `|logu| >> 0`. + + Args: + logu: Floating-type `Tensor` representing `log(u)` from above. + name: Python `str` name prefixed to Ops created by this function. + + Returns: + total_variation_of_u: Floating-type `Tensor` of the Csiszar-function + evaluated at `u = exp(logu)`. + """ + + with ops.name_scope(name, "total_variation", [logu]): + logu = ops.convert_to_tensor(logu, name="logu") + return 0.5 * math_ops.abs(math_ops.expm1(logu)) + + +def pearson(logu, name=None): + """The Pearson Csiszar-function in log-space. + + A Csiszar-function is a member of, + + ```none + F = { f:R_+ to R : f convex }. + ``` + + The Pearson Csiszar-function is: + + ```none + f(u) = (u - 1)**2 + ``` + + Warning: this function makes non-log-space calculations and may therefore be + numerically unstable for `|logu| >> 0`. + + Args: + logu: Floating-type `Tensor` representing `log(u)` from above. + name: Python `str` name prefixed to Ops created by this function. + + Returns: + pearson_of_u: Floating-type `Tensor` of the Csiszar-function evaluated at + `u = exp(logu)`. + """ + + with ops.name_scope(name, "pearson", [logu]): + logu = ops.convert_to_tensor(logu, name="logu") + return math_ops.square(math_ops.expm1(logu)) + + +def squared_hellinger(logu, name=None): + """The Amari-alpha Csiszar-function in log-space. 
+ + A Csiszar-function is a member of, + + ```none + F = { f:R_+ to R : f convex }. + ``` + + The Squared-Hellinger Csiszar-function is: + + ```none + f(u) = (sqrt(u) - 1)**2 + ``` + + This Csiszar-function induces a symmetric f-Divergence, i.e., + `D_f[p, q] = D_f[q, p]`. + + Warning: this function makes non-log-space calculations and may therefore be + numerically unstable for `|logu| >> 0`. + + Args: + logu: Floating-type `Tensor` representing `log(u)` from above. + name: Python `str` name prefixed to Ops created by this function. + + Returns: + squared_hellinger_of_u: Floating-type `Tensor` of the Csiszar-function + evaluated at `u = exp(logu)`. + """ + + with ops.name_scope(name, "squared_hellinger", [logu]): + logu = ops.convert_to_tensor(logu, name="logu") + return pearson(0.5 * logu) + + +def triangular(logu, name=None): + """The Amari-alpha Csiszar-function in log-space. + + A Csiszar-function is a member of, + + ```none + F = { f:R_+ to R : f convex }. + ``` + + The Triangular Csiszar-function is: + + ```none + f(u) = (u - 1)**2 / (1 + u) + ``` + + This Csiszar-function induces a symmetric f-Divergence, i.e., + `D_f[p, q] = D_f[q, p]`. + + Warning: this function makes non-log-space calculations and may therefore be + numerically unstable for `|logu| >> 0`. + + Args: + logu: Floating-type `Tensor` representing `log(u)` from above. + name: Python `str` name prefixed to Ops created by this function. + + Returns: + triangular_of_u: Floating-type `Tensor` of the Csiszar-function evaluated + at `u = exp(logu)`. + """ + + with ops.name_scope(name, "triangular", [logu]): + logu = ops.convert_to_tensor(logu, name="logu") + return pearson(logu) / (1. + math_ops.exp(logu)) + + +def log1p_abs(logu, name=None): + """The log1p-abs Csiszar-function in log-space. + + A Csiszar-function is a member of, + + ```none + F = { f:R_+ to R : f convex }. 
+ ``` + + The Log1p-Abs Csiszar-function is: + + ```none + f(u) = u**(sign(u-1)) - 1 + ``` + + This function is so-named because it was invented from the following recipe. + Choose a convex function g such that g(0)=0 and solve for f: + + ```none + log(1 + f(u)) = g(log(u)). + <=> + f(u) = exp(g(log(u))) - 1 + ``` + + That is, the graph is identically `g` when y-axis is `log1p`-domain and x-axis + is `log`-domain. + + Warning: this function makes non-log-space calculations and may therefore be + numerically unstable for `|logu| >> 0`. + + Args: + logu: Floating-type `Tensor` representing `log(u)` from above. + name: Python `str` name prefixed to Ops created by this function. + + Returns: + log1p_abs_of_u: Floating-type `Tensor` of the Csiszar-function evaluated + at `u = exp(logu)`. + """ + + with ops.name_scope(name, "log1p_abs", [logu]): + logu = ops.convert_to_tensor(logu, name="logu") + return math_ops.expm1(math_ops.abs(logu)) + + +def jeffreys(logu, name=None): + """The Jeffreys Csiszar-function in log-space. + + A Csiszar-function is a member of, + + ```none + F = { f:R_+ to R : f convex }. + ``` + + The Jeffreys Csiszar-function is: + + ```none + f(u) = 0.5 ( u log(u) - log(u) ) + = 0.5 kl_forward + 0.5 kl_reverse + = symmetrized_csiszar_function(kl_reverse) + = symmetrized_csiszar_function(kl_forward) + ``` + + This Csiszar-function induces a symmetric f-Divergence, i.e., + `D_f[p, q] = D_f[q, p]`. + + Warning: this function makes non-log-space calculations and may therefore be + numerically unstable for `|logu| >> 0`. + + Args: + logu: Floating-type `Tensor` representing `log(u)` from above. + name: Python `str` name prefixed to Ops created by this function. + + Returns: + jeffreys_of_u: Floating-type `Tensor` of the Csiszar-function evaluated + at `u = exp(logu)`. 
+ """ + + with ops.name_scope(name, "jeffreys", [logu]): + logu = ops.convert_to_tensor(logu, name="logu") + return 0.5 * math_ops.expm1(logu) * logu + + +def chi_square(logu, name=None): + """The chi-Square Csiszar-function in log-space. + + A Csiszar-function is a member of, + + ```none + F = { f:R_+ to R : f convex }. + ``` + + The Chi-square Csiszar-function is: + + ```none + f(u) = u**2 - 1 + ``` + + Warning: this function makes non-log-space calculations and may therefore be + numerically unstable for `|logu| >> 0`. + + Args: + logu: Floating-type `Tensor` representing `log(u)` from above. + name: Python `str` name prefixed to Ops created by this function. + + Returns: + chi_square_of_u: Floating-type `Tensor` of the Csiszar-function evaluated + at `u = exp(logu)`. + """ + + with ops.name_scope(name, "chi_square", [logu]): + logu = ops.convert_to_tensor(logu, name="logu") + return math_ops.expm1(2. * logu) + + +def modified_gan(logu, self_normalized=False, name=None): + """The Modified-GAN Csiszar-function in log-space. + + A Csiszar-function is a member of, + + ```none + F = { f:R_+ to R : f convex }. + ``` + + When `self_normalized = True` the modified-GAN (Generative/Adversarial + Network) Csiszar-function is: + + ```none + f(u) = log(1 + u) - log(u) + 0.5 (u - 1) + ``` + + When `self_normalized = False` the `0.5 (u - 1)` is omitted. + + The unmodified GAN Csiszar-function is identical to Jensen-Shannon (with + `self_normalized = False`). + + Warning: this function makes non-log-space calculations and may therefore be + numerically unstable for `|logu| >> 0`. + + Args: + logu: Floating-type `Tensor` representing `log(u)` from above. + self_normalized: Python `bool` indicating whether `f'(u=1)=0`. When + `f'(u=1)=0` the implied Csiszar f-Divergence remains non-negative even + when `p, q` are unnormalized measures. + name: Python `str` name prefixed to Ops created by this function. 
+ + Returns: + chi_square_of_u: Floating-type `Tensor` of the Csiszar-function evaluated + at `u = exp(logu)`. + """ + + with ops.name_scope(name, "chi_square", [logu]): + logu = ops.convert_to_tensor(logu, name="logu") + y = nn_ops.softplus(logu) - logu + if self_normalized: + y += 0.5 * math_ops.expm1(logu) + return y + + +def dual_csiszar_function(logu, csiszar_function, name=None): + """Calculates the dual Csiszar-function in log-space. + + A Csiszar-function is a member of, + + ```none + F = { f:R_+ to R : f convex }. + ``` + + The Csiszar-dual is defined as: + + ```none + f^*(u) = u f(1 / u) + ``` + + where `f` is some other Csiszar-function. + + For example, the dual of `kl_reverse` is `kl_forward`, i.e., + + ```none + f(u) = -log(u) + f^*(u) = u f(1 / u) = -u log(1 / u) = u log(u) + ``` + + The dual of the dual is the original function: + + ```none + f^**(u) = {u f(1/u)}^*(u) = u (1/u) f(1/(1/u)) = f(u) + ``` + + Warning: this function makes non-log-space calculations and may therefore be + numerically unstable for `|logu| >> 0`. + + Args: + logu: Floating-type `Tensor` representing `log(u)` from above. + csiszar_function: Python callable representing a Csiszar-function over + log-domain. + name: Python `str` name prefixed to Ops created by this function. + + Returns: + dual_f_of_u: Floating-type `Tensor` of the result of calculating the dual of + `f` at `u = exp(logu)`. + """ + + with ops.name_scope(name, "dual_csiszar_function", [logu]): + return math_ops.exp(logu) * csiszar_function(-logu) + + +def symmetrized_csiszar_function(logu, csiszar_function, name=None): + """Symmetrizes a Csiszar-function in log-space. + + A Csiszar-function is a member of, + + ```none + F = { f:R_+ to R : f convex }. + ``` + + The symmetrized Csiszar-function is defined as: + + ```none + f_g(u) = 0.5 g(u) + 0.5 u g (1 / u) + ``` + + where `g` is some other Csiszar-function. 
+ + We say the function is "symmetrized" because: + + ```none + D_{f_g}[p, q] = D_{f_g}[q, p] + ``` + + for all `p << >> q` (i.e., `support(p) = support(q)`). + + There exists alternatives for symmetrizing a Csiszar-function. For example, + + ```none + f_g(u) = max(f(u), f^*(u)), + ``` + + where `f^*` is the dual Csiszar-function, also implies a symmetric + f-Divergence. + + Example: + + When either of the following functions are symmetrized, we obtain the + Jensen-Shannon Csiszar-function, i.e., + + ```none + g(u) = -log(u) - (1 + u) log((1 + u) / 2) + u - 1 + h(u) = log(4) + 2 u log(u / (1 + u)) + ``` + + implies, + + ```none + f_g(u) = f_h(u) = u log(u) - (1 + u) log((1 + u) / 2) + = jensen_shannon(log(u)). + ``` + + Warning: this function makes non-log-space calculations and may therefore be + numerically unstable for `|logu| >> 0`. + + Args: + logu: Floating-type `Tensor` representing `log(u)` from above. + csiszar_function: Python callable representing a Csiszar-function over + log-domain. + name: Python `str` name prefixed to Ops created by this function. + + Returns: + symmetrized_g_of_u: Floating-type `Tensor` of the result of applying the + symmetrization of `g` evaluated at `u = exp(logu)`. + """ + + with ops.name_scope(name, "symmetrized_csiszar_function", [logu]): + logu = ops.convert_to_tensor(logu, name="logu") + return 0.5 * (csiszar_function(logu) + + dual_csiszar_function(logu, csiszar_function)) + + +def monte_carlo_csiszar_f_divergence( + f, p, q, num_draws, use_reparametrization=True, seed=None, name=None): + """Monte-Carlo approximation of the Csiszar f-Divergence. + + A Csiszar-function is a member of, + + ```none + F = { f:R_+ to R : f convex }. 
+ ``` + + The Csiszar f-Divergence for Csiszar-function f is given by: + + ```none + D_f[p(X), q(X)] := E_{q(X)}[ f( p(X) / q(X) ) ] + ~= m**-1 sum_j^m f( p(x_j) / q(x_j) ), + where x_j ~iid q(x) + ``` + + Tricks: Reparameterization and Score-Gradient + + When q is "reparameterized", i.e., a diffeomorphic transformation of a + parameterless distribution (e.g., + `Normal(Y; m, s) <=> Y = sX + m, X ~ Normal(0,1)`), we can swap gradient and + expectation, i.e., + `nabla Avg{ s_i : i=1...n } = Avg{ nabla s_i : i=1...n }` where `S_n=Avg{s_i}` + and `s_i = f(x_i), x_i ~ q`. + + However, if q is not reparameterized, TensorFlow's gradient will be incorrect + since the chain-rule stops at samples of unreparameterized distributions. In + this circumstance using the Score-Gradient trick results in an unbiased + gradient, i.e., + + ```none + nabla E_q[f(X)] + = nabla int dx q(x) f(x) + = int dx nabla [ q(x) f(x) ] + = int dx q'(x) f(x) + q(x) f'(x) + = int dx q(x) [q'(x) / q(x) f(x) + f'(x) ] + = int dx q(x) nabla [ log(q(x)) stopgrad[f(x)] + f(x) ] + = E_q[ nabla [ log(q(X)) stopgrad[f(X)] + f(X) ] ] + ~= Avg{ log(q(y_i)) stopgrad[f(y_i)] + f(y_i) : y_i = stopgrad[x_i], x_i ~ q} + ``` + + Unless `q.reparameterization_type != distribution.FULLY_REPARAMETERIZED` it is + usually preferable to `use_reparametrization = True`. + + Warning: using `use_reparametrization = False` will mean that the result is + *not* the Csiszar f-Divergence. However its expected gradient *is* the + gradient of the Csiszar f-Divergence. + + Example Application: + + The Csiszar f-Divergence is a useful framework for variational inference. + I.e., observe that, + + ```none + f(p(x)) = f( E_{q(Z | x)}[ p(x, Z) / q(Z | x) ] ) + <= E_{q(Z | x)}[ f( p(x, Z) / q(Z | x) ) ] + := D_f[p(x, Z), q(Z | x)] + ``` + + The inequality follows from the fact that the "perspective" of `f`, i.e., + `(s, t) |-> t f(s / t))`, is convex in `(s, t)` when `s/t in domain(f)` and + `t` is a real. 
Since the above framework includes the popular Evidence Lower + BOund (ELBO) as a special case, i.e., `f(u) = -log(u)`, we call this framework + "Evidence Divergence Bound Optimization" (EDBO). + + Args: + f: Python callable representing a Csiszar-function in log-space. + p: `tf.Distribution`-like instance; must implement `log_prob(x)`. + q: `tf.Distribution`-like instance; must implement: + `reparameterization_type`, `sample(n)`, and `log_prob(x)`. + num_draws: Integer scalar number of draws used to approximate the + f-Divergence expectation. + use_reparametrization: Python `bool`. When `True` uses the standard + Monte-Carlo average. When `False` uses the score-gradient trick. (See + above for details.) + seed: Python `int` seed for `q.sample`. + name: Python `str` name prefixed to Ops created by this function. + + Returns: + monte_carlo_csiszar_f_divergence: Floating-type `Tensor` Monte Carlo + approximation of the Csiszar f-Divergence. Warning: using + `use_reparametrization = False` will mean that the result is *not* the + Csiszar f-Divergence. However its expected gradient *is* the actual + gradient of the Csiszar f-Divergence. + + Raises: + ValueError: if `q` is not a reparameterized distribution and + `use_reparametrization = True`. A distribution `q` is said to be + "reparameterized" when its samples are generated by transforming the + samples of another distribution which does not depend on the + parameterization of `q`. This property ensures the gradient (with respect + to parameters) is valid. + """ + with ops.name_scope(name, "monte_carlo_csiszar_f_divergence", [num_draws]): + x = q.sample(num_draws, seed=seed) + if use_reparametrization: + # TODO(jvdillon): Consider only raising an exception if the gradient is + # requested. + if q.reparameterization_type != distribution.FULLY_REPARAMETERIZED: + raise ValueError( + "Distribution `q` must be reparameterized, i.e., a diffeomorphic " + "transformation of a parameterless distribution. 
(Otherwise this " + "function has a biased gradient.)") + return math_ops.reduce_mean(f(p.log_prob(x) - q.log_prob(x)), axis=0) + else: + x = array_ops.stop_gradient(x) + logqx = q.log_prob(x) + fx = f(p.log_prob(x) - logqx) + # Alternatively we could have returned: + # reduce_mean(fx * exp(logqx) / stop_gradient(exp(logqx)), axis=0) + # This is nice because it means the result is exactly the Csiszar + # f-Divergence yet the gradient is unbiased. However its numerically + # unstable since the q is not in log-domain. + return math_ops.reduce_mean(logqx * array_ops.stop_gradient(fx) + fx, + axis=0) diff --git a/tensorflow/contrib/cloud/BUILD b/tensorflow/contrib/cloud/BUILD index 840997223fb..eec2beddc48 100644 --- a/tensorflow/contrib/cloud/BUILD +++ b/tensorflow/contrib/cloud/BUILD @@ -49,9 +49,9 @@ py_library( srcs_version = "PY2AND3", deps = [ ":gen_bigquery_reader_ops", - "//tensorflow/contrib/util:util_py", "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:platform", + "//tensorflow/python:io_ops", + "//tensorflow/python:util", ], ) diff --git a/tensorflow/contrib/cluster_resolver/BUILD b/tensorflow/contrib/cluster_resolver/BUILD index 6792ebd615c..0dfc5a81d57 100644 --- a/tensorflow/contrib/cluster_resolver/BUILD +++ b/tensorflow/contrib/cluster_resolver/BUILD @@ -29,7 +29,7 @@ py_library( ], srcs_version = "PY2AND3", deps = [ - "//tensorflow/python:framework", + "//tensorflow/python:training", ], ) diff --git a/tensorflow/contrib/cmake/tf_core_kernels.cmake b/tensorflow/contrib/cmake/tf_core_kernels.cmake index 500b917ac99..7f3be2aedf6 100644 --- a/tensorflow/contrib/cmake/tf_core_kernels.cmake +++ b/tensorflow/contrib/cmake/tf_core_kernels.cmake @@ -84,6 +84,12 @@ if(tensorflow_BUILD_CONTRIB_KERNELS) "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/hybrid/core/ops/utils.cc" "${tensorflow_source_dir}/tensorflow/contrib/text/kernels/skip_gram_kernels.cc" 
"${tensorflow_source_dir}/tensorflow/contrib/text/ops/skip_gram_ops.cc" + "${tensorflow_source_dir}/tensorflow/contrib/tpu/ops/cross_replica_ops.cc" + "${tensorflow_source_dir}/tensorflow/contrib/tpu/ops/infeed_ops.cc" + "${tensorflow_source_dir}/tensorflow/contrib/tpu/ops/outfeed_ops.cc" + "${tensorflow_source_dir}/tensorflow/contrib/tpu/ops/replication_ops.cc" + "${tensorflow_source_dir}/tensorflow/contrib/tpu/ops/tpu_configuration_ops.cc" + "${tensorflow_source_dir}/tensorflow/contrib/tpu/ops/tpu_sendrecv_ops.cc" ) list(APPEND tf_core_kernels_srcs ${tf_contrib_kernels_srcs}) endif(tensorflow_BUILD_CONTRIB_KERNELS) @@ -102,10 +108,11 @@ file(GLOB_RECURSE tf_core_kernels_exclude_srcs "${tensorflow_source_dir}/tensorflow/core/kernels/*test*.cc" "${tensorflow_source_dir}/tensorflow/core/kernels/*testutil.h" "${tensorflow_source_dir}/tensorflow/core/kernels/*testutil.cc" + "${tensorflow_source_dir}/tensorflow/core/kernels/*test_utils.h" + "${tensorflow_source_dir}/tensorflow/core/kernels/*test_utils.cc" "${tensorflow_source_dir}/tensorflow/core/kernels/*main.cc" "${tensorflow_source_dir}/tensorflow/core/kernels/*.cu.cc" "${tensorflow_source_dir}/tensorflow/core/kernels/hexagon/*" - "${tensorflow_source_dir}/tensorflow/core/kernels/remote_fused_graph_execute*.cc" "${tensorflow_source_dir}/tensorflow/core/kernels/remote_fused_graph_rewriter_transform*.cc" ) list(REMOVE_ITEM tf_core_kernels_srcs ${tf_core_kernels_exclude_srcs}) diff --git a/tensorflow/contrib/cmake/tf_core_ops.cmake b/tensorflow/contrib/cmake/tf_core_ops.cmake index 3c2f89c6c82..f1b01250a44 100644 --- a/tensorflow/contrib/cmake/tf_core_ops.cmake +++ b/tensorflow/contrib/cmake/tf_core_ops.cmake @@ -14,6 +14,7 @@ # ============================================================================== set(tf_op_lib_names "array_ops" + "bitwise_ops" "candidate_sampling_ops" "control_flow_ops" "ctc_ops" @@ -66,6 +67,10 @@ file(GLOB_RECURSE tensor_forest_hybrid_srcs 
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/hybrid/core/ops/*.cc" ) +file(GLOB_RECURSE tpu_ops_srcs + "${tensorflow_source_dir}/tensorflow/contrib/tpu/ops/*.cc" +) + GENERATE_CONTRIB_OP_LIBRARY(cudnn_rnn "${tensorflow_source_dir}/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(factorization_clustering "${tensorflow_source_dir}/tensorflow/contrib/factorization/ops/clustering_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(factorization_factorization "${tensorflow_source_dir}/tensorflow/contrib/factorization/ops/factorization_ops.cc") @@ -82,6 +87,7 @@ GENERATE_CONTRIB_OP_LIBRARY(seq2seq_beam_search "${tensorflow_source_dir}/tensor GENERATE_CONTRIB_OP_LIBRARY(tensor_forest "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/ops/tensor_forest_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(tensor_forest_hybrid "${tensor_forest_hybrid_srcs}") GENERATE_CONTRIB_OP_LIBRARY(text_skip_gram "${tensorflow_source_dir}/tensorflow/contrib/text/ops/skip_gram_ops.cc") +GENERATE_CONTRIB_OP_LIBRARY(tpu "${tpu_ops_srcs}") GENERATE_CONTRIB_OP_LIBRARY(bigquery_reader "${tensorflow_source_dir}/tensorflow/contrib/cloud/ops/bigquery_reader_ops.cc") ######################################################## diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index a969bb03eec..67b0c86f7f2 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -124,7 +124,6 @@ file(GLOB_RECURSE tf_protos_python_srcs RELATIVE ${tensorflow_source_dir} "${tensorflow_source_dir}/tensorflow/core/*.proto" "${tensorflow_source_dir}/tensorflow/python/*.proto" "${tensorflow_source_dir}/tensorflow/contrib/session_bundle/*.proto" - "${tensorflow_source_dir}/tensorflow/tensorboard/*.proto" "${tensorflow_source_dir}/tensorflow/contrib/tensorboard/*.proto" "${tensorflow_source_dir}/tensorflow/contrib/training/*.proto" ) @@ -138,7 +137,6 @@ RELATIVE_PROTOBUF_GENERATE_PYTHON( file(GLOB_RECURSE 
tf_python_protos_cc_srcs RELATIVE ${tensorflow_source_dir} "${tensorflow_source_dir}/tensorflow/python/*.proto" "${tensorflow_source_dir}/tensorflow/contrib/session_bundle/*.proto" - "${tensorflow_source_dir}/tensorflow/tensorboard/*.proto" "${tensorflow_source_dir}/tensorflow/contrib/tensorboard/*.proto" "${tensorflow_source_dir}/tensorflow/contrib/training/*.proto" ) @@ -232,19 +230,6 @@ add_python_module("tensorflow/python/training") add_python_module("tensorflow/python/user_ops") add_python_module("tensorflow/python/util") add_python_module("tensorflow/python/util/protobuf") -add_python_module("tensorflow/tensorboard") -add_python_module("tensorflow/tensorboard/backend") -add_python_module("tensorflow/tensorboard/backend/event_processing") -add_python_module("tensorflow/tensorboard/plugins") -add_python_module("tensorflow/tensorboard/plugins/audio") -add_python_module("tensorflow/tensorboard/plugins/distributions") -add_python_module("tensorflow/tensorboard/plugins/graphs") -add_python_module("tensorflow/tensorboard/plugins/histograms") -add_python_module("tensorflow/tensorboard/plugins/images") -add_python_module("tensorflow/tensorboard/plugins/projector") -add_python_module("tensorflow/tensorboard/plugins/scalars") -add_python_module("tensorflow/tensorboard/plugins/text") -add_python_module("tensorflow/tensorboard/scripts") add_python_module("tensorflow/contrib") add_python_module("tensorflow/contrib/android") add_python_module("tensorflow/contrib/android/java") @@ -458,6 +443,9 @@ add_python_module("tensorflow/contrib/pi_examples/label_image") add_python_module("tensorflow/contrib/pi_examples/label_image/data") add_python_module("tensorflow/contrib/quantization") add_python_module("tensorflow/contrib/quantization/python") +add_python_module("tensorflow/contrib/remote_fused_graph/pylib") +add_python_module("tensorflow/contrib/remote_fused_graph/pylib/python") +add_python_module("tensorflow/contrib/remote_fused_graph/pylib/python/ops") 
add_python_module("tensorflow/contrib/rnn") add_python_module("tensorflow/contrib/rnn/kernels") add_python_module("tensorflow/contrib/rnn/ops") @@ -527,6 +515,11 @@ add_python_module("tensorflow/contrib/tfprof" DONTCOPY) # SWIG wrapper not impl #add_python_module("tensorflow/contrib/tfprof/python") #add_python_module("tensorflow/contrib/tfprof/python/tools") #add_python_module("tensorflow/contrib/tfprof/python/tools/tfprof") +add_python_module("tensorflow/contrib/tpu") +add_python_module("tensorflow/contrib/tpu/ops") +add_python_module("tensorflow/contrib/tpu/python") +add_python_module("tensorflow/contrib/tpu/python/ops") +add_python_module("tensorflow/contrib/tpu/python/tpu") add_python_module("tensorflow/contrib/training") add_python_module("tensorflow/contrib/training/python") add_python_module("tensorflow/contrib/training/python/training") @@ -603,6 +596,7 @@ function(GENERATE_PYTHON_OP_LIB tf_python_op_lib_name) endfunction() GENERATE_PYTHON_OP_LIB("array_ops") +GENERATE_PYTHON_OP_LIB("bitwise_ops") GENERATE_PYTHON_OP_LIB("math_ops") GENERATE_PYTHON_OP_LIB("functional_ops") GENERATE_PYTHON_OP_LIB("candidate_sampling_ops") @@ -619,6 +613,8 @@ GENERATE_PYTHON_OP_LIB("lookup_ops") GENERATE_PYTHON_OP_LIB("nn_ops") GENERATE_PYTHON_OP_LIB("parsing_ops") GENERATE_PYTHON_OP_LIB("random_ops") +GENERATE_PYTHON_OP_LIB("remote_fused_graph_ops" + DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/remote_fused_graph/pylib/python/ops/gen_remote_fused_graph_ops.py) GENERATE_PYTHON_OP_LIB("resource_variable_ops") GENERATE_PYTHON_OP_LIB("script_ops") GENERATE_PYTHON_OP_LIB("sdca_ops") @@ -911,19 +907,6 @@ add_custom_command(TARGET tf_python_build_pip_package POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy ${tensorflow_source_dir}/tensorflow/tools/pip_package/MANIFEST.in ${CMAKE_CURRENT_BINARY_DIR}/tf_python/) -# Copy resources for TensorBoard. 
-file(DOWNLOAD http://mirror.bazel.build/tensorboard/index.html ${DOWNLOAD_LOCATION}/tensorboard/index.html - EXPECTED_HASH SHA256=25554e708552ad8587152f7a444db3f4ca753f9ed72d9f8105203c1d1806d521) -add_custom_command(TARGET tf_python_touchup_modules PRE_BUILD - COMMAND ${CMAKE_COMMAND} -E make_directory - ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/tensorboard/components/) -add_custom_command(TARGET tf_python_build_pip_package POST_BUILD - COMMAND ${CMAKE_COMMAND} -E copy ${DOWNLOAD_LOCATION}/tensorboard/index.html - ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/tensorboard/components/) -add_custom_command(TARGET tf_python_build_pip_package POST_BUILD - COMMAND ${CMAKE_COMMAND} -E copy ${tensorflow_source_dir}/tensorflow/tensorboard/TAG - ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/tensorboard/) - # Copy datasets for tf.contrib.learn. add_custom_command(TARGET tf_python_build_pip_package POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy ${tensorflow_source_dir}/tensorflow/contrib/learn/python/learn/datasets/data/boston_house_prices.csv diff --git a/tensorflow/contrib/cmake/tf_tests.cmake b/tensorflow/contrib/cmake/tf_tests.cmake index 55e9e311f92..d415749ac56 100644 --- a/tensorflow/contrib/cmake/tf_tests.cmake +++ b/tensorflow/contrib/cmake/tf_tests.cmake @@ -143,7 +143,6 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/python/kernel_tests/*.py" "${tensorflow_source_dir}/tensorflow/python/saved_model/*_test.py" "${tensorflow_source_dir}/tensorflow/python/training/*_test.py" - "${tensorflow_source_dir}/tensorflow/tensorboard/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/data/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/factorization/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/keras/python/keras/integration_test.py" @@ -191,7 +190,6 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/python/kernel_tests/variable_scope_test.py" 
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/reshape_op_test.py" "${tensorflow_source_dir}/tensorflow/python/training/evaluation_test.py" - "${tensorflow_source_dir}/tensorflow/tensorboard/backend/server_test.py" "${tensorflow_source_dir}/tensorflow/python/kernel_tests/neon_depthwise_conv_op_test.py" # Depends on gemmlowp -> pthread. # int32/int64 mixup "${tensorflow_source_dir}/tensorflow/python/kernel_tests/functional_ops_test.py" @@ -206,13 +204,7 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/python/training/supervisor_test.py" # Flaky I/O error on rename. "${tensorflow_source_dir}/tensorflow/python/training/sync_replicas_optimizer_test.py" # Needs portpicker. "${tensorflow_source_dir}/tensorflow/python/kernel_tests/array_ops_test.py" # depends on python/framework/test_ops - # Broken TensorBoard tests due to different paths in windows - "${tensorflow_source_dir}/tensorflow/tensorboard/backend/application_test.py" - "${tensorflow_source_dir}/tensorflow/tensorboard/lib/python/http_util_test.py" - "${tensorflow_source_dir}/tensorflow/tensorboard/plugins/audio/audio_plugin_test.py" - "${tensorflow_source_dir}/tensorflow/tensorboard/plugins/images/images_plugin_test.py" # Broken tensorboard test due to cmake issues. - "${tensorflow_source_dir}/tensorflow/tensorboard/plugins/debugger/plugin_test.py" "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py" # tensor_forest tests (also note that we exclude the hybrid tests for now) "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/kernel_tests/count_extremely_random_stats_op_test.py" # Results in wrong order. 
@@ -221,8 +213,6 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/topn_test.py" # Results inaccurate "${tensorflow_source_dir}/tensorflow/python/ops/cloud/bigquery_reader_ops_test.py" # No libcurl support # Newly running on Windows since TensorBoard backend move. Fail on Windows and need debug. - "${tensorflow_source_dir}/tensorflow/tensorboard/backend/event_processing/directory_watcher_test.py" - "${tensorflow_source_dir}/tensorflow/tensorboard/backend/event_processing/event_multiplexer_test.py" "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py" # Segfaults on Windows. ) endif() @@ -288,7 +278,6 @@ if (tensorflow_BUILD_CC_TESTS) "${tensorflow_source_dir}/tensorflow/cc/framework/gradients_test.cc" "${tensorflow_source_dir}/tensorflow/core/distributed_runtime/call_options_test.cc" "${tensorflow_source_dir}/tensorflow/core/distributed_runtime/tensor_coding_test.cc" - "${tensorflow_source_dir}/tensorflow/core/kernels/remote_fused_graph_execute_utils_test.cc" "${tensorflow_source_dir}/tensorflow/core/kernels/remote_fused_graph_rewriter_transform_test.cc" "${tensorflow_source_dir}/tensorflow/core/kernels/hexagon/graph_transferer_test.cc" "${tensorflow_source_dir}/tensorflow/core/kernels/hexagon/quantized_matmul_op_for_hexagon_test.cc" diff --git a/tensorflow/contrib/crf/BUILD b/tensorflow/contrib/crf/BUILD index e82d2cf6f8a..7aad4abdb90 100644 --- a/tensorflow/contrib/crf/BUILD +++ b/tensorflow/contrib/crf/BUILD @@ -15,11 +15,12 @@ py_library( srcs = ["__init__.py"] + glob(["python/ops/*.py"]), srcs_version = "PY2AND3", deps = [ - "//tensorflow/contrib/rnn:rnn_py", "//tensorflow/python:array_ops", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:math_ops", "//tensorflow/python:rnn", + "//tensorflow/python:rnn_cell", + "//tensorflow/python:util", "//tensorflow/python:variable_scope", "//third_party/py/numpy", ], diff --git 
a/tensorflow/contrib/cudnn_rnn/BUILD b/tensorflow/contrib/cudnn_rnn/BUILD index b1caac476a2..fc473d3380d 100644 --- a/tensorflow/contrib/cudnn_rnn/BUILD +++ b/tensorflow/contrib/cudnn_rnn/BUILD @@ -87,6 +87,8 @@ cuda_py_test( additional_deps = [ ":cudnn_rnn_py", "//tensorflow/core:protos_all_py", + "//tensorflow/contrib/rnn:rnn_py", + "//tensorflow/python/ops/losses:losses", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:framework", diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py index 08ec3076e49..0e51ab99353 100644 --- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_test.py @@ -20,8 +20,13 @@ from __future__ import print_function import os import unittest +import numpy as np + from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops +from tensorflow.contrib.rnn.python.ops import lstm_ops from tensorflow.core.protobuf import saver_pb2 +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.framework.test_util import TensorFlowTestCase @@ -29,10 +34,14 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradient_checker from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops +from tensorflow.python.ops import rnn +from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import state_ops from tensorflow.python.ops import variables +from tensorflow.python.ops.losses import losses from tensorflow.python.platform import googletest from tensorflow.python.platform import test +from tensorflow.python.training import gradient_descent from tensorflow.python.training import saver as 
saver_lib @@ -69,7 +78,8 @@ class CudnnRNNTest(TensorFlowTestCase): model: a CudnnRNN model. """ params_saveable = cudnn_rnn_ops.RNNParamsSaveable( - model.params_to_canonical, model.canonical_to_params, [params]) + model, model.params_to_canonical, model.canonical_to_params, [params], + "rnn") ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, params_saveable) def _testSaveRestoreVariable(self, rnn_mode): @@ -93,6 +103,218 @@ class CudnnRNNTest(TensorFlowTestCase): params_v_restored = sess.run(params) self.assertAllEqual(params_v, params_v_restored) + def _create_equivalent_canonical_rnn(self, + cudnn_model, + inputs, + use_block_cell, + scope="rnn"): + if cudnn_model.rnn_mode is not "lstm": + raise ValueError("%s is not supported!" % cudnn_model.rnn_mode) + + num_units = cudnn_model.num_units + num_layers = cudnn_model.num_layers + + # To reuse cuDNN-trained models, must set + # forget_bias, clip_cell = 0, False + # In LSTMCell and LSTMBlockCell, forget_bias is added in addition to learned + # bias, whereas cuDNN does not apply the additional bias. 
+ if use_block_cell: + # pylint: disable=g-long-lambda + single_cell = lambda: lstm_ops.LSTMBlockCell(num_units, forget_bias=0, + clip_cell=False) + # pylint: enable=g-long-lambda + else: + single_cell = lambda: rnn_cell_impl.LSTMCell(num_units, forget_bias=0) + cell = rnn_cell_impl.MultiRNNCell( + [single_cell() for _ in range(num_layers)]) + return rnn.dynamic_rnn( + cell, inputs, dtype=dtypes.float32, time_major=True, scope=scope) + + def _build_forward_cudnn_model(self, + rnn_mode, + num_layers, + num_units, + input_data, + is_training=False): + input_data_shape = input_data.get_shape().with_rank(3) + batch_size = input_data_shape[1].value + input_size = input_data_shape[2].value + model = self._CreateModel(rnn_mode, num_layers, num_units, input_size) + + # Set zero init input states + input_h = constant_op.constant( + np.zeros([num_layers, batch_size, num_units]), dtype=dtypes.float32) + has_input_c = (rnn_mode == "lstm") + if has_input_c: + input_c = constant_op.constant( + np.zeros([num_layers, batch_size, num_units]), dtype=dtypes.float32) + + # Set rnn params + params_size_t = model.params_size() + params = variables.Variable( + random_ops.random_uniform([params_size_t]), validate_shape=False) + args = { + "input_data": input_data, + "input_h": input_h, + "params": params, + "is_training": is_training + } + if has_input_c: + args["input_c"] = input_c + # Build cell + output_tuple = model(**args) + + # Create savable objects for params + self._create_params_savable(params, model) + + return output_tuple, model, params + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testCheckpointReusableByCanonicalLSTMCells(self): + configs = [ + { + "num_layers": 1, + "seq_length": 3, + "num_units": 4, + "input_size": 5, + "batch_size": 6, + "rnn_mode": "lstm" + }, + { + "num_layers": 2, + "seq_length": 8, + "num_units": 4, + "input_size": 8, + "batch_size": 16, + "rnn_mode": "lstm" + }, + { + "num_layers": 2, + 
"seq_length": 3, + "num_units": 4, + "input_size": 5, + "batch_size": 6, + "rnn_mode": "lstm" + }, + { + "num_layers": 1, + "seq_length": 2, + "num_units": 2, + "input_size": 4, + "batch_size": 1, + "rnn_mode": "lstm" + }, + ] + for cfg in configs: + self._testCheckpointReusableByCanonicalLSTMCells( + cfg["num_layers"], + cfg["seq_length"], + cfg["num_units"], + cfg["input_size"], + cfg["batch_size"], + cfg["rnn_mode"], + use_block_cell=False) + self._testCheckpointReusableByCanonicalLSTMCells( + cfg["num_layers"], + cfg["seq_length"], + cfg["num_units"], + cfg["input_size"], + cfg["batch_size"], + cfg["rnn_mode"], + use_block_cell=True) + + def _testCheckpointReusableByCanonicalLSTMCells( + self, num_layers, seq_length, num_units, input_size, batch_size, rnn_mode, + use_block_cell): + np.random.seed(0) + # Train graph + with ops.Graph().as_default(): + random_seed.set_random_seed(299) + input_data = array_ops.placeholder( + dtypes.float32, shape=[seq_length, batch_size, input_size]) + output_tuple, cudnn_model, cudnn_params = self._build_forward_cudnn_model( + rnn_mode, num_layers, num_units, input_data, is_training=True) + target_output = array_ops.placeholder(dtype=dtypes.float32, shape=None) + total_sum = sum(map(math_ops.reduce_sum, output_tuple)) + + loss_op = losses.log_loss(labels=target_output, predictions=total_sum) + optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1e-2) + train_op = optimizer.minimize(loss_op) + + saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V2) + + # Train Cudnn model + with self.test_session( + use_gpu=True, graph=ops.get_default_graph()) as sess: + sess.run(variables.global_variables_initializer()) + # Train 128 steps + num_steps = 128 + for _ in range(num_steps): + inputs = np.random.rand(seq_length, batch_size, + input_size).astype(np.float32) + targets = np.random.rand() + sess.run( + train_op, feed_dict={input_data: inputs, + target_output: targets}) + + save_path = os.path.join(self.get_temp_dir(), 
+ ("cudnn-rnn-%s-test" % rnn_mode)) + save_v = saver.save(sess, save_path) + self.assertEqual(save_path, save_v) + cudnn_params_v = sess.run(cudnn_params) + + # cuDNN inference graph + with ops.Graph().as_default(): + random_seed.set_random_seed(299) + cudnn_inputs = array_ops.placeholder( + dtypes.float32, shape=[seq_length, batch_size, input_size]) + (cudnn_output_tuple, cudnn_model, + cudnn_params) = self._build_forward_cudnn_model( + rnn_mode, num_layers, num_units, cudnn_inputs, is_training=False) + saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V2) + + inference_input = np.random.rand(seq_length, batch_size, + input_size).astype(np.float32) + with self.test_session( + use_gpu=True, graph=ops.get_default_graph()) as sess: + sess.run(variables.global_variables_initializer()) + saver.restore(sess, save_path) + restored_cudnn_params_v = sess.run(cudnn_params) + self.assertAllEqual(cudnn_params_v, restored_cudnn_params_v) + + # Cudnn inference + (cudnn_output, cudnn_output_h, cudnn_output_c) = sess.run( + cudnn_output_tuple, feed_dict={cudnn_inputs: inference_input}) + + # LSTMBlockCell inference graph + with ops.Graph().as_default(): + random_seed.set_random_seed(299) + cell_inputs = array_ops.placeholder( + dtypes.float32, shape=[seq_length, batch_size, input_size]) + (output, states) = self._create_equivalent_canonical_rnn( + cudnn_model, cell_inputs, use_block_cell) + saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V2) + + with self.test_session( + use_gpu=True, graph=ops.get_default_graph()) as sess: + saver.restore(sess, save_path) + + # BlockCell inference + output_v, states_v = sess.run( + [output, states], feed_dict={cell_inputs: inference_input}) + + # output across timestamps are packed into one tensor. 
+ self.assertAllClose(cudnn_output, output_v, atol=1e-6, rtol=1e-6) + + for i in range(num_layers): + # output_h + self.assertAllClose( + cudnn_output_h[i, :], states_v[i].h, atol=1e-6, rtol=1e-6) + # output_c + self.assertAllClose( + cudnn_output_c[i, :], states_v[i].c, atol=1e-6, rtol=1e-6) + def _testSaveRestoreOutput(self, rnn_mode): num_layers = 2 num_units = 7 @@ -187,9 +409,13 @@ class CudnnRNNTest(TensorFlowTestCase): batch_size, seq_length, dir_count, dropout, expected, tolerance): random_seed.set_random_seed(5678) - model = self._CreateModel(rnn_mode, num_layers, num_units, input_size, - input_mode="auto_select", - dropout=dropout) + model = self._CreateModel( + rnn_mode, + num_layers, + num_units, + input_size, + input_mode="auto_select", + dropout=dropout) has_input_c = (rnn_mode == "lstm") params_size_t = model.params_size() input_data = array_ops.ones([seq_length, batch_size, input_size]) @@ -216,7 +442,7 @@ class CudnnRNNTest(TensorFlowTestCase): if has_input_c: output_c_sum = math_ops.reduce_sum(output_c) total_sum += output_c_sum - with self.test_session(use_gpu=True) as sess: + with self.test_session(use_gpu=True, graph=ops.get_default_graph()) as sess: sess.run(variables.global_variables_initializer()) total_sum_v = sess.run([total_sum]) @@ -310,8 +536,8 @@ class CudnnRNNTest(TensorFlowTestCase): os.environ["TF_CUDNN_RESET_RND_GEN_STATE"] = str(True) has_input_c = (rnn_mode == "lstm") random_seed.set_random_seed(1234) - model = self._CreateModel(rnn_mode, num_layers, num_units, input_size, - dropout=dropout) + model = self._CreateModel( + rnn_mode, num_layers, num_units, input_size, dropout=dropout) params_size_t = model.params_size() input_data = variables.Variable( random_ops.random_uniform([seq_length, batch_size, input_size])) @@ -417,6 +643,7 @@ class CudnnRNNTest(TensorFlowTestCase): }, }, ] + ops.reset_default_graph() with ops.Graph().as_default(): for config in test_configs: rnn_mode = config["rnn_mode"] diff --git 
a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py index cc0c7b08296..0437467f3fb 100644 --- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py +++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py @@ -16,7 +16,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import itertools from tensorflow.contrib.cudnn_rnn.ops import gen_cudnn_rnn_ops from tensorflow.contrib.util import loader @@ -46,9 +45,11 @@ class RNNParamsSaveable(saver.BaseSaverBuilder.SaveableObject): """SaveableObject implementation that handles the RNN params variable.""" def __init__(self, + cudnn_rnn, params_to_canonical, canonical_to_params, param_variables, + base_variable_scope=None, name="params_canonical"): """Creates a RNNParamsSaveable object. @@ -75,6 +76,7 @@ class RNNParamsSaveable(saver.BaseSaverBuilder.SaveableObject): tensor 1 and 4 the update gate; tensor 2 and 5 the new memory gate. Args: + cudnn_rnn: cudnn RNN class instance. params_to_canonical: a function to convert params from a specific format for cuDNN or other RNN ops to the canonical format. _CudnnRNN.params_to_canonical() should be provided here. @@ -87,25 +89,42 @@ class RNNParamsSaveable(saver.BaseSaverBuilder.SaveableObject): For cuDNN RNN ops, this is a single merged variable for both weights and biases; for other RNN ops, this might be multiple unmerged or partially merged variables respectively for weights and biases. + base_variable_scope: a string, name of outer variable scope, used as + part of prefix of names of saved variables. name: the name of the RNNParamsSaveable object. """ # There is only a single merged parameter variable for cuDNN when saving. 
+ self._cudnn_rnn = cudnn_rnn weights, biases = params_to_canonical(param_variables[0]) + weights, biases, = self._transform_canonical(weights, biases) + weight_names, biase_names = self._transformed_canonical_names( + weights, biases) self._canonical_to_params = canonical_to_params self._variables = param_variables # We currently don't use slice_spec. It might be useful in a distributed # setting where each parameter server node stores a slice of variable, # instead of having the master pull all slices and then save them. slice_spec = "" + params = weights + biases + param_names = weight_names + biase_names + if base_variable_scope: + param_names = ["%s/%s" % (base_variable_scope, pn) for pn in param_names] specs = [ - saver.BaseSaverBuilder.SaveSpec(param, slice_spec, param.name) - for param in itertools.chain(weights, biases) + saver.BaseSaverBuilder.SaveSpec(param, slice_spec, param_name) + for param, param_name in zip(params, param_names) ] super(RNNParamsSaveable, self).__init__(None, specs, name) def restore(self, restored_tensors, restored_shapes): - weights = restored_tensors[:len(restored_tensors) // 2] - biases = restored_tensors[len(restored_tensors) // 2:] + if (self._cudnn_rnn.direction == "unidirectional" and + self._cudnn_rnn.rnn_mode == "lstm"): + assert len(restored_tensors) % 4 == 0 + weights = restored_tensors[:len(restored_tensors) // 4] + biases = restored_tensors[len(restored_tensors) // 4:] + else: + weights = restored_tensors[:len(restored_tensors) // 2] + biases = restored_tensors[len(restored_tensors) // 2:] + weights, biases = self._untransform_canonical(weights, biases) params = self._canonical_to_params(weights, biases) if not isinstance(params, tuple): params = (params,) @@ -115,6 +134,159 @@ class RNNParamsSaveable(saver.BaseSaverBuilder.SaveableObject): ] return control_flow_ops.group(*assign_ops) + def _switch_inner(self, array, base_idx): + array[base_idx + 1], array[base_idx + 2] = (array[base_idx + 2], + array[base_idx + 1]) + + 
def _transform_canonical(self, weights, biases): + if (self._cudnn_rnn.direction != "unidirectional" or + self._cudnn_rnn.rnn_mode != "lstm"): + return weights, biases + return self._transform_lstm_canonical(weights, biases) + + def _transformed_canonical_names(self, weights, biases): + """Return canonical names for fused weight and bias tensors.""" + if (self._cudnn_rnn.direction != "unidirectional" or + self._cudnn_rnn.rnn_mode != "lstm"): + assert len(weights) == len(biases) + return ([w.name for w in weights], [b.name for b in biases]) + else: + w_names, b_names = [], [] + assert len(weights) * 3 == len(biases) + num_layers = self._cudnn_rnn.num_layers + # TODO(jamesqin): get rid of multi_rnn_cell when num_layers is 1 + for i in range(num_layers): + # One fused weight tensor each layer. + w_names.append("multi_rnn_cell/cell_%d/lstm_cell/kernel" % i) + # Three fused bias tensors each layer: + # the 1st is for LSTMBlockCell restore; the latter two sum up to the + # 1st, and are used for cuDNN restore. + b_names.append("multi_rnn_cell/cell_%d/lstm_cell/bias" % i) + b_names.extend([ + "multi_rnn_cell/cell_%d/lstm_cell/bias_cudnn_%d" % (i, j) + for j in range(2) + ]) + return w_names, b_names + + def _transform_lstm_canonical(self, weights, biases): + """Create fused lstm canonical params. + + Produce properly-shaped monolithic weight and bias tensors to share between + cuDNN and non-platform specific LSTM cells (w/o peephole). + Args: + weights: a list of Tensors recovered from cuDNN params_to_canonical. + biases: a list of Tensors recovered from cuDNN params_to_canonical. + Returns: + Two lists of tensors, one for weight and bias each. + The weight list contains num_layers tensors and bias one contains 3 * + num_layers tensors. Both original and combined biases since cuDNN biases + are not restorable from the fused version. 
+ """ + transformed_weights, transformed_biases = [], [] + for i in range(self._cudnn_rnn.num_layers): + base_idx = i * 8 + num_units = self._cudnn_rnn.num_units + input_size = self._cudnn_rnn.input_size if i == 0 else num_units + # cuDNN tensor shapes per time_step: + # input.shape: [batch_size, input_size], + # input_weights.shape: [num_units, input_size] (first layer) + # [num_units, num_units] (other layers) + # state_weights.shape: [num_units, num_units] + # biases.shape: [num_units] + # + # General LSTM cells compute gate functions using: + # [x, h_prev] * weights + biases + # Therefore for each layer, they expect + # weight.shape: [input_size + num_units, 4 * num_units] (first_layer) + # [num_units + num_units, 4 * num_units] (other layers) + # bias.shape: [4 * num_units] + + # Stitch weights together in this layer. + stitched_w = [] + for j in range(4): + stitched_w.append( + array_ops.concat( + [ + array_ops.reshape(weights[base_idx + j], + [num_units, input_size]), + array_ops.reshape(weights[base_idx + j + 4], + [num_units, num_units]) + ], + axis=1)) + # cuDNN weights are in ifco order, convert to icfo order. + self._switch_inner(stitched_w, 0) + transformed_weights.append( + array_ops.transpose(array_ops.concat(stitched_w, axis=0))) + + # Stitch biases together in this layer. + # Convert to icfo order. + self._switch_inner(biases, base_idx) + self._switch_inner(biases, base_idx + 4) + # The bias for layer input. + b_in = array_ops.concat(biases[base_idx:base_idx + 4], axis=0) + # The bias for recurrent input. 
+ b_rec = array_ops.concat(biases[base_idx + 4:base_idx + 8], axis=0) + + transformed_biases.extend([b_in + b_rec, b_in, b_rec]) + return transformed_weights, transformed_biases + + def _untransform_canonical(self, transformed_weights, transformed_biases): + if (self._cudnn_rnn.direction != "unidirectional" or + self._cudnn_rnn.rnn_mode != "lstm"): + return transformed_weights, transformed_biases + return self._untransform_lstm_canonical(transformed_weights, + transformed_biases) + + def _untransform_lstm_canonical(self, transformed_weights, + transformed_biases): + """The reverse procedure of _transform_lstm_canonical(). + + Args: + transformed_weights: a list of tensors, one for each layer. + transformed_biases: a list of tensors , 3 for each layer: the 2nd for + layer input, the 3rd for recurrent input, the 1st is the sum of the + latter two. + Returns: + Two lists of tensors for weights and biases respectively. + There are 8 tensors per weight and per bias for each layer: + tensor 0-3 are applied to the input from the previous layer; + tensor 4-7 to the recurrent input. Tensor 0 and 4 are for the input gate; + tensor 1 and 5 the forget gate; tensor 2 and 6 the new memory gate; + tensor 3 and 7 the output gate. + """ + weights, biases = [], [] + assert 3 * len(transformed_weights) == len(transformed_biases) + for i in range(len(transformed_weights)): + num_units = self._cudnn_rnn.num_units + input_size = self._cudnn_rnn.input_size if i == 0 else num_units + # weights applied on layer inputs. + wi = array_ops.slice(transformed_weights[i], [0, 0], + [input_size, 4 * num_units]) + # weights applied on recurrent inputs. 
+ wr = array_ops.slice(transformed_weights[i], [input_size, 0], + [num_units, 4 * num_units]) + wi_list = array_ops.split(wi, 4, axis=1) + wr_list = array_ops.split(wr, 4, axis=1) + + for j in range(len(wi_list)): + wi_list[j] = array_ops.reshape(array_ops.transpose(wi_list[j]), [-1]) + wr_list[j] = array_ops.reshape(array_ops.transpose(wr_list[j]), [-1]) + # canonical weights are in icfo order, convert to ifco order for cuDNN. + self._switch_inner(wi_list, 0) + self._switch_inner(wr_list, 0) + weights.extend(wi_list) + weights.extend(wr_list) + + base_idx = 3 * i + bi_list = array_ops.split(transformed_biases[base_idx + 1], 4, axis=0) + br_list = array_ops.split(transformed_biases[base_idx + 2], 4, axis=0) + # canonical weights are in icfo order, convert to ifco order for cuDNN. + self._switch_inner(bi_list, 0) + self._switch_inner(br_list, 0) + biases.extend(bi_list) + biases.extend(br_list) + return weights, biases + _cudnn_rnn_common_doc_string = """ Cudnn RNN has an opaque parameter buffer that can be used for inference and @@ -199,6 +371,26 @@ class _CudnnRNN(object): if self._seed is None and self._seed2 is None: self._seed, self._seed2 = 0, 0 + @property + def input_size(self): + return self._input_size + + @property + def num_units(self): + return self._num_units + + @property + def num_layers(self): + return self._num_layers + + @property + def rnn_mode(self): + return self._rnn_mode + + @property + def direction(self): + return self._direction + def params_size(self): """Calculates the size of the opaque parameter buffer needed for this model. @@ -222,9 +414,12 @@ class _CudnnRNN(object): """Runs the forward step for the RNN model. Args: - input_data: the input sequence to the RNN model. - input_h: the initial hidden state for h. + input_data: the input sequence to the RNN model. A Tensor of shape [?, + batch_size, input_size]. + input_h: the initial hidden state for h. A Tensor of shape [num_layers, + batch_size, num_units]. 
input_c: the initial hidden state for c. This is only relevant for LSTM. + A Tensor of the same shape as input_h. params: the parameter buffer created for this model. is_training: whether this operation will be used in training or inference. @@ -308,7 +503,7 @@ class CudnnLSTM(_CudnnRNN): num_layers, num_units, input_size, - input_mode="auto_select", + input_mode="linear_input", direction="unidirectional", dropout=0., seed=0): @@ -344,9 +539,12 @@ class CudnnLSTM(_CudnnRNN): """Runs the forward step for the Cudnn LSTM model. Args: - input_data: the input sequence to the LSTM model. - input_h: the initial hidden state for h. - input_c: the initial hidden state for c. + input_data: the input sequence to the LSTM model. A Tensor of shape [?, + batch_size, input_size]. + input_h: the initial hidden state for h. A Tensor of shape [num_layers, + batch_size, num_units]. + input_c: the initial hidden state for c. A Tensor of the same shape as + input_h. params: the parameter buffer created for this model. is_training: whether this operation will be used in training or inference. @@ -368,7 +566,7 @@ class _CudnnRNNNoInputC(_CudnnRNN): num_layers, num_units, input_size, - input_mode="auto_select", + input_mode="linear_input", direction="unidirectional", dropout=0., seed=0): @@ -390,6 +588,7 @@ class _CudnnRNNNoInputC(_CudnnRNN): dropout: whether to enable dropout. With it is 0, dropout is disabled. seed: the seed used for initializing dropout. """ + super(_CudnnRNNNoInputC, self).__init__( self._rnn_mode, num_layers, @@ -404,8 +603,10 @@ class _CudnnRNNNoInputC(_CudnnRNN): """Runs the forward step for the Cudnn LSTM model. Args: - input_data: the input sequence to the LSTM model. - input_h: the initial hidden state for h. + input_data: the input sequence to the RNN model. A Tensor of shape [?, + batch_size, input_size]. + input_h: the initial hidden state for h. A Tensor of shape [num_layers, + batch_size, num_units]. params: the parameter buffer created for this model. 
is_training: whether this operation will be used in training or inference. diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index ab4d80c3275..9909ea41c93 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -99,6 +99,21 @@ py_test( ], ) +py_test( + name = "list_files_dataset_op_test", + size = "small", + srcs = ["list_files_dataset_op_test.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/data", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", + "//tensorflow/python:framework", + "//tensorflow/python:platform_test", + ], +) + py_test( name = "map_dataset_op_test", size = "small", diff --git a/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py index 3ea783ad899..19be94e1742 100644 --- a/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py @@ -72,6 +72,17 @@ class FilterDatasetTest(test.TestCase): # Test an empty dataset. 
do_test(0, 1) + def testFilterRange(self): + dataset = dataset_ops.Dataset.range(100).filter( + lambda x: math_ops.not_equal(math_ops.mod(x, 3), 2)) + iterator = dataset.make_one_shot_iterator() + get_next = iterator.get_next() + + with self.test_session() as sess: + self.assertEqual(0, sess.run(get_next)) + self.assertEqual(1, sess.run(get_next)) + self.assertEqual(3, sess.run(get_next)) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py index d6dd134a5b9..b5b115dd705 100644 --- a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py @@ -20,11 +20,13 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.data.python.ops import dataset_ops +from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import math_ops from tensorflow.python.platform import test @@ -128,6 +130,71 @@ class IteratorTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) + def testOneShotIteratorNonBlocking(self): + dataset = dataset_ops.Dataset.from_tensors([1, 2, 3]).map(lambda x: x * x) + iterator = dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + # Create a session with a single thread to ensure that the + # one-shot iterator initializer does not deadlock. 
+ config = config_pb2.ConfigProto(inter_op_parallelism_threads=1, + use_per_session_threads=True) + with session.Session(config=config) as sess: + self.assertAllEqual([1, 4, 9], sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + # Test with multiple threads invoking the one-shot iterator concurrently. + with session.Session(config=config) as sess: + results = [] + def consumer_thread(): + try: + results.append(sess.run(next_element)) + except errors.OutOfRangeError: + results.append(None) + + num_threads = 8 + threads = [ + self.checkedThread(consumer_thread) for _ in range(num_threads)] + for t in threads: + t.start() + for t in threads: + t.join() + + self.assertEqual(num_threads, len(results)) + self.assertEqual(num_threads - 1, + len([None for r in results if r is None])) + self.assertAllEqual([[1, 4, 9]], [r for r in results if r is not None]) + + def testOneShotIteratorInitializerFails(self): + # Define a dataset whose initialization will always fail. + dataset = dataset_ops.Dataset.from_tensors( + array_ops.check_numerics( + constant_op.constant(1.0) / constant_op.constant(0.0), "oops")) + iterator = dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"): + sess.run(next_element) + + # Test that subsequent attempts to use the iterator also fail. 
+ with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"): + sess.run(next_element) + + with self.test_session() as sess: + def consumer_thread(): + with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"): + sess.run(next_element) + + num_threads = 8 + threads = [ + self.checkedThread(consumer_thread) for _ in range(num_threads)] + for t in threads: + t.start() + for t in threads: + t.join() + def testSimpleSharedResource(self): components = ( np.array(1, dtype=np.int64), diff --git a/tensorflow/contrib/data/python/kernel_tests/list_files_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/list_files_dataset_op_test.py new file mode 100644 index 00000000000..27298de65f9 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/list_files_dataset_op_test.py @@ -0,0 +1,159 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for the experimental input pipeline ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from os import path +import shutil +import tempfile + +from tensorflow.contrib.data.python.ops import dataset_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test +from tensorflow.python.util import compat + + +class ListFilesDatasetOpTest(test.TestCase): + + def setUp(self): + self.tmp_dir = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.tmp_dir, ignore_errors=True) + + def _touchTempFiles(self, filenames): + for filename in filenames: + open(path.join(self.tmp_dir, filename), 'a').close() + + def testEmptyDirectory(self): + dataset = dataset_ops.Dataset.list_files(path.join(self.tmp_dir, '*')) + with self.test_session() as sess: + itr = dataset.make_one_shot_iterator() + with self.assertRaises(errors.OutOfRangeError): + sess.run(itr.get_next()) + + def testSimpleDirectory(self): + filenames = ['a', 'b', 'c'] + self._touchTempFiles(filenames) + + dataset = dataset_ops.Dataset.list_files(path.join(self.tmp_dir, '*')) + with self.test_session() as sess: + itr = dataset.make_one_shot_iterator() + + full_filenames = [] + produced_filenames = [] + for filename in filenames: + full_filenames.append( + compat.as_bytes(path.join(self.tmp_dir, filename))) + produced_filenames.append(compat.as_bytes(sess.run(itr.get_next()))) + self.assertItemsEqual(full_filenames, produced_filenames) + with self.assertRaises(errors.OutOfRangeError): + sess.run(itr.get_next()) + + def testEmptyDirectoryInitializer(self): + filename_placeholder = array_ops.placeholder(dtypes.string, shape=[]) + dataset = dataset_ops.Dataset.list_files(filename_placeholder) + + with self.test_session() as 
sess: + itr = dataset.make_initializable_iterator() + sess.run( + itr.initializer, + feed_dict={filename_placeholder: path.join(self.tmp_dir, '*')}) + + with self.assertRaises(errors.OutOfRangeError): + sess.run(itr.get_next()) + + def testSimpleDirectoryInitializer(self): + filenames = ['a', 'b', 'c'] + self._touchTempFiles(filenames) + + filename_placeholder = array_ops.placeholder(dtypes.string, shape=[]) + dataset = dataset_ops.Dataset.list_files(filename_placeholder) + + with self.test_session() as sess: + itr = dataset.make_initializable_iterator() + sess.run( + itr.initializer, + feed_dict={filename_placeholder: path.join(self.tmp_dir, '*')}) + + full_filenames = [] + produced_filenames = [] + for filename in filenames: + full_filenames.append( + compat.as_bytes(path.join(self.tmp_dir, filename))) + produced_filenames.append(compat.as_bytes(sess.run(itr.get_next()))) + + self.assertItemsEqual(full_filenames, produced_filenames) + + with self.assertRaises(errors.OutOfRangeError): + sess.run(itr.get_next()) + + def testFileSuffixes(self): + filenames = ['a.txt', 'b.py', 'c.py', 'd.pyc'] + self._touchTempFiles(filenames) + + filename_placeholder = array_ops.placeholder(dtypes.string, shape=[]) + dataset = dataset_ops.Dataset.list_files(filename_placeholder) + + with self.test_session() as sess: + itr = dataset.make_initializable_iterator() + sess.run( + itr.initializer, + feed_dict={filename_placeholder: path.join(self.tmp_dir, '*.py')}) + + full_filenames = [] + produced_filenames = [] + for filename in filenames[1:-1]: + full_filenames.append( + compat.as_bytes(path.join(self.tmp_dir, filename))) + produced_filenames.append(compat.as_bytes(sess.run(itr.get_next()))) + self.assertItemsEqual(full_filenames, produced_filenames) + + with self.assertRaises(errors.OutOfRangeError): + sess.run(itr.get_next()) + + def testFileMiddles(self): + filenames = ['a.txt', 'b.py', 'c.pyc'] + self._touchTempFiles(filenames) + + filename_placeholder = 
array_ops.placeholder(dtypes.string, shape=[]) + dataset = dataset_ops.Dataset.list_files(filename_placeholder) + + with self.test_session() as sess: + itr = dataset.make_initializable_iterator() + sess.run( + itr.initializer, + feed_dict={filename_placeholder: path.join(self.tmp_dir, '*.py*')}) + + full_filenames = [] + produced_filenames = [] + for filename in filenames[1:]: + full_filenames.append( + compat.as_bytes(path.join(self.tmp_dir, filename))) + produced_filenames.append(compat.as_bytes(sess.run(itr.get_next()))) + + self.assertItemsEqual(full_filenames, produced_filenames) + + with self.assertRaises(errors.OutOfRangeError): + sess.run(itr.get_next()) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/data/python/ops/dataset_ops.py b/tensorflow/contrib/data/python/ops/dataset_ops.py index 89410bf8447..29f1209a58a 100644 --- a/tensorflow/contrib/data/python/ops/dataset_ops.py +++ b/tensorflow/contrib/data/python/ops/dataset_ops.py @@ -33,6 +33,7 @@ from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.ops import gen_io_ops from tensorflow.python.ops import logging_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import parsing_ops @@ -600,6 +601,29 @@ class Dataset(object): dataset = dataset.batch(batch_size) return dataset + @staticmethod + def list_files(file_pattern): + """A dataset of all files matching a pattern. + + Example: + If we had the following files on our filesystem: + - /path/to/dir/a.txt + - /path/to/dir/b.py + - /path/to/dir/c.py + If we pass "/path/to/dir/*.py" as the directory, the dataset would + produce: + - /path/to/dir/b.py + - /path/to/dir/c.py + + Args: + file_pattern: A string or scalar string `tf.Tensor`, representing + the filename pattern that will be matched. 
+ + Returns: + A `Dataset` of strings corresponding to file names. + """ + return Dataset.from_tensor_slices(gen_io_ops.matching_files(file_pattern)) + def repeat(self, count=None): """Repeats this dataset `count` times. diff --git a/tensorflow/contrib/decision_trees/BUILD b/tensorflow/contrib/decision_trees/BUILD deleted file mode 100644 index 4045b92f10d..00000000000 --- a/tensorflow/contrib/decision_trees/BUILD +++ /dev/null @@ -1,19 +0,0 @@ -# Files common to decision-tree algorithms. -package(default_visibility = [ - "//visibility:public", -]) - -licenses(["notice"]) # Apache 2.0 - -exports_files(["LICENSE"]) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/contrib/decision_trees/proto/BUILD b/tensorflow/contrib/decision_trees/proto/BUILD index 86174c5865f..87c80740a8f 100644 --- a/tensorflow/contrib/decision_trees/proto/BUILD +++ b/tensorflow/contrib/decision_trees/proto/BUILD @@ -9,12 +9,21 @@ exports_files([ load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library") +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + ), + visibility = ["//tensorflow:__subpackages__"], +) + tf_proto_library( name = "generic_tree_model", srcs = ["generic_tree_model.proto"], cc_api_version = 2, go_api_version = 2, java_api_version = 2, + visibility = ["//visibility:public"], ) tf_proto_library( @@ -23,4 +32,5 @@ tf_proto_library( cc_api_version = 2, go_api_version = 2, protodeps = [":generic_tree_model"], + visibility = ["//visibility:public"], ) diff --git a/tensorflow/contrib/ffmpeg/BUILD b/tensorflow/contrib/ffmpeg/BUILD index 31a286939b6..e495ab48803 100644 --- a/tensorflow/contrib/ffmpeg/BUILD +++ b/tensorflow/contrib/ffmpeg/BUILD @@ -124,10 +124,9 @@ py_library( ":decode_audio_op_py", ":encode_audio_op_py", "//tensorflow/contrib/util:util_py", - "//tensorflow/python:errors", - "//tensorflow/python:framework", 
"//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:platform", + "//tensorflow/python:util", ], ) diff --git a/tensorflow/contrib/framework/BUILD b/tensorflow/contrib/framework/BUILD index 61fe729fd75..a953c04c1a9 100644 --- a/tensorflow/contrib/framework/BUILD +++ b/tensorflow/contrib/framework/BUILD @@ -178,7 +178,6 @@ py_test( deps = [ ":framework_py", "//tensorflow/python:client_testlib", - "//third_party/py/numpy", ], ) @@ -198,7 +197,6 @@ py_test( "//tensorflow/python:training", "//tensorflow/python:variable_scope", "//tensorflow/python:variables", - "//third_party/py/numpy", ], ) @@ -211,7 +209,6 @@ py_test( ":framework_py", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", - "//third_party/py/numpy", ], ) @@ -223,9 +220,9 @@ py_test( deps = [ ":framework_py", "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:sparse_ops", + "//tensorflow/python:sparse_tensor", "//tensorflow/python:tensor_array_ops", "//tensorflow/python:variables", "//third_party/py/numpy", @@ -241,7 +238,6 @@ py_test( ":framework_py", "//tensorflow/python:client_testlib", "//tensorflow/python:platform", - "//third_party/py/numpy", ], ) @@ -255,8 +251,8 @@ py_test( "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:errors", - "//tensorflow/python:framework", "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:sparse_tensor", "//tensorflow/python:variables", "//third_party/py/numpy", ], @@ -308,6 +304,7 @@ py_test( tags = ["no_pip"], deps = [ ":framework_py", + ":gen_checkpoint_ops", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", @@ -316,6 +313,7 @@ py_test( "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", "//tensorflow/python:partitioned_variables", + 
"//tensorflow/python:platform", "//tensorflow/python:training", "//tensorflow/python:variable_scope", "//tensorflow/python:variables", diff --git a/tensorflow/contrib/framework/python/ops/variables.py b/tensorflow/contrib/framework/python/ops/variables.py index f02a7c63606..bf709f921da 100644 --- a/tensorflow/contrib/framework/python/ops/variables.py +++ b/tensorflow/contrib/framework/python/ops/variables.py @@ -525,7 +525,7 @@ def get_variable_full_name(var): # TODO(sguada): Update docs in slim/g3doc/index.md to describe # the new feature where the var_list dictionary can have values that # are each a list of Variables. -def assign_from_checkpoint(model_path, var_list): +def assign_from_checkpoint(model_path, var_list, ignore_missing_vars=False): """Creates an operation to assign specific variables from a checkpoint. Args: @@ -538,13 +538,15 @@ def assign_from_checkpoint(model_path, var_list): name in the checkpoint must be the full variable, not the name of the partitioned variable, eg. "my_var" rather than "my_var/part_4". If empty, returns no_op(), {}. + ignore_missing_vars: Boolean, if True ignore variables missing in the + checkpoint with a warning instead of failing. Returns: the restore_op and the feed_dict that need to be run to restore var_list. Raises: - ValueError: If the checkpoint specified at `model_path` is missing one of - the variables in `var_list`. + ValueError: If `ignore_missing_vars` is False and the checkpoint specified + at `model_path` is missing one of the variables in `var_list`. 
""" # Normalize var_list into a dictionary mapping names in the # checkpoint to the list of variables to initialize from that @@ -572,8 +574,12 @@ def assign_from_checkpoint(model_path, var_list): assign_ops = [] for ckpt_name in grouped_vars: if not reader.has_tensor(ckpt_name): - raise ValueError( - 'Checkpoint is missing variable [%s]' % ckpt_name) + log_str = 'Checkpoint is missing variable [%s]' % ckpt_name + if ignore_missing_vars: + logging.warning(log_str) + continue + else: + raise ValueError(log_str) ckpt_value = reader.get_tensor(ckpt_name) for var in grouped_vars[ckpt_name]: diff --git a/tensorflow/contrib/graph_editor/BUILD b/tensorflow/contrib/graph_editor/BUILD index d570a6d702d..bdb58fb5434 100644 --- a/tensorflow/contrib/graph_editor/BUILD +++ b/tensorflow/contrib/graph_editor/BUILD @@ -25,6 +25,7 @@ py_library( "//tensorflow/python:array_ops", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:platform", + "//tensorflow/python:util", "@six_archive//:six", ], ) @@ -54,7 +55,11 @@ py_library( name = "match", srcs = ["tests/match.py"], srcs_version = "PY2AND3", - deps = [":graph_editor_py"], + deps = [ + ":graph_editor_py", + "//tensorflow/python:framework_ops", + "@six_archive//:six", + ], ) py_test( @@ -66,9 +71,7 @@ py_test( ":graph_editor_py", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", "//tensorflow/python:math_ops", - "//tensorflow/python:platform_test", ], ) @@ -81,9 +84,7 @@ py_test( ":graph_editor_py", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", "//tensorflow/python:math_ops", - "//tensorflow/python:platform_test", ], ) @@ -93,13 +94,10 @@ py_test( srcs = ["tests/match_test.py"], srcs_version = "PY2AND3", deps = [ - ":graph_editor_py", ":match", "//tensorflow/python:client_testlib", 
"//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", "//tensorflow/python:math_ops", - "//tensorflow/python:platform_test", ], ) @@ -112,9 +110,7 @@ py_test( ":graph_editor_py", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", "//tensorflow/python:math_ops", - "//tensorflow/python:platform_test", ], ) @@ -128,9 +124,7 @@ py_test( ":match", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", "//tensorflow/python:math_ops", - "//tensorflow/python:platform_test", ], ) @@ -144,9 +138,7 @@ py_test( ":match", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", "//tensorflow/python:math_ops", - "//tensorflow/python:platform_test", ], ) @@ -163,9 +155,7 @@ py_test( "//tensorflow/python:client_testlib", "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", "//tensorflow/python:math_ops", - "//tensorflow/python:platform_test", "//tensorflow/python:variables", "//third_party/py/numpy", ], diff --git a/tensorflow/contrib/grid_rnn/BUILD b/tensorflow/contrib/grid_rnn/BUILD index 73473becf9b..7fbb9f024c5 100644 --- a/tensorflow/contrib/grid_rnn/BUILD +++ b/tensorflow/contrib/grid_rnn/BUILD @@ -20,6 +20,7 @@ py_library( "//tensorflow/python:array_ops", "//tensorflow/python:math_ops", "//tensorflow/python:nn", + "//tensorflow/python:platform", "//tensorflow/python:variable_scope", ], ) diff --git a/tensorflow/contrib/hvx/hexagon_controller/src_impl/graph_functions_wrapper.c b/tensorflow/contrib/hvx/hexagon_controller/src_impl/graph_functions_wrapper.c index 7c821585224..f6a38fe8a94 100644 --- a/tensorflow/contrib/hvx/hexagon_controller/src_impl/graph_functions_wrapper.c +++ 
b/tensorflow/contrib/hvx/hexagon_controller/src_impl/graph_functions_wrapper.c @@ -52,21 +52,21 @@ static enum InceptionVersion s_inception_version = INCEPTION_V3; ///////////////////////////////////////////////// // file local functions -static const char *ConvertGraphInfoIdToName(unsigned int id) { +static const char* ConvertGraphInfoIdToName(unsigned int id) { // TODO(satok): implement return "?"; } -static const char *ConvertGraphInfoIdToOpName(unsigned int id) { +static const char* ConvertGraphInfoIdToOpName(unsigned int id) { // TODO(satok): implement return "?"; } ///////////////////////////////////////////////// // file local utilities -static uint32_t FindMaxIdxWithExcludeList( - const float *data, uint32_t entries, const int exclude_size, - const int* exclude_idx) { +static uint32_t FindMaxIdxWithExcludeList(const float* data, uint32_t entries, + const int exclude_size, + const int* exclude_idx) { int i; float maxval = data[0]; int maxidx = 0; @@ -93,13 +93,16 @@ static uint32_t FindMaxIdx(const float* data, uint32_t entries) { return FindMaxIdxWithExcludeList(data, entries, 0, NULL); } -void hexagon_controller_PrintMaxNIdx(const float *data, const uint32_t entries, - const int n, int* out_ranking) { +void hexagon_controller_PrintMaxNIdx(const float* data, const uint32_t entries, + const int n, int* out_ranking) { if (DUMP_OUTPUT) { for (int i = 0; i < entries; ++i) { TFMLOGD("%d: val = %f", i, data[i]); } } + if (n >= entries) { + TFMLOGD("Too many N %d >= %d", n, entries); + } for (int i = 0; i < n; ++i) { out_ranking[i] = INT_MAX; } @@ -120,9 +123,9 @@ static inline unsigned long long int GetCounter(hexagon_nn_perfinfo s) { return ret; } -static int CompareCycle(const void *va, const void *vb) { - const hexagon_nn_perfinfo *a = va; - const hexagon_nn_perfinfo *b = vb; +static int CompareCycle(const void* va, const void* vb) { + const hexagon_nn_perfinfo* a = va; + const hexagon_nn_perfinfo* b = vb; unsigned long long int acount = GetCounter(*a); 
unsigned long long int bcount = GetCounter(*b); if (acount < bcount) { @@ -139,8 +142,6 @@ static int CompareCycle(const void *va, const void *vb) { uint32_t hexagon_controller_InstantiateGraph() { const uint32_t nn_id = hexagon_nn_init(); - // set debug level to 99 for now - //hexagon_nn_set_debug_level(nn_id, 99); // TODO(satok): make this as argument hexagon_nn_set_debug_level(nn_id, 0); return nn_id; @@ -167,7 +168,7 @@ bool hexagon_controller_ConstructGraph(uint32_t nn_id) { int err; if ((err = hexagon_nn_prepare(nn_id)) != 0) { TFMLOGE("Prepare failed! returned 0x%x\n", err); - hexagon_controller_PrintLog(nn_id); + DumpNNId(nn_id); return false; } else { TFMLOGD("Prepare success!\n"); @@ -175,65 +176,80 @@ bool hexagon_controller_ConstructGraph(uint32_t nn_id) { } } -uint32_t hexagon_controller_SetupGraph(int version) { +uint32_t hexagon_controller_SetupGraph(int version) { const uint32_t nn_id = hexagon_controller_InstantiateGraph(); hexagon_controller_InitGraph(version, nn_id); hexagon_controller_ConstructGraph(nn_id); return nn_id; } -bool hexagon_controller_ExecuteGraph( - const uint32_t nn_id, - const uint32_t batches, - const uint32_t height, - const uint32_t width, - const uint32_t depth, - uint8_t* int_data, - const uint32_t int_data_size, - uint32_t* out_batches, - uint32_t* out_height, - uint32_t* out_width, - uint32_t* out_depth, - uint8_t* out_vals, - const uint32_t output_val_byte_size, - uint32_t* out_data_byte_size) { - int err; +bool hexagon_controller_ExecuteGraphWithMultipleInOut( + const uint32_t nn_id, const int input_count, hexagon_nn_tensordef* inputs, + const int output_count, hexagon_nn_tensordef* outputs) { if (DBG_EXECUTION) { - TFMLOGD("Preparing to execute..."); - TFMLOGD("Input: %d, %d, %d, %d, %d, %d", - batches, height, width, depth, int_data[0], int_data_size); - TFMLOGD("Output: %d, %p", output_val_byte_size, out_vals); + TFMLOGD("Preparing to execute... 
in = %d, out = %d", input_count, + output_count); LogDHexagon("Execute graph!"); } - - if ((err = hexagon_nn_execute(nn_id, - batches, - height, - width, - depth, - int_data, - int_data_size, - out_batches, - out_height, - out_width, - out_depth, - out_vals, - output_val_byte_size, - out_data_byte_size)) != 0) { + + const int err = + hexagon_nn_execute_new(nn_id, inputs, input_count, outputs, output_count); + if (err != 0) { if (DBG_EXECUTION) { LogDHexagon("Execution failed!"); - TFMLOGE("execute got err: %d\n",err); + TFMLOGE("execute got err: %d\n", err); + DumpNNId(nn_id); } return false; } else { if (DBG_EXECUTION) { LogDHexagon("Execution succeeded!"); - TFMLOGD("%d x %d x %d x %d, byte size = %d\n", - *out_batches, - *out_height, - *out_width, - *out_depth, - *out_data_byte_size); + } + return true; + } +} + +bool hexagon_controller_ExecuteGraph( + const uint32_t nn_id, const uint32_t batches, const uint32_t height, + const uint32_t width, const uint32_t depth, uint8_t* int_data, + const uint32_t int_data_size, uint32_t* out_batches, uint32_t* out_height, + uint32_t* out_width, uint32_t* out_depth, uint8_t* out_vals, + const uint32_t output_val_byte_size, uint32_t* out_data_byte_size) { + if (DBG_EXECUTION) { + TFMLOGD("Preparing to execute..."); + TFMLOGD("Input: %d, %d, %d, %d, %d, %d", batches, height, width, depth, + int_data[0], int_data_size); + TFMLOGD("Output: %d, %p", output_val_byte_size, out_vals); + LogDHexagon("Execute graph!"); + } + + hexagon_nn_tensordef input; + hexagon_nn_tensordef output; + + input.batches = batches; + input.height = height; + input.width = width; + input.depth = depth; + input.data = int_data; + input.dataLen = int_data_size; + + output.data = out_vals; + output.dataLen = output_val_byte_size; + + if (!hexagon_controller_ExecuteGraphWithMultipleInOut(nn_id, 1, &input, 1, + &output)) { + return false; + } else { + *out_batches = output.batches; + *out_height = output.height; + *out_width = output.width; + *out_depth = 
output.depth; + *out_data_byte_size = output.dataLen; + + if (DBG_EXECUTION) { + LogDHexagon("Execution succeeded!"); + TFMLOGD("%d x %d x %d x %d, byte size = %d\n", *out_batches, *out_height, + *out_width, *out_depth, *out_data_byte_size); } return true; } @@ -246,27 +262,21 @@ bool hexagon_controller_ExecuteInceptionDummyData(uint32_t nn_id) { const bool success = hexagon_controller_ExecuteGraph( nn_id, INCEPTION_PARAM_BATCHES, INCEPTION_PARAM_HEIGHT_V3, INCEPTION_PARAM_WIDTH_V3, INCEPTION_PARAM_DEPTH, - (uint8_t *)inception_dummy_int_data_299x299, + (uint8_t*)inception_dummy_int_data_299x299, INCEPTION_PARAM_HEIGHT_V3 * INCEPTION_PARAM_WIDTH_V3 * - INCEPTION_PARAM_DEPTH, + INCEPTION_PARAM_DEPTH, &out_batches, &out_height, &out_width, &out_depth, - (uint8_t *)s_output_values, sizeof(s_output_values), - &out_data_size); + (uint8_t*)s_output_values, sizeof(s_output_values), &out_data_size); if (success) { int out_ranking[OUT_RANKING_SIZE]; hexagon_controller_PrintMaxNIdx( - s_output_values, - out_batches * out_height * out_width * out_depth, + s_output_values, out_batches * out_height * out_width * out_depth, OUT_RANKING_SIZE, out_ranking); - TFMLOGD("%d x %d x %d x %d, size = %d\n", - out_batches, - out_height, - out_width, - out_depth, - out_data_size); - TFMLOGD("max idx: %d\n", FindMaxIdx( - s_output_values, - out_batches * out_height * out_width * out_depth)); + TFMLOGD("%d x %d x %d x %d, size = %d\n", out_batches, out_height, + out_width, out_depth, out_data_size); + TFMLOGD("max idx: %d\n", + FindMaxIdx(s_output_values, + out_batches * out_height * out_width * out_depth)); if (out_ranking[0] == 169 && out_ranking[1] == 7) { return true; } else { @@ -290,25 +300,22 @@ void hexagon_controller_DumpPerf(uint32_t nn_id) { TFMLOGE("perf info failure"); return; } - TFMLOGD("Total %d nodes.",n_nodes); - qsort(info,n_nodes,sizeof(info[0]), CompareCycle); + TFMLOGD("Total %d nodes.", n_nodes); + qsort(info, n_nodes, sizeof(info[0]), CompareCycle); for (i = 0; i < 
n_nodes; i++) { total_cycles += GetCounter(info[i]); } - TFMLOGD("Total %lld cycles.",total_cycles); + TFMLOGD("Total %lld cycles.", total_cycles); for (i = 0; i < n_nodes; i++) { counter = GetCounter(info[i]); cum_cycles += counter; - TFMLOGD("node,0x%x,%s,%s,executions,%d,cycles,%lld,%f %%," - "cum_cycles,%lld,%f %%\n", - info[i].node_id, - ConvertGraphInfoIdToName(info[i].node_id), - ConvertGraphInfoIdToOpName(info[i].node_id), - info[i].executions, - counter, - 100*((double)counter)/total_cycles, - cum_cycles, - 100*((double)cum_cycles)/total_cycles); + TFMLOGD( + "node,0x%x,%s,%s,executions,%d,cycles,%lld,%f %%," + "cum_cycles,%lld,%f %%\n", + info[i].node_id, ConvertGraphInfoIdToName(info[i].node_id), + ConvertGraphInfoIdToOpName(info[i].node_id), info[i].executions, + counter, 100 * ((double)counter) / total_cycles, cum_cycles, + 100 * ((double)cum_cycles) / total_cycles); } #ifdef ENABLE_HVX_FULL_DEBUG DumpAllPerf(nn_id); @@ -329,7 +336,7 @@ void hexagon_controller_DumpNodeName(uint32_t nn_id) { TFMLOGD("perf info failure"); return; } - TFMLOGD("Total %d nodes.",node_count); + TFMLOGD("Total %d nodes.", node_count); qsort(info, node_count, sizeof(info[0]), CompareCycle); for (i = 0; i < node_count; i++) { total_cycles += GetCounter(info[i]); @@ -338,19 +345,14 @@ void hexagon_controller_DumpNodeName(uint32_t nn_id) { for (i = 0; i < node_count; i++) { counter = GetCounter(info[i]); cum_cycles += counter; - TFMLOGD("node,0x%x,%s,%s,executions,%d,cycles,%lld,%f %%," - "cum_cycles,%lld,%f %%", - info[i].node_id, - ConvertGraphInfoIdToName(info[i].node_id), - ConvertGraphInfoIdToOpName(info[i].node_id), - info[i].executions, - counter, - 100*((double)counter)/total_cycles, - cum_cycles, - 100*((double)cum_cycles)/total_cycles); + TFMLOGD( + "node,0x%x,%s,%s,executions,%d,cycles,%lld,%f %%," + "cum_cycles,%lld,%f %%", + info[i].node_id, ConvertGraphInfoIdToName(info[i].node_id), + ConvertGraphInfoIdToOpName(info[i].node_id), info[i].executions, + counter, 100 * 
((double)counter) / total_cycles, cum_cycles, + 100 * ((double)cum_cycles) / total_cycles); } } -void hexagon_controller_Teardown(uint32_t nn_id) { - hexagon_nn_teardown(nn_id); -} +void hexagon_controller_Teardown(uint32_t nn_id) { hexagon_nn_teardown(nn_id); } diff --git a/tensorflow/contrib/hvx/hexagon_controller/src_impl/hexagon_controller.c b/tensorflow/contrib/hvx/hexagon_controller/src_impl/hexagon_controller.c index 31caebf8728..6a5d982dc85 100644 --- a/tensorflow/contrib/hvx/hexagon_controller/src_impl/hexagon_controller.c +++ b/tensorflow/contrib/hvx/hexagon_controller/src_impl/hexagon_controller.c @@ -24,11 +24,13 @@ limitations under the License. #include "adspmsgd.h" #include "dspCV.h" -#include "rpcmem.h" // helper API's for shared buffer allocation +#include "node_data_float.h" +#include "rpcmem.h" // helper API's for shared buffer allocation #include "soc_interface.h" #include "tfm_log.h" -// if false, use int data as input. This is only for acceleration purpose +// if false, use int data as input. This is only for acceleration purpose. +// Also you may need to change android.min. 
static const bool USE_FLOAT_DATA = true; // if true, show id for each node @@ -43,27 +45,96 @@ extern uint8_t inception_dummy_int_data_224x224[]; extern uint8_t inception_dummy_int_data_299x299[]; extern float inception_dummy_float_data_299x299[]; -#define HEXAGON_CONTROLLER_VERSION 92 +#define HEXAGON_CONTROLLER_VERSION 101 // allocate print bufsize in advance @MB #define PRINT_BUFSIZE (2 * 1024 * 1024) static unsigned char s_print_buf[PRINT_BUFSIZE]; -// input node data buffer size -// x2 1024 * 1024 * 2 > 299 * 299 * 3 * 4 > 1024 * 1024 -static const int INPUT_NODE_DATA_BUFFER_SIZE = 1024 * 1024 * 2; -// output node data buffer size -// (1008 is enough for inception) -static const int OUTPUT_NODE_DATA_BUFFER_SIZE = 300 * 300 * 3 * 4; +#define MAX_INPUTS 10 +#define MAX_OUTPUTS 10 -static struct NodeDataFloat s_input_node_data_float_buffer; -static float* s_output_node_data_float_buffer; -static int s_output_node_data_float_buffer_byte_size; -static int s_output_node_data_float_array_size; +static struct NodeDataFloat s_input_node_data_buffer[MAX_INPUTS]; +static uint8_t* s_output_node_data_buffer[MAX_OUTPUTS]; +static int s_output_node_data_buffer_max_byte_size[MAX_OUTPUTS]; +static int s_output_node_data_array_byte_size[MAX_OUTPUTS]; static uint32_t s_target_graph_id; static bool s_dbg_use_inception_dummy_data = false; +static int s_dbg_inception_version = 3; + +static int GetInputNodeCount() { + for (int i = 0; i < MAX_INPUTS; ++i) { + if (s_input_node_data_buffer[i].max_buf_byte_size == 0) { + return i; + } + } + return 0; +} + +static int GetOutputNodeCount() { + for (int i = 0; i < MAX_OUTPUTS; ++i) { + if (s_output_node_data_buffer_max_byte_size[i] == 0) { + return i; + } + } + return 0; +} + +static bool SetInputTensorDef(int port, hexagon_nn_tensordef* tensordef) { + if (port >= GetInputNodeCount()) { + TFMLOGE("Error exceeds input count."); + return false; + } + struct NodeDataFloat* input_node_data_buffer = + &s_input_node_data_buffer[port]; + 
tensordef->batches = input_node_data_buffer->x; + tensordef->height = input_node_data_buffer->y; + tensordef->width = input_node_data_buffer->z; + tensordef->depth = input_node_data_buffer->d; + tensordef->data = input_node_data_buffer->byte_array_data; + tensordef->dataLen = input_node_data_buffer->array_byte_size; + + return true; +} + +bool hexagon_controller_SetAllInputTensorDef(int node_count, + hexagon_nn_tensordef* tensordef) { + bool success = true; + if (node_count != GetInputNodeCount()) { + TFMLOGE("Error invalid input node count."); + return false; + } + for (int i = 0; i < node_count; ++i) { + SetInputTensorDef(i, &tensordef[i]); + } + return success; +} + +static bool SetOutputTensorDef(int port, hexagon_nn_tensordef* tensordef) { + if (port >= GetOutputNodeCount()) { + TFMLOGE("Error exceeds output count."); + return false; + } + tensordef->data = s_output_node_data_buffer[port]; + tensordef->dataLen = s_output_node_data_buffer_max_byte_size[port]; + return true; +} + +bool hexagon_controller_SetAllOutputTensorDef(int node_count, + hexagon_nn_tensordef* tensordef) { + bool success = true; + if (node_count != GetOutputNodeCount()) { + TFMLOGE("Error invalid output node count. 
%d != %d", node_count, + GetOutputNodeCount()); + return false; + } + for (int i = 0; i < node_count; ++i) { + SetOutputTensorDef(i, &tensordef[i]); + } + return success; +} void hexagon_controller_InitInputNodeDataToInceptionDummyData(int version) { if (version == 1) { @@ -72,44 +143,54 @@ void hexagon_controller_InitInputNodeDataToInceptionDummyData(int version) { return; } hexagon_controller_CopyByteNodeData( - INCEPTION_PARAM_BATCHES, INCEPTION_PARAM_HEIGHT_V1, - INCEPTION_PARAM_WIDTH_V1, INCEPTION_PARAM_DEPTH, - 1, inception_dummy_int_data_224x224); + 0, INCEPTION_PARAM_BATCHES, INCEPTION_PARAM_HEIGHT_V1, + INCEPTION_PARAM_WIDTH_V1, INCEPTION_PARAM_DEPTH, 1, + inception_dummy_int_data_224x224); } else if (version == 3) { if (USE_FLOAT_DATA) { hexagon_controller_CopyByteNodeData( - INCEPTION_PARAM_BATCHES, INCEPTION_PARAM_HEIGHT_V3, - INCEPTION_PARAM_WIDTH_V3, INCEPTION_PARAM_DEPTH, - sizeof(float), (uint8_t*)inception_dummy_float_data_299x299); + 0, INCEPTION_PARAM_BATCHES, INCEPTION_PARAM_HEIGHT_V3, + INCEPTION_PARAM_WIDTH_V3, INCEPTION_PARAM_DEPTH, sizeof(float), + (uint8_t*)inception_dummy_float_data_299x299); } else { hexagon_controller_CopyByteNodeData( - INCEPTION_PARAM_BATCHES, INCEPTION_PARAM_HEIGHT_V3, - INCEPTION_PARAM_WIDTH_V3, INCEPTION_PARAM_DEPTH, - 1, inception_dummy_int_data_299x299); + 0, INCEPTION_PARAM_BATCHES, INCEPTION_PARAM_HEIGHT_V3, + INCEPTION_PARAM_WIDTH_V3, INCEPTION_PARAM_DEPTH, 1, + inception_dummy_int_data_299x299); } } } -bool hexagon_controller_ExecuteGraphWithBuffer( - uint32_t nn_id, bool show_ranking) { - uint32_t out_batches, out_height, out_width, out_depth; - uint32_t out_data_size; - int x = s_input_node_data_float_buffer.x; - int y = s_input_node_data_float_buffer.y; - int z = s_input_node_data_float_buffer.z; - int d = s_input_node_data_float_buffer.d; - uint8_t *byte_data = s_input_node_data_float_buffer.byte_array_data; - int array_size = s_input_node_data_float_buffer.array_size; - const bool success = 
hexagon_controller_ExecuteGraph( - nn_id, x, y, z, d, byte_data, array_size, - &out_batches, &out_height, &out_width, &out_depth, - (uint8_t *)s_output_node_data_float_buffer, - s_output_node_data_float_buffer_byte_size, - &out_data_size); - s_output_node_data_float_array_size = - out_batches * out_height * out_width * out_depth; +bool hexagon_controller_ExecuteGraphWithBuffer(uint32_t nn_id, + bool show_ranking) { + const int input_node_count = GetInputNodeCount(); + hexagon_nn_tensordef inputs[input_node_count]; + const int output_node_count = GetOutputNodeCount(); + if (output_node_count <= 0) { + TFMLOGI("Error output node count is 0."); + return false; + } + hexagon_nn_tensordef outputs[output_node_count]; + hexagon_controller_SetAllInputTensorDef(input_node_count, inputs); + hexagon_controller_SetAllOutputTensorDef(output_node_count, outputs); + const bool success = hexagon_controller_ExecuteGraphWithMultipleInOut( + nn_id, input_node_count, inputs, output_node_count, outputs); + for (int i = 0; i < output_node_count; ++i) { + s_output_node_data_array_byte_size[i] = outputs[i].data_valid_len; + } + + const hexagon_nn_tensordef* output0 = &outputs[0]; + + const uint32_t out_batches = output0->batches; + const uint32_t out_height = output0->height; + const uint32_t out_width = output0->width; + const uint32_t out_depth = output0->depth; + const uint32_t out_data_size = output0->data_valid_len; + const uint32_t out_buf_byte_size = output0->dataLen; + if (!success) { TFMLOGE("Execution failed"); + DumpNNId(nn_id); return false; } else if (!show_ranking) { return true; @@ -118,15 +199,11 @@ bool hexagon_controller_ExecuteGraphWithBuffer( static const int OUT_RANKING_SIZE = 5; int out_ranking[OUT_RANKING_SIZE]; hexagon_controller_PrintMaxNIdx( - s_output_node_data_float_buffer, - out_batches * out_height * out_width * out_depth, - OUT_RANKING_SIZE, out_ranking); - TFMLOGD("%d x %d x %d x %d, byte size = %d\n", - out_batches, - out_height, - out_width, - out_depth, 
- out_data_size); + (float*)s_output_node_data_buffer[0], + out_batches * out_height * out_width * out_depth, OUT_RANKING_SIZE, + out_ranking); + TFMLOGD("%d x %d x %d x %d, byte size = %d, buf size = %d\n", out_batches, + out_height, out_width, out_depth, out_data_size, out_buf_byte_size); if (s_dbg_use_inception_dummy_data) { // Check the result of inception with a dummy data. This step shouldn't // be passed when show_ranking != true to avoid adding unnecessary @@ -142,9 +219,7 @@ bool hexagon_controller_ExecuteGraphWithBuffer( return true; } -uint32_t hexagon_controller_GetTargetGraphId() { - return s_target_graph_id; -} +uint32_t hexagon_controller_GetTargetGraphId() { return s_target_graph_id; } void hexagon_controller_SetTargetGraphId(uint32_t graph_id) { s_target_graph_id = graph_id; @@ -168,69 +243,129 @@ int hexagon_controller_GetHexagonBinaryVersion() { return retval; } -bool hexagon_controller_AllocateNodeDataBuffers( - int input_size, int output_size) { - TFMLOGD("Allocate memory for input / output node data float"); - if (s_input_node_data_float_buffer.buf_size != 0) { +bool hexagon_controller_AllocateInputNodeDataBuffers(int port, + int input_buf_byte_size) { + TFMLOGD("Allocate memory for input node data. port = %d, size = %d", port, + input_buf_byte_size); + if (s_input_node_data_buffer[port].max_buf_byte_size != 0) { TFMLOGE("ERROR! input buffer is already allocated!!"); return false; } else { - int byte_array_data_size = USE_FLOAT_DATA ? - input_size * sizeof(float) : input_size; /* sizeof(uint8_t) ? */ - s_input_node_data_float_buffer.buf_size = input_size; - // unused? remove? 
- s_input_node_data_float_buffer.array_data = - malloc(input_size * sizeof(float)); - s_input_node_data_float_buffer.byte_array_data = - malloc(byte_array_data_size); + s_input_node_data_buffer[port].max_buf_byte_size = input_buf_byte_size; + posix_memalign((void**)&s_input_node_data_buffer[port].byte_array_data, 128, + input_buf_byte_size); + TFMLOGD("allocate input node data buffers done"); + } + return true; +} - s_output_node_data_float_buffer = malloc(output_size * sizeof(float)); - s_output_node_data_float_buffer_byte_size = output_size * sizeof(float); - s_output_node_data_float_array_size = 0; - TFMLOGD("allocate node data buffers"); +bool hexagon_controller_AllocateOutputNodeDataBuffers( + int port, int output_buf_byte_size) { + TFMLOGD("Allocate memory for output node data. port = %d, size = %d", port, + output_buf_byte_size); + if (s_output_node_data_buffer_max_byte_size[port] != 0) { + TFMLOGE("ERROR! input buffer is already allocated!!"); + return false; + } else { + // s_output_node_data_buffer = malloc(output_size * sizeof(float)); + posix_memalign((void**)&s_output_node_data_buffer[port], 128, + output_buf_byte_size); + s_output_node_data_buffer_max_byte_size[port] = output_buf_byte_size; + s_output_node_data_array_byte_size[port] = 0; + TFMLOGD("allocate output node data buffers"); + } + return true; +} + +bool hexagon_controller_AllocateMultipleNodeDataBuffers(int input_count, + int* input_sizes, + int output_count, + int* output_sizes) { + bool success = true; + for (int i = 0; i < input_count; ++i) { + success &= + hexagon_controller_AllocateInputNodeDataBuffers(i, input_sizes[i]); + } + for (int i = 0; i < output_count; ++i) { + success &= + hexagon_controller_AllocateOutputNodeDataBuffers(i, output_sizes[i]); + } + + if (s_dbg_use_inception_dummy_data) { + hexagon_controller_InitInputNodeDataToInceptionDummyData( + s_dbg_inception_version); + } + return success; +} + +bool hexagon_controller_AllocateNodeDataBuffers(int input_size, + int 
output_size) { + return hexagon_controller_AllocateMultipleNodeDataBuffers(1, &input_size, 1, + &output_size); +} + +bool hexagon_controller_ReleaseInputNodeDataBuffersWithPort(int port) { + struct NodeDataFloat* input_node_data_buffer = + &s_input_node_data_buffer[port]; + if (input_node_data_buffer->max_buf_byte_size == 0) { + TFMLOGE("ERROR! input buffer has not been allocated yet!!"); + return false; + } else { + input_node_data_buffer->max_buf_byte_size = 0; + input_node_data_buffer->array_byte_size = 0; + free(input_node_data_buffer->byte_array_data); + } + return true; +} + +bool hexagon_controller_ReleaseOutputNodeDataBuffersWithPort(int port) { + if (s_output_node_data_buffer_max_byte_size[port] == 0) { + TFMLOGE("ERROR! output buffer has not been allocated yet!!"); + return false; + } else { + s_output_node_data_buffer_max_byte_size[port] = 0; + s_output_node_data_array_byte_size[port] = 0; + free(s_output_node_data_buffer[port]); } return true; } bool hexagon_controller_ReleaseNodeDataBuffers() { - if (s_input_node_data_float_buffer.buf_size == 0) { - TFMLOGE("ERROR! input buffer has not been allocated yet!!"); - return false; - } else { - s_input_node_data_float_buffer.buf_size = 0; - free(s_input_node_data_float_buffer.array_data); + bool success = true; + for (int i = 0; i < GetInputNodeCount(); ++i) { + success &= hexagon_controller_ReleaseInputNodeDataBuffersWithPort(i); } - if (s_output_node_data_float_buffer_byte_size == 0) { - TFMLOGE("ERROR! 
output buffer has not been allocated yet!!"); - return false; - } else { - s_output_node_data_float_buffer_byte_size = 0; - free(s_input_node_data_float_buffer.byte_array_data); + for (int i = 0; i < GetOutputNodeCount(); ++i) { + success &= hexagon_controller_ReleaseOutputNodeDataBuffersWithPort(i); } - return true; + return success; } -bool hexagon_controller_CopyByteNodeData( - int x, int y, int z, int d, int type_byte_size, uint8_t* array_data) { +bool hexagon_controller_CopyByteNodeData(int port, int x, int y, int z, int d, + int type_byte_size, + uint8_t* array_data) { int array_byte_size = x * y * z * d * type_byte_size; - TFMLOGD("--- %d, %d, %d, %d, %d, %d",x,y,z,d,type_byte_size,array_byte_size); - if (s_input_node_data_float_buffer.buf_size < array_byte_size) { + TFMLOGD("--- %d, %d, %d, %d, %d, %d", x, y, z, d, type_byte_size, + array_byte_size); + struct NodeDataFloat* input_node_data_buffer = &s_input_node_data_buffer[0]; + + if (input_node_data_buffer->max_buf_byte_size < array_byte_size) { TFMLOGE("ERROR! input buffer size is too small! 
%d < %d", - s_input_node_data_float_buffer.buf_size, array_byte_size); + input_node_data_buffer->max_buf_byte_size, array_byte_size); return false; } - memcpy(s_input_node_data_float_buffer.byte_array_data, - array_data, array_byte_size); - s_input_node_data_float_buffer.array_size = array_byte_size; - s_input_node_data_float_buffer.x = x; - s_input_node_data_float_buffer.y = y; - s_input_node_data_float_buffer.z = z; - s_input_node_data_float_buffer.d = d; + memcpy(input_node_data_buffer->byte_array_data, array_data, array_byte_size); + input_node_data_buffer->array_byte_size = array_byte_size; + input_node_data_buffer->x = x; + input_node_data_buffer->y = y; + input_node_data_buffer->z = z; + input_node_data_buffer->d = d; return true; } -int hexagon_controller_InitHexagonWithMaxAttributes( - int enable_dcvs, int bus_usage, int version) { +int hexagon_controller_InitHexagonWithMaxAttributes(int enable_dcvs, + int bus_usage, + int version) { TFMLOGI("Init hexagon with max attributes (Controller version = %d)", HEXAGON_CONTROLLER_VERSION); const int MCPS = 1000; @@ -239,17 +374,17 @@ int hexagon_controller_InitHexagonWithMaxAttributes( adspmsgd_start(0, RPCMEM_HEAP_DEFAULT, 4096); dspCV_Attribute attrib[] = { - // The below values will result in the maximum aDSP performance, - // at Turbo voltage. - // Slightly more MCPS than are available on current targets - {DSP_TOTAL_MCPS, MCPS}, - // drive the clock to MAX on known targets - {DSP_MCPS_PER_THREAD, MCPS / 2}, - // 12 GB/sec is slightly higher than the max realistic - // max BW on existing targets. - {PEAK_BUS_BANDWIDTH_MBPS, MBPS}, - // This app is non-real time, and constantly reading/writing memory - {BUS_USAGE_PERCENT, bus_usage}, + // The below values will result in the maximum aDSP performance, + // at Turbo voltage. 
+ // Slightly more MCPS than are available on current targets + {DSP_TOTAL_MCPS, MCPS}, + // drive the clock to MAX on known targets + {DSP_MCPS_PER_THREAD, MCPS / 2}, + // 12 GB/sec is slightly higher than the max realistic + // max BW on existing targets. + {PEAK_BUS_BANDWIDTH_MBPS, MBPS}, + // This app is non-real time, and constantly reading/writing memory + {BUS_USAGE_PERCENT, bus_usage}, }; int retval = 0; if (!enable_dcvs) { @@ -263,13 +398,8 @@ int hexagon_controller_InitHexagonWithMaxAttributes( dspCV_initQ6_with_attributes(attrib, sizeof(attrib) / sizeof(attrib[0])); TFMLOGD("Return value from dspCV_initQ6() : %d\n", retval); - hexagon_controller_AllocateNodeDataBuffers( - INPUT_NODE_DATA_BUFFER_SIZE, OUTPUT_NODE_DATA_BUFFER_SIZE); - - if (s_dbg_use_inception_dummy_data) { - hexagon_controller_InitInputNodeDataToInceptionDummyData(version); - } s_target_graph_id = 0; + s_dbg_inception_version = version; return retval; } @@ -285,31 +415,36 @@ int hexagon_controller_DeInitHexagon() { return retval; } -void hexagon_controller_GrowMemorySize() { - hexagon_nn_config(); +void hexagon_controller_GrowMemorySize() { hexagon_nn_config(); } + +struct NodeDataFloat* hexagon_controller_GetInputNodeDataBuffer(int port) { + if (port >= GetInputNodeCount()) { + TFMLOGE("port should be less than 1"); + } + return &s_input_node_data_buffer[port]; } -struct NodeDataFloat* hexagon_controller_GetInputNodeDataFloatBuffer() { - return &s_input_node_data_float_buffer; -} - -float* hexagon_controller_GetOutputNodeDataFloatBuffer( - const char *const node_name, int* out_array_size) { - *out_array_size = s_output_node_data_float_array_size; - return s_output_node_data_float_buffer; +uint8_t* hexagon_controller_GetOutputNodeDataBuffer(int port, + int* out_array_byte_size) { + if (port >= GetOutputNodeCount()) { + TFMLOGE("port should be less than 1"); + } + *out_array_byte_size = s_output_node_data_array_byte_size[port]; + return s_output_node_data_buffer[port]; } // Append const 
node to the graph -int hexagon_controller_AppendConstNode( - const char* const name, int graph_id, int node_id, - int batch, int height, int width, int depth, - const uint8_t* const data, int data_length) { +int hexagon_controller_AppendConstNode(const char* const name, int graph_id, + int node_id, int batch, int height, + int width, int depth, + const uint8_t* const data, + int data_length) { if (DBG_SHOW_ID) { - TFMLOGV("---(CONST) %s, %d, %d, %d, %d, %d, %d", - name, node_id, batch, height, width, depth, data_length); + TFMLOGV("---(CONST) %s, %d, %d, %d, %d, %d, %d", name, node_id, batch, + height, width, depth, data_length); } else { - TFMLOGV("---(CONST) %s, %d, %d, %d, %d, %d", - name, batch, height, width, depth, data_length); + TFMLOGV("---(CONST) %s, %d, %d, %d, %d, %d", name, batch, height, width, + depth, data_length); } const int retval = hexagon_nn_append_const_node( graph_id, node_id, batch, height, width, depth, data, data_length); @@ -321,11 +456,12 @@ int hexagon_controller_AppendConstNode( } // Append node to the graph -int hexagon_controller_AppendNode( - const char* const name, int graph_id, int node_id, int ops_id, - int padding_id, const hexagon_nn_input* const inputs, - int inputs_count, const hexagon_nn_output* const outputs, - int outputs_count) { +int hexagon_controller_AppendNode(const char* const name, int graph_id, + int node_id, int ops_id, int padding_id, + const hexagon_nn_input* const inputs, + int inputs_count, + const hexagon_nn_output* const outputs, + int outputs_count) { char input_param_buf[OUTPUT_PARAM_MAX_LINE_SIZE]; memset(input_param_buf, 0, OUTPUT_PARAM_MAX_LINE_SIZE); int pos = 0; @@ -335,8 +471,8 @@ int hexagon_controller_AppendNode( pos += snprintf(&input_param_buf[pos], 500, "(%d, %d), ", inputs[i].src_id, inputs[i].output_idx); } else { - pos += snprintf(&input_param_buf[pos], 500, "(%d), ", - inputs[i].output_idx); + pos += + snprintf(&input_param_buf[pos], 500, "(%d), ", inputs[i].output_idx); } } @@ -349,18 
+485,16 @@ int hexagon_controller_AppendNode( } if (DBG_SHOW_ID) { - TFMLOGV("---(OP) %s, %d, %d, %d, %d, %d, %s, %s", name, node_id, - ops_id, padding_id, inputs_count, outputs_count, input_param_buf, + TFMLOGV("---(OP) %s, %d, %d, %d, %d, %d, %s, %s", name, node_id, ops_id, + padding_id, inputs_count, outputs_count, input_param_buf, output_param_buf); } else { - TFMLOGV("---(OP) %s, %d, %d, %d, %d, %s, %s", name, - ops_id, padding_id, inputs_count, outputs_count, input_param_buf, - output_param_buf); + TFMLOGV("---(OP) %s, %d, %d, %d, %d, %s, %s", name, ops_id, padding_id, + inputs_count, outputs_count, input_param_buf, output_param_buf); } - const int retval = hexagon_nn_append_node( - graph_id, node_id, ops_id, padding_id, - inputs, inputs_count, - outputs, outputs_count); + const int retval = + hexagon_nn_append_node(graph_id, node_id, ops_id, padding_id, inputs, + inputs_count, outputs, outputs_count); if (retval != 0) { TFMLOGE("Failed to append const node %d", node_id); return retval; @@ -375,13 +509,3 @@ void hexagon_controller_EnableDbgUseInceptionDummyData(bool enable) { bool hexagon_controller_IsDbgUseInceptionDummyDataEnabled() { return s_dbg_use_inception_dummy_data; } - -void hexagon_controller_PrintLog(uint32_t nn_id) { - unsigned char *buf; - if ((buf = malloc(PRINT_BUFSIZE)) == NULL) { - return; - } - hexagon_nn_getlog(nn_id, buf, PRINT_BUFSIZE); - TFMLOGE("DUMP HEXAGON LOG: %s", buf); - free(buf); -} diff --git a/tensorflow/contrib/hvx/hexagon_controller/src_impl/include/hexagon_controller.h b/tensorflow/contrib/hvx/hexagon_controller/src_impl/include/hexagon_controller.h index ab8c80c0f32..fc921ff8b98 100644 --- a/tensorflow/contrib/hvx/hexagon_controller/src_impl/include/hexagon_controller.h +++ b/tensorflow/contrib/hvx/hexagon_controller/src_impl/include/hexagon_controller.h @@ -40,16 +40,37 @@ int hexagon_controller_GetWrapperVersion(); int hexagon_controller_GetHexagonBinaryVersion(); +// Buffer operations +bool 
hexagon_controller_SetAllInputTensorDef(int node_count, + hexagon_nn_tensordef* tensordef); + +bool hexagon_controller_SetAllInputTensorDef(int node_count, + hexagon_nn_tensordef* tensordef); + // Hexagon perf functions int hexagon_controller_InitHexagonWithMaxAttributes(int enable_dcvs, int bus_usage, int version); +bool hexagon_controller_AllocateInputNodeDataBuffersWithPort(int port, + int input_size); + +bool hexagon_controller_AllocateOutNodeDataBuffersWithPort(int port, + int output_size); + bool hexagon_controller_AllocateNodeDataBuffers(int input_size, int output_size); +bool hexagon_controller_AllocateMultipleNodeDataBuffers(int input_count, + int* input_sizes, + int output_count, + int* output_sizes); + +bool hexagon_controller_ReleaseInputNodeDataBuffersWithPort(int port); +bool hexagon_controller_ReleaseOutputNodeDataBuffersWithPort(int port); + bool hexagon_controller_ReleaseNodeDataBuffers(); -bool hexagon_controller_CopyByteNodeData(int x, int y, int z, int d, +bool hexagon_controller_CopyByteNodeData(int port, int x, int y, int z, int d, int type_byte_size, uint8_t* array_data); @@ -63,10 +84,10 @@ void hexagon_controller_SetTargetGraphId(uint32_t graph_id); void hexagon_controller_GrowMemorySize(); // Graph data transfer functions -struct NodeDataFloat* hexagon_controller_GetInputNodeDataFloatBuffer(); +struct NodeDataFloat* hexagon_controller_GetInputNodeDataBuffer(int port); -float* hexagon_controller_GetOutputNodeDataFloatBuffer( - const char* const node_name, int* out_array_size); +uint8_t* hexagon_controller_GetOutputNodeDataBuffer(int port, + int* out_array_byte_size); // Graph functions uint32_t hexagon_controller_InstantiateGraph(); @@ -79,6 +100,10 @@ uint32_t hexagon_controller_SetupGraph(int version); bool hexagon_controller_ExecuteInceptionDummyData(uint32_t nn_id); +bool hexagon_controller_ExecuteGraphWithMultipleInOut( + const uint32_t nn_id, const int input_count, hexagon_nn_tensordef* inputs, + const int output_count, 
hexagon_nn_tensordef* outputs); + bool hexagon_controller_ExecuteGraph( const uint32_t nn_id, const uint32_t batches, const uint32_t height, const uint32_t width, const uint32_t depth, uint8_t* int_data, @@ -117,8 +142,6 @@ void hexagon_controller_EnableDbgUseInceptionDummyData(bool enable); bool hexagon_controller_IsDbgUseInceptionDummyDataEnabled(); -void hexagon_controller_PrintLog(uint32_t nn_id); - #ifdef __cplusplus } #endif // __cplusplus diff --git a/tensorflow/contrib/hvx/hexagon_controller/src_log/include/tfm_log.h b/tensorflow/contrib/hvx/hexagon_controller/src_log/include/tfm_log.h index e8615fd4ec0..8d11ee4a340 100644 --- a/tensorflow/contrib/hvx/hexagon_controller/src_log/include/tfm_log.h +++ b/tensorflow/contrib/hvx/hexagon_controller/src_log/include/tfm_log.h @@ -33,6 +33,9 @@ static inline bool IsLogOn(int log_level) { return log_level >= s_log_level; } static inline void SetLogLevel(int log_level) { s_log_level = log_level; } +// Do nothing +static inline void SetExperimentalDebug() {} + #define TFMLOGV(fmt, ...) \ do { \ if (!IsLogOn(TFM_LOG_LEVEL_VERBOSE)) break; \ @@ -71,4 +74,9 @@ static inline void LogDHexagon(const char* fmt, ...) 
{ va_end(ap); } +static inline void DumpNNId(uint32_t nn_id) { + // TODO(satok): Dump more information + TFMLOGI("NN Id = %d", nn_id); +} + #endif diff --git a/tensorflow/contrib/hvx/hexagon_controller/src_soc_interface/include/node_data_float.h b/tensorflow/contrib/hvx/hexagon_controller/src_soc_interface/include/node_data_float.h index a9c3296e9f4..c7034cc3a0d 100644 --- a/tensorflow/contrib/hvx/hexagon_controller/src_soc_interface/include/node_data_float.h +++ b/tensorflow/contrib/hvx/hexagon_controller/src_soc_interface/include/node_data_float.h @@ -28,9 +28,8 @@ struct NodeDataFloat { int y; int z; int d; - int buf_size; - int array_size; - float* array_data; + int max_buf_byte_size; + int array_byte_size; uint8_t* byte_array_data; char node_name[NODE_DATA_FLOAT_NODE_NAME_BUF_SIZE]; }; diff --git a/tensorflow/contrib/hvx/hexagon_controller/src_soc_interface/include/soc_interface.h b/tensorflow/contrib/hvx/hexagon_controller/src_soc_interface/include/soc_interface.h index 6d85e6ce487..30fad13fb5f 100644 --- a/tensorflow/contrib/hvx/hexagon_controller/src_soc_interface/include/soc_interface.h +++ b/tensorflow/contrib/hvx/hexagon_controller/src_soc_interface/include/soc_interface.h @@ -43,13 +43,30 @@ bool soc_interface_Finalize(); bool soc_interface_ExecuteGraph(); // Teardown graph setup bool soc_interface_TeardownGraph(); + +// Allocate buffers for input node and output node +bool soc_interface_AllocateInOutNodeBuffers(int input_count, int* input_sizes, + int output_count, + int* output_sizes); + +// Send input data to SOC with port +bool soc_interface_FillInputNodeWithPort(int port, int x, int y, int z, int d, + const uint8_t* const buf, + uint64_t buf_byte_size); + // Send input data to SOC bool soc_interface_FillInputNodeFloat(int x, int y, int z, int d, const uint8_t* const buf, - uint64_t buf_size); + uint64_t buf_byte_size); + +// Load output data from SOC with port +bool soc_interface_ReadOutputNodeWithPort(int port, uint8_t** buf, + uint64_t* 
buf_byte_size); + // Load output data from SOC bool soc_interface_ReadOutputNodeFloat(const char* const node_name, - uint8_t** buf, uint64_t* buf_size); + uint8_t** buf, uint64_t* buf_byte_size); + // Setup graph // TODO(satok): Remove and use runtime version bool soc_interface_setupDummyGraph(int version); diff --git a/tensorflow/contrib/hvx/hexagon_controller/src_soc_interface/soc_interface.c b/tensorflow/contrib/hvx/hexagon_controller/src_soc_interface/soc_interface.c index 7db8d4870c7..a1387ee5736 100755 --- a/tensorflow/contrib/hvx/hexagon_controller/src_soc_interface/soc_interface.c +++ b/tensorflow/contrib/hvx/hexagon_controller/src_soc_interface/soc_interface.c @@ -22,7 +22,12 @@ limitations under the License. #include "node_data_float.h" #include "tfm_log.h" +// to demonstrate the performance difference between ION and HLOS memory +// for sharing with ADSP. +#define USE_ION_MEMORY + const int64_t FLAG_ENABLE_INCEPTION_DUMMY_BINARY_INPUT = 0x01; +const int64_t FLAG_ENABLE_EXPERIMENTAL_DEBUG = 0x02; static const int INCEPTION_VERSION = 3; @@ -84,48 +89,62 @@ bool soc_interface_TeardownGraph() { return true; } -bool soc_interface_FillInputNodeFloat( - int x, int y, int z, int d, const uint8_t* const buf, - uint64_t buf_size) { - TFMLOGD("FillInputNodeFloat"); - struct NodeDataFloat* node_data_float = - hexagon_controller_GetInputNodeDataFloatBuffer(); - const int array_size = x * y * z * d; - if (array_size > node_data_float->buf_size) { - TFMLOGE("Array size exceeds buf size %d > %d", - array_size, node_data_float->buf_size); +bool soc_interface_AllocateInOutNodeBuffers(int input_count, int* input_sizes, + int output_count, + int* output_sizes) { + TFMLOGD("AllocateInOutNodeBuffers"); + return hexagon_controller_AllocateMultipleNodeDataBuffers( + input_count, input_sizes, output_count, output_sizes); +} + +bool soc_interface_FillInputNodeWithPort(int port, int x, int y, int z, int d, + const uint8_t* const buf, + uint64_t buf_byte_size) { + 
TFMLOGD("FillInputNodeWithPort %d", port); + struct NodeDataFloat* node_data = + hexagon_controller_GetInputNodeDataBuffer(port); + if (buf_byte_size > node_data->max_buf_byte_size) { + TFMLOGE("buf size exceeds max buf size"); return false; } - if (buf_size != array_size * sizeof(float)) { - TFMLOGE("Invalid buf size!"); - return false; - } - memcpy(node_data_float->byte_array_data, buf, buf_size); - node_data_float->x = x; - node_data_float->y = y; - node_data_float->z = z; - node_data_float->d = d; - node_data_float->array_size = buf_size; + memcpy(node_data->byte_array_data, buf, buf_byte_size); + node_data->x = x; + node_data->y = y; + node_data->z = z; + node_data->d = d; + node_data->array_byte_size = buf_byte_size; return true; } +bool soc_interface_FillInputNodeFloat(int x, int y, int z, int d, + const uint8_t* const buf, + uint64_t buf_byte_size) { + return soc_interface_FillInputNodeWithPort( + /*port=*/0, x, y, z, d, buf, buf_byte_size); +} + // TODO(satok): Remove and use runtime version -bool soc_interface_ReadOutputNodeFloat( - const char* const node_name, uint8_t** buf, uint64_t *buf_size) { - TFMLOGD("ReadOutputNodeFloat"); - int array_size = -1; - float* output_node_data_float = - hexagon_controller_GetOutputNodeDataFloatBuffer(node_name, &array_size); - if (array_size < 0) { +bool soc_interface_ReadOutputNodeWithPort(int port, uint8_t** buf, + uint64_t* buf_byte_size) { + TFMLOGD("ReadOutputNodeWithPort"); + int array_byte_size = -1; + uint8_t* output_node_data_buffer = + hexagon_controller_GetOutputNodeDataBuffer(port, &array_byte_size); + if (array_byte_size < 0) { TFMLOGE("Failed to read data."); return false; } - *buf = (uint8_t*)output_node_data_float; - *buf_size = array_size * sizeof(float); + *buf = output_node_data_buffer; + *buf_byte_size = array_byte_size; return true; } -bool soc_interface_SetupGraphDummy(int version) { +bool soc_interface_ReadOutputNodeFloat(const char* const node_name, + uint8_t** buf, uint64_t* buf_byte_size) { + 
return soc_interface_ReadOutputNodeWithPort(/*port=*/0, buf, buf_byte_size); +} + +bool soc_interface_setupDummyGraph(int version) { TFMLOGD("SetupGraphDummy"); const uint32_t graph_id = hexagon_controller_SetupGraph(version); if (graph_id == 0) { @@ -136,12 +155,14 @@ bool soc_interface_SetupGraphDummy(int version) { return true; } -bool soc_interface_AllocateNodeInputAndNodeOutputArray( - int total_input_count, int total_output_count) { +bool soc_interface_AllocateNodeInputAndNodeOutputArray(int total_input_count, + int total_output_count) { TFMLOGD("Allocate node inputs and node outputs array %d, %d", total_input_count, total_output_count); - s_node_inputs_array = malloc(total_input_count * sizeof(hexagon_nn_input)); - s_node_outputs_array = malloc(total_output_count * sizeof(hexagon_nn_output)); + posix_memalign((void**)&s_node_inputs_array, 128, + total_input_count * sizeof(hexagon_nn_input)); + posix_memalign((void**)&s_node_outputs_array, 128, + total_output_count * sizeof(hexagon_nn_output)); s_node_inputs_array_index = 0; s_node_outputs_array_index = 0; s_node_inputs_array_max_count = total_input_count; @@ -188,9 +209,9 @@ void* soc_interface_SetOneNodeOutputs(int output_count, int* max_size) { } // Append const node to the graph -bool soc_interface_AppendConstNode( - const char* const name, int node_id, int batch, int height, int width, int depth, - const uint8_t* const data, int data_length) { +bool soc_interface_AppendConstNode(const char* const name, int node_id, + int batch, int height, int width, int depth, + const uint8_t* const data, int data_length) { const uint32_t graph_id = hexagon_controller_GetTargetGraphId(); const int retval = hexagon_controller_AppendConstNode( name, graph_id, node_id, batch, height, width, depth, data, data_length); @@ -202,14 +223,14 @@ bool soc_interface_AppendConstNode( } // Append node to the graph -bool soc_interface_AppendNode( - const char* const name, int node_id, int ops_id, int padding_id, const void* const 
inputs, - int inputs_count, const void* const outputs, int outputs_count) { +bool soc_interface_AppendNode(const char* const name, int node_id, int ops_id, + int padding_id, const void* const inputs, + int inputs_count, const void* const outputs, + int outputs_count) { const uint32_t graph_id = hexagon_controller_GetTargetGraphId(); const int retval = hexagon_controller_AppendNode( - name, graph_id, node_id, ops_id, padding_id, - (hexagon_nn_input*) inputs, inputs_count, - (hexagon_nn_output*) outputs, outputs_count); + name, graph_id, node_id, ops_id, padding_id, (hexagon_nn_input*)inputs, + inputs_count, (hexagon_nn_output*)outputs, outputs_count); if (retval != 0) { TFMLOGE("Failed to append const node %d", node_id); return false; @@ -217,7 +238,6 @@ bool soc_interface_AppendNode( return true; } - // Instantiate graph bool soc_interface_InstantiateGraph() { const uint32_t nn_id = hexagon_controller_InstantiateGraph(); @@ -240,5 +260,7 @@ void soc_interface_SetDebugFlag(uint64_t flag) { if ((flag & FLAG_ENABLE_INCEPTION_DUMMY_BINARY_INPUT) != 0) { TFMLOGI("Enable always use panda data"); hexagon_controller_EnableDbgUseInceptionDummyData(true); + } else if ((flag & FLAG_ENABLE_EXPERIMENTAL_DEBUG) != 0) { + SetExperimentalDebug(); } } diff --git a/tensorflow/contrib/hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc b/tensorflow/contrib/hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc index 6ae7c4a7420..6af608396ab 100644 --- a/tensorflow/contrib/hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc +++ b/tensorflow/contrib/hvx/hvx_ops_support_checker/hvx_ops_support_checker_main.cc @@ -33,10 +33,15 @@ limitations under the License. 
#include "tensorflow/tools/graph_transforms/transform_utils.h" namespace tensorflow { + namespace { -static int ParseFlags(int argc, char* argv[], string* in_graph) { +static int ParseFlags(int argc, char* argv[], string* in_graph, + bool* dump_all_nodes, bool* dump_shape_and_type) { std::vector flag_list = { - Flag("in_graph", in_graph, "input graph file name"), + Flag("in_graph", in_graph, "Input graph file name to check hvx support."), + Flag("dump_all_nodes", dump_all_nodes, "Dump all nodes in the model."), + Flag("dump_shape_and_type", dump_shape_and_type, + "Dump shape and type of nodes"), }; CHECK(Flags::Parse(&argc, argv, flag_list)); // We need to call this to set up global state for TensorFlow. @@ -48,12 +53,25 @@ static int ParseFlags(int argc, char* argv[], string* in_graph) { return 0; } -static void SummarizeNode(const NodeDef& node_def) { +static void SummarizeNode(const NodeDef& node_def, + const bool dump_shape_and_type) { LOG(INFO) << "Node(" << node_def.name() << ")"; LOG(INFO) << " op: " << node_def.op(); for (const string& input : node_def.input()) { LOG(INFO) << " Input: " << input; } + std::vector data_types; + std::vector shapes; + const Status status = RemoteFusedGraphExecuteUtils::GetOutputTensorShapeType( + node_def, &data_types, &shapes); + if (data_types.empty() || shapes.empty()) { + return; + } + CHECK_EQ(data_types.size(), shapes.size()); + for (int i = 0; i < data_types.size(); ++i) { + LOG(INFO) << " Output(" << i << "): " << DataType_Name(data_types.at(i)) + << ", " << shapes.at(i).DebugString(); + } } static void DumpRemoteFusedGraph(const NodeDef& node_def) { @@ -89,10 +107,14 @@ static void DumpRemoteFusedGraph(const NodeDef& node_def) { } } -static void CheckOpsSupport(const GraphDef& graph_def) { +static void CheckOpsSupport(const GraphDef& graph_def, + const bool dump_all_nodes, + const bool dump_shape_and_type) { const IGraphTransferOpsDefinitions& ops_definition = HexagonOpsDefinitions::getInstance(); LOG(INFO) << 
"Checking " << graph_def.node_size() << " nodes"; + LOG(INFO) << "dump_all_nodes = " << dump_all_nodes + << ", dump_shape_and_tpye = " << dump_shape_and_type; std::unordered_set unsupported_ops; bool all_supported = true; @@ -125,9 +147,9 @@ static void CheckOpsSupport(const GraphDef& graph_def) { LOG(INFO) << count << " ops are not supported."; } - if (contains_remote_graph) { + if (contains_remote_graph || dump_all_nodes) { for (const NodeDef& node : graph_def.node()) { - SummarizeNode(node); + SummarizeNode(node, dump_shape_and_type); } } } @@ -137,7 +159,10 @@ static void CheckOpsSupport(const GraphDef& graph_def) { int main(int argc, char** argv) { tensorflow::string in_graph; - const int ret = tensorflow::ParseFlags(argc, argv, &in_graph); + bool dump_all_nodes; + bool dump_shape_and_type; + const int ret = tensorflow::ParseFlags(argc, argv, &in_graph, &dump_all_nodes, + &dump_shape_and_type); if (ret != 0) { return ret; } @@ -146,6 +171,6 @@ int main(int argc, char** argv) { TF_CHECK_OK(tensorflow::graph_transforms::LoadTextOrBinaryGraphFile( in_graph, &graph_def)); - tensorflow::CheckOpsSupport(graph_def); + tensorflow::CheckOpsSupport(graph_def, dump_all_nodes, dump_shape_and_type); return 0; } diff --git a/tensorflow/contrib/image/python/ops/image_ops.py b/tensorflow/contrib/image/python/ops/image_ops.py index b396dcea211..aef3e385b57 100644 --- a/tensorflow/contrib/image/python/ops/image_ops.py +++ b/tensorflow/contrib/image/python/ops/image_ops.py @@ -136,7 +136,7 @@ def transform(images, transforms, interpolation="NEAREST"): `(x', y') = ((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k)`, where `k = c0 x + c1 y + 1`. The transforms are *inverted* compared to the transform mapping input points to output points. - interpolation: Interpolation mode. Supported values: "NEAREST", "BILINEAR". + interpolation: Interpolation mode. Supported values: "NEAREST", "BILINEAR". 
Returns: Image(s) with the same type and shape as `images`, with the given diff --git a/tensorflow/contrib/input_pipeline/BUILD b/tensorflow/contrib/input_pipeline/BUILD index 9aa5763efcc..bb7857eb998 100644 --- a/tensorflow/contrib/input_pipeline/BUILD +++ b/tensorflow/contrib/input_pipeline/BUILD @@ -82,9 +82,10 @@ py_test( srcs_version = "PY2AND3", deps = [ ":input_pipeline_py", - "//tensorflow:tensorflow_py", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:platform_test", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", "//tensorflow/python:state_ops", "//tensorflow/python:variables", ], diff --git a/tensorflow/contrib/integrate/python/ops/odes.py b/tensorflow/contrib/integrate/python/ops/odes.py index d13844d6132..b4a99867ed4 100644 --- a/tensorflow/contrib/integrate/python/ops/odes.py +++ b/tensorflow/contrib/integrate/python/ops/odes.py @@ -12,47 +12,55 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """ODE solvers for TensorFlow.""" + from __future__ import absolute_import from __future__ import division from __future__ import print_function +import abc import collections +import six + from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import functional_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import tensor_array_ops - -_ButcherTableau = collections.namedtuple( - '_ButcherTableau', 'alpha beta c_sol c_mid c_error') +_ButcherTableau = collections.namedtuple('_ButcherTableau', + 'alpha beta c_sol c_mid c_error') # Parameters from Shampine (1986), section 4. 
_DORMAND_PRINCE_TABLEAU = _ButcherTableau( - alpha=[1/5, 3/10, 4/5, 8/9, 1., 1.], - beta=[[1/5], - [3/40, 9/40], - [44/45, -56/15, 32/9], - [19372/6561, -25360/2187, 64448/6561, -212/729], - [9017/3168, -355/33, 46732/5247, 49/176, -5103/18656], - [35/384, 0, 500/1113, 125/192, -2187/6784, 11/84]], - c_sol=[35/384, 0, 500/1113, 125/192, -2187/6784, 11/84, 0], - c_mid=[6025192743/30085553152 / 2, 0, 51252292925/65400821598 / 2, - -2691868925/45128329728 / 2, 187940372067/1594534317056 / 2, - -1776094331/19743644256 / 2, 11237099/235043384 / 2], - c_error=[1951/21600 - 35/384, - 0, - 22642/50085 - 500/1113, - 451/720 - 125/192, - -12231/42400 - -2187/6784, - 649/6300 - 11/84, - 1/60], -) + alpha=[1 / 5, 3 / 10, 4 / 5, 8 / 9, 1., 1.], + beta=[ + [1 / 5], + [3 / 40, 9 / 40], + [44 / 45, -56 / 15, 32 / 9], + [19372 / 6561, -25360 / 2187, 64448 / 6561, -212 / 729], + [9017 / 3168, -355 / 33, 46732 / 5247, 49 / 176, -5103 / 18656], + [35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84], + ], + c_sol=[35 / 384, 0, 500 / 1113, 125 / 192, -2187 / 6784, 11 / 84, 0], + c_mid=[ + 6025192743 / 30085553152 / 2, 0, 51252292925 / 65400821598 / 2, + -2691868925 / 45128329728 / 2, 187940372067 / 1594534317056 / 2, + -1776094331 / 19743644256 / 2, 11237099 / 235043384 / 2 + ], + c_error=[ + 1951 / 21600 - 35 / 384, + 0, + 22642 / 50085 - 500 / 1113, + 451 / 720 - 125 / 192, + -12231 / 42400 - -2187 / 6784, + 649 / 6300 - 11 / 84, + 1 / 60, + ],) def _possibly_nonzero(x): @@ -64,9 +72,10 @@ def _scaled_dot_product(scale, xs, ys, name=None): with ops.name_scope(name, 'scaled_dot_product', [scale, xs, ys]) as scope: # Some of the parameters in our Butcher tableau include zeros. Using # _possibly_nonzero lets us avoid wasted computation. 
- return math_ops.add_n([(scale * x) * y for x, y in zip(xs, ys) - if _possibly_nonzero(x) or _possibly_nonzero(y)], - name=scope) + return math_ops.add_n( + [(scale * x) * y for x, y in zip(xs, ys) + if _possibly_nonzero(x) or _possibly_nonzero(y)], + name=scope) def _dot_product(xs, ys, name=None): @@ -75,7 +84,12 @@ def _dot_product(xs, ys, name=None): return math_ops.add_n([x * y for x, y in zip(xs, ys)], name=scope) -def _runge_kutta_step(func, y0, f0, t0, dt, tableau=_DORMAND_PRINCE_TABLEAU, +def _runge_kutta_step(func, + y0, + f0, + t0, + dt, + tableau=_DORMAND_PRINCE_TABLEAU, name=None): """Take an arbitrary Runge-Kutta step and estimate error. @@ -115,8 +129,8 @@ def _runge_kutta_step(func, y0, f0, t0, dt, tableau=_DORMAND_PRINCE_TABLEAU, y1 = array_ops.identity(yi, name='%s/y1' % scope) f1 = array_ops.identity(k[-1], name='%s/f1' % scope) - y1_error = _scaled_dot_product(dt_cast, tableau.c_error, k, - name='%s/y1_error' % scope) + y1_error = _scaled_dot_product( + dt_cast, tableau.c_error, k, name='%s/y1_error' % scope) return (y1, f1, y1_error, k) @@ -208,15 +222,15 @@ def _optimal_step_size(last_step, order=5, name=None): """Calculate the optimal size for the next Runge-Kutta step.""" - with ops.name_scope( - name, 'optimal_step_size', [last_step, error_ratio]) as scope: + with ops.name_scope(name, 'optimal_step_size', [last_step, + error_ratio]) as scope: error_ratio = math_ops.cast(error_ratio, last_step.dtype) exponent = math_ops.cast(1 / order, last_step.dtype) # this looks more complex than necessary, but importantly it keeps # error_ratio in the numerator so we can't divide by zero: - factor = math_ops.maximum( - 1 / ifactor, - math_ops.minimum(error_ratio ** exponent / safety, 1 / dfactor)) + factor = math_ops.maximum(1 / ifactor, + math_ops.minimum(error_ratio**exponent / safety, + 1 / dfactor)) return math_ops.div(last_step, factor, name=scope) @@ -232,8 +246,9 @@ def _ta_append(tensor_array, value): return 
tensor_array.write(tensor_array.size(), value) -class _RungeKuttaState(collections.namedtuple( - '_RungeKuttaState', 'y1, f1, t0, t1, dt, interp_coeff')): +class _RungeKuttaState( + collections.namedtuple('_RungeKuttaState', + 'y1, f1, t0, t1, dt, interp_coeff')): """Saved state of the Runge Kutta solver. Attributes: @@ -247,8 +262,8 @@ class _RungeKuttaState(collections.namedtuple( """ -class _History(collections.namedtuple( - '_History', 'integrate_points, error_ratio')): +class _History( + collections.namedtuple('_History', 'integrate_points, error_ratio')): """Saved integration history for use in `info_dict`. Attributes: @@ -258,6 +273,20 @@ class _History(collections.namedtuple( """ +def _assert_increasing(t): + assert_increasing = control_flow_ops.Assert( + math_ops.reduce_all(t[1:] > t[:-1]), ['`t` must be monotonic increasing']) + return ops.control_dependencies([assert_increasing]) + + +def _check_input_types(t, y0): + if not (y0.dtype.is_floating or y0.dtype.is_complex): + raise TypeError('`y0` must have a floating point or complex floating ' + 'point dtype') + if not t.dtype.is_floating: + raise TypeError('`t` must have a floating point dtype') + + def _dopri5(func, y0, t, @@ -277,24 +306,24 @@ def _dopri5(func, # automatically first_step = 1.0 - with ops.name_scope( - name, 'dopri5', - [y0, t, rtol, atol, safety, ifactor, dfactor, max_num_steps]) as scope: + with ops.name_scope(name, 'dopri5', [ + y0, t, rtol, atol, safety, ifactor, dfactor, max_num_steps + ]) as scope: - first_step = ops.convert_to_tensor(first_step, dtype=t.dtype, - name='first_step') + first_step = ops.convert_to_tensor( + first_step, dtype=t.dtype, name='first_step') safety = ops.convert_to_tensor(safety, dtype=t.dtype, name='safety') ifactor = ops.convert_to_tensor(ifactor, dtype=t.dtype, name='ifactor') dfactor = ops.convert_to_tensor(dfactor, dtype=t.dtype, name='dfactor') - max_num_steps = ops.convert_to_tensor(max_num_steps, dtype=dtypes.int32, - name='max_num_steps') + 
max_num_steps = ops.convert_to_tensor( + max_num_steps, dtype=dtypes.int32, name='max_num_steps') def adaptive_runge_kutta_step(rk_state, history, n_steps): """Take an adaptive Runge-Kutta step to integrate the ODE.""" y0, f0, _, t0, dt, interp_coeff = rk_state with ops.name_scope('assertions'): - check_underflow = control_flow_ops.Assert( - t0 + dt > t0, ['underflow in dt', dt]) + check_underflow = control_flow_ops.Assert(t0 + dt > t0, + ['underflow in dt', dt]) check_max_num_steps = control_flow_ops.Assert( n_steps < max_num_steps, ['max_num_steps exceeded']) check_numerics = control_flow_ops.Assert( @@ -320,16 +349,16 @@ def _dopri5(func, f_next = control_flow_ops.cond(accept_step, lambda: f1, lambda: f0) t_next = control_flow_ops.cond(accept_step, lambda: t0 + dt, lambda: t0) interp_coeff = control_flow_ops.cond( - accept_step, - lambda: _interp_fit_rk(y0, y1, k, dt), + accept_step, lambda: _interp_fit_rk(y0, y1, k, dt), lambda: interp_coeff) dt_next = _optimal_step_size(dt, error_ratio, safety, ifactor, dfactor) - rk_state = _RungeKuttaState( - y_next, f_next, t0, t_next, dt_next, interp_coeff) + rk_state = _RungeKuttaState(y_next, f_next, t0, t_next, dt_next, + interp_coeff) with ops.name_scope('update/history'): - history = _History(_ta_append(history.integrate_points, t0 + dt), - _ta_append(history.error_ratio, error_ratio)) + history = _History( + _ta_append(history.integrate_points, t0 + dt), + _ta_append(history.error_ratio, error_ratio)) return rk_state, history, n_steps + 1 def interpolate(solution, history, rk_state, i): @@ -337,18 +366,14 @@ def _dopri5(func, with ops.name_scope('interpolate'): rk_state, history, _ = control_flow_ops.while_loop( lambda rk_state, *_: t[i] > rk_state.t1, - adaptive_runge_kutta_step, - (rk_state, history, 0), + adaptive_runge_kutta_step, (rk_state, history, 0), name='integrate_loop') - y = _interp_evaluate( - rk_state.interp_coeff, rk_state.t0, rk_state.t1, t[i]) + y = _interp_evaluate(rk_state.interp_coeff, 
rk_state.t0, rk_state.t1, + t[i]) solution = solution.write(i, y) return solution, history, rk_state, i + 1 - assert_increasing = control_flow_ops.Assert( - math_ops.reduce_all(t[1:] > t[:-1]), - ['`t` must be monotonic increasing']) - with ops.control_dependencies([assert_increasing]): + with _assert_increasing(t): num_times = array_ops.size(t) solution = tensor_array_ops.TensorArray( @@ -363,8 +388,7 @@ def _dopri5(func, solution, history, _, _ = control_flow_ops.while_loop( lambda _, __, ___, i: i < num_times, - interpolate, - (solution, history, rk_state, 1), + interpolate, (solution, history, rk_state, 1), name='interpolate_loop') y = solution.stack(name=scope) @@ -373,9 +397,11 @@ def _dopri5(func, return y else: integrate_points = history.integrate_points.stack() - info_dict = {'num_func_evals': 6 * array_ops.size(integrate_points) + 1, - 'integrate_points': integrate_points, - 'error_ratio': history.error_ratio.stack()} + info_dict = { + 'num_func_evals': 6 * array_ops.size(integrate_points) + 1, + 'integrate_points': integrate_points, + 'error_ratio': history.error_ratio.stack() + } return (y, info_dict) @@ -390,7 +416,7 @@ def odeint(func, name=None): """Integrate a system of ordinary differential equations. - Solves the initial value problem for a non-stiff system of first order ode-s: + Solves the initial value problem for a non-stiff system of first order ODEs: ``` dy/dt = func(y, t), y(t[0]) = y0 @@ -483,21 +509,109 @@ def odeint(func, # arbitrarily nested tuple. This will help performance and usability by # avoiding the need to pack/unpack in user functions. 
y0 = ops.convert_to_tensor(y0, name='y0') - if not (y0.dtype.is_floating or y0.dtype.is_complex): - raise TypeError('`y0` must have a floating point or complex floating ' - 'point dtype') - t = ops.convert_to_tensor(t, preferred_dtype=dtypes.float64, name='t') - if not t.dtype.is_floating: - raise TypeError('`t` must have a floating point dtype') + _check_input_types(t, y0) error_dtype = abs(y0).dtype rtol = ops.convert_to_tensor(rtol, dtype=error_dtype, name='rtol') atol = ops.convert_to_tensor(atol, dtype=error_dtype, name='atol') - return _dopri5(func, y0, t, - rtol=rtol, - atol=atol, - full_output=full_output, - name=scope, - **options) + return _dopri5( + func, + y0, + t, + rtol=rtol, + atol=atol, + full_output=full_output, + name=scope, + **options) + + +class _FixedGridIntegrator(six.with_metaclass(abc.ABCMeta)): + """Base class for fixed-grid ODE integrators.""" + + def integrate(self, evol_func, y0, time_grid): + time_delta_grid = time_grid[1:] - time_grid[:-1] + + scan_func = self._make_scan_func(evol_func) + + y_grid = functional_ops.scan(scan_func, (time_grid[:-1], time_delta_grid), + y0) + return array_ops.concat([[y0], y_grid], axis=0) + + def _make_scan_func(self, evol_func): + + def scan_func(y, t_and_dt): + t, dt = t_and_dt + dy = self._step_func(evol_func, t, dt, y) + dy = math_ops.cast(dy, dtype=y.dtype) + return y + dy + + return scan_func + + @abc.abstractmethod + def _step_func(self, evol_func, t, dt, y): + pass + + +class _MidpointFixedGridIntegrator(_FixedGridIntegrator): + + def _step_func(self, evol_func, t, dt, y): + dt_cast = math_ops.cast(dt, y.dtype) + # yn1 = yn + h * f(tn + h/2, yn + f(tn, yn) * h/2) + return dt_cast * evol_func(y + evol_func(y, t) * dt_cast / 2, t + dt / 2) + + +class _RK4FixedGridIntegrator(_FixedGridIntegrator): + + def _step_func(self, evol_func, t, dt, y): + k1 = evol_func(y, t) + half_step = t + dt / 2 + dt_cast = math_ops.cast(dt, y.dtype) + + k2 = evol_func(y + dt_cast * k1 / 2, half_step) + k3 = evol_func(y 
+ dt_cast * k2 / 2, half_step) + k4 = evol_func(y + dt_cast * k3, t + dt) + return math_ops.add_n([k1, 2 * k2, 2 * k3, k4]) * (dt_cast / 6) + + +def odeint_fixed(func, y0, t, method='rk4', name=None): + """ODE integration on a fixed grid (with no step size control). + + Useful in certain scenarios to avoid the overhead of adaptive step size + control, e.g. when differentiation of the integration result is desired and/or + the time grid is known a priori to be sufficient. + + Args: + func: Function that maps a Tensor holding the state `y` and a scalar Tensor + `t` into a Tensor of state derivatives with respect to time. + y0: N-D Tensor giving starting value of `y` at time point `t[0]`. + t: 1-D Tensor holding a sequence of time points for which to solve for + `y`. The initial time point should be the first element of this sequence, + and each time must be larger than the previous time. May have any floating + point dtype. + method: One of 'midpoint' or 'rk4'. + name: Optional name for the resulting operation. + + Returns: + y: (N+1)-D tensor, where the first dimension corresponds to different + time points. Contains the solved value of y for each desired time point in + `t`, with the initial value `y0` being the first element along the first + dimension. + + Raises: + ValueError: Upon caller errors. 
+ """ + with ops.name_scope(name, 'odeint_fixed', [y0, t]): + t = ops.convert_to_tensor(t, preferred_dtype=dtypes.float64, name='t') + y0 = ops.convert_to_tensor(y0, name='y0') + _check_input_types(t, y0) + + with _assert_increasing(t): + with ops.name_scope(method): + if method == 'midpoint': + return _MidpointFixedGridIntegrator().integrate(func, y0, t) + elif method == 'rk4': + return _RK4FixedGridIntegrator().integrate(func, y0, t) + else: + raise ValueError('method not supported: {!s}'.format(method)) diff --git a/tensorflow/contrib/integrate/python/ops/odes_test.py b/tensorflow/contrib/integrate/python/ops/odes_test.py index 009e1d1f77c..3ec01212d25 100644 --- a/tensorflow/contrib/integrate/python/ops/odes_test.py +++ b/tensorflow/contrib/integrate/python/ops/odes_test.py @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================== """Tests for ODE solvers.""" + from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -165,11 +166,9 @@ class OdeIntTest(test.TestCase): with self.test_session() as sess: y_solved_0, info_0 = sess.run( - odes.odeint( - self.func, self.y0, times0, full_output=True)) + odes.odeint(self.func, self.y0, times0, full_output=True)) y_solved_1, info_1 = sess.run( - odes.odeint( - self.func, self.y0, times1, full_output=True)) + odes.odeint(self.func, self.y0, times1, full_output=True)) self.assertAllClose(y_solved_0, y_solved_1[::10]) self.assertEqual(info_0['num_func_evals'], info_1['num_func_evals']) @@ -182,11 +181,9 @@ class OdeIntTest(test.TestCase): full_output=True, method='dopri5', options=dict(max_num_steps=2000)) with self.test_session() as sess: _, info_0 = sess.run( - odes.odeint( - self.func, self.y0, t, rtol=0, atol=1e-6, **kwargs)) + odes.odeint(self.func, self.y0, t, rtol=0, atol=1e-6, **kwargs)) _, info_1 = sess.run( - odes.odeint( - self.func, self.y0, t, rtol=0, atol=1e-9, **kwargs)) + 
odes.odeint(self.func, self.y0, t, rtol=0, atol=1e-9, **kwargs)) self.assertAllClose( info_0['integrate_points'].size * 1000**0.2, float(info_1['integrate_points'].size), @@ -243,5 +240,49 @@ class InterpolationTest(test.TestCase): sess.run(y_invalid) +class OdeIntFixedTest(test.TestCase): + + def _test_integrate_sine(self, method): + + def evol_func(y, t): + del t + return array_ops.stack([y[1], -y[0]]) + + y0 = [0., 1.] + time_grid = np.linspace(0., 10., 200) + y_grid = odes.odeint_fixed(evol_func, y0, time_grid, method=method) + + with self.test_session() as sess: + y_grid_array = sess.run(y_grid) + + np.testing.assert_allclose( + y_grid_array[:, 0], np.sin(time_grid), rtol=1e-2, atol=1e-2) + + def _test_integrate_gaussian(self, method): + + def evol_func(y, t): + return -math_ops.cast(t, dtype=y.dtype) * y[0] + + y0 = [1.] + time_grid = np.linspace(0., 2., 100) + y_grid = odes.odeint_fixed(evol_func, y0, time_grid, method=method) + + with self.test_session() as sess: + y_grid_array = sess.run(y_grid) + + np.testing.assert_allclose( + y_grid_array[:, 0], np.exp(-time_grid**2 / 2), rtol=1e-2, atol=1e-2) + + def _test_everything(self, method): + self._test_integrate_sine(method) + self._test_integrate_gaussian(method) + + def test_midpoint(self): + self._test_everything('midpoint') + + def test_rk4(self): + self._test_everything('rk4') + + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/keras/BUILD b/tensorflow/contrib/keras/BUILD index 71ce6540d62..619ebb7ce07 100644 --- a/tensorflow/contrib/keras/BUILD +++ b/tensorflow/contrib/keras/BUILD @@ -7,6 +7,7 @@ exports_files(["LICENSE"]) package(default_visibility = ["//tensorflow:__subpackages__"]) +load("//tensorflow:tensorflow.bzl", "cuda_py_test") load("//tensorflow:tensorflow.bzl", "py_test") py_library( @@ -393,12 +394,11 @@ py_test( ], ) -py_test( +cuda_py_test( name = "normalization_test", size = "small", srcs = ["python/keras/layers/normalization_test.py"], - srcs_version = "PY2AND3", - 
deps = [ + additional_deps = [ ":keras", ":testing_utils", "//tensorflow/python:client_testlib", diff --git a/tensorflow/contrib/keras/python/keras/backend.py b/tensorflow/contrib/keras/python/keras/backend.py index 7a005603084..324f510301a 100644 --- a/tensorflow/contrib/keras/python/keras/backend.py +++ b/tensorflow/contrib/keras/python/keras/backend.py @@ -269,9 +269,10 @@ def get_uid(prefix=''): def reset_uids(): - layer_name_uids_collection = ops.get_collection_ref('LAYER_NAME_UIDS') - if layer_name_uids_collection: - layer_name_uids_collection.pop() + per_graph_layer_name_uids = tf_base_layers.PER_GRAPH_LAYER_NAME_UIDS + keys = list(per_graph_layer_name_uids.keys()) + for key in keys: + del per_graph_layer_name_uids[key] def clear_session(): diff --git a/tensorflow/contrib/keras/python/keras/backend_test.py b/tensorflow/contrib/keras/python/keras/backend_test.py index 2da5aee58e5..a2bc95e4a10 100644 --- a/tensorflow/contrib/keras/python/keras/backend_test.py +++ b/tensorflow/contrib/keras/python/keras/backend_test.py @@ -105,10 +105,13 @@ class BackendUtilsTest(test.TestCase): self.assertEqual(keras.backend.image_data_format(), image_data_format) keras.backend.set_image_data_format('channels_last') - def test_get_uid(self): + def test_get_reset_uids(self): self.assertEqual(keras.backend.get_uid('foo'), 1) self.assertEqual(keras.backend.get_uid('foo'), 2) + keras.backend.reset_uids() + self.assertEqual(keras.backend.get_uid('foo'), 1) + class BackendVariableTest(test.TestCase): diff --git a/tensorflow/contrib/keras/python/keras/engine/topology.py b/tensorflow/contrib/keras/python/keras/engine/topology.py index 07d708ada3c..637d0c5a487 100644 --- a/tensorflow/contrib/keras/python/keras/engine/topology.py +++ b/tensorflow/contrib/keras/python/keras/engine/topology.py @@ -1176,6 +1176,7 @@ class Container(Layer): # The following properties are not actually used by Keras; # they exist for compatibility with TF. 
self._updates = [] + self._losses = [] self._scope = None self._reuse = None self._base_name = name diff --git a/tensorflow/contrib/keras/python/keras/integration_test.py b/tensorflow/contrib/keras/python/keras/integration_test.py index bcd844201c1..0f6db097d15 100644 --- a/tensorflow/contrib/keras/python/keras/integration_test.py +++ b/tensorflow/contrib/keras/python/keras/integration_test.py @@ -161,6 +161,73 @@ class KerasIntegrationTest(test.TestCase): verbose=2) self.assertGreater(history.history['val_acc'][-1], 0.70) + def test_vector_classification_shared_sequential(self): + # Test that Sequential models that feature internal updates + # and internal losses can be shared. + with self.test_session(): + np.random.seed(1337) + (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data( + train_samples=200, + test_samples=100, + input_shape=(10,), + num_classes=2) + y_train = keras.utils.to_categorical(y_train) + y_test = keras.utils.to_categorical(y_test) + + base_model = keras.models.Sequential([ + keras.layers.Dense(16, + activation='relu', + kernel_regularizer=keras.regularizers.l2(1e-5), + bias_regularizer=keras.regularizers.l2(1e-5), + input_shape=x_train.shape[1:]), + keras.layers.BatchNormalization(), + ]) + x = keras.layers.Input(x_train.shape[1:]) + y = base_model(x) + y = keras.layers.Dense(y_train.shape[-1], activation='softmax')(y) + model = keras.models.Model(x, y) + model.compile(loss='categorical_crossentropy', + optimizer='rmsprop', + metrics=['accuracy']) + history = model.fit(x_train, y_train, epochs=10, batch_size=16, + validation_data=(x_test, y_test), + verbose=2) + self.assertGreater(history.history['val_acc'][-1], 0.85) + + def test_vector_classification_shared_model(self): + # Test that functional models that feature internal updates + # and internal losses can be shared. 
+ with self.test_session(): + np.random.seed(1337) + (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data( + train_samples=200, + test_samples=100, + input_shape=(10,), + num_classes=2) + y_train = keras.utils.to_categorical(y_train) + y_test = keras.utils.to_categorical(y_test) + + inputs = keras.layers.Input(x_train.shape[1:]) + x = keras.layers.Dense(16, + activation='relu', + kernel_regularizer=keras.regularizers.l2(1e-5), + bias_regularizer=keras.regularizers.l2(1e-5), + input_shape=x_train.shape[1:])(inputs) + x = keras.layers.BatchNormalization()(x) + base_model = keras.models.Model(inputs, x) + + x = keras.layers.Input(x_train.shape[1:]) + y = base_model(x) + y = keras.layers.Dense(y_train.shape[-1], activation='softmax')(y) + model = keras.models.Model(x, y) + model.compile(loss='categorical_crossentropy', + optimizer='rmsprop', + metrics=['accuracy']) + history = model.fit(x_train, y_train, epochs=10, batch_size=16, + validation_data=(x_test, y_test), + verbose=2) + self.assertGreater(history.history['val_acc'][-1], 0.85) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/keras/python/keras/layers/normalization_test.py b/tensorflow/contrib/keras/python/keras/layers/normalization_test.py index dc410f84d85..1a0686800eb 100644 --- a/tensorflow/contrib/keras/python/keras/layers/normalization_test.py +++ b/tensorflow/contrib/keras/python/keras/layers/normalization_test.py @@ -94,22 +94,23 @@ class NoiseLayersTest(test.TestCase): np.testing.assert_allclose(out.std(), 1.0, atol=1e-1) def test_batchnorm_convnet(self): - with self.test_session(): - model = keras.models.Sequential() - norm = keras.layers.BatchNormalization( - axis=1, input_shape=(3, 4, 4), momentum=0.8) - model.add(norm) - model.compile(loss='mse', optimizer='sgd') + if test.is_gpu_available(cuda_only=True): + with self.test_session(use_gpu=True): + model = keras.models.Sequential() + norm = keras.layers.BatchNormalization( + axis=1, input_shape=(3, 4, 4), 
momentum=0.8) + model.add(norm) + model.compile(loss='mse', optimizer='sgd') - # centered on 5.0, variance 10.0 - x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 3, 4, 4)) - model.fit(x, x, epochs=4, verbose=0) - out = model.predict(x) - out -= np.reshape(keras.backend.eval(norm.beta), (1, 3, 1, 1)) - out /= np.reshape(keras.backend.eval(norm.gamma), (1, 3, 1, 1)) + # centered on 5.0, variance 10.0 + x = np.random.normal(loc=5.0, scale=10.0, size=(1000, 3, 4, 4)) + model.fit(x, x, epochs=4, verbose=0) + out = model.predict(x) + out -= np.reshape(keras.backend.eval(norm.beta), (1, 3, 1, 1)) + out /= np.reshape(keras.backend.eval(norm.gamma), (1, 3, 1, 1)) - np.testing.assert_allclose(np.mean(out, axis=(0, 2, 3)), 0.0, atol=1e-1) - np.testing.assert_allclose(np.std(out, axis=(0, 2, 3)), 1.0, atol=1e-1) + np.testing.assert_allclose(np.mean(out, axis=(0, 2, 3)), 0.0, atol=1e-1) + np.testing.assert_allclose(np.std(out, axis=(0, 2, 3)), 1.0, atol=1e-1) def test_shared_batchnorm(self): """Test that a BN layer can be shared across different data streams. diff --git a/tensorflow/contrib/keras/python/keras/models.py b/tensorflow/contrib/keras/python/keras/models.py index 0ae373da3cd..8786e0b97ac 100644 --- a/tensorflow/contrib/keras/python/keras/models.py +++ b/tensorflow/contrib/keras/python/keras/models.py @@ -436,6 +436,7 @@ class Sequential(Model): # The following properties are not actually used by Keras; # they exist for compatibility with TF's variable scoping mechanism. self._updates = [] + self._losses = [] self._scope = None self._reuse = None self._base_name = name diff --git a/tensorflow/contrib/kernel_methods/g3doc/tutorial.md b/tensorflow/contrib/kernel_methods/g3doc/tutorial.md index 9877375c2c1..f39a8d80d22 100644 --- a/tensorflow/contrib/kernel_methods/g3doc/tutorial.md +++ b/tensorflow/contrib/kernel_methods/g3doc/tutorial.md @@ -13,7 +13,7 @@ for sparse features is in the works. 
We will use [tf.contrib.learn](https://www.tensorflow.org/code/tensorflow/contrib/learn/python/learn) (TensorFlow's high-level Machine Learning API) Estimators for our ML models. The tf.contrib.learn API reduces the boilerplate code one needs to write for configuring, training and evaluating models and will let us focus on the core -ideas. If you are not familiar with this API, [tf.contrib.learn Quickstart](https://www.tensorflow.org/get_started/tflearn) is a good place to start. We +ideas. If you are not familiar with this API, [tf.estimator Quickstart](https://www.tensorflow.org/get_started/estimator) is a good place to start. We will use MNIST, a widely-used dataset containing images of handwritten digits (between 0 and 9). The tutorial consists of the following steps: diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py index d3b10949630..f2a904b5211 100644 --- a/tensorflow/contrib/layers/python/layers/layers.py +++ b/tensorflow/contrib/layers/python/layers/layers.py @@ -257,27 +257,33 @@ def _fused_batch_norm( 'beta') if not param_initializers: param_initializers = {} - beta_initializer = param_initializers.get('beta', - init_ops.zeros_initializer()) - beta = variables.model_variable( - 'beta', - shape=params_shape, - dtype=dtype, - initializer=beta_initializer, - collections=beta_collections, - trainable=trainable_beta) - trainable_gamma = trainable and scale - gamma_collections = utils.get_variable_collections(variables_collections, - 'gamma') - gamma_initializer = param_initializers.get('gamma', - init_ops.ones_initializer()) - gamma = variables.model_variable( - 'gamma', - shape=params_shape, - dtype=dtype, - initializer=gamma_initializer, - collections=gamma_collections, - trainable=trainable_gamma) + if center: + beta_initializer = param_initializers.get('beta', + init_ops.zeros_initializer()) + beta = variables.model_variable( + 'beta', + shape=params_shape, + dtype=dtype, + 
initializer=beta_initializer, + collections=beta_collections, + trainable=trainable_beta) + else: + beta = array_ops.constant(0.0, shape=params_shape) + + if scale: + gamma_collections = utils.get_variable_collections( + variables_collections, 'gamma') + gamma_initializer = param_initializers.get('gamma', + init_ops.ones_initializer()) + gamma = variables.model_variable( + 'gamma', + shape=params_shape, + dtype=dtype, + initializer=gamma_initializer, + collections=gamma_collections, + trainable=trainable) + else: + gamma = array_ops.constant(1.0, shape=params_shape) # Create moving_mean and moving_variance variables and add them to the # appropriate collections. @@ -449,7 +455,8 @@ def batch_norm(inputs, then the batch normalization uses weighted mean and variance. (This can be used to correct for bias in training example selection.) - fused: Use nn.fused_batch_norm if True, nn.batch_normalization otherwise. + fused: if `True`, use a faster, fused implementation based on + nn.fused_batch_norm. If `None`, use the fused implementation if possible. data_format: A string. `NHWC` (default) and `NCHW` are supported. zero_debias_moving_mean: Use zero_debias for moving_mean. It creates a new pair of variables 'moving_mean/biased' and 'moving_mean/local_step'. @@ -473,7 +480,6 @@ def batch_norm(inputs, Raises: ValueError: If `batch_weights` is not None and `fused` is True. - ValueError: If `param_regularizers` is not None and `fused` is True. ValueError: If `data_format` is neither `NHWC` nor `NCHW`. ValueError: If the rank of `inputs` is undefined. ValueError: If rank or channels dimension of `inputs` is undefined. 
@@ -487,6 +493,21 @@ def batch_norm(inputs, 'supported for fused batch norm.') if renorm: raise ValueError('Renorm is not supported for fused batch norm.') + + # Only use _fused_batch_norm (1) if fused is set True or if it is + # possible to use (currently it doesn't support batch weights, + # renorm, and the case when rank is neither 2 nor 4), + # and (2) if used with zero_debias_moving_mean, or an input shape of rank 2, + # or non-default updates_collections (not implemented in + # normalization_layers.BatchNormalization yet); otherwise use the fused + # implementation in normalization_layers.BatchNormalization. + inputs = ops.convert_to_tensor(inputs) + rank = inputs.get_shape().ndims + feature_supported = batch_weights is None and not renorm and rank in [2, 4] + possible_to_fuse = fused is None and feature_supported + if (fused or possible_to_fuse) and ( + zero_debias_moving_mean or rank == 2 or + updates_collections is not ops.GraphKeys.UPDATE_OPS): return _fused_batch_norm( inputs, decay=decay, @@ -552,7 +573,8 @@ def batch_norm(inputs, renorm_momentum=renorm_decay, name=sc.name, _scope=sc, - _reuse=reuse) + _reuse=reuse, + fused=fused) outputs = layer.apply(inputs, training=is_training) # Add variables to collections. @@ -560,9 +582,9 @@ def batch_norm(inputs, layer.moving_mean, variables_collections, 'moving_mean') _add_variable_to_collections( layer.moving_variance, variables_collections, 'moving_variance') - if layer.beta: + if layer.beta is not None: _add_variable_to_collections(layer.beta, variables_collections, 'beta') - if layer.gamma: + if layer.gamma is not None: _add_variable_to_collections( layer.gamma, variables_collections, 'gamma') @@ -2145,6 +2167,44 @@ def unit_norm(inputs, dim, epsilon=1e-7, scope=None): return math_ops.div(inputs, array_ops.tile(lengths, multiples)) +def poincare_normalize(x, axis=1, epsilon=1e-5, name=None): + """Project into the Poincare ball with norm <= 1.0 - epsilon. 
+ + https://en.wikipedia.org/wiki/Poincare_ball_model + + Used in + Poincare Embeddings for Learning Hierarchical Representations + Maximilian Nickel, Douwe Kiela + https://arxiv.org/pdf/1705.08039.pdf + + For a 1-D tensor with `axis = 0`, computes + + (x * (1 - epsilon)) / ||x|| if ||x|| > 1 - epsilon + output = + x otherwise + + For `x` with more dimensions, independently normalizes each 1-D slice along + dimension `axis`. + + Args: + x: A `Tensor`. + axis: Axis along which to normalize. A scalar or a vector of + integers. + epsilon: A small deviation from the edge of the unit sphere for numerical + stability. + name: A name for this operation (optional). + + Returns: + A `Tensor` with the same shape as `x`. + """ + with ops.name_scope(name, 'poincare_normalize', [x]) as name: + x = ops.convert_to_tensor(x, name='x') + square_sum = math_ops.reduce_sum(math_ops.square(x), axis, keep_dims=True) + x_inv_norm = math_ops.rsqrt(square_sum) + x_inv_norm = math_ops.minimum((1. - epsilon) * x_inv_norm, 1.) 
+ return math_ops.multiply(x, x_inv_norm, name=name) + + def legacy_fully_connected(x, num_output_units, activation_fn=None, diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py index b49c33e9969..d4ee85b5501 100644 --- a/tensorflow/contrib/layers/python/layers/layers_test.py +++ b/tensorflow/contrib/layers/python/layers/layers_test.py @@ -38,6 +38,7 @@ from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import gradient_checker from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops @@ -1702,13 +1703,6 @@ class BatchNormTest(test.TestCase): with self.assertRaisesRegexp(ValueError, 'Weighted mean and variance'): _layers.batch_norm(inputs, batch_weights=batch_weights, fused=True) - def testParamRegularizersFused(self): - with ops.Graph().as_default() as g, self.test_session(g): - inputs = array_ops.placeholder(dtype=dtypes.float32, shape=(5, 3, 3, 7)) - with self.assertRaisesRegexp(ValueError, - 'Regularizers are not currently'): - _layers.batch_norm(inputs, param_regularizers={}, fused=True) - def _testCreateOp(self, fused): height, width = 3, 3 with self.test_session(): @@ -1779,7 +1773,8 @@ class BatchNormTest(test.TestCase): height, width = 3, 3 with self.test_session(): images = random_ops.random_uniform((5, height, width, 3), seed=1) - _layers.batch_norm(images, scale=True, zero_debias_moving_mean=True) + _layers.batch_norm( + images, scale=True, zero_debias_moving_mean=True, fused=False) self.assertEqual(len(variables.get_model_variables()), 6) moving_mean = variables.get_variables_by_name('moving_mean')[0] moving_variance = variables.get_variables_by_name('moving_variance')[0] @@ -1873,7 +1868,8 @@ class BatchNormTest(test.TestCase): 
images, decay=0.1, updates_collections=None, - zero_debias_moving_mean=True) + zero_debias_moving_mean=True, + fused=False) moving_mean = variables.get_variables_by_name('BatchNorm/moving_mean')[0] moving_variance = variables.get_variables_by_name('moving_variance')[0] biased = variables.get_variables_by_name('biased')[0] @@ -2522,7 +2518,7 @@ class BatchNormTest(test.TestCase): def _runBatchNormalizationWithFormat(self, shape, data_format, is_training): channels = shape[-1] - with self.test_session() as sess: + with self.test_session(use_gpu=True) as sess: images = np.arange(np.product(shape), dtype=np.float32).reshape(shape) beta = init_ops.constant_initializer( np.arange( @@ -2560,20 +2556,22 @@ class BatchNormTest(test.TestCase): return sess.run(output) def testNHWCAndNCHWInferenceProduceSameOutput(self): - for shape in [[7, 3, 5], [5, 2, 3, 4], [11, 3, 2, 4, 5]]: - nhwc = self._runBatchNormalizationWithFormat( - data_format='NHWC', shape=shape, is_training=False) - nchw = self._runBatchNormalizationWithFormat( - data_format='NCHW', shape=shape, is_training=False) - self.assertAllClose(nhwc, nchw, atol=1e-4, rtol=1e-4) + if test.is_gpu_available(cuda_only=True): + for shape in [[7, 3, 5], [5, 2, 3, 4], [11, 3, 2, 4, 5]]: + nhwc = self._runBatchNormalizationWithFormat( + data_format='NHWC', shape=shape, is_training=False) + nchw = self._runBatchNormalizationWithFormat( + data_format='NCHW', shape=shape, is_training=False) + self.assertAllClose(nhwc, nchw, atol=1e-4, rtol=1e-4) def testNHWCAndNCHWTrainingProduceSameOutput(self): - for shape in [[7, 3, 5], [5, 2, 3, 4], [11, 3, 2, 4, 5]]: - nhwc = self._runBatchNormalizationWithFormat( - data_format='NHWC', shape=shape, is_training=True) - nchw = self._runBatchNormalizationWithFormat( - data_format='NCHW', shape=shape, is_training=True) - self.assertAllClose(nhwc, nchw, atol=1e-4, rtol=1e-4) + if test.is_gpu_available(cuda_only=True): + for shape in [[7, 3, 5], [5, 2, 3, 4], [11, 3, 2, 4, 5]]: + nhwc = 
self._runBatchNormalizationWithFormat( + data_format='NHWC', shape=shape, is_training=True) + nchw = self._runBatchNormalizationWithFormat( + data_format='NCHW', shape=shape, is_training=True) + self.assertAllClose(nhwc, nchw, atol=1e-4, rtol=1e-4) class LayerNormTest(test.TestCase): @@ -3231,6 +3229,69 @@ class UnitNormTests(test.TestCase): self.assertAllClose(expected, actual, 1e-4, 1e-4) +class PoincareNormalizeTest(test.TestCase): + + def _PoincareNormalize(self, x, dim, epsilon=1e-5): + if isinstance(dim, list): + norm = np.linalg.norm(x, axis=tuple(dim)) + for d in dim: + norm = np.expand_dims(norm, d) + norm_x = ((1. - epsilon) * x) / norm + else: + norm = np.expand_dims(np.apply_along_axis(np.linalg.norm, dim, x), dim) + norm_x = ((1. - epsilon) * x) / norm + return np.where(norm > 1.0 - epsilon, norm_x, x) + + def testPoincareNormalize(self): + x_shape = [20, 7, 3] + epsilon = 1e-5 + tol = 1e-6 + np.random.seed(1) + x_np = np.random.random_sample(x_shape).astype(np.float32) + for dim in range(len(x_shape)): + y_np = self._PoincareNormalize(x_np, dim, epsilon) + with self.test_session(): + x_tf = constant_op.constant(x_np, name='x') + y_tf = _layers.poincare_normalize(x_tf, dim, epsilon) + y_tf_eval = y_tf.eval() + norm = np.linalg.norm(y_np, axis=dim) + self.assertLessEqual(norm.max(), 1. - epsilon + tol) + norm = np.linalg.norm(y_tf_eval, axis=dim) + self.assertLessEqual(norm.max(), 1. - epsilon + tol) + self.assertAllClose(y_np, y_tf_eval) + + def testPoincareNormalizeDimArray(self): + x_shape = [20, 7, 3] + epsilon = 1e-5 + tol = 1e-6 + np.random.seed(1) + x_np = np.random.random_sample(x_shape).astype(np.float32) + dim = [1, 2] + y_np = self._PoincareNormalize(x_np, dim, epsilon) + with self.test_session(): + x_tf = constant_op.constant(x_np, name='x') + y_tf = _layers.poincare_normalize(x_tf, dim, epsilon) + y_tf_eval = y_tf.eval() + norm = np.linalg.norm(y_np, axis=tuple(dim)) + self.assertLess(norm.max(), 1. 
- epsilon + tol) + norm = np.linalg.norm(y_tf_eval, axis=tuple(dim)) + self.assertLess(norm.max(), 1. - epsilon + tol) + self.assertAllClose(y_np, y_tf_eval, rtol=1e-6, atol=1e-6) + + def testPoincareNormalizeGradient(self): + x_shape = [20, 7, 3] + np.random.seed(1) + x_np = np.random.random_sample(x_shape).astype(np.float64) + for dim in range(len(x_shape)): + with self.test_session(): + x_tf = constant_op.constant(x_np, name='x') + y_tf = _layers.poincare_normalize(x_tf, dim) + err = gradient_checker.compute_gradient_error(x_tf, x_shape, + y_tf, x_shape) + print('PoinCareNormalize gradient err = %g ' % err) + self.assertLess(err, 1e-4) + + # TODO(b/28426988): Add separate tests for non-legacy versions. class LegacyFullyConnectedTest(test.TestCase): diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD index 0ab39e35d0b..980f971c50e 100644 --- a/tensorflow/contrib/learn/BUILD +++ b/tensorflow/contrib/learn/BUILD @@ -780,6 +780,28 @@ py_test( ], ) +py_test( + name = "debug_test", + size = "medium", + srcs = ["python/learn/estimators/debug_test.py"], + shard_count = 4, + srcs_version = "PY2AND3", + deps = [ + ":learn", + "//tensorflow/contrib/layers:layers_py", + "//tensorflow/contrib/learn/python/learn/datasets", + "//tensorflow/contrib/metrics:metrics_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:math_ops", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python:training", + "//third_party/py/numpy", + ], +) + py_test( name = "composable_model_test", size = "medium", diff --git a/tensorflow/contrib/learn/python/learn/estimators/debug.py b/tensorflow/contrib/learn/python/learn/estimators/debug.py new file mode 100644 index 00000000000..010448be12e --- /dev/null +++ b/tensorflow/contrib/learn/python/learn/estimators/debug.py @@ -0,0 +1,321 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Debug estimators. + +Debug estimators are bias-only estimators that can be used for debugging +and as simple baselines. + +Example: + +``` +# Build DebugClassifier +classifier = DebugClassifier() + +# Input builders +def input_fn_train: # returns x, y (where y represents label's class index). + pass + +def input_fn_eval: # returns x, y (where y represents label's class index). + pass + +# Fit model. +classifier.fit(input_fn=input_fn_train) + +# Evaluate cross entropy between the test and train labels. +loss = classifier.evaluate(input_fn=input_fn_eval)["loss"] + +# predict_classes outputs the most commonly seen class in training. +predicted_label = classifier.predict_classes(new_samples) + +# predict_proba outputs the class distribution from training. 
+label_distribution = classifier.predict_proba(new_samples) +``` +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.layers.python.layers import optimizers +from tensorflow.contrib.learn.python.learn.estimators import estimator +from tensorflow.contrib.learn.python.learn.estimators import head as head_lib +from tensorflow.contrib.learn.python.learn.estimators import prediction_key +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops + + +def _get_feature_dict(features): + if isinstance(features, dict): + return features + return {"": features} + + +def debug_model_fn(features, labels, mode, params, config=None): + """Model_fn for debug models. + + Args: + features: `Tensor` or dict of `Tensor` (depends on data passed to `fit`). + labels: Labels that are compatible with the `_Head` instance in `params`. + mode: Defines whether this is training, evaluation or prediction. + See `ModeKeys`. + params: A dict of hyperparameters containing: + * head: A `_Head` instance. + config: `RunConfig` object to configure the runtime settings. + + Raises: + KeyError: If weight column is specified but not present. + ValueError: If features is an empty dictionary. + + Returns: + A `ModelFnOps` instance. + """ + del config # Unused. + + features = _get_feature_dict(features) + if not features: + raise ValueError("Features cannot be empty.") + + head = params["head"] + size_checks = [] + batch_size = None + + # The first dimension is assumed to be a batch size and must be consistent + # among all of the features. 
+ for feature in features.values(): + first_dim = array_ops.shape(feature)[0] + if batch_size is None: + batch_size = first_dim + else: + size_checks.append(check_ops.assert_equal(batch_size, first_dim)) + + with ops.control_dependencies(size_checks): + logits = array_ops.zeros([batch_size, head.logits_dimension]) + + def train_op_fn(loss): + return optimizers.optimize_loss( + loss, global_step=None, learning_rate=0.3, optimizer="Adagrad") + + return head.create_model_fn_ops( + features=features, + labels=labels, + mode=mode, + train_op_fn=train_op_fn, + logits=logits) + + +class DebugClassifier(estimator.Estimator): + """A classifier for TensorFlow Debug models. + + Example: + + ```python + + # Build DebugClassifier + classifier = DebugClassifier() + + # Input builders + def input_fn_train: # returns x, y (where y represents label's class index). + pass + + def input_fn_eval: # returns x, y (where y represents label's class index). + pass + + # Fit model. + classifier.fit(input_fn=input_fn_train) + + # Evaluate cross entropy between the test and train labels. + loss = classifier.evaluate(input_fn=input_fn_eval)["loss"] + + # predict_class outputs the most commonly seen class in training. + predicted_label = classifier.predict_class(new_samples) + + # predict_proba outputs the class distribution from training. + label_distribution = classifier.predict_proba(new_samples) + ``` + + Input of `fit` and `evaluate` should have following features, + otherwise there will be a `KeyError`: + + * if `weight_column_name` is not `None`, a feature with + `key=weight_column_name` whose value is a `Tensor`. + """ + + def __init__(self, + model_dir=None, + n_classes=2, + weight_column_name=None, + config=None, + feature_engineering_fn=None): + """Initializes a DebugClassifier instance. + + Args: + model_dir: Directory to save model parameters, graph and etc. 
This can + also be used to load checkpoints from the directory into a estimator to + continue training a previously saved model. + n_classes: number of label classes. Default is binary classification. + It must be greater than 1. Note: Class labels are integers representing + the class index (i.e. values from 0 to n_classes-1). For arbitrary + label values (e.g. string labels), convert to class indices first. + weight_column_name: A string defining feature column name representing + weights. It is used to down weight or boost examples during training. It + will be multiplied by the loss of the example. + config: `RunConfig` object to configure the runtime settings. + feature_engineering_fn: Feature engineering function. Takes features and + labels which are the output of `input_fn` and returns + features and labels which will be fed into the model. + Returns: + A `DebugClassifier` estimator. + + Raises: + ValueError: If `n_classes` < 2. + """ + params = {"head": + head_lib._multi_class_head( # pylint: disable=protected-access + n_classes=n_classes, + weight_column_name=weight_column_name, + enable_centered_bias=True)} + + super(DebugClassifier, self).__init__( + model_fn=debug_model_fn, + model_dir=model_dir, + config=config, + params=params, + feature_engineering_fn=feature_engineering_fn) + + def predict_classes(self, input_fn=None, batch_size=None): + """Returns predicted classes for given features. + + Args: + input_fn: Input function. + batch_size: Override default batch size. + + Returns: + An iterable of predicted classes. Each predicted class is represented by + its class index (i.e. integer from 0 to n_classes-1). + """ + key = prediction_key.PredictionKey.CLASSES + preds = self.predict( + input_fn=input_fn, batch_size=batch_size, outputs=[key]) + return (pred[key] for pred in preds) + + def predict_proba(self, + input_fn=None, + batch_size=None): + """Returns prediction probabilities for given features. + + Args: + input_fn: Input function. 
+ batch_size: Override default batch size. + + Returns: + An iterable of predicted probabilities with shape [batch_size, n_classes]. + """ + key = prediction_key.PredictionKey.PROBABILITIES + preds = self.predict( + input_fn=input_fn, + batch_size=batch_size, + outputs=[key]) + return (pred[key] for pred in preds) + + +class DebugRegressor(estimator.Estimator): + """A regressor for TensorFlow Debug models. + + Example: + + ```python + + # Build DebugRegressor + regressor = DebugRegressor() + + # Input builders + def input_fn_train: # returns x, y (where y represents label's class index). + pass + + def input_fn_eval: # returns x, y (where y represents label's class index). + pass + + # Fit model. + regressor.fit(input_fn=input_fn_train) + + # Evaluate squared-loss between the test and train targets. + loss = regressor.evaluate(input_fn=input_fn_eval)["loss"] + + # predict_scores outputs mean value seen during training. + predicted_targets = regressor.predict_scores(new_samples) + ``` + + Input of `fit` and `evaluate` should have following features, + otherwise there will be a `KeyError`: + + * if `weight_column_name` is not `None`, a feature with + `key=weight_column_name` whose value is a `Tensor`. + """ + + def __init__(self, + model_dir=None, + label_dimension=1, + weight_column_name=None, + config=None, + feature_engineering_fn=None): + """Initializes a DebugRegressor instance. + + Args: + model_dir: Directory to save model parameters, graph and etc. This can + also be used to load checkpoints from the directory into a estimator to + continue training a previously saved model. + label_dimension: Number of regression targets per example. This is the + size of the last dimension of the labels and logits `Tensor` objects + (typically, these have shape `[batch_size, label_dimension]`). + weight_column_name: A string defining feature column name representing + weights. It is used to down weight or boost examples during training. 
It + will be multiplied by the loss of the example. + config: `RunConfig` object to configure the runtime settings. + feature_engineering_fn: Feature engineering function. Takes features and + labels which are the output of `input_fn` and returns + features and labels which will be fed into the model. + Returns: + A `DebugRegressor` estimator. + """ + + params = { + "head": + head_lib._regression_head( # pylint: disable=protected-access + weight_column_name=weight_column_name, + label_dimension=label_dimension, + enable_centered_bias=True) + } + + super(DebugRegressor, self).__init__( + model_fn=debug_model_fn, + model_dir=model_dir, + config=config, + params=params, + feature_engineering_fn=feature_engineering_fn) + + def predict_scores(self, input_fn=None, batch_size=None): + """Returns predicted scores for given features. + + Args: + input_fn: Input function. + batch_size: Override default batch size. + + Returns: + An iterable of predicted scores. + """ + key = prediction_key.PredictionKey.SCORES + preds = self.predict( + input_fn=input_fn, batch_size=batch_size, outputs=[key]) + return (pred[key] for pred in preds) diff --git a/tensorflow/contrib/learn/python/learn/estimators/debug_test.py b/tensorflow/contrib/learn/python/learn/estimators/debug_test.py new file mode 100644 index 00000000000..935e66ee8ca --- /dev/null +++ b/tensorflow/contrib/learn/python/learn/estimators/debug_test.py @@ -0,0 +1,840 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Debug estimators.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import functools +import operator +import tempfile + +import numpy as np + +from tensorflow.contrib.layers.python.layers import feature_column +from tensorflow.contrib.layers.python.layers import feature_column_ops +from tensorflow.contrib.learn.python.learn import experiment +from tensorflow.contrib.learn.python.learn.datasets import base +from tensorflow.contrib.learn.python.learn.estimators import _sklearn +from tensorflow.contrib.learn.python.learn.estimators import debug +from tensorflow.contrib.learn.python.learn.estimators import estimator_test_utils +from tensorflow.contrib.learn.python.learn.estimators import run_config +from tensorflow.contrib.learn.python.learn.estimators import test_data +from tensorflow.contrib.learn.python.learn.metric_spec import MetricSpec +from tensorflow.contrib.metrics.python.ops import metric_ops +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test +from tensorflow.python.training import input as input_lib + + +NUM_EXAMPLES = 100 +N_CLASSES = 5 # Cardinality of multiclass labels. +LABEL_DIMENSION = 3 # Dimensionality of regression labels. 
+ + +def _train_test_split(features_and_labels): + features, labels = features_and_labels + train_set = (features[:len(features) / 2], labels[:len(features) / 2]) + test_set = (features[len(features) / 2:], labels[len(features) / 2:]) + return train_set, test_set + + +def _input_fn_builder(features, labels): + + def input_fn(): + feature_dict = {'features': constant_op.constant(features)} + my_labels = labels + if my_labels is not None: + my_labels = constant_op.constant(my_labels) + return feature_dict, my_labels + + return input_fn + + +class DebugClassifierTest(test.TestCase): + + def setUp(self): + np.random.seed(100) + self.features = np.random.rand(NUM_EXAMPLES, 5) + self.labels = np.random.choice( + range(N_CLASSES), p=[0.1, 0.3, 0.4, 0.1, 0.1], size=NUM_EXAMPLES) + self.binary_labels = np.random.choice( + range(2), p=[0.2, 0.8], size=NUM_EXAMPLES) + self.binary_float_labels = np.random.choice( + range(2), p=[0.2, 0.8], size=NUM_EXAMPLES) + + def testPredict(self): + """Tests that DebugClassifier outputs the majority class.""" + (train_features, train_labels), (test_features, + test_labels) = _train_test_split( + [self.features, self.labels]) + majority_class, _ = max(collections.Counter(train_labels).items(), + key=operator.itemgetter(1)) + expected_prediction = np.vstack( + [[majority_class] for _ in range(test_labels.shape[0])]) + + classifier = debug.DebugClassifier(n_classes=N_CLASSES) + classifier.fit(input_fn=_input_fn_builder(train_features, train_labels), + steps=50) + + pred = classifier.predict_classes(input_fn=_input_fn_builder(test_features, + None)) + self.assertAllEqual(expected_prediction, np.vstack(pred)) + + def testPredictBinary(self): + """Same as above for binary predictions.""" + (train_features, train_labels), (test_features, + test_labels) = _train_test_split( + [self.features, self.binary_labels]) + + majority_class, _ = max(collections.Counter(train_labels).items(), + key=operator.itemgetter(1)) + expected_prediction = np.vstack( + 
[[majority_class] for _ in range(test_labels.shape[0])]) + + classifier = debug.DebugClassifier(n_classes=2) + classifier.fit(input_fn=_input_fn_builder(train_features, train_labels), + steps=50) + + pred = classifier.predict_classes(input_fn=_input_fn_builder(test_features, + None)) + self.assertAllEqual(expected_prediction, np.vstack(pred)) + + (train_features, train_labels), ( + test_features, test_labels) = _train_test_split( + [self.features, self.binary_float_labels]) + + majority_class, _ = max(collections.Counter(train_labels).items(), + key=operator.itemgetter(1)) + expected_prediction = np.vstack( + [[majority_class] for _ in range(test_labels.shape[0])]) + + classifier = debug.DebugClassifier(n_classes=2) + classifier.fit(input_fn=_input_fn_builder(train_features, train_labels), + steps=50) + + pred = classifier.predict_classes(input_fn=_input_fn_builder(test_features, + None)) + self.assertAllEqual(expected_prediction, np.vstack(pred)) + + def testPredictProba(self): + """Tests that DebugClassifier outputs observed class distribution.""" + (train_features, train_labels), (test_features, + test_labels) = _train_test_split( + [self.features, self.labels]) + + class_distribution = np.zeros((1, N_CLASSES)) + for label in train_labels: + class_distribution[0, label] += 1 + class_distribution /= len(train_labels) + + expected_prediction = np.vstack( + [class_distribution for _ in range(test_labels.shape[0])]) + + classifier = debug.DebugClassifier(n_classes=N_CLASSES) + classifier.fit(input_fn=_input_fn_builder(train_features, train_labels), + steps=50) + + pred = classifier.predict_proba( + input_fn=_input_fn_builder(test_features, None)) + + self.assertAllClose(expected_prediction, np.vstack(pred), atol=0.1) + + def testPredictProbaBinary(self): + """Same as above but for binary classification.""" + (train_features, train_labels), (test_features, + test_labels) = _train_test_split( + [self.features, self.binary_labels]) + + class_distribution = np.zeros((1, 
2)) + for label in train_labels: + class_distribution[0, label] += 1 + class_distribution /= len(train_labels) + + expected_prediction = np.vstack( + [class_distribution for _ in range(test_labels.shape[0])]) + + classifier = debug.DebugClassifier(n_classes=2) + classifier.fit(input_fn=_input_fn_builder(train_features, train_labels), + steps=50) + + pred = classifier.predict_proba( + input_fn=_input_fn_builder(test_features, None)) + + self.assertAllClose(expected_prediction, np.vstack(pred), atol=0.1) + + (train_features, train_labels), ( + test_features, test_labels) = _train_test_split( + [self.features, self.binary_float_labels]) + + class_distribution = np.zeros((1, 2)) + for label in train_labels: + class_distribution[0, int(label)] += 1 + class_distribution /= len(train_labels) + + expected_prediction = np.vstack( + [class_distribution for _ in range(test_labels.shape[0])]) + + classifier = debug.DebugClassifier(n_classes=2) + classifier.fit(input_fn=_input_fn_builder(train_features, train_labels), + steps=50) + + pred = classifier.predict_proba( + input_fn=_input_fn_builder(test_features, None)) + + self.assertAllClose(expected_prediction, np.vstack(pred), atol=0.1) + + def testExperimentIntegration(self): + exp = experiment.Experiment( + estimator=debug.DebugClassifier(n_classes=3), + train_input_fn=test_data.iris_input_multiclass_fn, + eval_input_fn=test_data.iris_input_multiclass_fn) + exp.test() + + def _assertInRange(self, expected_min, expected_max, actual): + self.assertLessEqual(expected_min, actual) + self.assertGreaterEqual(expected_max, actual) + + def testEstimatorContract(self): + estimator_test_utils.assert_estimator_contract(self, debug.DebugClassifier) + + def testLogisticRegression_MatrixData(self): + """Tests binary classification using matrix data as input.""" + classifier = debug.DebugClassifier( + config=run_config.RunConfig(tf_random_seed=1)) + input_fn = test_data.iris_input_logistic_fn + classifier.fit(input_fn=input_fn, steps=5) + 
scores = classifier.evaluate(input_fn=input_fn, steps=1) + self._assertInRange(0.0, 1.0, scores['accuracy']) + self.assertIn('loss', scores) + + def testLogisticRegression_MatrixData_Labels1D(self): + """Same as the last test, but label shape is [100] instead of [100, 1].""" + + def _input_fn(): + iris = test_data.prepare_iris_data_for_logistic_regression() + return { + 'feature': constant_op.constant( + iris.data, dtype=dtypes.float32) + }, constant_op.constant( + iris.target, shape=[100], dtype=dtypes.int32) + + classifier = debug.DebugClassifier(config=run_config.RunConfig( + tf_random_seed=1)) + classifier.fit(input_fn=_input_fn, steps=5) + scores = classifier.evaluate(input_fn=_input_fn, steps=1) + self.assertIn('loss', scores) + + def testLogisticRegression_NpMatrixData(self): + """Tests binary classification using numpy matrix data as input.""" + iris = test_data.prepare_iris_data_for_logistic_regression() + train_x = iris.data + train_y = iris.target + classifier = debug.DebugClassifier( + config=run_config.RunConfig(tf_random_seed=1)) + classifier.fit(x=train_x, y=train_y, steps=5) + scores = classifier.evaluate(x=train_x, y=train_y, steps=1) + self._assertInRange(0.0, 1.0, scores['accuracy']) + + def _assertBinaryPredictions(self, expected_len, predictions): + self.assertEqual(expected_len, len(predictions)) + for prediction in predictions: + self.assertIn(prediction, (0, 1)) + + def _assertProbabilities(self, expected_batch_size, expected_n_classes, + probabilities): + self.assertEqual(expected_batch_size, len(probabilities)) + for b in range(expected_batch_size): + self.assertEqual(expected_n_classes, len(probabilities[b])) + for i in range(expected_n_classes): + self._assertInRange(0.0, 1.0, probabilities[b][i]) + + def testLogisticRegression_TensorData(self): + """Tests binary classification using tensor data as input.""" + + def _input_fn(num_epochs=None): + features = { + 'age': + input_lib.limit_epochs( + constant_op.constant([[.8], [0.2], [.1]]), 
+ num_epochs=num_epochs), + 'language': + sparse_tensor.SparseTensor( + values=input_lib.limit_epochs( + ['en', 'fr', 'zh'], num_epochs=num_epochs), + indices=[[0, 0], [0, 1], [2, 0]], + dense_shape=[3, 2]) + } + return features, constant_op.constant([[1], [0], [0]], dtype=dtypes.int32) + + classifier = debug.DebugClassifier(n_classes=2) + + classifier.fit(input_fn=_input_fn, steps=50) + + scores = classifier.evaluate(input_fn=_input_fn, steps=1) + self._assertInRange(0.0, 1.0, scores['accuracy']) + self.assertIn('loss', scores) + predict_input_fn = functools.partial(_input_fn, num_epochs=1) + predictions = list(classifier.predict_classes(input_fn=predict_input_fn)) + self._assertBinaryPredictions(3, predictions) + + def testLogisticRegression_FloatLabel(self): + """Tests binary classification with float labels.""" + + def _input_fn_float_label(num_epochs=None): + features = { + 'age': + input_lib.limit_epochs( + constant_op.constant([[50], [20], [10]]), + num_epochs=num_epochs), + 'language': + sparse_tensor.SparseTensor( + values=input_lib.limit_epochs( + ['en', 'fr', 'zh'], num_epochs=num_epochs), + indices=[[0, 0], [0, 1], [2, 0]], + dense_shape=[3, 2]) + } + labels = constant_op.constant([[0.8], [0.], [0.2]], dtype=dtypes.float32) + return features, labels + + classifier = debug.DebugClassifier(n_classes=2) + + classifier.fit(input_fn=_input_fn_float_label, steps=50) + + predict_input_fn = functools.partial(_input_fn_float_label, num_epochs=1) + predictions = list(classifier.predict_classes(input_fn=predict_input_fn)) + self._assertBinaryPredictions(3, predictions) + predictions_proba = list( + classifier.predict_proba(input_fn=predict_input_fn)) + self._assertProbabilities(3, 2, predictions_proba) + + def testMultiClass_MatrixData(self): + """Tests multi-class classification using matrix data as input.""" + classifier = debug.DebugClassifier(n_classes=3) + + input_fn = test_data.iris_input_multiclass_fn + classifier.fit(input_fn=input_fn, steps=200) + scores 
= classifier.evaluate(input_fn=input_fn, steps=1) + self._assertInRange(0.0, 1.0, scores['accuracy']) + self.assertIn('loss', scores) + + def testMultiClass_MatrixData_Labels1D(self): + """Same as the last test, but label shape is [150] instead of [150, 1].""" + + def _input_fn(): + iris = base.load_iris() + return { + 'feature': constant_op.constant( + iris.data, dtype=dtypes.float32) + }, constant_op.constant( + iris.target, shape=[150], dtype=dtypes.int32) + + classifier = debug.DebugClassifier(n_classes=3) + + classifier.fit(input_fn=_input_fn, steps=200) + scores = classifier.evaluate(input_fn=_input_fn, steps=1) + self._assertInRange(0.0, 1.0, scores['accuracy']) + + def testMultiClass_NpMatrixData(self): + """Tests multi-class classification using numpy matrix data as input.""" + iris = base.load_iris() + train_x = iris.data + train_y = iris.target + classifier = debug.DebugClassifier(n_classes=3) + classifier.fit(x=train_x, y=train_y, steps=200) + scores = classifier.evaluate(x=train_x, y=train_y, steps=1) + self._assertInRange(0.0, 1.0, scores['accuracy']) + + def testLoss(self): + """Tests loss calculation.""" + + def _input_fn_train(): + # Create 4 rows, one of them (y = x), three of them (y=Not(x)) + # The logistic prediction should be (y = 0.25). + labels = constant_op.constant([[1], [0], [0], [0]]) + features = {'x': array_ops.ones(shape=[4, 1], dtype=dtypes.float32),} + return features, labels + + classifier = debug.DebugClassifier(n_classes=2) + + classifier.fit(input_fn=_input_fn_train, steps=5) + scores = classifier.evaluate(input_fn=_input_fn_train, steps=1) + self.assertIn('loss', scores) + + def testLossWithWeights(self): + """Tests loss calculation with weights.""" + + def _input_fn_train(): + # 4 rows with equal weight, one of them (y = x), three of them (y=Not(x)) + # The logistic prediction should be (y = 0.25). 
+ labels = constant_op.constant([[1.], [0.], [0.], [0.]]) + features = { + 'x': array_ops.ones( + shape=[4, 1], dtype=dtypes.float32), + 'w': constant_op.constant([[1.], [1.], [1.], [1.]]) + } + return features, labels + + def _input_fn_eval(): + # 4 rows, with different weights. + labels = constant_op.constant([[1.], [0.], [0.], [0.]]) + features = { + 'x': array_ops.ones( + shape=[4, 1], dtype=dtypes.float32), + 'w': constant_op.constant([[7.], [1.], [1.], [1.]]) + } + return features, labels + + classifier = debug.DebugClassifier( + weight_column_name='w', + n_classes=2, + config=run_config.RunConfig(tf_random_seed=1)) + + classifier.fit(input_fn=_input_fn_train, steps=5) + scores = classifier.evaluate(input_fn=_input_fn_eval, steps=1) + self.assertIn('loss', scores) + + def testTrainWithWeights(self): + """Tests training with given weight column.""" + + def _input_fn_train(): + # Create 4 rows, one of them (y = x), three of them (y=Not(x)) + # First row has more weight than others. Model should fit (y=x) better + # than (y=Not(x)) due to the relative higher weight of the first row. 
+ labels = constant_op.constant([[1], [0], [0], [0]]) + features = { + 'x': array_ops.ones( + shape=[4, 1], dtype=dtypes.float32), + 'w': constant_op.constant([[100.], [3.], [2.], [2.]]) + } + return features, labels + + def _input_fn_eval(): + # Create 4 rows (y = x) + labels = constant_op.constant([[1], [1], [1], [1]]) + features = { + 'x': array_ops.ones( + shape=[4, 1], dtype=dtypes.float32), + 'w': constant_op.constant([[1.], [1.], [1.], [1.]]) + } + return features, labels + + classifier = debug.DebugClassifier(weight_column_name='w') + + classifier.fit(input_fn=_input_fn_train, steps=5) + scores = classifier.evaluate(input_fn=_input_fn_eval, steps=1) + self._assertInRange(0.0, 1.0, scores['accuracy']) + + def testCustomMetrics(self): + """Tests custom evaluation metrics.""" + + def _input_fn(num_epochs=None): + # Create 4 rows, one of them (y = x), three of them (y=Not(x)) + labels = constant_op.constant([[1], [0], [0], [0]]) + features = { + 'x': + input_lib.limit_epochs( + array_ops.ones( + shape=[4, 1], dtype=dtypes.float32), + num_epochs=num_epochs), + } + return features, labels + + def _my_metric_op(predictions, labels): + # For the case of binary classification, the 2nd column of "predictions" + # denotes the model predictions. 
+ labels = math_ops.to_float(labels) + predictions = array_ops.strided_slice( + predictions, [0, 1], [-1, 2], end_mask=1) + labels = math_ops.cast(labels, predictions.dtype) + return math_ops.reduce_sum(math_ops.multiply(predictions, labels)) + + classifier = debug.DebugClassifier( + config=run_config.RunConfig(tf_random_seed=1)) + + classifier.fit(input_fn=_input_fn, steps=5) + scores = classifier.evaluate( + input_fn=_input_fn, + steps=5, + metrics={ + 'my_accuracy': + MetricSpec( + metric_fn=metric_ops.streaming_accuracy, + prediction_key='classes'), + 'my_precision': + MetricSpec( + metric_fn=metric_ops.streaming_precision, + prediction_key='classes'), + 'my_metric': + MetricSpec( + metric_fn=_my_metric_op, prediction_key='probabilities') + }) + self.assertTrue( + set(['loss', 'my_accuracy', 'my_precision', 'my_metric']).issubset( + set(scores.keys()))) + predict_input_fn = functools.partial(_input_fn, num_epochs=1) + predictions = np.array( + list(classifier.predict_classes(input_fn=predict_input_fn))) + self.assertEqual( + _sklearn.accuracy_score([1, 0, 0, 0], predictions), + scores['my_accuracy']) + + # Test the case where the 2nd element of the key is neither "classes" nor + # "probabilities". 
+ with self.assertRaisesRegexp(KeyError, 'bad_type'): + classifier.evaluate( + input_fn=_input_fn, + steps=5, + metrics={ + 'bad_name': + MetricSpec( + metric_fn=metric_ops.streaming_auc, + prediction_key='bad_type') + }) + + def testTrainSaveLoad(self): + """Tests that insures you can save and reload a trained model.""" + + def _input_fn(num_epochs=None): + features = { + 'age': + input_lib.limit_epochs( + constant_op.constant([[.8], [.2], [.1]]), + num_epochs=num_epochs), + 'language': + sparse_tensor.SparseTensor( + values=input_lib.limit_epochs( + ['en', 'fr', 'zh'], num_epochs=num_epochs), + indices=[[0, 0], [0, 1], [2, 0]], + dense_shape=[3, 2]) + } + return features, constant_op.constant([[1], [0], [0]], dtype=dtypes.int32) + + model_dir = tempfile.mkdtemp() + classifier = debug.DebugClassifier( + model_dir=model_dir, + n_classes=3, + config=run_config.RunConfig(tf_random_seed=1)) + + classifier.fit(input_fn=_input_fn, steps=5) + predict_input_fn = functools.partial(_input_fn, num_epochs=1) + predictions1 = classifier.predict_classes(input_fn=predict_input_fn) + del classifier + + classifier2 = debug.DebugClassifier( + model_dir=model_dir, + n_classes=3, + config=run_config.RunConfig(tf_random_seed=1)) + predictions2 = classifier2.predict_classes(input_fn=predict_input_fn) + self.assertEqual(list(predictions1), list(predictions2)) + + def testExport(self): + """Tests export model for servo.""" + + def input_fn(): + return { + 'age': + constant_op.constant([1]), + 'language': + sparse_tensor.SparseTensor( + values=['english'], indices=[[0, 0]], dense_shape=[1, 1]) + }, constant_op.constant([[1]]) + + language = feature_column.sparse_column_with_hash_bucket('language', 100) + feature_columns = [ + feature_column.real_valued_column('age'), + feature_column.embedding_column( + language, dimension=1) + ] + + classifier = debug.DebugClassifier(config=run_config.RunConfig( + tf_random_seed=1)) + classifier.fit(input_fn=input_fn, steps=5) + + def 
default_input_fn(unused_estimator, examples): + return feature_column_ops.parse_feature_columns_from_examples( + examples, feature_columns) + + export_dir = tempfile.mkdtemp() + classifier.export(export_dir, input_fn=default_input_fn) + + +class DebugRegressorTest(test.TestCase): + + def setUp(self): + np.random.seed(100) + self.features = np.random.rand(NUM_EXAMPLES, 5) + self.targets = np.random.rand(NUM_EXAMPLES, LABEL_DIMENSION) + + def testPredictScores(self): + """Tests that DebugRegressor outputs the mean target.""" + (train_features, train_labels), (test_features, + test_labels) = _train_test_split( + [self.features, self.targets]) + mean_target = np.mean(train_labels, 0) + expected_prediction = np.vstack( + [mean_target for _ in range(test_labels.shape[0])]) + + classifier = debug.DebugRegressor(label_dimension=LABEL_DIMENSION) + classifier.fit( + input_fn=_input_fn_builder(train_features, train_labels), steps=50) + + pred = classifier.predict_scores(input_fn=_input_fn_builder(test_features, + None)) + self.assertAllClose(expected_prediction, np.vstack(pred), atol=0.1) + + def testExperimentIntegration(self): + exp = experiment.Experiment( + estimator=debug.DebugRegressor(), + train_input_fn=test_data.iris_input_logistic_fn, + eval_input_fn=test_data.iris_input_logistic_fn) + exp.test() + + def testEstimatorContract(self): + estimator_test_utils.assert_estimator_contract(self, debug.DebugRegressor) + + def testRegression_MatrixData(self): + """Tests regression using matrix data as input.""" + regressor = debug.DebugRegressor( + config=run_config.RunConfig(tf_random_seed=1)) + input_fn = test_data.iris_input_logistic_fn + regressor.fit(input_fn=input_fn, steps=200) + scores = regressor.evaluate(input_fn=input_fn, steps=1) + self.assertIn('loss', scores) + + def testRegression_MatrixData_Labels1D(self): + """Same as the last test, but label shape is [100] instead of [100, 1].""" + + def _input_fn(): + iris = 
test_data.prepare_iris_data_for_logistic_regression() + return { + 'feature': constant_op.constant(iris.data, dtype=dtypes.float32) + }, constant_op.constant( + iris.target, shape=[100], dtype=dtypes.int32) + + regressor = debug.DebugRegressor( + config=run_config.RunConfig(tf_random_seed=1)) + + regressor.fit(input_fn=_input_fn, steps=200) + scores = regressor.evaluate(input_fn=_input_fn, steps=1) + self.assertIn('loss', scores) + + def testRegression_NpMatrixData(self): + """Tests binary classification using numpy matrix data as input.""" + iris = test_data.prepare_iris_data_for_logistic_regression() + train_x = iris.data + train_y = iris.target + regressor = debug.DebugRegressor( + config=run_config.RunConfig(tf_random_seed=1)) + regressor.fit(x=train_x, y=train_y, steps=200) + scores = regressor.evaluate(x=train_x, y=train_y, steps=1) + self.assertIn('loss', scores) + + def testRegression_TensorData(self): + """Tests regression using tensor data as input.""" + + def _input_fn(num_epochs=None): + features = { + 'age': + input_lib.limit_epochs( + constant_op.constant([[.8], [.15], [0.]]), + num_epochs=num_epochs), + 'language': + sparse_tensor.SparseTensor( + values=input_lib.limit_epochs( + ['en', 'fr', 'zh'], num_epochs=num_epochs), + indices=[[0, 0], [0, 1], [2, 0]], + dense_shape=[3, 2]) + } + return features, constant_op.constant([1., 0., 0.2], dtype=dtypes.float32) + + regressor = debug.DebugRegressor( + config=run_config.RunConfig(tf_random_seed=1)) + + regressor.fit(input_fn=_input_fn, steps=200) + + scores = regressor.evaluate(input_fn=_input_fn, steps=1) + self.assertIn('loss', scores) + + def testLoss(self): + """Tests loss calculation.""" + + def _input_fn_train(): + # Create 4 rows, one of them (y = x), three of them (y=Not(x)) + # The algorithm should learn (y = 0.25). 
+ labels = constant_op.constant([[1.], [0.], [0.], [0.]]) + features = {'x': array_ops.ones(shape=[4, 1], dtype=dtypes.float32),} + return features, labels + + regressor = debug.DebugRegressor( + config=run_config.RunConfig(tf_random_seed=1)) + + regressor.fit(input_fn=_input_fn_train, steps=5) + scores = regressor.evaluate(input_fn=_input_fn_train, steps=1) + self.assertIn('loss', scores) + + def testLossWithWeights(self): + """Tests loss calculation with weights.""" + + def _input_fn_train(): + # 4 rows with equal weight, one of them (y = x), three of them (y=Not(x)) + # The algorithm should learn (y = 0.25). + labels = constant_op.constant([[1.], [0.], [0.], [0.]]) + features = { + 'x': array_ops.ones(shape=[4, 1], dtype=dtypes.float32), + 'w': constant_op.constant([[1.], [1.], [1.], [1.]]) + } + return features, labels + + def _input_fn_eval(): + # 4 rows, with different weights. + labels = constant_op.constant([[1.], [0.], [0.], [0.]]) + features = { + 'x': array_ops.ones(shape=[4, 1], dtype=dtypes.float32), + 'w': constant_op.constant([[7.], [1.], [1.], [1.]]) + } + return features, labels + + regressor = debug.DebugRegressor( + weight_column_name='w', config=run_config.RunConfig(tf_random_seed=1)) + + regressor.fit(input_fn=_input_fn_train, steps=5) + scores = regressor.evaluate(input_fn=_input_fn_eval, steps=1) + self.assertIn('loss', scores) + + def testTrainWithWeights(self): + """Tests training with given weight column.""" + + def _input_fn_train(): + # Create 4 rows, one of them (y = x), three of them (y=Not(x)) + # First row has more weight than others. Model should fit (y=x) better + # than (y=Not(x)) due to the relative higher weight of the first row. 
+ labels = constant_op.constant([[1.], [0.], [0.], [0.]]) + features = { + 'x': array_ops.ones(shape=[4, 1], dtype=dtypes.float32), + 'w': constant_op.constant([[100.], [3.], [2.], [2.]]) + } + return features, labels + + def _input_fn_eval(): + # Create 4 rows (y = x) + labels = constant_op.constant([[1.], [1.], [1.], [1.]]) + features = { + 'x': array_ops.ones(shape=[4, 1], dtype=dtypes.float32), + 'w': constant_op.constant([[1.], [1.], [1.], [1.]]) + } + return features, labels + + regressor = debug.DebugRegressor( + weight_column_name='w', config=run_config.RunConfig(tf_random_seed=1)) + + regressor.fit(input_fn=_input_fn_train, steps=5) + scores = regressor.evaluate(input_fn=_input_fn_eval, steps=1) + self.assertIn('loss', scores) + + def testCustomMetrics(self): + """Tests custom evaluation metrics.""" + + def _input_fn(num_epochs=None): + # Create 4 rows, one of them (y = x), three of them (y=Not(x)) + labels = constant_op.constant([[1.], [0.], [0.], [0.]]) + features = { + 'x': + input_lib.limit_epochs( + array_ops.ones(shape=[4, 1], dtype=dtypes.float32), + num_epochs=num_epochs), + } + return features, labels + + def _my_metric_op(predictions, labels): + return math_ops.reduce_sum(math_ops.multiply(predictions, labels)) + + regressor = debug.DebugRegressor( + config=run_config.RunConfig(tf_random_seed=1)) + + regressor.fit(input_fn=_input_fn, steps=5) + scores = regressor.evaluate( + input_fn=_input_fn, + steps=1, + metrics={ + 'my_error': + MetricSpec( + metric_fn=metric_ops.streaming_mean_squared_error, + prediction_key='scores'), + 'my_metric': + MetricSpec(metric_fn=_my_metric_op, prediction_key='scores') + }) + self.assertIn('loss', set(scores.keys())) + self.assertIn('my_error', set(scores.keys())) + self.assertIn('my_metric', set(scores.keys())) + predict_input_fn = functools.partial(_input_fn, num_epochs=1) + predictions = np.array( + list(regressor.predict_scores(input_fn=predict_input_fn))) + self.assertAlmostEqual( + 
_sklearn.mean_squared_error(np.array([1, 0, 0, 0]), predictions), + scores['my_error']) + + # Tests the case where the prediction_key is not "scores". + with self.assertRaisesRegexp(KeyError, 'bad_type'): + regressor.evaluate( + input_fn=_input_fn, + steps=1, + metrics={ + 'bad_name': + MetricSpec( + metric_fn=metric_ops.streaming_auc, + prediction_key='bad_type') + }) + + def testTrainSaveLoad(self): + """Tests that insures you can save and reload a trained model.""" + + def _input_fn(num_epochs=None): + features = { + 'age': + input_lib.limit_epochs( + constant_op.constant([[0.8], [0.15], [0.]]), + num_epochs=num_epochs), + 'language': + sparse_tensor.SparseTensor( + values=input_lib.limit_epochs( + ['en', 'fr', 'zh'], num_epochs=num_epochs), + indices=[[0, 0], [0, 1], [2, 0]], + dense_shape=[3, 2]) + } + return features, constant_op.constant([1., 0., 0.2], dtype=dtypes.float32) + + model_dir = tempfile.mkdtemp() + regressor = debug.DebugRegressor( + model_dir=model_dir, config=run_config.RunConfig(tf_random_seed=1)) + + regressor.fit(input_fn=_input_fn, steps=5) + predict_input_fn = functools.partial(_input_fn, num_epochs=1) + predictions = list(regressor.predict_scores(input_fn=predict_input_fn)) + del regressor + + regressor2 = debug.DebugRegressor( + model_dir=model_dir, config=run_config.RunConfig(tf_random_seed=1)) + predictions2 = list(regressor2.predict_scores(input_fn=predict_input_fn)) + self.assertAllClose(predictions, predictions2) + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py b/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py index 61208ba24e1..bdb88b89bb3 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py @@ -75,10 +75,10 @@ def read_batch_examples(file_pattern, `tf.local_variables_initializer()` and run the op in a session. queue_capacity: Capacity for input queue. 
num_threads: The number of threads enqueuing examples. In order to have - predicted and repeatable order of reading and enqueueing, such as in + predictable and repeatable order of reading and enqueueing, such as in prediction and evaluation mode, `num_threads` should be 1. read_batch_size: An int or scalar `Tensor` specifying the number of - records to read at once + records to read at once. parse_fn: Parsing function, takes `Example` Tensor returns parsed representation. If `None`, no parsing is done. name: Name of resulting op. @@ -142,10 +142,10 @@ def read_keyed_batch_examples(file_pattern, `tf.local_variables_initializer()` and run the op in a session. queue_capacity: Capacity for input queue. num_threads: The number of threads enqueuing examples. In order to have - predicted and repeatable order of reading and enqueueing, such as in + predictable and repeatable order of reading and enqueueing, such as in prediction and evaluation mode, `num_threads` should be 1. read_batch_size: An int or scalar `Tensor` specifying the number of - records to read at once + records to read at once. parse_fn: Parsing function, takes `Example` Tensor returns parsed representation. If `None`, no parsing is done. name: Name of resulting op. @@ -217,7 +217,7 @@ def read_keyed_batch_examples_shared_queue(file_pattern, queue_capacity: Capacity for input queue. num_threads: The number of threads enqueuing examples. read_batch_size: An int or scalar `Tensor` specifying the number of - records to read at once + records to read at once. parse_fn: Parsing function, takes `Example` Tensor returns parsed representation. If `None`, no parsing is done. name: Name of resulting op. @@ -335,7 +335,7 @@ def _read_keyed_batch_examples_helper(file_pattern, queue_capacity: Capacity for input queue. num_threads: The number of threads enqueuing examples. read_batch_size: An int or scalar `Tensor` specifying the number of - records to read at once + records to read at once. 
filter_fn: Filtering function, takes both keys as well `Example` Tensors and returns a boolean mask of the same shape as the input Tensors to be applied for filtering. If `None`, no filtering is done. @@ -470,13 +470,15 @@ def read_keyed_batch_features(file_pattern, tf.local_variables_initializer() and run the op in a session. queue_capacity: Capacity for input queue. reader_num_threads: The number of threads to read examples. In order to have - predicted and repeatable order of reading and enqueueing, such as in + predictable and repeatable order of reading and enqueueing, such as in prediction and evaluation mode, `reader_num_threads` should be 1. feature_queue_capacity: Capacity of the parsed features queue. num_enqueue_threads: Number of threads to enqueue the parsed example queue. Using multiple threads to enqueue the parsed example queue helps maintain a full queue when the subsequent computations overall are cheaper than - parsing. + parsing. In order to have predictable and repeatable order of reading and + enqueueing, such as in prediction and evaluation mode, + `num_enqueue_threads` should be 1. parse_fn: Parsing function, takes `Example` Tensor returns parsed representation. If `None`, no parsing is done. name: Name of resulting op. @@ -617,7 +619,9 @@ def queue_parsed_features(parsed_features, num_enqueue_threads: Number of threads to enqueue the parsed example queue. Using multiple threads to enqueue the parsed example queue helps maintain a full queue when the subsequent computations overall are cheaper than - parsing. + parsing. In order to have predictable and repeatable order of reading and + enqueueing, such as in prediction and evaluation mode, + `num_enqueue_threads` should be 1. name: Name of resulting op. 
Returns: @@ -721,6 +725,7 @@ def read_batch_features(file_pattern, queue_capacity=10000, feature_queue_capacity=100, reader_num_threads=1, + num_enqueue_threads=2, parse_fn=None, name=None): """Adds operations to read, queue, batch and parse `Example` protos. @@ -752,8 +757,14 @@ def read_batch_features(file_pattern, feature_queue_capacity: Capacity of the parsed features queue. Set this value to a small number, for example 5 if the parsed features are large. reader_num_threads: The number of threads to read examples. In order to have - predicted and repeatable order of reading and enqueueing, such as in + predictable and repeatable order of reading and enqueueing, such as in prediction and evaluation mode, `reader_num_threads` should be 1. + num_enqueue_threads: Number of threads to enqueue the parsed example queue. + Using multiple threads to enqueue the parsed example queue helps maintain + a full queue when the subsequent computations overall are cheaper than + parsing. In order to have predictable and repeatable order of reading and + enqueueing, such as in prediction and evaluation mode, + `num_enqueue_threads` should be 1. parse_fn: Parsing function, takes `Example` Tensor returns parsed representation. If `None`, no parsing is done. name: Name of resulting op. @@ -772,8 +783,9 @@ def read_batch_features(file_pattern, randomize_input=randomize_input, num_epochs=num_epochs, queue_capacity=queue_capacity, - feature_queue_capacity=feature_queue_capacity, reader_num_threads=reader_num_threads, + feature_queue_capacity=feature_queue_capacity, + num_enqueue_threads=num_enqueue_threads, parse_fn=parse_fn, name=name) return features @@ -804,7 +816,7 @@ def read_batch_record_features(file_pattern, tf.local_variables_initializer() and run the op in a session. queue_capacity: Capacity for input queue. reader_num_threads: The number of threads to read examples. 
In order to have - predicted and repeatable order of reading and enqueueing, such as in + predictable and repeatable order of reading and enqueueing, such as in prediction and evaluation mode, `reader_num_threads` should be 1. name: Name of resulting op. diff --git a/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py b/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py index f25f7caf615..6f0fd9a2976 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py @@ -350,6 +350,16 @@ class GraphIOTest(test.TestCase): coord.request_stop() coord.join(threads) + def _create_file_from_list_of_features(self, lines): + json_lines = [ + "".join([ + '{"features": { "feature": { "sequence": {', + '"bytes_list": { "value": ["', base64.b64encode(l).decode("ascii"), + '"]}}}}}\n' + ]) for l in lines + ] + return self._create_temp_file("".join(json_lines)) + def test_read_text_lines_large(self): gfile.Glob = self._orig_glob sequence_prefix = "abcdefghijklmnopqrstuvwxyz123456789" @@ -358,14 +368,7 @@ class GraphIOTest(test.TestCase): "".join([sequence_prefix, str(l)]).encode("ascii") for l in xrange(num_records) ] - json_lines = [ - "".join([ - '{"features": { "feature": { "sequence": {', - '"bytes_list": { "value": ["', base64.b64encode(l).decode("ascii"), - '"]}}}}}\n' - ]) for l in lines - ] - filename = self._create_temp_file("".join(json_lines)) + filename = self._create_file_from_list_of_features(lines) batch_size = 10000 queue_capacity = 100000 name = "my_large_batch" @@ -410,6 +413,61 @@ class GraphIOTest(test.TestCase): self.assertEqual(len(parsed_records), num_records) self.assertEqual(set(parsed_records), set(lines)) + def test_read_batch_features_maintains_order(self): + """Make sure that examples are read in the right order. 
+ + When randomize_input=False, num_enqueue_threads=1 and reader_num_threads=1 + read_keyed_batch_features() should read the examples in the same order as + they appear in the file. + """ + gfile.Glob = self._orig_glob + num_records = 1000 + lines = ["".join(str(l)).encode("ascii") for l in xrange(num_records)] + filename = self._create_file_from_list_of_features(lines) + batch_size = 10 + queue_capacity = 1000 + name = "my_large_batch" + + features = {"sequence": parsing_ops.FixedLenFeature([], dtypes_lib.string)} + + with ops.Graph().as_default() as g, self.test_session(graph=g) as session: + result = graph_io.read_batch_features( + filename, + batch_size, + features, + io_ops.TextLineReader, + randomize_input=False, + num_epochs=1, + queue_capacity=queue_capacity, + reader_num_threads=1, + num_enqueue_threads=1, + parse_fn=parsing_ops.decode_json_example, + name=name) + self.assertEqual(1, len(result)) + self.assertAllEqual((None,), result["sequence"].get_shape().as_list()) + session.run(variables.local_variables_initializer()) + coord = coordinator.Coordinator() + threads = queue_runner_impl.start_queue_runners(session, coord=coord) + + data = [] + try: + while not coord.should_stop(): + data.append(session.run(result)) + except errors.OutOfRangeError: + pass + finally: + coord.request_stop() + + coord.join(threads) + + parsed_records = [ + item for sublist in [d["sequence"] for d in data] for item in sublist + ] + # Check that the number of records matches expected and all records + # are present in the right order. 
+ self.assertEqual(len(parsed_records), num_records) + self.assertEqual(parsed_records, lines) + def test_read_text_lines_multifile(self): gfile.Glob = self._orig_glob filenames = self._create_sorted_temp_files(["ABC\n", "DEF\nGHK\n"]) diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator.py b/tensorflow/contrib/linalg/python/ops/linear_operator.py index 6cdfa861893..91c0938e395 100644 --- a/tensorflow/contrib/linalg/python/ops/linear_operator.py +++ b/tensorflow/contrib/linalg/python/ops/linear_operator.py @@ -126,7 +126,8 @@ class LinearOperator(object): This `LinearOperator` is initialized with boolean flags of the form `is_X`, for `X = non_singular, self_adjoint, positive_definite, square`. - These have the following meaning + These have the following meaning: + * If `is_X == True`, callers should expect the operator to have the property `X`. This is a promise that should be fulfilled, but is *not* a runtime assert. For example, finite floating point precision may result @@ -893,6 +894,23 @@ class LinearOperator(object): with self._name_scope(name): return self._diag_part() + def _trace(self): + return math_ops.reduce_sum(self.diag_part(), axis=-1) + + def trace(self, name="trace"): + """Trace of the linear operator, equal to sum of `self.diag_part()`. + + If the operator is square, this is also the sum of the eigenvalues. + + Args: + name: A name for this `Op`. + + Returns: + Shape `[B1,...,Bb]` `Tensor` of same `dtype` as `self`. + """ + with self._name_scope(name): + return self._trace() + def _add_to_tensor(self, x): # Override if a more efficient implementation is available. 
return self._get_cached_dense_matrix() + x diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_composition.py b/tensorflow/contrib/linalg/python/ops/linear_operator_composition.py index 0853ea03af0..0a71a73a9c5 100644 --- a/tensorflow/contrib/linalg/python/ops/linear_operator_composition.py +++ b/tensorflow/contrib/linalg/python/ops/linear_operator_composition.py @@ -97,7 +97,8 @@ class LinearOperatorComposition(linear_operator.LinearOperator): This `LinearOperator` is initialized with boolean flags of the form `is_X`, for `X = non_singular, self_adjoint, positive_definite, square`. - These have the following meaning + These have the following meaning: + * If `is_X == True`, callers should expect the operator to have the property `X`. This is a promise that should be fulfilled, but is *not* a runtime assert. For example, finite floating point precision may result diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_diag.py b/tensorflow/contrib/linalg/python/ops/linear_operator_diag.py index 56bc967706a..29184483bf8 100644 --- a/tensorflow/contrib/linalg/python/ops/linear_operator_diag.py +++ b/tensorflow/contrib/linalg/python/ops/linear_operator_diag.py @@ -98,7 +98,8 @@ class LinearOperatorDiag(linear_operator.LinearOperator): This `LinearOperator` is initialized with boolean flags of the form `is_X`, for `X = non_singular, self_adjoint, positive_definite, square`. - These have the following meaning + These have the following meaning: + * If `is_X == True`, callers should expect the operator to have the property `X`. This is a promise that should be fulfilled, but is *not* a runtime assert. 
For example, finite floating point precision may result diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_full_matrix.py b/tensorflow/contrib/linalg/python/ops/linear_operator_full_matrix.py index 67889511cbf..52b40eaf8d0 100644 --- a/tensorflow/contrib/linalg/python/ops/linear_operator_full_matrix.py +++ b/tensorflow/contrib/linalg/python/ops/linear_operator_full_matrix.py @@ -92,7 +92,8 @@ class LinearOperatorFullMatrix(linear_operator.LinearOperator): This `LinearOperator` is initialized with boolean flags of the form `is_X`, for `X = non_singular, self_adjoint, positive_definite, square`. - These have the following meaning + These have the following meaning: + * If `is_X == True`, callers should expect the operator to have the property `X`. This is a promise that should be fulfilled, but is *not* a runtime assert. For example, finite floating point precision may result diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_identity.py b/tensorflow/contrib/linalg/python/ops/linear_operator_identity.py index acba1c7035d..b9ac90ff337 100644 --- a/tensorflow/contrib/linalg/python/ops/linear_operator_identity.py +++ b/tensorflow/contrib/linalg/python/ops/linear_operator_identity.py @@ -44,17 +44,15 @@ class BaseLinearOperatorIdentity(linear_operator.LinearOperator): """Static check of init arg `num_rows`, possibly add asserts.""" # Possibly add asserts. 
if self._assert_proper_shapes: - self._num_rows = control_flow_ops.with_dependencies( - [ - check_ops.assert_rank( - self._num_rows, - 0, - message="Argument num_rows must be a 0-D Tensor."), - check_ops.assert_non_negative( - self._num_rows, - message="Argument num_rows must be non-negative."), - ], - self._num_rows) + self._num_rows = control_flow_ops.with_dependencies([ + check_ops.assert_rank( + self._num_rows, + 0, + message="Argument num_rows must be a 0-D Tensor."), + check_ops.assert_non_negative( + self._num_rows, + message="Argument num_rows must be non-negative."), + ], self._num_rows) # Static checks. if not self._num_rows.dtype.is_integer: @@ -74,15 +72,26 @@ class BaseLinearOperatorIdentity(linear_operator.LinearOperator): raise ValueError("Argument num_rows must be non-negative. Found:" " %s" % num_rows_static) + def _min_matrix_dim(self): + """Minimum of domain/range dimension, if statically available, else None.""" + domain_dim = self.domain_dimension.value + range_dim = self.range_dimension.value + if domain_dim is None or range_dim is None: + return None + return min(domain_dim, range_dim) + + def _min_matrix_dim_tensor(self): + """Minimum of domain/range dimension, as a tensor.""" + return math_ops.reduce_min(self.shape_tensor()[-2:]) + def _ones_diag(self): """Returns the diagonal of this operator as all ones.""" if self.shape.is_fully_defined(): - d_shape = self.batch_shape.concatenate( - [min(self.domain_dimension.value, self.range_dimension.value)]) + d_shape = self.batch_shape.concatenate([self._min_matrix_dim()]) else: d_shape = array_ops.concat( [self.batch_shape_tensor(), - [math_ops.reduce_min(self.shape_tensor()[-2:])]], axis=0) + [self._min_matrix_dim_tensor()]], axis=0) return array_ops.ones(shape=d_shape, dtype=self.dtype) @@ -181,7 +190,8 @@ class LinearOperatorIdentity(BaseLinearOperatorIdentity): This `LinearOperator` is initialized with boolean flags of the form `is_X`, for `X = non_singular, self_adjoint, positive_definite, 
square`. - These have the following meaning + These have the following meaning: + * If `is_X == True`, callers should expect the operator to have the property `X`. This is a promise that should be fulfilled, but is *not* a runtime assert. For example, finite floating point precision may result @@ -276,8 +286,8 @@ class LinearOperatorIdentity(BaseLinearOperatorIdentity): self._check_batch_shape_possibly_add_asserts() def _shape(self): - matrix_shape = tensor_shape.TensorShape( - (self._num_rows_static, self._num_rows_static)) + matrix_shape = tensor_shape.TensorShape((self._num_rows_static, + self._num_rows_static)) if self._batch_shape_arg is None: return matrix_shape @@ -285,8 +295,7 @@ class LinearOperatorIdentity(BaseLinearOperatorIdentity): return batch_shape.concatenate(matrix_shape) def _shape_tensor(self): - matrix_shape = array_ops.stack( - (self._num_rows, self._num_rows), axis=0) + matrix_shape = array_ops.stack((self._num_rows, self._num_rows), axis=0) if self._batch_shape_arg is None: return matrix_shape @@ -338,8 +347,7 @@ class LinearOperatorIdentity(BaseLinearOperatorIdentity): # Note that adjoint has no effect since this matrix is self-adjoint. x = linear_operator_util.matrix_adjoint(x) if adjoint_arg else x if self._assert_proper_shapes: - aps = linear_operator_util.assert_compatible_matrix_dimensions( - self, x) + aps = linear_operator_util.assert_compatible_matrix_dimensions(self, x) x = control_flow_ops.with_dependencies([aps], x) return self._possibly_broadcast_batch_shape(x) @@ -352,6 +360,20 @@ class LinearOperatorIdentity(BaseLinearOperatorIdentity): def _solve(self, rhs, adjoint=False, adjoint_arg=False): return self._matmul(rhs, adjoint_arg=adjoint_arg) + def _trace(self): + # Get Tensor of all ones of same shape as self.batch_shape. 
+ if self.batch_shape.is_fully_defined(): + batch_of_ones = array_ops.ones(shape=self.batch_shape, dtype=self.dtype) + else: + batch_of_ones = array_ops.ones( + shape=self.batch_shape_tensor(), dtype=self.dtype) + + if self._min_matrix_dim() is not None: + return self._min_matrix_dim() * batch_of_ones + else: + return (math_ops.cast(self._min_matrix_dim_tensor(), self.dtype) * + batch_of_ones) + def _diag_part(self): return self._ones_diag() @@ -375,17 +397,15 @@ class LinearOperatorIdentity(BaseLinearOperatorIdentity): """Static check of init arg `num_rows`, possibly add asserts.""" # Possibly add asserts. if self._assert_proper_shapes: - self._num_rows = control_flow_ops.with_dependencies( - [ - check_ops.assert_rank( - self._num_rows, - 0, - message="Argument num_rows must be a 0-D Tensor."), - check_ops.assert_non_negative( - self._num_rows, - message="Argument num_rows must be non-negative."), - ], - self._num_rows) + self._num_rows = control_flow_ops.with_dependencies([ + check_ops.assert_rank( + self._num_rows, + 0, + message="Argument num_rows must be a 0-D Tensor."), + check_ops.assert_non_negative( + self._num_rows, + message="Argument num_rows must be non-negative."), + ], self._num_rows) # Static checks. 
if not self._num_rows.dtype.is_integer: @@ -412,17 +432,15 @@ class LinearOperatorIdentity(BaseLinearOperatorIdentity): # Possibly add asserts if self._assert_proper_shapes: - self._batch_shape_arg = control_flow_ops.with_dependencies( - [ - check_ops.assert_rank( - self._batch_shape_arg, - 1, - message="Argument batch_shape must be a 1-D Tensor."), - check_ops.assert_non_negative( - self._batch_shape_arg, - message="Argument batch_shape must be non-negative."), - ], - self._batch_shape_arg) + self._batch_shape_arg = control_flow_ops.with_dependencies([ + check_ops.assert_rank( + self._batch_shape_arg, + 1, + message="Argument batch_shape must be a 1-D Tensor."), + check_ops.assert_non_negative( + self._batch_shape_arg, + message="Argument batch_shape must be non-negative."), + ], self._batch_shape_arg) # Static checks if not self._batch_shape_arg.dtype.is_integer: @@ -585,8 +603,7 @@ class LinearOperatorScaledIdentity(BaseLinearOperatorIdentity): # Shape [B1,...Bb, 1, 1] self._multiplier_matrix = array_ops.expand_dims( array_ops.expand_dims(self.multiplier, -1), -1) - self._multiplier_matrix_conj = math_ops.conj( - self._multiplier_matrix) + self._multiplier_matrix_conj = math_ops.conj(self._multiplier_matrix) self._abs_multiplier = math_ops.abs(self.multiplier) self._num_rows = linear_operator_util.shape_tensor( @@ -594,27 +611,25 @@ class LinearOperatorScaledIdentity(BaseLinearOperatorIdentity): self._num_rows_static = tensor_util.constant_value(self._num_rows) self._check_num_rows_possibly_add_asserts() self._num_rows_cast_to_dtype = math_ops.cast(self._num_rows, self.dtype) - self._num_rows_cast_to_real_dtype = math_ops.cast( - self._num_rows, self.dtype.real_dtype) + self._num_rows_cast_to_real_dtype = math_ops.cast(self._num_rows, + self.dtype.real_dtype) def _shape(self): - matrix_shape = tensor_shape.TensorShape( - (self._num_rows_static, self._num_rows_static)) + matrix_shape = tensor_shape.TensorShape((self._num_rows_static, + self._num_rows_static)) 
batch_shape = self.multiplier.get_shape() return batch_shape.concatenate(matrix_shape) def _shape_tensor(self): - matrix_shape = array_ops.stack( - (self._num_rows, self._num_rows), axis=0) + matrix_shape = array_ops.stack((self._num_rows, self._num_rows), axis=0) batch_shape = array_ops.shape(self.multiplier) return array_ops.concat((batch_shape, matrix_shape), 0) def _assert_non_singular(self): return check_ops.assert_positive( - math_ops.abs(self.multiplier), - message="LinearOperator was singular") + math_ops.abs(self.multiplier), message="LinearOperator was singular") def _assert_positive_definite(self): return check_ops.assert_positive( @@ -635,13 +650,12 @@ class LinearOperatorScaledIdentity(BaseLinearOperatorIdentity): else: matrix = self._multiplier_matrix if self._assert_proper_shapes: - aps = linear_operator_util.assert_compatible_matrix_dimensions( - self, x) + aps = linear_operator_util.assert_compatible_matrix_dimensions(self, x) x = control_flow_ops.with_dependencies([aps], x) return x * matrix def _determinant(self): - return self.multiplier ** self._num_rows_cast_to_dtype + return self.multiplier**self._num_rows_cast_to_dtype def _log_abs_determinant(self): return self._num_rows_cast_to_real_dtype * math_ops.log( @@ -654,11 +668,24 @@ class LinearOperatorScaledIdentity(BaseLinearOperatorIdentity): else: matrix = self._multiplier_matrix if self._assert_proper_shapes: - aps = linear_operator_util.assert_compatible_matrix_dimensions( - self, rhs) + aps = linear_operator_util.assert_compatible_matrix_dimensions(self, rhs) rhs = control_flow_ops.with_dependencies([aps], rhs) return rhs / matrix + def _trace(self): + # Get Tensor of all ones of same shape as self.batch_shape. 
+ if self.batch_shape.is_fully_defined(): + batch_of_ones = array_ops.ones(shape=self.batch_shape, dtype=self.dtype) + else: + batch_of_ones = array_ops.ones( + shape=self.batch_shape_tensor(), dtype=self.dtype) + + if self._min_matrix_dim() is not None: + return self.multiplier * self._min_matrix_dim() * batch_of_ones + else: + return (self.multiplier * math_ops.cast(self._min_matrix_dim_tensor(), + self.dtype) * batch_of_ones) + def _diag_part(self): return self._ones_diag() * self.multiplier[..., array_ops.newaxis] diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_test_util.py b/tensorflow/contrib/linalg/python/ops/linear_operator_test_util.py index b2d7b10157b..3d316450d80 100644 --- a/tensorflow/contrib/linalg/python/ops/linear_operator_test_util.py +++ b/tensorflow/contrib/linalg/python/ops/linear_operator_test_util.py @@ -263,6 +263,23 @@ class LinearOperatorDerivedClassTest(test.TestCase): feed_dict=feed_dict) self.assertAC(op_solve_v, mat_solve_v) + def test_trace(self): + self._skip_if_tests_to_skip_contains("trace") + for use_placeholder in False, True: + for shape in self._shapes_to_test: + for dtype in self._dtypes_to_test: + with self.test_session(graph=ops.Graph()) as sess: + sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED + operator, mat, feed_dict = self._operator_and_mat_and_feed_dict( + shape, dtype, use_placeholder=use_placeholder) + op_trace = operator.trace() + mat_trace = math_ops.trace(mat) + if not use_placeholder: + self.assertAllEqual(op_trace.get_shape(), mat_trace.get_shape()) + op_trace_v, mat_trace_v = sess.run([op_trace, mat_trace], + feed_dict=feed_dict) + self.assertAC(op_trace_v, mat_trace_v) + def test_add_to_tensor(self): self._skip_if_tests_to_skip_contains("add_to_tensor") for use_placeholder in False, True: diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_tril.py b/tensorflow/contrib/linalg/python/ops/linear_operator_tril.py index 8a152a9b475..22ccf6f1310 100644 --- 
a/tensorflow/contrib/linalg/python/ops/linear_operator_tril.py +++ b/tensorflow/contrib/linalg/python/ops/linear_operator_tril.py @@ -91,7 +91,8 @@ class LinearOperatorTriL(linear_operator.LinearOperator): This `LinearOperator` is initialized with boolean flags of the form `is_X`, for `X = non_singular, self_adjoint, positive_definite, square`. - These have the following meaning + These have the following meaning: + * If `is_X == True`, callers should expect the operator to have the property `X`. This is a promise that should be fulfilled, but is *not* a runtime assert. For example, finite floating point precision may result diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_udvh_update.py b/tensorflow/contrib/linalg/python/ops/linear_operator_udvh_update.py index 546d899e74e..9c9c3595746 100644 --- a/tensorflow/contrib/linalg/python/ops/linear_operator_udvh_update.py +++ b/tensorflow/contrib/linalg/python/ops/linear_operator_udvh_update.py @@ -112,9 +112,9 @@ class LinearOperatorUDVHUpdate(linear_operator.LinearOperator): #### Matrix property hints This `LinearOperator` is initialized with boolean flags of the form `is_X`, - for `X = non_singular, self_adjoint, positive_definite, diag_update_positive` - and `square` - These have the following meaning + for `X = non_singular`, `self_adjoint`, `positive_definite`, + `diag_update_positive` and `square`. These have the following meaning: + * If `is_X == True`, callers should expect the operator to have the property `X`. This is a promise that should be fulfilled, but is *not* a runtime assert. 
For example, finite floating point precision may result diff --git a/tensorflow/contrib/lookup/lookup_ops.py b/tensorflow/contrib/lookup/lookup_ops.py index f0f1c14fcaa..f53f38f3cf8 100644 --- a/tensorflow/contrib/lookup/lookup_ops.py +++ b/tensorflow/contrib/lookup/lookup_ops.py @@ -86,7 +86,7 @@ def index_table_from_tensor(mapping, Any lookup of an out-of-vocabulary token will return a bucket ID based on its hash if `num_oov_buckets` is greater than zero. Otherwise it is assigned the `default_value`. - The bucket ID range is `[mapping size, mapping size + num_oov_buckets]`. + The bucket ID range is `[mapping size, mapping size + num_oov_buckets - 1]`. The underlying table must be initialized by calling `tf.tables_initializer.run()` or `table.init.run()` once. diff --git a/tensorflow/contrib/makefile/Makefile b/tensorflow/contrib/makefile/Makefile index 305ed0d11ec..2150cfe9ea8 100644 --- a/tensorflow/contrib/makefile/Makefile +++ b/tensorflow/contrib/makefile/Makefile @@ -279,6 +279,16 @@ ifeq ($(TARGET),ANDROID) LIBS += -lhexagon_controller LDFLAGS += -L$(HEXAGON_LIBS) CXXFLAGS += -DUSE_HEXAGON_LIBS + +# CAVEAT: We should disable TENSORFLOW_DISABLE_META while running +# quantized_matmul on Android because it crashes in +# MultiThreadGemm in tensorflow/core/kernels/meta_support.cc +# See http://b/33270149 +# TODO(satok): Remove once it's fixed + CXXFLAGS += -DTENSORFLOW_DISABLE_META + +# Declare __ANDROID_TYPES_FULL__ to enable required types for hvx + CXXFLAGS += -D__ANDROID_TYPES_FULL__ endif ifdef ENABLE_EXPERIMENTAL_HEXNN_OPS @@ -500,6 +510,18 @@ tensorflow/core/util/reporter.cc \ tensorflow/tools/benchmark/benchmark_model.cc \ tensorflow/tools/benchmark/benchmark_model_main.cc +ifdef HEXAGON_LIBS + TF_CC_SRCS += \ +tensorflow/cc/framework/scope.cc \ +tensorflow/cc/framework/ops.cc \ +tensorflow/cc/ops/const_op.cc \ +tensorflow/core/kernels/hexagon/graph_transfer_utils.cc \ +tensorflow/core/kernels/hexagon/graph_transferer.cc \ 
+tensorflow/core/kernels/hexagon/hexagon_control_wrapper.cc \ +tensorflow/core/kernels/hexagon/hexagon_ops_definitions.cc \ +tensorflow/core/kernels/hexagon/hexagon_remote_fused_graph_executor_build.cc +endif + # File names of the intermediate files target compilation generates. TF_CC_OBJS := $(addprefix $(OBJDIR), $(TF_CC_SRCS:.cc=.o)) PBT_GEN_FILES := $(addprefix $(PBTGENDIR), $(PBT_CC_SRCS)) diff --git a/tensorflow/contrib/makefile/sub_makefiles/hexagon_graph_execution/Makefile.in b/tensorflow/contrib/makefile/sub_makefiles/hexagon_graph_execution/Makefile.in index 2a6f66edcb7..9aa81144fd2 100644 --- a/tensorflow/contrib/makefile/sub_makefiles/hexagon_graph_execution/Makefile.in +++ b/tensorflow/contrib/makefile/sub_makefiles/hexagon_graph_execution/Makefile.in @@ -34,27 +34,7 @@ $(wildcard $(GTEST_DIR)/src/*.cc) \ $(wildcard $(GTEST_DIR)/src/*.h) \ $(GTEST_HEADERS) -# CAVEAT: We should disable TENSORFLOW_DISABLE_META while running -# quantized_matmul on Android because it crashes in -# MultiThreadGemm in tensorflow/core/kernels/meta_support.cc -# TODO(satok): Remove once it's fixed -CXXFLAGS += -DTENSORFLOW_DISABLE_META - -# Declare __ANDROID_TYPES_FULL__ to enable required types for hvx -CXXFLAGS += -D__ANDROID_TYPES_FULL__ - GRAPH_TRANSFER_SRCS := \ -tensorflow/cc/framework/scope.cc \ -tensorflow/cc/framework/ops.cc \ -tensorflow/cc/ops/const_op.cc \ -tensorflow/core/kernels/hexagon/graph_transfer_utils.cc \ -tensorflow/core/kernels/hexagon/graph_transferer.cc \ -tensorflow/core/kernels/hexagon/hexagon_control_wrapper.cc \ -tensorflow/core/kernels/hexagon/hexagon_ops_definitions.cc \ -tensorflow/core/kernels/hexagon/hexagon_remote_fused_graph_executor_build.cc \ -tensorflow/core/kernels/remote_fused_graph_execute_op.cc \ -tensorflow/core/kernels/remote_fused_graph_execute_utils.cc \ -tensorflow/core/ops/remote_fused_graph_ops.cc \ tensorflow/core/platform/posix/test.cc GRAPH_EXECUTION_SRCS := \ diff --git a/tensorflow/contrib/makefile/tf_op_files.txt 
b/tensorflow/contrib/makefile/tf_op_files.txt index 857d6fa21bc..541d13aed1b 100644 --- a/tensorflow/contrib/makefile/tf_op_files.txt +++ b/tensorflow/contrib/makefile/tf_op_files.txt @@ -138,6 +138,7 @@ tensorflow/core/kernels/cwise_op_logical_and.cc tensorflow/core/kernels/cwise_op_log.cc tensorflow/core/kernels/cwise_op_less.cc tensorflow/core/kernels/cwise_op_isfinite.cc +tensorflow/core/kernels/cwise_op_invert.cc tensorflow/core/kernels/cwise_op_greater_equal.cc tensorflow/core/kernels/cwise_op_greater.cc tensorflow/core/kernels/cwise_op_floor_div.cc @@ -146,6 +147,9 @@ tensorflow/core/kernels/cwise_op_exp.cc tensorflow/core/kernels/cwise_op_equal_to_2.cc tensorflow/core/kernels/cwise_op_equal_to_1.cc tensorflow/core/kernels/cwise_op_div.cc +tensorflow/core/kernels/cwise_op_bitwise_xor.cc +tensorflow/core/kernels/cwise_op_bitwise_or.cc +tensorflow/core/kernels/cwise_op_bitwise_and.cc tensorflow/core/kernels/cwise_op_add_2.cc tensorflow/core/kernels/cwise_op_add_1.cc tensorflow/core/kernels/cwise_op_abs.cc @@ -202,12 +206,15 @@ tensorflow/core/kernels/quantized_reshape_op.cc tensorflow/core/kernels/quantized_resize_bilinear_op.cc tensorflow/core/kernels/requantization_range_op.cc tensorflow/core/kernels/requantize.cc +tensorflow/core/kernels/remote_fused_graph_execute_op.cc +tensorflow/core/kernels/remote_fused_graph_execute_utils.cc tensorflow/core/ops/training_ops.cc tensorflow/core/ops/string_ops.cc tensorflow/core/ops/state_ops.cc tensorflow/core/ops/sparse_ops.cc tensorflow/core/ops/sendrecv_ops.cc tensorflow/core/ops/script_ops.cc +tensorflow/core/ops/remote_fused_graph_ops.cc tensorflow/core/ops/random_ops.cc tensorflow/core/ops/random_grad.cc tensorflow/core/ops/parsing_ops.cc diff --git a/tensorflow/contrib/metrics/BUILD b/tensorflow/contrib/metrics/BUILD index 8b792a0f685..c0741f2997a 100644 --- a/tensorflow/contrib/metrics/BUILD +++ b/tensorflow/contrib/metrics/BUILD @@ -31,7 +31,6 @@ py_library( "//tensorflow/python:check_ops", 
"//tensorflow/python:confusion_matrix", "//tensorflow/python:control_flow_ops", - "//tensorflow/python:framework", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:histogram_ops", "//tensorflow/python:init_ops", @@ -40,11 +39,9 @@ py_library( "//tensorflow/python:nn", "//tensorflow/python:nn_ops", "//tensorflow/python:sets", - "//tensorflow/python:sparse_ops", "//tensorflow/python:state_ops", "//tensorflow/python:util", "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", ], ) @@ -58,9 +55,6 @@ py_test( "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:platform_test", - "//third_party/py/numpy", ], ) @@ -74,8 +68,6 @@ py_test( "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:platform_test", "//tensorflow/python:variables", "//third_party/py/numpy", ], @@ -98,6 +90,7 @@ py_test( "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:random_ops", + "//tensorflow/python:sparse_tensor", "//tensorflow/python:variables", "//third_party/py/numpy", ], diff --git a/tensorflow/contrib/nn/BUILD b/tensorflow/contrib/nn/BUILD index 13a05bf3787..dbac049d833 100644 --- a/tensorflow/contrib/nn/BUILD +++ b/tensorflow/contrib/nn/BUILD @@ -17,7 +17,10 @@ py_library( ], srcs_version = "PY2AND3", visibility = ["//visibility:public"], - deps = ["//tensorflow/python:nn"], + deps = [ + "//tensorflow/python:nn", + "//tensorflow/python:util", + ], ) filegroup( diff --git a/tensorflow/contrib/opt/python/training/external_optimizer.py b/tensorflow/contrib/opt/python/training/external_optimizer.py index 0909760b383..ff87a95f72f 100644 --- a/tensorflow/contrib/opt/python/training/external_optimizer.py +++ 
b/tensorflow/contrib/opt/python/training/external_optimizer.py @@ -12,20 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """TensorFlow interface for third-party optimizers.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np + from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradients from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging - __all__ = ['ExternalOptimizerInterface', 'ScipyOptimizerInterface'] @@ -43,19 +43,41 @@ class ExternalOptimizerInterface(object): @@minimize """ - def __init__(self, loss, var_list=None, equalities=None, inequalities=None, + def __init__(self, + loss, + var_list=None, + equalities=None, + inequalities=None, + var_to_bounds=None, **optimizer_kwargs): """Initialize a new interface instance. Args: loss: A scalar `Tensor` to be minimized. - var_list: Optional list of `Variable` objects to update to minimize + var_list: Optional `list` of `Variable` objects to update to minimize `loss`. Defaults to the list of variables collected in the graph under the key `GraphKeys.TRAINABLE_VARIABLES`. - equalities: Optional list of equality constraint scalar `Tensor`s to be + equalities: Optional `list` of equality constraint scalar `Tensor`s to be held equal to zero. - inequalities: Optional list of inequality constraint scalar `Tensor`s - to be kept nonnegative. + inequalities: Optional `list` of inequality constraint scalar `Tensor`s + to be held nonnegative. + var_to_bounds: Optional `dict` where each key is an optimization + `Variable` and each corresponding value is a length-2 tuple of + `(low, high)` bounds. 
Although enforcing this kind of simple constraint + could be accomplished with the `inequalities` arg, not all optimization + algorithms support general inequality constraints, e.g. L-BFGS-B. Both + `low` and `high` can either be numbers or anything convertible to a + NumPy array that can be broadcast to the shape of `var` (using + `np.broadcast_to`). To indicate that there is no bound, use `None` (or + `+/- np.infty`). For example, if `var` is a 2x3 matrix, then any of + the following corresponding `bounds` could be supplied: + * `(0, np.infty)`: Each element of `var` held positive. + * `(-np.infty, [1, 2])`: First column less than 1, second column less + than 2. + * `(-np.infty, [[1], [2], [3]])`: First row less than 1, second row less + than 2, etc. + * `(-np.infty, [[1, 2, 3], [4, 5, 6]])`: Entry `var[0, 0]` less than 1, + `var[0, 1]` less than 2, etc. **optimizer_kwargs: Other subclass-specific keyword arguments. """ self._loss = loss @@ -67,37 +89,55 @@ class ExternalOptimizerInterface(object): else: self._vars = list(var_list) - self._update_placeholders = [array_ops.placeholder(var.dtype) - for var in self._vars] - self._var_updates = [var.assign(array_ops.reshape(placeholder, - _get_shape_tuple(var))) - for var, placeholder in - zip(self._vars, self._update_placeholders)] + packed_bounds = None + if var_to_bounds is not None: + left_packed_bounds = [] + right_packed_bounds = [] + for var in self._vars: + shape = var.get_shape().as_list() + bounds = (-np.infty, np.infty) + if var in var_to_bounds: + bounds = var_to_bounds[var] + left_packed_bounds.extend(list(np.broadcast_to(bounds[0], shape).flat)) + right_packed_bounds.extend(list(np.broadcast_to(bounds[1], shape).flat)) + packed_bounds = list(zip(left_packed_bounds, right_packed_bounds)) + self._packed_bounds = packed_bounds + + self._update_placeholders = [ + array_ops.placeholder(var.dtype) for var in self._vars + ] + self._var_updates = [ + var.assign(array_ops.reshape(placeholder, 
_get_shape_tuple(var))) + for var, placeholder in zip(self._vars, self._update_placeholders) + ] loss_grads = _compute_gradients(loss, self._vars) - equalities_grads = [_compute_gradients(equality, self._vars) - for equality in self._equalities] - inequalities_grads = [_compute_gradients(inequality, self._vars) - for inequality in self._inequalities] + equalities_grads = [ + _compute_gradients(equality, self._vars) + for equality in self._equalities + ] + inequalities_grads = [ + _compute_gradients(inequality, self._vars) + for inequality in self._inequalities + ] self.optimizer_kwargs = optimizer_kwargs self._packed_var = self._pack(self._vars) self._packed_loss_grad = self._pack(loss_grads) self._packed_equality_grads = [ - self._pack(equality_grads) - for equality_grads in equalities_grads + self._pack(equality_grads) for equality_grads in equalities_grads ] self._packed_inequality_grads = [ - self._pack(inequality_grads) - for inequality_grads in inequalities_grads + self._pack(inequality_grads) for inequality_grads in inequalities_grads ] dims = [_prod(_get_shape_tuple(var)) for var in self._vars] accumulated_dims = list(_accumulate(dims)) self._packing_slices = [ - slice(start, end) for start, end in zip(accumulated_dims[:-1], - accumulated_dims[1:])] + slice(start, end) + for start, end in zip(accumulated_dims[:-1], accumulated_dims[1:]) + ] def minimize(self, session=None, @@ -135,35 +175,39 @@ class ExternalOptimizerInterface(object): step_callback = step_callback or (lambda xk: None) # Construct loss function and associated gradient. - loss_grad_func = self._make_eval_func( - [self._loss, self._packed_loss_grad], - session, feed_dict, fetches, loss_callback) + loss_grad_func = self._make_eval_func([self._loss, + self._packed_loss_grad], session, + feed_dict, fetches, loss_callback) # Construct equality constraint functions and associated gradients. 
- equality_funcs = self._make_eval_funcs( - self._equalities, session, feed_dict, fetches) - equality_grad_funcs = self._make_eval_funcs( - self._packed_equality_grads, session, feed_dict, fetches) + equality_funcs = self._make_eval_funcs(self._equalities, session, feed_dict, + fetches) + equality_grad_funcs = self._make_eval_funcs(self._packed_equality_grads, + session, feed_dict, fetches) # Construct inequality constraint functions and associated gradients. - inequality_funcs = self._make_eval_funcs( - self._inequalities, session, feed_dict, fetches) - inequality_grad_funcs = self._make_eval_funcs( - self._packed_inequality_grads, session, feed_dict, fetches) + inequality_funcs = self._make_eval_funcs(self._inequalities, session, + feed_dict, fetches) + inequality_grad_funcs = self._make_eval_funcs(self._packed_inequality_grads, + session, feed_dict, fetches) # Get initial value from TF session. initial_packed_var_val = session.run(self._packed_var) # Perform minimization. packed_var_val = self._minimize( - initial_val=initial_packed_var_val, loss_grad_func=loss_grad_func, + initial_val=initial_packed_var_val, + loss_grad_func=loss_grad_func, equality_funcs=equality_funcs, equality_grad_funcs=equality_grad_funcs, inequality_funcs=inequality_funcs, inequality_grad_funcs=inequality_grad_funcs, - step_callback=step_callback, optimizer_kwargs=self.optimizer_kwargs) - var_vals = [packed_var_val[packing_slice] - for packing_slice in self._packing_slices] + packed_bounds=self._packed_bounds, + step_callback=step_callback, + optimizer_kwargs=self.optimizer_kwargs) + var_vals = [ + packed_var_val[packing_slice] for packing_slice in self._packing_slices + ] # Set optimization variables to their new values. 
session.run( @@ -173,7 +217,7 @@ class ExternalOptimizerInterface(object): def _minimize(self, initial_val, loss_grad_func, equality_funcs, equality_grad_funcs, inequality_funcs, inequality_grad_funcs, - step_callback, optimizer_kwargs): + packed_bounds, step_callback, optimizer_kwargs): """Wrapper for a particular optimization algorithm implementation. It would be appropriate for a subclass implementation of this method to @@ -191,6 +235,7 @@ class ExternalOptimizerInterface(object): inequality_funcs: A list of functions each of which specifies a scalar quantity that an optimizer should hold >= 0. inequality_grad_funcs: A list of gradients of inequality_funcs. + packed_bounds: A list of bounds for each index, or `None`. step_callback: A callback function to execute at each optimization step, supplied with the current value of the packed variable vector. optimizer_kwargs: Other key-value arguments available to the optimizer. @@ -239,7 +284,11 @@ class ExternalOptimizerInterface(object): return eval_func - def _make_eval_funcs(self, tensors, session, feed_dict, fetches, + def _make_eval_funcs(self, + tensors, + session, + feed_dict, + fetches, callback=None): return [ self._make_eval_func(tensor, session, feed_dict, fetches, callback) @@ -266,7 +315,24 @@ class ScipyOptimizerInterface(ExternalOptimizerInterface): # The value of vector should now be [0., 0.]. ``` - Example with constraints: + Example with simple bound constraints: + + ```python + vector = tf.Variable([7., 7.], 'vector') + + # Make vector norm as small as possible. + loss = tf.reduce_sum(tf.square(vector)) + + optimizer = ScipyOptimizerInterface( + loss, var_to_bounds={vector: ([1, 2], np.infty)}) + + with tf.Session() as session: + optimizer.minimize(session) + + # The value of vector should now be [1., 2.]. 
+ ``` + + Example with more complicated constraints: ```python vector = tf.Variable([7., 7.], 'vector') @@ -294,7 +360,8 @@ class ScipyOptimizerInterface(ExternalOptimizerInterface): def _minimize(self, initial_val, loss_grad_func, equality_funcs, equality_grad_funcs, inequality_funcs, inequality_grad_funcs, - step_callback, optimizer_kwargs): + packed_bounds, step_callback, optimizer_kwargs): + def loss_grad_func_wrapper(x): # SciPy's L-BFGS-B Fortran implementation requires gradients as doubles. loss, gradient = loss_grad_func(x) @@ -314,7 +381,20 @@ class ScipyOptimizerInterface(ExternalOptimizerInterface): 'callback': step_callback, 'method': method, 'constraints': constraints, + 'bounds': packed_bounds, } + + for kwarg in minimize_kwargs: + if kwarg in optimizer_kwargs: + if kwarg == 'bounds': + # Special handling for 'bounds' kwarg since ability to specify bounds + # was added after this module was already publicly released. + raise ValueError( + 'Bounds must be set using the var_to_bounds argument') + raise ValueError( + 'Optimizer keyword arg \'{}\' is set ' + 'automatically and cannot be injected manually'.format(kwarg)) + minimize_kwargs.update(optimizer_kwargs) if method == 'SLSQP': # SLSQP doesn't support step callbacks. Obviate associated warning @@ -327,8 +407,8 @@ class ScipyOptimizerInterface(ExternalOptimizerInterface): ' Message: %s\n' ' Objective function value: %f\n' ' Number of iterations: %d\n' - ' Number of functions evaluations: %d', - result.message, result.fun, result.nit, result.nfev) + ' Number of functions evaluations: %d', result.message, + result.fun, result.nit, result.nfev) return result['x'] @@ -355,5 +435,7 @@ def _prod(array): def _compute_gradients(tensor, var_list): grads = gradients.gradients(tensor, var_list) # tf.gradients sometimes returns `None` when it should return 0. 
- return [grad if grad is not None else array_ops.zeros_like(var) - for var, grad in zip(var_list, grads)] + return [ + grad if grad is not None else array_ops.zeros_like(var) + for var, grad in zip(var_list, grads) + ] diff --git a/tensorflow/contrib/opt/python/training/external_optimizer_test.py b/tensorflow/contrib/opt/python/training/external_optimizer_test.py index c9f5a2ca3f1..f39134936f9 100644 --- a/tensorflow/contrib/opt/python/training/external_optimizer_test.py +++ b/tensorflow/contrib/opt/python/training/external_optimizer_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import numpy as np + from tensorflow.contrib.opt.python.training import external_optimizer from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -74,13 +75,13 @@ class ExternalOptimizerInterfaceTest(TestCase): minimum_location = constant_op.constant(np.arange(9), dtype=dtypes.float32) - loss = math_ops.reduce_sum(math_ops.square(vector - - minimum_location[:2])) / 2. - loss += math_ops.reduce_sum(math_ops.square(scalar - minimum_location[ - 2])) / 2. + loss = math_ops.reduce_sum( + math_ops.square(vector - minimum_location[:2])) / 2. loss += math_ops.reduce_sum( - math_ops.square(matrix - array_ops.reshape(minimum_location[3:], - [2, 3]))) / 2. + math_ops.square(scalar - minimum_location[2])) / 2. + loss += math_ops.reduce_sum( + math_ops.square( + matrix - array_ops.reshape(minimum_location[3:], [2, 3]))) / 2. optimizer = MockOptimizerInterface(loss) @@ -184,6 +185,41 @@ class ScipyOptimizerInterfaceTest(TestCase): optimizer.minimize(sess) self.assertAllClose(np.ones(2), sess.run(vector)) + def test_scalar_bounds(self): + vector_initial_value = [7., 7.] + vector = variables.Variable(vector_initial_value, 'vector') + + # Make norm as small as possible. + loss = math_ops.reduce_sum(math_ops.square(vector)) + + # Make the minimum value of each component be 1. 
+ var_to_bounds = {vector: (1., np.infty)} + + optimizer = external_optimizer.ScipyOptimizerInterface( + loss, var_to_bounds=var_to_bounds) + + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + optimizer.minimize(sess) + self.assertAllClose(np.ones(2), sess.run(vector)) + + def test_vector_bounds(self): + vector_initial_value = [7., 7.] + vector = variables.Variable(vector_initial_value, 'vector') + + # Make norm as small as possible. + loss = math_ops.reduce_sum(math_ops.square(vector)) + + var_to_bounds = {vector: ([None, 2.], None)} + + optimizer = external_optimizer.ScipyOptimizerInterface( + loss, var_to_bounds=var_to_bounds) + + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + optimizer.minimize(sess) + self.assertAllClose([0., 2.], sess.run(vector)) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/predictor/BUILD b/tensorflow/contrib/predictor/BUILD new file mode 100644 index 00000000000..c4b46551c12 --- /dev/null +++ b/tensorflow/contrib/predictor/BUILD @@ -0,0 +1,163 @@ +# `Predictor` classes provide an interface for efficient, repeated inference. 
+ +package(default_visibility = ["//third_party/tensorflow/contrib/predictor:__subpackages__"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load("//tensorflow:tensorflow.bzl", "py_test") + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) + +py_library( + name = "predictor", + srcs = [ + "__init__.py", + ], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + ":predictor_factories", + ], +) + +py_library( + name = "predictor_factories", + srcs = [ + "predictor_factories.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":contrib_estimator_predictor", + ":core_estimator_predictor", + ":saved_model_predictor", + "//tensorflow/contrib/learn", + ], +) + +py_library( + name = "base_predictor", + srcs = [ + "predictor.py", + ], + srcs_version = "PY2AND3", +) + +py_library( + name = "saved_model_predictor", + srcs = [ + "saved_model_predictor.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":base_predictor", + "//tensorflow/python/tools:saved_model_cli", + ], +) + +py_library( + name = "core_estimator_predictor", + srcs = [ + "core_estimator_predictor.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":base_predictor", + "//tensorflow/contrib/learn", + ], +) + +py_library( + name = "contrib_estimator_predictor", + srcs = [ + "contrib_estimator_predictor.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":base_predictor", + "//tensorflow/contrib/learn", + ], +) + +py_library( + name = "testing_common", + srcs = [ + "testing_common.py", + ], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ], +) + +# Transitive dependencies of this target will be included in the pip package. 
+py_library( + name = "predictor_pip", + visibility = ["//visibility:public"], + deps = [ + ":contrib_estimator_predictor", + ":core_estimator_predictor", + ":saved_model_predictor", + ], +) + +py_test( + name = "saved_model_predictor_test", + srcs = [ + "saved_model_predictor_test.py", + ], + data = [":test_export_dir"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":saved_model_predictor", + ":testing_common", + "//tensorflow/python:client_testlib", + ], +) + +py_test( + name = "core_estimator_predictor_test", + srcs = [ + "core_estimator_predictor_test.py", + ], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":core_estimator_predictor", + ":testing_common", + "//tensorflow/python:client_testlib", + ], +) + +py_test( + name = "contrib_estimator_predictor_test", + srcs = [ + "contrib_estimator_predictor_test.py", + ], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":contrib_estimator_predictor", + ":testing_common", + "//tensorflow/python:client_testlib", + "//third_party/py/numpy", + ], +) + +filegroup( + name = "test_export_dir", + srcs = glob(["test_export_dir/**/*"]), + tags = ["no_pip"], +) diff --git a/tensorflow/contrib/predictor/README.md b/tensorflow/contrib/predictor/README.md new file mode 100644 index 00000000000..16cdcf3e706 --- /dev/null +++ b/tensorflow/contrib/predictor/README.md @@ -0,0 +1,96 @@ +# Predictors + +The `Predictor` classes provide a simple interface for performing repeated, +efficient inference. A `Predictor` can be constructed from a `SavedModel` on +disk, a `tf.Estimator` or a `tf.contrib.Estimator`. 
+ +To facilitate the examples below, let's define a trivial `Estimator` that just +calculates a sum: + +```python +def model_fn(features, labels, mode): + z = tf.add(features['x'], features['y'], name='z') + return tf.contrib.learn.ModelFnOps( + mode, {'z': z}, loss=tf.constant(0.0), train_op=tf.no_op()) + +estimator = tf.contrib.learn.Estimator(model_fn=model_fn) +``` + +We can then construct a `Predictor` in two different ways. + +## `Predictor` from a `SavedModel` + +Given a trained `Estimator`, we first export a `SavedModel`: + +```python +def serving_input_fn(): + x = tf.placeholder(dtype=tf.float32, shape=[None], name='x') + y = tf.placeholder(dtype=tf.float32, shape=[None], name='y') + + features = {'x': x, 'y': y} + return tf.contrib.learn.utils.input_fn_utils.InputFnOps( + features, None, default_inputs=features) + +saved_model_dir = estimator.export_savedmodel(my_export_dir, serving_input_fn) +``` + +We can then construct a `Predictor` as follows: + +```python +saved_model_predictor = predictor.from_saved_model(export_dir=saved_model_dir) +output_dict = saved_model_predictor({'x': [1.0], 'y': [5.2]}) +# output_dict == {'z': [6.2]} +``` + +By specifying a signature definition, we can feed and fetch any `Tensor`s in +the `Graph`. 
In this example, we feed and fetch the same `Tensor`, `z`: + +```python +inputs = outputs = {'z': tf.TensorInfo( + name='z:0', + dtype=types_pb2.DT_FLOAT, + tensor_shape=tensor_shape_pb2.TensorShapeProto())} + +signature_def = tf.saved_model.signature_def_utils.build_signature_def( + inputs=inputs, + outputs=outputs, + method_name='tensorflow/serving/regress') + +trivial_predictor = predictor.from_saved_model( + export_dir=saved_model_dir, + signature_def=signature_def) + +output_dict = trivial_predictor({'z': [32.]}) +# output_dict == {'z': [32.]} +``` + +You can also specify input and output `Tensor`s by name using the `input_names` +and `output_names` keywords: + +```python +saved_model_predictor = predictor.from_saved_model( + export_dir=saved_model_dir, + input_names={'x': 'x:0', 'y': 'y:0'}, + output_names={'z': 'z:0'}) + +output_dict = saved_model_predictor({'x': [6.], 'y': [11.]}) +# output_dict == {'z': [17.]} +``` + +This functionality is particularly useful for performing encoding once, but +doing multiple decoding iterations with e.g. seq2seq models. + +## `Predictor` from an `Estimator` + +We can also construct a `Predictor` directly from an `Estimator`. Defining +`serving_input_fn` as above, + +```python +estimator_predictor = predictor.from_contrib_estimator( + estimator, serving_input_fn) +output_dict = estimator_predictor({'x': [1., 2.], 'y': [3., 4.]}) +# output_dict == {'z': [4., 6.]} +``` + +Construction from a `tf.Estimator` is almost identical. + diff --git a/tensorflow/tools/ci_build/builds/tensorboard.sh b/tensorflow/contrib/predictor/__init__.py old mode 100755 new mode 100644 similarity index 60% rename from tensorflow/tools/ci_build/builds/tensorboard.sh rename to tensorflow/contrib/predictor/__init__.py index 77bd29c09f8..d270c3f7983 --- a/tensorflow/tools/ci_build/builds/tensorboard.sh +++ b/tensorflow/contrib/predictor/__init__.py @@ -1,5 +1,4 @@ -#!/usr/bin/env bash -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,16 +13,12 @@ # limitations under the License. # ============================================================================== -set -e +"""Modules for `Predictor`s.""" -export LAUNCHPAD_CHROME=${LAUNCHPAD_CHROME:-$(which chromium-browser)} +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function -cd tensorflow/tensorboard - -# Install all js dependencies (tooling via npm, frontend assets via bower) -npm run prepare - -npm run compile - -# Run wct in headless chrome using xvfb -xvfb-run ./node_modules/web-component-tester/bin/wct --skip-plugin=sauce +from tensorflow.contrib.predictor.predictor_factories import from_contrib_estimator +from tensorflow.contrib.predictor.predictor_factories import from_estimator +from tensorflow.contrib.predictor.predictor_factories import from_saved_model diff --git a/tensorflow/contrib/predictor/contrib_estimator_predictor.py b/tensorflow/contrib/predictor/contrib_estimator_predictor.py new file mode 100644 index 00000000000..b7a98c68e23 --- /dev/null +++ b/tensorflow/contrib/predictor/contrib_estimator_predictor.py @@ -0,0 +1,74 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +"""A `Predictor constructed from a `tf.contrib.learn.Estimator`.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.learn.python.learn.utils import saved_model_export_utils +from tensorflow.contrib.predictor import predictor +from tensorflow.python.framework import ops +from tensorflow.python.training import monitored_session +from tensorflow.python.training import saver + + +class ContribEstimatorPredictor(predictor.Predictor): + """A `Predictor constructed from a `tf.contrib.learn.Estimator`.""" + + def __init__(self, + estimator, + prediction_input_fn, + input_alternative_key=None, + output_alternative_key=None, + graph=None): + """Initialize a `ContribEstimatorPredictor`. + + Args: + estimator: an instance of `tf.contrib.learn.Estimator`. + prediction_input_fn: a function that takes no arguments and returns an + instance of `InputFnOps`. + input_alternative_key: Optional. Specify the input alternative used for + prediction. + output_alternative_key: Specify the output alternative used for + prediction. Not needed for single-headed models but required for + multi-headed models. + graph: Optional. The Tensorflow `graph` in which prediction should be + done. 
+ """ + self._graph = graph or ops.Graph() + with self._graph.as_default(): + input_fn_ops = prediction_input_fn() + # pylint: disable=protected-access + model_fn_ops = estimator._get_predict_ops(input_fn_ops.features) + # pylint: enable=protected-access + checkpoint_path = saver.latest_checkpoint(estimator.model_dir) + self._session = monitored_session.MonitoredSession( + session_creator=monitored_session.ChiefSessionCreator( + checkpoint_filename_with_path=checkpoint_path)) + + input_alternative_key = ( + input_alternative_key or + saved_model_export_utils.DEFAULT_INPUT_ALTERNATIVE_KEY) + input_alternatives, _ = saved_model_export_utils.get_input_alternatives( + input_fn_ops) + self._feed_tensors = input_alternatives[input_alternative_key] + + (output_alternatives, + output_alternative_key) = saved_model_export_utils.get_output_alternatives( + model_fn_ops, output_alternative_key) + _, fetch_tensors = output_alternatives[output_alternative_key] + self._fetch_tensors = fetch_tensors diff --git a/tensorflow/contrib/predictor/contrib_estimator_predictor_test.py b/tensorflow/contrib/predictor/contrib_estimator_predictor_test.py new file mode 100644 index 00000000000..4b97a52b1a3 --- /dev/null +++ b/tensorflow/contrib/predictor/contrib_estimator_predictor_test.py @@ -0,0 +1,70 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +"""Tests for predictor.contrib_estimator_predictor.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tempfile +import numpy as np + +from tensorflow.contrib.predictor import contrib_estimator_predictor +from tensorflow.contrib.predictor import testing_common +from tensorflow.python.platform import test + + +KEYS_AND_OPS = (('sum', lambda x, y: x + y), + ('product', lambda x, y: x * y,), + ('difference', lambda x, y: x - y)) + + +class ContribEstimatorPredictorTest(test.TestCase): + """Test fixture for `ContribEstimatorPredictor`.""" + + def setUp(self): + model_dir = tempfile.mkdtemp() + self._estimator = testing_common.get_arithmetic_estimator( + core=False, model_dir=model_dir) + self._prediction_input_fn = testing_common.get_arithmetic_input_fn( + core=False, train=False) + + def testSpecifiedSignatureKey(self): + """Test prediction with spedicified signatures.""" + np.random.seed(1234) + for key, op in KEYS_AND_OPS: + x = np.random.rand() + y = np.random.rand() + expected_output = op(x, y) + + predictor = contrib_estimator_predictor.ContribEstimatorPredictor( + estimator=self._estimator, + prediction_input_fn=self._prediction_input_fn, + output_alternative_key=key) + output_tensor_name = predictor.fetch_tensors[key].name + self.assertRegexpMatches( + output_tensor_name, + key, + msg='Unexpected fetch tensor.') + output = predictor({'x': x, 'y': y})[key] + self.assertAlmostEqual( + expected_output, output, places=3, + msg='Failed for output key "{}." 
' + 'Got output {} for x = {} and y = {}'.format( + key, output, x, y)) + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/predictor/core_estimator_predictor.py b/tensorflow/contrib/predictor/core_estimator_predictor.py new file mode 100644 index 00000000000..5557ef51017 --- /dev/null +++ b/tensorflow/contrib/predictor/core_estimator_predictor.py @@ -0,0 +1,80 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +"""A `Predictor` constructed from an `learn.python.estimator.Estimator`.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.predictor import predictor +from tensorflow.python.estimator import model_fn +from tensorflow.python.framework import ops +from tensorflow.python.saved_model import signature_constants +from tensorflow.python.training import monitored_session + + +def _get_signature_def( + serving_input_receiver, estimator, output_key=None): + """Construct a `SignatureDef` proto.""" + if output_key is None: + output_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY + # pylint: disable=protected-access + estimator_spec = estimator._call_model_fn( + serving_input_receiver.features, None, model_fn.ModeKeys.PREDICT) + # pylint: enable=protected-access + export_outputs = estimator_spec.export_outputs + export_output = export_outputs.get(output_key) + if export_output is None: + raise KeyError('output_key must be one of {}; got {}'.format( + export_outputs.keys(), output_key)) + return export_output.as_signature_def(serving_input_receiver.receiver_tensors) + + +class CoreEstimatorPredictor(predictor.Predictor): + """A `Predictor` constructed from an `learn.python.estimator.Estimator`.""" + + def __init__(self, + estimator, + serving_input_receiver_fn, + output_key=None, + graph=None): + """Initialize a `CoreEstimatorPredictor`. + + Args: + estimator: an instance of `learn.python.estimator.Estimator`. + serving_input_receiver_fn: a function that takes no arguments and returns + an instance of `ServingInputReceiver` compatible with `estimator`. + output_key: Optional string specifying the export output to use. If + `None`, then `DEFAULT_SERVING_SIGNATURE_DEF_KEY` is used. + graph: Optional. The Tensorflow `graph` in which prediction should be + done. 
+ """ + self._graph = graph or ops.Graph() + with self._graph.as_default(): + serving_input_receiver = serving_input_receiver_fn() + signature_def = _get_signature_def( + serving_input_receiver, estimator, output_key) + checkpoint_path = estimator.model_dir + self._session = monitored_session.MonitoredSession( + session_creator=monitored_session.ChiefSessionCreator( + checkpoint_filename_with_path=checkpoint_path)) + + feed_tensor_info = signature_def.inputs + self._feed_tensors = {k: self._graph.get_tensor_by_name(v.name) + for k, v in feed_tensor_info.items()} + fetch_tensor_info = signature_def.outputs + self._fetch_tensors = {k: self._graph.get_tensor_by_name(v.name) + for k, v in fetch_tensor_info.items()} diff --git a/tensorflow/contrib/predictor/core_estimator_predictor_test.py b/tensorflow/contrib/predictor/core_estimator_predictor_test.py new file mode 100644 index 00000000000..42210867944 --- /dev/null +++ b/tensorflow/contrib/predictor/core_estimator_predictor_test.py @@ -0,0 +1,81 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +"""Tests for predictor.core_estimator_predictor.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tempfile +import numpy as np + +from tensorflow.contrib.predictor import core_estimator_predictor +from tensorflow.contrib.predictor import testing_common +from tensorflow.python.platform import test + + +KEYS_AND_OPS = (('sum', lambda x, y: x + y), + ('product', lambda x, y: x * y,), + ('difference', lambda x, y: x - y)) + + +class CoreEstimatorPredictorTest(test.TestCase): + """Test fixture for `CoreEstimatorPredictor`.""" + + def setUp(self): + model_dir = tempfile.mkdtemp() + self._estimator = testing_common.get_arithmetic_estimator( + core=True, model_dir=model_dir) + self._serving_input_receiver_fn = testing_common.get_arithmetic_input_fn( + core=True, train=False) + + def testDefault(self): + """Test prediction with default signature.""" + np.random.seed(1111) + x = np.random.rand() + y = np.random.rand() + predictor = core_estimator_predictor.CoreEstimatorPredictor( + estimator=self._estimator, + serving_input_receiver_fn=self._serving_input_receiver_fn) + output = predictor({'x': x, 'y': y})['sum'] + self.assertAlmostEqual(output, x + y, places=3) + + def testSpecifiedSignatureKey(self): + """Test prediction with spedicified signatures.""" + np.random.seed(1234) + for output_key, op in KEYS_AND_OPS: + x = np.random.rand() + y = np.random.rand() + expected_output = op(x, y) + + predictor = core_estimator_predictor.CoreEstimatorPredictor( + estimator=self._estimator, + serving_input_receiver_fn=self._serving_input_receiver_fn, + output_key=output_key) + output_tensor_name = predictor.fetch_tensors[output_key].name + self.assertRegexpMatches( + output_tensor_name, + output_key, + msg='Unexpected fetch tensor.') + output = predictor({'x': x, 'y': y})[output_key] + self.assertAlmostEqual( + 
expected_output, output, places=3, + msg='Failed for output key "{}." ' + 'Got output {} for x = {} and y = {}'.format( + output_key, output, x, y)) + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/predictor/predictor.py b/tensorflow/contrib/predictor/predictor.py new file mode 100644 index 00000000000..dbc0028259e --- /dev/null +++ b/tensorflow/contrib/predictor/predictor.py @@ -0,0 +1,77 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Abstract base class for all predictors.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc +import six + + +@six.add_metaclass(abc.ABCMeta) +class Predictor(object): + """Abstract base class for all predictors.""" + + @property + def graph(self): + return self._graph + + @property + def session(self): + return self._session + + @property + def feed_tensors(self): + return self._feed_tensors + + @property + def fetch_tensors(self): + return self._fetch_tensors + + def __repr__(self): + return '{} with feed tensors {} and fetch_tensors {}'.format( + type(self).__name__, self._feed_tensors, self._fetch_tensors) + + def __call__(self, input_dict): + """Returns predictions based on `input_dict`. + + Args: + input_dict: a `dict` mapping strings to numpy arrays. 
These keys + must match `self._feed_tensors.keys()`. + + Returns: + A `dict` mapping strings to numpy arrays. The keys match + `self.fetch_tensors.keys()`. + + Raises: + ValueError: `input_dict` does not match `feed_tensors`. + """ + # TODO(jamieas): make validation optional? + input_keys = set(input_dict.keys()) + expected_keys = set(self.feed_tensors.keys()) + unexpected_keys = input_keys - expected_keys + if unexpected_keys: + raise ValueError('Got unexpected keys in input_dict: {}'.format( + unexpected_keys)) + + feed_dict = {} + for key in self.feed_tensors.keys(): + value = input_dict.get(key) + if value is not None: + feed_dict[self.feed_tensors[key]] = value + return self._session.run(fetches=self.fetch_tensors, feed_dict=feed_dict) diff --git a/tensorflow/contrib/predictor/predictor_factories.py b/tensorflow/contrib/predictor/predictor_factories.py new file mode 100644 index 00000000000..e3f30d917d6 --- /dev/null +++ b/tensorflow/contrib/predictor/predictor_factories.py @@ -0,0 +1,132 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +"""Factory functions for `Predictor`s.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.predictor import contrib_estimator_predictor +from tensorflow.contrib.predictor import core_estimator_predictor +from tensorflow.contrib.predictor import saved_model_predictor +from tensorflow.python.estimator import estimator as core_estimator + + +def from_contrib_estimator(estimator, + prediction_input_fn, + input_alternative_key=None, + output_alternative_key=None, + graph=None): + """Constructs a `Predictor` from a `tf.contrib.learn.Estimator`. + + Args: + estimator: an instance of `tf.contrib.learn.Estimator`. + prediction_input_fn: a function that takes no arguments and returns an + instance of `InputFnOps`. + input_alternative_key: Optional. Specify the input alternative used for + prediction. + output_alternative_key: Specify the output alternative used for + prediction. Not needed for single-headed models but required for + multi-headed models. + graph: Optional. The Tensorflow `graph` in which prediction should be + done. + + Returns: + An initialized `Predictor`. + + Raises: + TypeError: if `estimator` is a core `Estimator` instead of a contrib + `Estimator`. + """ + if isinstance(estimator, core_estimator.Estimator): + raise TypeError('Expected estimator to be of type ' + 'tf.contrib.learn.Estimator, but got type ' + 'tf.python.estimator.Estimator. You likely want to call ' + 'from_estimator.') + return contrib_estimator_predictor.ContribEstimatorPredictor( + estimator, + prediction_input_fn, + input_alternative_key, + output_alternative_key, + graph) + + +def from_estimator(estimator, + serving_input_receiver_fn, + output_key=None, + graph=None): + """Constructs a `Predictor` from a `tf.python.estimator.Estimator`. + + Args: + estimator: an instance of `tf.python.estimator.Estimator`.
+ serving_input_receiver_fn: a function that takes no arguments and returns + an instance of `ServingInputReceiver` compatible with `estimator`. + output_key: Optional string specifying the export output to use. If + `None`, then `DEFAULT_SERVING_SIGNATURE_DEF_KEY` is used. + graph: Optional. The Tensorflow `graph` in which prediction should be + done. + + Returns: + An initialized `Predictor`. + + Raises: + TypeError: if `estimator` is a contrib `Estimator` instead of a core + `Estimator`. + """ + if not isinstance(estimator, core_estimator.Estimator): + raise TypeError('Expected estimator to be of type ' + 'tf.python.estimator.Estimator, but got type ' + 'tf.contrib.learn.Estimator. You likely want to call ' + 'from_contrib_estimator.') + return core_estimator_predictor.CoreEstimatorPredictor( + estimator, + serving_input_receiver_fn, + output_key, + graph) + + +def from_saved_model(export_dir, + signature_def_key=None, + signature_def=None, + tags=None, + graph=None): + """Constructs a `Predictor` from a `SavedModel` on disk. + + Args: + export_dir: a path to a directory containing a `SavedModel`. + signature_def_key: Optional string specifying the signature to use. If + `None`, then `DEFAULT_SERVING_SIGNATURE_DEF_KEY` is used. Only one of + `signature_def_key` and `signature_def` should be specified. + signature_def: A `SignatureDef` proto specifying the inputs and outputs + for prediction. Only one of `signature_def_key` and `signature_def` + should be specified. + tags: Optional. Tags that will be used to retrieve the correct + `SignatureDef`. Defaults to `DEFAULT_TAGS`. + graph: Optional. The Tensorflow `graph` in which prediction should be + done. + + Returns: + An initialized `Predictor`. + + Raises: + ValueError: More than one of `signature_def_key` and `signature_def` is + specified.
+ """ + return saved_model_predictor.SavedModelPredictor(export_dir, + signature_def_key, + signature_def, + tags, + graph) diff --git a/tensorflow/contrib/predictor/saved_model_predictor.py b/tensorflow/contrib/predictor/saved_model_predictor.py new file mode 100644 index 00000000000..ab2bafa0c86 --- /dev/null +++ b/tensorflow/contrib/predictor/saved_model_predictor.py @@ -0,0 +1,143 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +"""A `Predictor` constructed from a `SavedModel`.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import logging + +from tensorflow.contrib.predictor import predictor +from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils +from tensorflow.python.client import session +from tensorflow.python.framework import ops +from tensorflow.python.saved_model import loader +from tensorflow.python.saved_model import signature_constants +from tensorflow.python.tools import saved_model_cli + + +DEFAULT_TAGS = 'serve' + +_DEFAULT_INPUT_ALTERNATIVE_FORMAT = 'default_input_alternative:{}' + + +def _get_signature_def(signature_def_key, export_dir, tags): + """Construct a `SignatureDef` proto.""" + signature_def_key = ( + signature_def_key or + signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY) + + metagraph_def = saved_model_cli.get_meta_graph_def(export_dir, tags) + + try: + signature_def = signature_def_utils.get_signature_def_by_key( + metagraph_def, + signature_def_key) + except ValueError as e: + try: + formatted_key = _DEFAULT_INPUT_ALTERNATIVE_FORMAT.format( + signature_def_key) + signature_def = signature_def_utils.get_signature_def_by_key( + metagraph_def, formatted_key) + + logging.warning('Could not find signature def "%s". ' + 'Using "%s" instead', signature_def_key, formatted_key) + except ValueError: + raise ValueError( + 'Got signature_def_key "{}". Available signatures are {}. 
' + 'Original error:\n{}'.format( + signature_def_key, list(metagraph_def.signature_def), e)) + return signature_def + + +def _check_signature_arguments(signature_def_key, + signature_def, + input_names, + output_names): + """Validates signature arguments for `SavedModelPredictor`.""" + signature_def_key_specified = signature_def_key is not None + signature_def_specified = signature_def is not None + input_names_specified = input_names is not None + output_names_specified = output_names is not None + if input_names_specified != output_names_specified: + raise ValueError( + 'input_names and output_names must both be specified or both be ' + 'unspecified.' + ) + + if (signature_def_key_specified + signature_def_specified + + input_names_specified > 1): + raise ValueError( + 'You must specify at most one of signature_def_key OR signature_def OR' + '(input_names AND output_names).' + ) + + +class SavedModelPredictor(predictor.Predictor): + """A `Predictor` constructed from a `SavedModel`.""" + + def __init__(self, + export_dir, + signature_def_key=None, + signature_def=None, + input_names=None, + output_names=None, + tags=None, + graph=None): + """Initialize a `CoreEstimatorPredictor`. + + Args: + export_dir: a path to a directory containing a `SavedModel`. + signature_def_key: Optional string specifying the signature to use. If + `None`, then `DEFAULT_SERVING_SIGNATURE_DEF_KEY` is used. Only one of + `signature_def_key` and `signature_def` should be specified. + signature_def: A `SignatureDef` proto specifying the inputs and outputs + for prediction. Only one of `signature_def_key` and `signature_def` + should be specified. + input_names: A dictionary mapping strings to `Tensor`s in the `SavedModel` + that represent the input. The keys can be any string of the user's + choosing. + output_names: A dictionary mapping strings to `Tensor`s in the + `SavedModel` that represent the output. The keys can be any string of + the user's choosing. + tags: Optional. 
Tags that will be used to retrieve the correct + `SignatureDef`. Defaults to `DEFAULT_TAGS`. + graph: Optional. The Tensorflow `graph` in which prediction should be + done. + Raises: + ValueError: If more than one of signature_def_key OR signature_def OR + (input_names AND output_names) is specified. + """ + _check_signature_arguments( + signature_def_key, signature_def, input_names, output_names) + tags = tags or DEFAULT_TAGS + self._graph = graph or ops.Graph() + + with self._graph.as_default(): + self._session = session.Session() + loader.load(self._session, tags.split(','), export_dir) + + if input_names is None: + if signature_def is None: + signature_def = _get_signature_def(signature_def_key, export_dir, tags) + input_names = {k: v.name for k, v in signature_def.inputs.items()} + output_names = {k: v.name for k, v in signature_def.outputs.items()} + + self._feed_tensors = {k: self._graph.get_tensor_by_name(v) + for k, v in input_names.items()} + self._fetch_tensors = {k: self._graph.get_tensor_by_name(v) + for k, v in output_names.items()} diff --git a/tensorflow/contrib/predictor/saved_model_predictor_test.py b/tensorflow/contrib/predictor/saved_model_predictor_test.py new file mode 100644 index 00000000000..f40e2e73d99 --- /dev/null +++ b/tensorflow/contrib/predictor/saved_model_predictor_test.py @@ -0,0 +1,170 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +"""Tests for predictor.saved_model_predictor.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.predictor import saved_model_predictor +from tensorflow.core.framework import tensor_shape_pb2 +from tensorflow.core.framework import types_pb2 +from tensorflow.core.protobuf import meta_graph_pb2 +from tensorflow.python.framework import ops +from tensorflow.python.platform import test +from tensorflow.python.saved_model import signature_def_utils + + +KEYS_AND_OPS = (('sum', lambda x, y: x + y), + ('product', lambda x, y: x * y,), + ('difference', lambda x, y: x - y)) + +MODEL_DIR_NAME = 'contrib/predictor/test_export_dir' + + +class SavedModelPredictorTest(test.TestCase): + + @classmethod + def setUpClass(cls): + # Load a saved model exported from the arithmetic `Estimator`. + # See `testing_common.py`. 
+ cls._export_dir = test.test_src_dir_path(MODEL_DIR_NAME) + + def testDefault(self): + """Test prediction with default signature.""" + np.random.seed(1111) + x = np.random.rand() + y = np.random.rand() + predictor = saved_model_predictor.SavedModelPredictor( + export_dir=self._export_dir) + output = predictor({'x': x, 'y': y})['outputs'] + self.assertAlmostEqual(output, x + y, places=3) + + def testSpecifiedSignatureKey(self): + """Test prediction with spedicified signature key.""" + np.random.seed(1234) + for signature_def_key, op in KEYS_AND_OPS: + x = np.random.rand() + y = np.random.rand() + expected_output = op(x, y) + + predictor = saved_model_predictor.SavedModelPredictor( + export_dir=self._export_dir, + signature_def_key=signature_def_key) + + output_tensor_name = predictor.fetch_tensors['outputs'].name + self.assertRegexpMatches( + output_tensor_name, + signature_def_key, + msg='Unexpected fetch tensor.') + + output = predictor({'x': x, 'y': y})['outputs'] + self.assertAlmostEqual( + expected_output, output, places=3, + msg='Failed for signature "{}." 
' + 'Got output {} for x = {} and y = {}'.format( + signature_def_key, output, x, y)) + + def testSpecifiedSignature(self): + """Test prediction with spedicified signature definition.""" + np.random.seed(4444) + for key, op in KEYS_AND_OPS: + x = np.random.rand() + y = np.random.rand() + expected_output = op(x, y) + + inputs = { + 'x': meta_graph_pb2.TensorInfo( + name='inputs/x:0', + dtype=types_pb2.DT_FLOAT, + tensor_shape=tensor_shape_pb2.TensorShapeProto()), + 'y': meta_graph_pb2.TensorInfo( + name='inputs/y:0', + dtype=types_pb2.DT_FLOAT, + tensor_shape=tensor_shape_pb2.TensorShapeProto())} + outputs = { + key: meta_graph_pb2.TensorInfo( + name='outputs/{}:0'.format(key), + dtype=types_pb2.DT_FLOAT, + tensor_shape=tensor_shape_pb2.TensorShapeProto())} + signature_def = signature_def_utils.build_signature_def( + inputs=inputs, + outputs=outputs, + method_name='tensorflow/serving/regress') + predictor = saved_model_predictor.SavedModelPredictor( + export_dir=self._export_dir, + signature_def=signature_def) + + output_tensor_name = predictor.fetch_tensors[key].name + self.assertRegexpMatches( + output_tensor_name, + key, + msg='Unexpected fetch tensor.') + + output = predictor({'x': x, 'y': y})[key] + self.assertAlmostEqual( + expected_output, output, places=3, + msg='Failed for signature "{}". 
' + 'Got output {} for x = {} and y = {}'.format(key, output, x, y)) + + def testSpecifiedTensors(self): + """Test prediction with spedicified `Tensor`s.""" + np.random.seed(987) + for key, op in KEYS_AND_OPS: + x = np.random.rand() + y = np.random.rand() + expected_output = op(x, y) + input_names = {'x': 'inputs/x:0', + 'y': 'inputs/y:0'} + output_names = {key: 'outputs/{}:0'.format(key)} + predictor = saved_model_predictor.SavedModelPredictor( + export_dir=self._export_dir, + input_names=input_names, + output_names=output_names) + + output_tensor_name = predictor.fetch_tensors[key].name + self.assertRegexpMatches( + output_tensor_name, + key, + msg='Unexpected fetch tensor.') + + output = predictor({'x': x, 'y': y})[key] + self.assertAlmostEqual( + expected_output, output, places=3, + msg='Failed for signature "{}". ' + 'Got output {} for x = {} and y = {}'.format(key, output, x, y)) + + def testBadTagsFail(self): + """Test that predictor construction fails for bad tags.""" + bad_tags_regex = ('.* could not be found in SavedModel') + with self.assertRaisesRegexp(RuntimeError, bad_tags_regex): + _ = saved_model_predictor.SavedModelPredictor( + export_dir=self._export_dir, + tags=('zomg, bad, tags')) + + def testSpecifiedGraph(self): + """Test that the predictor remembers a specified `Graph`.""" + g = ops.Graph() + predictor = saved_model_predictor.SavedModelPredictor( + export_dir=self._export_dir, + graph=g) + self.assertEqual(predictor.graph, g) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/predictor/test_export_dir/saved_model.pb b/tensorflow/contrib/predictor/test_export_dir/saved_model.pb new file mode 100644 index 00000000000..9100fefb720 Binary files /dev/null and b/tensorflow/contrib/predictor/test_export_dir/saved_model.pb differ diff --git a/tensorflow/contrib/predictor/test_export_dir/variables/variables.data-00000-of-00001 b/tensorflow/contrib/predictor/test_export_dir/variables/variables.data-00000-of-00001 new file 
mode 100644 index 00000000000..1b1cb4d44c5 Binary files /dev/null and b/tensorflow/contrib/predictor/test_export_dir/variables/variables.data-00000-of-00001 differ diff --git a/tensorflow/contrib/predictor/test_export_dir/variables/variables.index b/tensorflow/contrib/predictor/test_export_dir/variables/variables.index new file mode 100644 index 00000000000..dd32e9b71b3 Binary files /dev/null and b/tensorflow/contrib/predictor/test_export_dir/variables/variables.index differ diff --git a/tensorflow/contrib/predictor/testing_common.py b/tensorflow/contrib/predictor/testing_common.py new file mode 100644 index 00000000000..1767704b993 --- /dev/null +++ b/tensorflow/contrib/predictor/testing_common.py @@ -0,0 +1,102 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +"""Common code used for testing `Predictor`s.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.learn.python.learn.estimators import constants +from tensorflow.contrib.learn.python.learn.estimators import estimator as contrib_estimator +from tensorflow.contrib.learn.python.learn.estimators import model_fn as contrib_model_fn +from tensorflow.contrib.learn.python.learn.utils import input_fn_utils +from tensorflow.python.estimator import estimator as core_estimator +from tensorflow.python.estimator import model_fn +from tensorflow.python.estimator.export import export_lib +from tensorflow.python.estimator.export import export_output +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.saved_model import signature_constants + + +def get_arithmetic_estimator(core=True, model_dir=None): + """Returns an `Estimator` that performs basic arithmetic. + + Args: + core: if `True`, returns a `tensorflow.python.estimator.Estimator`. + Otherwise, returns a `tensorflow.contrib.learn.Estimator`. + model_dir: directory in which to export checkpoints and saved models. + Returns: + An `Estimator` that performs arithmetic operations on its inputs. + """ + def _model_fn(features, labels, mode): + _ = labels + x = features['x'] + y = features['y'] + with ops.name_scope('outputs'): + predictions = {'sum': math_ops.add(x, y, name='sum'), + 'product': math_ops.multiply(x, y, name='product'), + 'difference': math_ops.subtract(x, y, name='difference')} + if core: + export_outputs = {k: export_output.PredictOutput({k: v}) + for k, v in predictions.items()} + export_outputs[signature_constants. 
+ DEFAULT_SERVING_SIGNATURE_DEF_KEY] = export_outputs['sum'] + return model_fn.EstimatorSpec(mode=mode, + predictions=predictions, + export_outputs=export_outputs, + loss=constant_op.constant(0), + train_op=control_flow_ops.no_op()) + else: + output_alternatives = {k: (constants.ProblemType.UNSPECIFIED, {k: v}) + for k, v in predictions.items()} + return contrib_model_fn.ModelFnOps( + mode=mode, + predictions=predictions, + output_alternatives=output_alternatives, + loss=constant_op.constant(0), + train_op=control_flow_ops.no_op()) + if core: + return core_estimator.Estimator(_model_fn) + else: + return contrib_estimator.Estimator(_model_fn, model_dir=model_dir) + + +def get_arithmetic_input_fn(core=True, train=False): + """Returns a input functions or serving input receiver function.""" + def _input_fn(): + with ops.name_scope('inputs'): + x = array_ops.placeholder_with_default(0.0, shape=[], name='x') + y = array_ops.placeholder_with_default(0.0, shape=[], name='y') + label = constant_op.constant(0.0) + features = {'x': x, 'y': y} + if core: + if train: + return features, label + return export_lib.ServingInputReceiver( + features=features, + receiver_tensors=features) + else: + if train: + return features, label + return input_fn_utils.InputFnOps( + features=features, + labels={}, + default_inputs=features) + return _input_fn diff --git a/tensorflow/contrib/quantization/BUILD b/tensorflow/contrib/quantization/BUILD index b1d12cc510a..c19a31afb2a 100644 --- a/tensorflow/contrib/quantization/BUILD +++ b/tensorflow/contrib/quantization/BUILD @@ -35,13 +35,10 @@ py_library( ], srcs_version = "PY2AND3", deps = [ - "//tensorflow/python:array_ops", "//tensorflow/python:array_ops_gen", - "//tensorflow/python:framework", + "//tensorflow/python:common_shapes", "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:math_ops", "//tensorflow/python:math_ops_gen", - "//tensorflow/python:nn_ops", "//tensorflow/python:nn_ops_gen", ], ) diff --git 
a/tensorflow/contrib/remote_fused_graph/README.md b/tensorflow/contrib/remote_fused_graph/README.md new file mode 100644 index 00000000000..267cfa10192 --- /dev/null +++ b/tensorflow/contrib/remote_fused_graph/README.md @@ -0,0 +1,8 @@ +# Remote Fused Graph + +## Description + +This module contains libraries for remote fused graph utilities + +Maintainers: +- Satoshi Kataoka (satok@google.com, github.com/satok16) diff --git a/tensorflow/contrib/remote_fused_graph/pylib/BUILD b/tensorflow/contrib/remote_fused_graph/pylib/BUILD new file mode 100644 index 00000000000..c7ed6631315 --- /dev/null +++ b/tensorflow/contrib/remote_fused_graph/pylib/BUILD @@ -0,0 +1,56 @@ +# Description: +# Contains ops for remote fused graph + +package(default_visibility = ["//tensorflow:__subpackages__"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") + +tf_gen_op_wrapper_py( + name = "gen_remote_fused_graph_ops", + out = "python/ops/gen_remote_fused_graph_ops.py", + deps = [ + "//tensorflow/core:remote_fused_graph_ops_op_lib", + ], +) + +py_library( + name = "remote_fused_graph_ops_py", + srcs = ["__init__.py"] + glob(["python/ops/*.py"]), + srcs_version = "PY2AND3", + deps = [ + ":gen_remote_fused_graph_ops", + "//tensorflow/python:framework_for_generated_wrappers", + ], +) + +py_test( + name = "remote_fused_graph_ops_test", + size = "small", + srcs = ["python/ops/remote_fused_graph_ops_test.py"], + srcs_version = "PY2AND3", + tags = ["no_windows"], + deps = [ + ":remote_fused_graph_ops_py", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:nn_ops", + "//third_party/py/numpy", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git 
a/tensorflow/contrib/remote_fused_graph/pylib/__init__.py b/tensorflow/contrib/remote_fused_graph/pylib/__init__.py new file mode 100644 index 00000000000..4d23c38932e --- /dev/null +++ b/tensorflow/contrib/remote_fused_graph/pylib/__init__.py @@ -0,0 +1,33 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Remote fused graph ops python library. + +## This package provides classes for remote fused graph ops. 
+ +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import,wildcard-import, line-too-long +from tensorflow.contrib.remote_fused_graph.pylib.python.ops.remote_fused_graph_ops import * +# pylint: enable=unused-import,wildcard-import,line-too-long + +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = ['remote_fused_graph_execute'] + +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/tensorboard/__main__.py b/tensorflow/contrib/remote_fused_graph/pylib/python/__init__.py similarity index 80% rename from tensorflow/tensorboard/__main__.py rename to tensorflow/contrib/remote_fused_graph/pylib/python/__init__.py index f172583d7c5..b66091f8759 100644 --- a/tensorflow/tensorboard/__main__.py +++ b/tensorflow/contrib/remote_fused_graph/pylib/python/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,14 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== +"""Remote fused graph ops python library.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function - -import sys - -from tensorflow.tensorboard.tensorboard import main - -if __name__ == '__main__': - sys.exit(main()) diff --git a/tensorflow/tools/ci_build/install/install_tensorboard_packages.sh b/tensorflow/contrib/remote_fused_graph/pylib/python/ops/__init__.py old mode 100755 new mode 100644 similarity index 65% rename from tensorflow/tools/ci_build/install/install_tensorboard_packages.sh rename to tensorflow/contrib/remote_fused_graph/pylib/python/ops/__init__.py index ca5092cd475..b66091f8759 --- a/tensorflow/tools/ci_build/install/install_tensorboard_packages.sh +++ b/tensorflow/contrib/remote_fused_graph/pylib/python/ops/__init__.py @@ -1,5 +1,4 @@ -#!/usr/bin/env bash -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,17 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +"""Remote fused graph ops python library.""" -set -e - -# Install dependencies from ubuntu deb repository. 
-apt-get update -apt-get install -y --no-install-recommends \ - chromium-browser \ - nodejs \ - nodejs-legacy \ - npm \ - python-numpy \ - xvfb -apt-get clean -rm -rf /var/lib/apt/lists/* +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function diff --git a/tensorflow/contrib/remote_fused_graph/pylib/python/ops/remote_fused_graph_ops.py b/tensorflow/contrib/remote_fused_graph/pylib/python/ops/remote_fused_graph_ops.py new file mode 100644 index 00000000000..2054367f0d1 --- /dev/null +++ b/tensorflow/contrib/remote_fused_graph/pylib/python/ops/remote_fused_graph_ops.py @@ -0,0 +1,66 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Operations to execute a subgraph on a remote processor.""" + +# pylint: disable=g-bad-name +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import,wildcard-import, line-too-long +from tensorflow.contrib.remote_fused_graph.pylib.python.ops import gen_remote_fused_graph_ops +from tensorflow.core.framework import remote_fused_graph_execute_info_pb2 as info_pb2 +# pylint: enable=unused-import,wildcard-import,line-too-long + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops + +# RemoteFusedGraphExecute is not differenciable op. +ops.NotDifferentiable("RemoteFusedGraphExecute") + + +def remote_fused_graph_execute(inputs, + output_types, + graph_def, + graph_input_node_names, + graph_output_node_names, + executor_name, + serialized_executor_parameters, + default_graph_input_tensor_type_shapes=None, + default_graph_output_tensor_type_shapes=None): + """A wrapper for remote_fused_graph_execute.""" + info_proto = info_pb2.RemoteFusedGraphExecuteInfo() + info_proto.remote_graph.CopyFrom(graph_def) + info_proto.graph_input_node_name.extend(graph_input_node_names) + info_proto.graph_output_node_name.extend(graph_output_node_names) + info_proto.executor_name = executor_name + info_proto.serialized_executor_parameters = serialized_executor_parameters + if default_graph_input_tensor_type_shapes: + for type_shape in default_graph_input_tensor_type_shapes: + type_shape_proto = info_proto.default_graph_input_tensor_shape.add() + type_shape_proto.dtype = int(dtypes.as_dtype(type_shape[0])) + for dim in type_shape[1]: + type_shape_proto.shape.dim.add().size = dim + if default_graph_output_tensor_type_shapes: + for type_shape in default_graph_output_tensor_type_shapes: + type_shape_proto = info_proto.default_graph_output_tensor_shape.add() + type_shape_proto.dtype = 
int(dtypes.as_dtype(type_shape[0])) + for dim in type_shape[1]: + type_shape_proto.shape.dim.add().size = dim + + serialized_info = info_proto.SerializeToString() + + return gen_remote_fused_graph_ops.remote_fused_graph_execute( + inputs, output_types, serialized_info) diff --git a/tensorflow/contrib/remote_fused_graph/pylib/python/ops/remote_fused_graph_ops_test.py b/tensorflow/contrib/remote_fused_graph/pylib/python/ops/remote_fused_graph_ops_test.py new file mode 100644 index 00000000000..45df9091482 --- /dev/null +++ b/tensorflow/contrib/remote_fused_graph/pylib/python/ops/remote_fused_graph_ops_test.py @@ -0,0 +1,66 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for tensorflow.ops.remote_fused_graph_ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +# pylint: disable=unused-import,wildcard-import, line-too-long +from tensorflow.contrib.remote_fused_graph.pylib.python.ops import remote_fused_graph_ops +# pylint: enable=unused-import,wildcard-import,line-too-long + +from tensorflow.core.framework import graph_pb2 +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.platform import googletest + + +class RemoteFusedGraphExecuteTest(test_util.TensorFlowTestCase): + """Tests for RemoteFusedGraphExecute op.""" + + def testBuild(self): + graph = graph_pb2.GraphDef() + node = graph.node.add() + node.name = "a" + node.op = "op0" + node = graph.node.add() + node.name = "b" + node.op = "op1" + inputs = [ops.convert_n_to_tensor([1], dtypes.int64)] + output_types = [np.int64, np.int64] + graph_input_node_names = ["a"] + graph_output_node_names = ["a", "b"] + executor_name = "" + serialized_executor_parameters = b"" + default_graph_input_tensor_type_shapes = [[dtypes.int64, [1]]] + default_graph_output_tensor_type_shapes = [[dtypes.int64, [1]], + [dtypes.int64, [1]]] + + output_nodes = remote_fused_graph_ops.remote_fused_graph_execute( + inputs, output_types, graph, graph_input_node_names, + graph_output_node_names, executor_name, serialized_executor_parameters, + default_graph_input_tensor_type_shapes, + default_graph_output_tensor_type_shapes) + self.assertEqual(2, len(output_nodes)) + for output_node in output_nodes: + with self.test_session(use_gpu=False): + output_node.eval() + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/contrib/rnn/BUILD b/tensorflow/contrib/rnn/BUILD index 8c71977d5ac..835f7df8b2a 100644 
--- a/tensorflow/contrib/rnn/BUILD +++ b/tensorflow/contrib/rnn/BUILD @@ -65,7 +65,7 @@ tf_custom_op_py_library( cuda_py_tests( name = "rnn_cell_test", - size = "small", + size = "medium", srcs = ["python/kernel_tests/rnn_cell_test.py"], additional_deps = [ ":rnn_py", @@ -360,7 +360,10 @@ py_binary( srcs_version = "PY2AND3", deps = [ "//tensorflow/core:protos_all_py", + "//tensorflow/python:client", + "//tensorflow/python:framework_ops", "//tensorflow/python:platform", + "//tensorflow/python:pywrap_tensorflow", "//tensorflow/python:training", "//tensorflow/python:variables", ], @@ -374,6 +377,10 @@ py_test( tags = ["no_pip"], deps = [ ":checkpoint_convert", + "//tensorflow/python:client", "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_ops", + "//tensorflow/python:training", + "//tensorflow/python:variables", ], ) diff --git a/tensorflow/contrib/rnn/kernels/lstm_ops.h b/tensorflow/contrib/rnn/kernels/lstm_ops.h index d9ed9e3ab71..6317f32ac3b 100644 --- a/tensorflow/contrib/rnn/kernels/lstm_ops.h +++ b/tensorflow/contrib/rnn/kernels/lstm_ops.h @@ -279,9 +279,6 @@ struct LSTMBlockCellBprop : public LSTMBlockCell { cs_prev_grad + di * wci.reshape(p_shape).broadcast(p_broadcast_shape) + df * wcf.reshape(p_shape).broadcast(p_broadcast_shape); - } - - if (use_peephole) { wci_grad.device(d) = (di * cs_prev).sum(Eigen::array({0})); wcf_grad.device(d) = (df * cs_prev).sum(Eigen::array({0})); wco_grad.device(d) = (do_ * cs).sum(Eigen::array({0})); diff --git a/tensorflow/contrib/rnn/python/ops/lstm_ops.py b/tensorflow/contrib/rnn/python/ops/lstm_ops.py index c41b5793fc9..97b9dcc905d 100644 --- a/tensorflow/contrib/rnn/python/ops/lstm_ops.py +++ b/tensorflow/contrib/rnn/python/ops/lstm_ops.py @@ -58,7 +58,7 @@ def _lstm_block_cell(x, ```python xh = [x, h_prev] - [i, f, ci, o] = xh * w + b + [i, ci, f, o] = xh * w + b f = f + forget_bias if not use_peephole: @@ -93,7 +93,7 @@ def _lstm_block_cell(x, The weight matrix for output gate peephole connection. 
forget_bias: An optional `float`. Defaults to `1`. The forget gate bias. cell_clip: An optional `float`. Defaults to `3`. - Value to clip the 'cs' value to. + Value to clip the 'cs' value to. Disable by setting to negative value. use_peephole: An optional `bool`. Defaults to `False`. Whether to use peephole weights. name: A name for the operation (optional). @@ -341,17 +341,24 @@ class LSTMBlockCell(rnn_cell_impl.RNNCell): def __init__(self, num_units, forget_bias=1.0, + clip_cell=True, use_peephole=False): """Initialize the basic LSTM cell. Args: num_units: int, The number of units in the LSTM cell. forget_bias: float, The bias added to forget gates (see above). + clip_cell: boolean, whether to apply cell clipping. See + `_lstm_block_cell()` for details. use_peephole: Whether to use peephole connections or not. + + When restoring from CudnnLSTM-trained checkpoints, must set the following: + forget_bias, clip_cell, use_peephole = 0, False, False """ self._num_units = num_units self._forget_bias = forget_bias self._use_peephole = use_peephole + self._clip_cell = clip_cell self._names = { "W": "kernel", "b": "bias", @@ -400,6 +407,7 @@ class LSTMBlockCell(rnn_cell_impl.RNNCell): wco=wco, wcf=wcf, forget_bias=self._forget_bias, + cell_clip=None if self._clip_cell else -1, use_peephole=self._use_peephole) new_state = rnn_cell_impl.LSTMStateTuple(cs, h) diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py index cb12bc9450c..c99562555a1 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/basic_decoder_test.py @@ -111,6 +111,8 @@ class BasicDecoderTest(test.TestCase): sess_results["first_finished"]) self.assertAllEqual([False, False, False, True, True], sess_results["step_finished"]) + self.assertEqual(output_dtype.sample_id, + sess_results["step_outputs"].sample_id.dtype) 
self.assertAllEqual( np.argmax(sess_results["step_outputs"].rnn_output, -1), sess_results["step_outputs"].sample_id) @@ -186,6 +188,8 @@ class BasicDecoderTest(test.TestCase): self.assertAllEqual([False, False, False, False, False], sess_results["first_finished"]) self.assertAllEqual(expected_step_finished, sess_results["step_finished"]) + self.assertEqual(output_dtype.sample_id, + sess_results["step_outputs"].sample_id.dtype) self.assertAllEqual(expected_sample_ids, sess_results["step_outputs"].sample_id) self.assertAllEqual(expected_step_next_inputs, @@ -254,6 +258,7 @@ class BasicDecoderTest(test.TestCase): }) sample_ids = sess_results["step_outputs"].sample_id + self.assertEqual(output_dtype.sample_id, sample_ids.dtype) expected_step_finished = (sample_ids == end_token) expected_step_next_inputs = embeddings[sample_ids] self.assertAllEqual(expected_step_finished, @@ -337,6 +342,7 @@ class BasicDecoderTest(test.TestCase): self.assertAllEqual([False, False, False, True, True], sess_results["step_finished"]) sample_ids = sess_results["step_outputs"].sample_id + self.assertEqual(output_dtype.sample_id, sample_ids.dtype) batch_where_not_sampling = np.where(sample_ids == -1) batch_where_sampling = np.where(sample_ids > -1) self.assertAllClose( @@ -441,6 +447,7 @@ class BasicDecoderTest(test.TestCase): sess_results["step_finished"]) sample_ids = sess_results["step_outputs"].sample_id + self.assertEqual(output_dtype.sample_id, sample_ids.dtype) batch_where_not_sampling = np.where(np.logical_not(sample_ids)) batch_where_sampling = np.where(sample_ids) diff --git a/tensorflow/contrib/seq2seq/python/ops/helper.py b/tensorflow/contrib/seq2seq/python/ops/helper.py index bee75479357..6b8cad7fd79 100644 --- a/tensorflow/contrib/seq2seq/python/ops/helper.py +++ b/tensorflow/contrib/seq2seq/python/ops/helper.py @@ -382,9 +382,7 @@ class ScheduledOutputTrainingHelper(TrainingHelper): with ops.name_scope(name, "ScheduledOutputTrainingHelperSample", [time, outputs, state]): 
sampler = bernoulli.Bernoulli(probs=self._sampling_probability) - return math_ops.cast( - sampler.sample(sample_shape=self.batch_size, seed=self._seed), - dtypes.bool) + return sampler.sample(sample_shape=self.batch_size, seed=self._seed) def next_inputs(self, time, outputs, state, sample_ids, name=None): with ops.name_scope(name, "ScheduledOutputTrainingHelperNextInputs", @@ -396,6 +394,7 @@ class ScheduledOutputTrainingHelper(TrainingHelper): state=state, sample_ids=sample_ids, name=name)) + sample_ids = math_ops.cast(sample_ids, dtypes.bool) def maybe_sample(): """Perform scheduled sampling.""" diff --git a/tensorflow/contrib/signal/BUILD b/tensorflow/contrib/signal/BUILD index 5b65a6ae05e..1d9cce239a6 100644 --- a/tensorflow/contrib/signal/BUILD +++ b/tensorflow/contrib/signal/BUILD @@ -12,8 +12,10 @@ py_library( srcs_version = "PY2AND3", deps = [ "//tensorflow/python:array_ops", - "//tensorflow/python:framework", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", + "//tensorflow/python:util", ], ) diff --git a/tensorflow/contrib/slim/python/slim/learning.py b/tensorflow/contrib/slim/python/slim/learning.py index f7dddc46c36..8f690fb5490 100644 --- a/tensorflow/contrib/slim/python/slim/learning.py +++ b/tensorflow/contrib/slim/python/slim/learning.py @@ -603,9 +603,9 @@ def train(train_op, saver: Saver to save checkpoints. If None, a default one will be created and used. save_interval_secs: How often, in seconds, to save the model to `logdir`. - sync_optimizer: an instance of tf.train.SyncReplicasOptimizer. If the - argument is supplied, gradient updates will be synchronous. If left as - `None`, gradient updates will be asynchronous. + sync_optimizer: an instance of tf.train.SyncReplicasOptimizer, or a list of + them. If the argument is supplied, gradient updates will be synchronous. + If left as `None`, gradient updates will be asynchronous. 
session_config: An instance of `tf.ConfigProto` that will be used to configure the `Session`. If left as `None`, the default will be used. trace_every_n_steps: produce and save a `Timeline` in Chrome trace format @@ -633,6 +633,8 @@ def train(train_op, raise ValueError('Cannot provide trace_every_n_steps because ' 'logdir=None') + if isinstance(sync_optimizer, sync_replicas_optimizer.SyncReplicasOptimizer): + sync_optimizer = [sync_optimizer] if sync_optimizer is not None and startup_delay_steps > 0: raise ValueError( 'startup_delay_steps must be zero when sync_optimizer is supplied.') @@ -647,6 +649,12 @@ def train(train_op, global_step = variables.get_or_create_global_step() saver = saver or tf_saver.Saver() + if sync_optimizer is not None: + for opt in sync_optimizer: + if not isinstance(opt, sync_replicas_optimizer.SyncReplicasOptimizer): + raise ValueError( + '`sync_optimizer` must be a tf.train.SyncReplicasOptimizer.') + with ops.name_scope('init_ops'): if init_op == _USE_DEFAULT: init_op = tf_variables.global_variables_initializer() @@ -659,15 +667,17 @@ def train(train_op, tf_variables.local_variables_initializer(), lookup_ops.tables_initializer()) - if sync_optimizer is not None and isinstance( - sync_optimizer, sync_replicas_optimizer.SyncReplicasOptimizer): + if sync_optimizer is not None and isinstance(sync_optimizer, list): with ops.control_dependencies([local_init_op] if local_init_op is not None else []): if is_chief: - local_init_op = sync_optimizer.chief_init_op + local_init_op = control_flow_ops.group( + *[opt.chief_init_op for opt in sync_optimizer]) else: - local_init_op = sync_optimizer.local_step_init_op - ready_for_local_init_op = sync_optimizer.ready_for_local_init_op + local_init_op = control_flow_ops.group( + *[opt.local_step_init_op for opt in sync_optimizer]) + ready_for_local_init_op = control_flow_ops.group( + *[opt.ready_for_local_init_op for opt in sync_optimizer]) else: ready_for_local_init_op = None @@ -678,14 +688,10 @@ def 
train(train_op, summary_writer = supervisor.Supervisor.USE_DEFAULT if is_chief and sync_optimizer is not None: - if not isinstance(sync_optimizer, - (sync_replicas_optimizer.SyncReplicasOptimizer)): - raise ValueError( - '`sync_optimizer` must be a tf.train.SyncReplicasOptimizer.') - # Need to create these BEFORE the supervisor finalizes the graph: - init_tokens_op = sync_optimizer.get_init_tokens_op() - chief_queue_runner = sync_optimizer.get_chief_queue_runner() + init_tokens_op = [opt.get_init_tokens_op() for opt in sync_optimizer] + chief_queue_runner = [ + opt.get_chief_queue_runner() for opt in sync_optimizer] if train_step_kwargs == _USE_DEFAULT: with ops.name_scope('train_step'): @@ -741,7 +747,7 @@ def train(train_op, threads = sv.start_queue_runners(sess) logging.info('Starting Queues.') if is_chief and sync_optimizer is not None: - sv.start_queue_runners(sess, [chief_queue_runner]) + sv.start_queue_runners(sess, chief_queue_runner) sess.run(init_tokens_op) try: while not sv.should_stop(): diff --git a/tensorflow/contrib/slim/python/slim/learning_test.py b/tensorflow/contrib/slim/python/slim/learning_test.py index 83d45f6f5ad..69061460eb6 100644 --- a/tensorflow/contrib/slim/python/slim/learning_test.py +++ b/tensorflow/contrib/slim/python/slim/learning_test.py @@ -220,7 +220,7 @@ def LogisticClassifier(inputs): def BatchNormClassifier(inputs): - inputs = layers.batch_norm(inputs, decay=0.1) + inputs = layers.batch_norm(inputs, decay=0.1, fused=None) return layers.fully_connected(inputs, 1, activation_fn=math_ops.sigmoid) @@ -267,6 +267,11 @@ class CreateTrainOpTest(test.TestCase): self._inputs = np.random.rand(16, 4).astype(np.float32) self._labels = np.random.randint(0, 2, size=(16, 1)).astype(np.float32) + def _addBesselsCorrection(self, sample_size, expected_var): + correction_factor = sample_size / (sample_size - 1) + expected_var *= correction_factor + return expected_var + def testUseUpdateOps(self): with ops.Graph().as_default(): 
random_seed.set_random_seed(0) @@ -275,6 +280,7 @@ class CreateTrainOpTest(test.TestCase): expected_mean = np.mean(self._inputs, axis=(0)) expected_var = np.var(self._inputs, axis=(0)) + expected_var = self._addBesselsCorrection(16, expected_var) tf_predictions = BatchNormClassifier(tf_inputs) loss_ops.log_loss(tf_predictions, tf_labels) diff --git a/tensorflow/contrib/specs/BUILD b/tensorflow/contrib/specs/BUILD index dfdbb61dccf..b60a1ef61e8 100644 --- a/tensorflow/contrib/specs/BUILD +++ b/tensorflow/contrib/specs/BUILD @@ -25,15 +25,12 @@ py_library( "//tensorflow/contrib/layers:layers_py", "//tensorflow/contrib/ndlstm", "//tensorflow/python:array_ops", - "//tensorflow/python:framework", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:logging_ops", "//tensorflow/python:math_ops", "//tensorflow/python:nn", "//tensorflow/python:nn_ops", - "//tensorflow/python:ops", - "//tensorflow/python:platform", - "//tensorflow/python:training", + "//tensorflow/python:util", "//tensorflow/python:variable_scope", ], ) diff --git a/tensorflow/contrib/stateless/BUILD b/tensorflow/contrib/stateless/BUILD index 1d9c1ffa50d..598e6513aeb 100644 --- a/tensorflow/contrib/stateless/BUILD +++ b/tensorflow/contrib/stateless/BUILD @@ -21,6 +21,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":stateless_random_ops", + "//tensorflow/python:util", ], ) diff --git a/tensorflow/contrib/tensor_forest/BUILD b/tensorflow/contrib/tensor_forest/BUILD index 40d23183112..7b5f9472e76 100644 --- a/tensorflow/contrib/tensor_forest/BUILD +++ b/tensorflow/contrib/tensor_forest/BUILD @@ -21,6 +21,8 @@ filegroup( exclude = [ "**/METADATA", "**/OWNERS", + "kernels/v4/*", + "proto/*", ], ), visibility = ["//tensorflow:__subpackages__"], diff --git a/tensorflow/contrib/tensor_forest/client/eval_metrics.py b/tensorflow/contrib/tensor_forest/client/eval_metrics.py index 17269863542..a0ae083fdba 100644 --- a/tensorflow/contrib/tensor_forest/client/eval_metrics.py +++ 
b/tensorflow/contrib/tensor_forest/client/eval_metrics.py @@ -117,7 +117,13 @@ def _recall_at_thresholds(predictions, targets, weights=None): weights=weights) +def _auc(probs, targets, weights=None): + return metric_ops.streaming_auc(array_ops.slice(probs, [0, 1], [-1, 1]), + targets, weights=weights) + + _EVAL_METRICS = { + 'auc': _auc, 'sigmoid_entropy': _sigmoid_entropy, 'softmax_entropy': _softmax_entropy, 'accuracy': _accuracy, @@ -132,6 +138,7 @@ _EVAL_METRICS = { } _PREDICTION_KEYS = { + 'auc': INFERENCE_PROB_NAME, 'sigmoid_entropy': INFERENCE_PROB_NAME, 'softmax_entropy': INFERENCE_PROB_NAME, 'accuracy': INFERENCE_PRED_NAME, diff --git a/tensorflow/contrib/tensor_forest/hybrid/core/ops/unpack_path_op.cc b/tensorflow/contrib/tensor_forest/hybrid/core/ops/unpack_path_op.cc index 555674ca69e..9d5e1400a58 100644 --- a/tensorflow/contrib/tensor_forest/hybrid/core/ops/unpack_path_op.cc +++ b/tensorflow/contrib/tensor_forest/hybrid/core/ops/unpack_path_op.cc @@ -52,7 +52,7 @@ REGISTER_OP("UnpackPath") auto tree_depth = c->Dim(params, 1); int64 num_nodes = InferenceContext::kUnknownDim; if (c->ValueKnown(tree_depth)) { - num_nodes = (1 << c->Value(tree_depth)) - 1; + num_nodes = (static_cast(1) << c->Value(tree_depth)) - 1; } c->set_output(0, c->Matrix(num_points, num_nodes)); diff --git a/tensorflow/contrib/tensor_forest/kernels/sample_inputs_op.cc b/tensorflow/contrib/tensor_forest/kernels/sample_inputs_op.cc index 6bfc29d96fe..3ddc72f216e 100644 --- a/tensorflow/contrib/tensor_forest/kernels/sample_inputs_op.cc +++ b/tensorflow/contrib/tensor_forest/kernels/sample_inputs_op.cc @@ -284,6 +284,8 @@ class SampleInputs : public OpKernel { index = rand_feature; val = inputs(*it, rand_feature); } else { + CHECK(sparse_input) << rand_feature << " selected, and dense is " + << input_spec_.dense_features_size(); const auto indices = sparse_input_indices.matrix(); const auto values = sparse_input_values.vec(); const int32 sparse_index = sparse_input_start + rand_feature - 
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/BUILD b/tensorflow/contrib/tensor_forest/kernels/v4/BUILD new file mode 100644 index 00000000000..0542508a8e9 --- /dev/null +++ b/tensorflow/contrib/tensor_forest/kernels/v4/BUILD @@ -0,0 +1,232 @@ +# TensorFlow code for training random forests. +package(default_visibility = [ + "//visibility:public", +]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + ), +) + +cc_library( + name = "decision-tree-resource", + srcs = ["decision-tree-resource.cc"], + hdrs = ["decision-tree-resource.h"], + deps = [ + ":decision_node_evaluator", + ":input_data", + ":leaf_model_operators", + "//tensorflow/contrib/decision_trees/proto:generic_tree_model_cc", + "//tensorflow/contrib/tensor_forest/proto:fertile_stats_proto_cc", + "//tensorflow/core:framework_headers_lib", + ], +) + +cc_library( + name = "fertile-stats-resource", + srcs = ["fertile-stats-resource.cc"], + hdrs = ["fertile-stats-resource.h"], + deps = [ + ":decision_node_evaluator", + ":input_data", + ":input_target", + ":leaf_model_operators", + ":split_collection_operators", + "//tensorflow/contrib/tensor_forest/proto:fertile_stats_proto_cc", + "//tensorflow/contrib/tensor_forest/proto:tensor_forest_params_proto_cc", + "//tensorflow/core:framework", + "//tensorflow/core:framework_headers_lib", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "input_data", + srcs = ["input_data.cc"], + hdrs = ["input_data.h"], + deps = [ + "//tensorflow/contrib/decision_trees/proto:generic_tree_model_cc", + "//tensorflow/contrib/decision_trees/proto:generic_tree_model_extensions_cc", + "//tensorflow/contrib/tensor_forest:tree_utils", + "//tensorflow/core:framework_headers_lib", + ], +) + +cc_library( + name = "input_target", + hdrs = ["input_target.h"], + deps = [ + "//tensorflow/core:framework_headers_lib", + "//third_party/eigen3", + ], +) + +cc_library( + name = "leaf_model_operators", + 
srcs = ["leaf_model_operators.cc"], + hdrs = ["leaf_model_operators.h"], + deps = [ + ":input_target", + ":params", + "//tensorflow/contrib/decision_trees/proto:generic_tree_model_cc", + "//tensorflow/contrib/tensor_forest/proto:fertile_stats_proto_cc", + "//tensorflow/contrib/tensor_forest/proto:tensor_forest_params_proto_cc", + "//tensorflow/core:framework_headers_lib", + "//third_party/eigen3", + ], +) + +cc_test( + name = "leaf_model_operators_test", + srcs = ["leaf_model_operators_test.cc"], + deps = [ + ":leaf_model_operators", + ":test_utils", + "//tensorflow/contrib/decision_trees/proto:generic_tree_model_cc", + "//tensorflow/contrib/tensor_forest/proto:tensor_forest_params_proto_cc", + "//tensorflow/core", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +cc_library( + name = "grow_stats", + srcs = ["grow_stats.cc"], + hdrs = ["grow_stats.h"], + deps = [ + ":decision_node_evaluator", + ":input_data", + ":input_target", + ":params", + ":stat_utils", + "//tensorflow/contrib/decision_trees/proto:generic_tree_model_cc", + "//tensorflow/contrib/tensor_forest:tree_utils", + "//tensorflow/contrib/tensor_forest/proto:fertile_stats_proto_cc", + "//tensorflow/contrib/tensor_forest/proto:tensor_forest_params_proto_cc", + "//tensorflow/core:lib", + ], +) + +cc_test( + name = "grow_stats_test", + srcs = ["grow_stats_test.cc"], + deps = [ + ":grow_stats", + ":test_utils", + "//tensorflow/contrib/decision_trees/proto:generic_tree_model_cc", + "//tensorflow/contrib/tensor_forest/proto:tensor_forest_params_proto_cc", + "//tensorflow/core", + "//tensorflow/core:framework_headers_lib", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +cc_library( + name = "candidate_graph_runner", + srcs = ["candidate_graph_runner.cc"], + hdrs = ["candidate_graph_runner.h"], + deps = [ + ":input_data", + ":input_target", + "//tensorflow/contrib/decision_trees/proto:generic_tree_model_cc", + 
"//tensorflow/contrib/tensor_forest/proto:fertile_stats_proto_cc", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework_headers_lib", + ], +) + +cc_library( + name = "decision_node_evaluator", + srcs = ["decision_node_evaluator.cc"], + hdrs = ["decision_node_evaluator.h"], + deps = [ + ":input_data", + "//tensorflow/contrib/decision_trees/proto:generic_tree_model_cc", + "//tensorflow/contrib/decision_trees/proto:generic_tree_model_extensions_cc", + "//tensorflow/core:lib", + ], +) + +cc_test( + name = "decision_node_evaluator_test", + srcs = ["decision_node_evaluator_test.cc"], + deps = [ + ":decision_node_evaluator", + ":test_utils", + "//tensorflow/contrib/decision_trees/proto:generic_tree_model_cc", + "//tensorflow/core", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +cc_library( + name = "split_collection_operators", + srcs = ["split_collection_operators.cc"], + hdrs = ["split_collection_operators.h"], + deps = [ + ":candidate_graph_runner", + ":grow_stats", + ":input_data", + ":input_target", + ":leaf_model_operators", + ":params", + ":stat_utils", + "//tensorflow/contrib/decision_trees/proto:generic_tree_model_cc", + "//tensorflow/contrib/decision_trees/proto:generic_tree_model_extensions_cc", + "//tensorflow/contrib/tensor_forest:tree_utils", + "//tensorflow/contrib/tensor_forest/proto:fertile_stats_proto_cc", + "//tensorflow/contrib/tensor_forest/proto:tensor_forest_params_proto_cc", + ], +) + +cc_library( + name = "stat_utils", + srcs = ["stat_utils.cc"], + hdrs = ["stat_utils.h"], + deps = [ + "//tensorflow/contrib/decision_trees/proto:generic_tree_model_cc", + "//tensorflow/contrib/tensor_forest/proto:fertile_stats_proto_cc", + "//tensorflow/core:framework_headers_lib", + ], +) + +cc_library( + name = "test_utils", + hdrs = ["test_utils.h"], + deps = [ + ":input_data", + ":input_target", + ], +) + +cc_library( + name = "params", + srcs = ["params.cc"], + hdrs = ["params.h"], + deps = [ + 
"//tensorflow/contrib/tensor_forest/proto:tensor_forest_params_proto_cc", + "//tensorflow/core:framework_headers_lib", + ], +) + +cc_test( + name = "params_test", + srcs = ["params_test.cc"], + deps = [ + ":params", + "//tensorflow/contrib/tensor_forest/proto:tensor_forest_params_proto_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/candidate_graph_runner.cc b/tensorflow/contrib/tensor_forest/kernels/v4/candidate_graph_runner.cc new file mode 100644 index 00000000000..81e2a1b2a1b --- /dev/null +++ b/tensorflow/contrib/tensor_forest/kernels/v4/candidate_graph_runner.cc @@ -0,0 +1,137 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +#include "tensorflow/contrib/tensor_forest/kernels/v4/candidate_graph_runner.h" + +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/platform/env.h" + +namespace tensorflow { +namespace tensorforest { + +// Names of ops in the graph to run. 
+constexpr char kInitializeOp[] = "init"; +constexpr char kAddExampleOp[] = "add_example"; +constexpr char kSplitScoreName[] = "split_score"; +constexpr char kGetSplitName[] = "get_split"; +constexpr char kGetLeftStatsName[] = "get_left_stats"; +constexpr char kGetRightStatsName[] = "get_right_stats"; + +// Names of files written by python graph builder. +constexpr char kGraphFilename[] = "graph"; +constexpr char kSaverDefFilename[] = "saver"; +constexpr char kMetaDefFilename[] = "meta"; + +// Names of Tensor inputs. +constexpr char kFeaturesName[] = "features"; +constexpr char kInputDataName[] = "input_data"; +constexpr char kTargetsName[] = "targets"; +constexpr char kExamplesName[] = "examples"; + +constexpr char kNoOp[] = "none"; + +CandidateGraphRunner::CandidateGraphRunner( + const string& graph_dir, const decision_trees::BinaryNode& split) + : split_(split) { + // read graph from file. + GraphDef graph_def; + TF_CHECK_OK(ReadBinaryProto( + Env::Default(), io::JoinPath(graph_dir, kGraphFilename), &graph_def)) + << "Could not read graph def."; + + // create session. + session_.reset(::tensorflow::NewSession(SessionOptions())); + TF_CHECK_OK(session_->Create(graph_def)) << "Failed to create session"; + + // Features don't change, store them in a tensor. 
+ const auto& oblique = split.inequality_left_child_test().oblique(); + const int32 feat_size = oblique.features_size(); + features_.reset( + new Tensor(tensorflow::DT_INT32, TensorShape({feat_size}))); + auto feat = features_->flat(); + int i = 0; + for (const auto& id : oblique.features()) { + safe_strto32(id.id().value(), &feat(i++)); + } +} + +void CandidateGraphRunner::RunOp( + const string& name, const TensorNameValueList& inputs, + const std::vector& output_tensor_names, + std::vector* outputs) { + std::vector op_name; + if (name != kNoOp) { + op_name.push_back(name); + } + TF_CHECK_OK(session_->Run(inputs, output_tensor_names, op_name, outputs)) + << "Failed to run: " << name; +} + +void CandidateGraphRunner::Init() { + RunOp(kInitializeOp, TensorNameValueList(), std::vector(), nullptr); +} + +void CandidateGraphRunner::AddExample(const Tensor& input_data, + const Tensor& target, + const Tensor& examples) { + TensorNameValueList inputs; + inputs.emplace_back(kFeaturesName, *features_); + inputs.emplace_back(kExamplesName, examples); + inputs.emplace_back(kInputDataName, input_data); + inputs.emplace_back(kTargetsName, target); + + RunOp(kAddExampleOp, inputs, std::vector(), nullptr); +} + +float CandidateGraphRunner::SplitScore() { + std::vector outputs; + RunOp(kNoOp, TensorNameValueList(), {kSplitScoreName}, &outputs); + return outputs[0].unaligned_flat()(0); +} + +void CandidateGraphRunner::GetSplit(decision_trees::BinaryNode* node) { + std::vector outputs; + RunOp(kNoOp, TensorNameValueList(), {kGetSplitName}, &outputs); + ParseProtoUnlimited(node, outputs[0].unaligned_flat()(0)); + const auto& oblique = split_.inequality_left_child_test().oblique(); + auto* new_split = + node->mutable_inequality_left_child_test()->mutable_oblique(); + for (const auto& id : oblique.features()) { + *new_split->add_features() = id; + } +} + +void CandidateGraphRunner::GetLeftStats(LeafStat* stats) { + std::vector outputs; + RunOp(kNoOp, TensorNameValueList(), 
{kGetLeftStatsName}, &outputs); + const auto& counts = outputs[0].unaligned_flat(); + auto* dense = stats->mutable_classification()->mutable_dense_counts(); + for (int i = 0; i < counts.size(); ++i) { + dense->add_value()->set_float_value(counts(i)); + } +} + +void CandidateGraphRunner::GetRightStats(LeafStat* stats) { + std::vector outputs; + RunOp(kNoOp, TensorNameValueList(), {kGetRightStatsName}, &outputs); + const auto& counts = outputs[0].unaligned_flat(); + auto* dense = stats->mutable_classification()->mutable_dense_counts(); + for (int i = 0; i < counts.size(); ++i) { + dense->add_value()->set_float_value(counts(i)); + } +} + +} // namespace tensorforest +} // namespace tensorflow diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/candidate_graph_runner.h b/tensorflow/contrib/tensor_forest/kernels/v4/candidate_graph_runner.h new file mode 100644 index 00000000000..4bd1f06c729 --- /dev/null +++ b/tensorflow/contrib/tensor_forest/kernels/v4/candidate_graph_runner.h @@ -0,0 +1,73 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ============================================================================= +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_CANDIDATE_GRAPH_RUNNER_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_CANDIDATE_GRAPH_RUNNER_H_ +#include +#include + +#include "tensorflow/contrib/decision_trees/proto/generic_tree_model.pb.h" +#include "tensorflow/contrib/tensor_forest/kernels/v4/input_data.h" +#include "tensorflow/contrib/tensor_forest/kernels/v4/input_target.h" +#include "tensorflow/contrib/tensor_forest/proto/fertile_stats.pb.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/public/session.h" + +namespace tensorflow { +namespace tensorforest { + +typedef std::vector> + TensorNameValueList; + +// Class that represents one split candidate, and can perform operations +// on a session created from a graph. +class CandidateGraphRunner { + public: + // split should contain the features that are being used. + CandidateGraphRunner(const string& graph_dir, + const decision_trees::BinaryNode& split); + + // Input the given data and target Tensors to the add_example op. + void AddExample(const Tensor& input_data, const Tensor& target, + const Tensor& examples); + + // Get the candidates' split score with the split_score op. + float SplitScore(); + + // Fills in the split in node with weights and threshold. + void GetSplit(decision_trees::BinaryNode* node); + + // Fills in the stats for the left-branch taken. + void GetLeftStats(LeafStat* stats); + + // Fills in the stats for the right-branch taken. + void GetRightStats(LeafStat* stats); + + // Initializes variables, must be run before other ops. 
+ void Init(); + + protected: + void RunOp(const string& name, const TensorNameValueList& inputs, + const std::vector& output_tensor_names, + std::vector* outputs); + + std::unique_ptr session_; + decision_trees::BinaryNode split_; + std::unique_ptr features_; +}; + +} // namespace tensorforest +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_CANDIDATE_GRAPH_RUNNER_H_ diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.cc b/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.cc new file mode 100644 index 00000000000..165685ca53b --- /dev/null +++ b/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.cc @@ -0,0 +1,88 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ============================================================================= +#include "tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h" + +namespace tensorflow { +namespace tensorforest { + +using decision_trees::DecisionTree; +using decision_trees::TreeNode; + +int32 DecisionTreeResource::TraverseTree( + const std::unique_ptr& input_data, int example, + int32* leaf_depth) const { + const DecisionTree& tree = decision_tree_->decision_tree(); + int32 current_id = 0; + int32 depth = 0; + while (true) { + const TreeNode& current = tree.nodes(current_id); + if (current.has_leaf()) { + if (leaf_depth != nullptr) { + *leaf_depth = depth; + } + return current_id; + } + ++depth; + const int32 next_id = + node_evaluators_[current_id]->Decide(input_data, example); + current_id = tree.nodes(next_id).node_id().value(); + } +} + +void DecisionTreeResource::SplitNode(int32 node_id, SplitCandidate* best, + std::vector* new_children) { + DecisionTree* tree = decision_tree_->mutable_decision_tree(); + TreeNode* node = tree->mutable_nodes(node_id); + int32 newid = tree->nodes_size(); + + // left + new_children->push_back(newid); + TreeNode* new_left = tree->add_nodes(); + new_left->mutable_node_id()->set_value(newid++); + new_left->mutable_leaf(); + + // right + new_children->push_back(newid); + TreeNode* new_right = tree->add_nodes(); + new_right->mutable_node_id()->set_value(newid); + new_right->mutable_leaf(); + + node->clear_leaf(); + node->mutable_binary_node()->Swap(best->mutable_split()); + node->mutable_binary_node()->mutable_left_child_id()->set_value(newid - 1); + node->mutable_binary_node()->mutable_right_child_id()->set_value(newid); + while (node_evaluators_.size() <= node_id) { + node_evaluators_.emplace_back(nullptr); + } + node_evaluators_[node_id] = CreateDecisionNodeEvaluator(*node); +} + +void DecisionTreeResource::MaybeInitialize() { + DecisionTree* tree = decision_tree_->mutable_decision_tree(); + if (tree->nodes_size() == 0) { + 
tree->add_nodes()->mutable_leaf(); + } else if (node_evaluators_.empty()) { // reconstruct evaluators + for (const auto& node : tree->nodes()) { + if (node.has_leaf()) { + node_evaluators_.emplace_back(nullptr); + } else { + node_evaluators_.push_back(CreateDecisionNodeEvaluator(node)); + } + } + } +} + +} // namespace tensorforest +} // namespace tensorflow diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h b/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h new file mode 100644 index 00000000000..c8f09d8e075 --- /dev/null +++ b/tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h @@ -0,0 +1,90 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ============================================================================= +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_DECISION_TREE_RESOURCE_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_DECISION_TREE_RESOURCE_H_ + +#include "tensorflow/contrib/decision_trees/proto/generic_tree_model.pb.h" +#include "tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.h" +#include "tensorflow/contrib/tensor_forest/kernels/v4/input_data.h" +#include "tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.h" +#include "tensorflow/contrib/tensor_forest/proto/fertile_stats.pb.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/platform/mutex.h" + +namespace tensorflow { +namespace tensorforest { + + +// Keep a tree ensemble in memory for efficient evaluation and mutation. +class DecisionTreeResource : public ResourceBase { + public: + // Constructor. + explicit DecisionTreeResource() + : decision_tree_(new decision_trees::Model()) {} + + string DebugString() override { + return strings::StrCat("DecisionTree[size=", + decision_tree_->decision_tree().nodes_size(), + "]"); + } + + void MaybeInitialize(); + + const decision_trees::Model& decision_tree() const { + return *decision_tree_; + } + + decision_trees::Model* mutable_decision_tree() { + return decision_tree_.get(); + } + + const decision_trees::Leaf& get_leaf(int32 id) const { + return decision_tree_->decision_tree().nodes(id).leaf(); + } + + decision_trees::TreeNode* get_mutable_tree_node(int32 id) { + return decision_tree_->mutable_decision_tree()->mutable_nodes(id); + } + + // Resets the resource and frees the proto. + // Caller needs to hold the mutex lock while calling this. + void Reset() { + decision_tree_.reset(new decision_trees::Model()); + } + + mutex* get_mutex() { return &mu_; } + + // Return the TreeNode for the leaf that the example ends up at according + // to decsion_tree_. 
Also fill in that leaf's depth if it isn't nullptr. + int32 TraverseTree(const std::unique_ptr& input_data, + int example, int32* depth) const; + + // Split the given node_id, turning it from a Leaf to a BinaryNode and + // setting it's split to the given best. Add new children ids to + // new_children. + void SplitNode(int32 node_id, SplitCandidate* best, + std::vector* new_children); + + private: + mutex mu_; + std::unique_ptr decision_tree_; + std::vector> node_evaluators_; +}; + + +} // namespace tensorforest +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_DECISION_TREE_RESOURCE_H_ diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.cc b/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.cc new file mode 100644 index 00000000000..7e25579070e --- /dev/null +++ b/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.cc @@ -0,0 +1,120 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ============================================================================= +#include "tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.h" +#include "tensorflow/core/lib/strings/numbers.h" + +namespace tensorflow { +namespace tensorforest { + +std::unique_ptr CreateDecisionNodeEvaluator( + const decision_trees::TreeNode& node) { + const decision_trees::BinaryNode& bnode = node.binary_node(); + return CreateBinaryDecisionNodeEvaluator(bnode, bnode.left_child_id().value(), + bnode.right_child_id().value()); +} + +std::unique_ptr CreateBinaryDecisionNodeEvaluator( + const decision_trees::BinaryNode& bnode, int32 left, int32 right) { + if (bnode.has_inequality_left_child_test()) { + const auto& test = bnode.inequality_left_child_test(); + if (test.has_oblique()) { + return std::unique_ptr( + new ObliqueInequalityDecisionNodeEvaluator(test, left, right)); + } else { + return std::unique_ptr( + new InequalityDecisionNodeEvaluator(test, left, right)); + } + } else { + decision_trees::MatchingValuesTest test; + if (bnode.custom_left_child_test().UnpackTo(&test)) { + return std::unique_ptr( + new MatchingValuesDecisionNodeEvaluator(test, left, right)); + } else { + LOG(ERROR) << "Unknown split test: " << bnode.DebugString(); + return nullptr; + } + } +} + +InequalityDecisionNodeEvaluator::InequalityDecisionNodeEvaluator( + const decision_trees::InequalityTest& test, int32 left, int32 right) + : BinaryDecisionNodeEvaluator(left, right) { + safe_strto32(test.feature_id().id().value(), &feature_num_); + threshold_ = test.threshold().float_value(); + include_equals_ = + test.type() == decision_trees::InequalityTest::LESS_OR_EQUAL; +} + +int32 InequalityDecisionNodeEvaluator::Decide( + const std::unique_ptr& dataset, int example) const { + const float val = dataset->GetExampleValue(example, feature_num_); + if (val < threshold_ || (include_equals_ && val == threshold_)) { + return left_child_id_; + } else { + return right_child_id_; + } +} + 
+ObliqueInequalityDecisionNodeEvaluator::ObliqueInequalityDecisionNodeEvaluator( + const decision_trees::InequalityTest& test, int32 left, int32 right) + : BinaryDecisionNodeEvaluator(left, right) { + for (int i = 0; i < test.oblique().features_size(); ++i) { + int32 val; + safe_strto32(test.oblique().features(i).id().value(), &val); + feature_num_.push_back(val); + feature_weights_.push_back(test.oblique().weights(i)); + } + threshold_ = test.threshold().float_value(); +} + +int32 ObliqueInequalityDecisionNodeEvaluator::Decide( + const std::unique_ptr& dataset, int example) const { + float val = 0; + for (int i = 0; i < feature_num_.size(); ++i) { + val += feature_weights_[i] * + dataset->GetExampleValue(example, feature_num_[i]); + } + + if (val <= threshold_) { + return left_child_id_; + } else { + return right_child_id_; + } +} + +MatchingValuesDecisionNodeEvaluator::MatchingValuesDecisionNodeEvaluator( + const decision_trees::MatchingValuesTest& test, int32 left, int32 right) + : BinaryDecisionNodeEvaluator(left, right) { + safe_strto32(test.feature_id().id().value(), &feature_num_); + for (const auto& val : test.value()) { + values_.push_back(val.float_value()); + } + inverse_ = test.inverse(); +} + +int32 MatchingValuesDecisionNodeEvaluator::Decide( + const std::unique_ptr& dataset, int example) const { + const float val = dataset->GetExampleValue(example, feature_num_); + for (float testval : values_) { + if (val == testval) { + return inverse_ ? right_child_id_ : left_child_id_; + } + } + + return inverse_ ? left_child_id_ : right_child_id_; +} + +} // namespace tensorforest +} // namespace tensorflow diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.h b/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.h new file mode 100644 index 00000000000..3f03c2d05bb --- /dev/null +++ b/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.h @@ -0,0 +1,107 @@ +// Copyright 2017 The TensorFlow Authors. 
All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_DECISION_NODE_EVALUATOR_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_DECISION_NODE_EVALUATOR_H_ + +#include "tensorflow/contrib/decision_trees/proto/generic_tree_model.pb.h" +#include "tensorflow/contrib/decision_trees/proto/generic_tree_model_extensions.pb.h" +#include "tensorflow/contrib/tensor_forest/kernels/v4/input_data.h" + +namespace tensorflow { +namespace tensorforest { + + +// Base class for evaluators of decision nodes that effectively copy proto +// contents into C++ structures for faster execution. +class DecisionNodeEvaluator { + public: + virtual ~DecisionNodeEvaluator() {} + + // Returns the index of the child node. + virtual int32 Decide(const std::unique_ptr& dataset, + int example) const = 0; +}; + +// An evaluator for binary decisions with left and right children. +class BinaryDecisionNodeEvaluator : public DecisionNodeEvaluator { + protected: + BinaryDecisionNodeEvaluator(int32 left, int32 right) + : left_child_id_(left), right_child_id_(right) {} + + int32 left_child_id_; + int32 right_child_id_; +}; + +// Evaluator for basic inequality decisions (f[x] <= T). 
+class InequalityDecisionNodeEvaluator : public BinaryDecisionNodeEvaluator { + public: + InequalityDecisionNodeEvaluator(const decision_trees::InequalityTest& test, + int32 left, int32 right); + + int32 Decide(const std::unique_ptr& dataset, + int example) const override; + + protected: + int32 feature_num_; + float threshold_; + + // If decision is '<=' as opposed to '<'. + bool include_equals_; +}; + +// Evalutor for splits with multiple weighted features. +class ObliqueInequalityDecisionNodeEvaluator + : public BinaryDecisionNodeEvaluator { + public: + ObliqueInequalityDecisionNodeEvaluator( + const decision_trees::InequalityTest& test, int32 left, int32 right); + + int32 Decide(const std::unique_ptr& dataset, + int example) const override; + + protected: + std::vector feature_num_; + std::vector feature_weights_; + float threshold_; +}; + +// Evaluator for contains-in-set decisions. Also supports inverse (not-in-set). +class MatchingValuesDecisionNodeEvaluator : public BinaryDecisionNodeEvaluator { + public: + MatchingValuesDecisionNodeEvaluator( + const decision_trees::MatchingValuesTest& test, int32 left, int32 right); + + int32 Decide(const std::unique_ptr& dataset, + int example) const override; + + protected: + int32 feature_num_; + std::vector values_; + bool inverse_; +}; + +std::unique_ptr CreateDecisionNodeEvaluator( + const decision_trees::TreeNode& node); +std::unique_ptr CreateBinaryDecisionNodeEvaluator( + const decision_trees::BinaryNode& node, int32 left, int32 right); + +struct CandidateEvalatorCollection { + std::vector> splits; +}; + +} // namespace tensorforest +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_DECISION_NODE_EVALUATOR_H_ diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator_test.cc b/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator_test.cc new file mode 100644 index 00000000000..5c49b87443e --- /dev/null +++ 
b/tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator_test.cc @@ -0,0 +1,127 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +#include "tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.h" +#include "tensorflow/contrib/decision_trees/proto/generic_tree_model.pb.h" +#include "tensorflow/contrib/tensor_forest/kernels/v4/test_utils.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +using tensorflow::tensorforest::InequalityDecisionNodeEvaluator; +using tensorflow::tensorforest::MatchingValuesDecisionNodeEvaluator; +using tensorflow::tensorforest::ObliqueInequalityDecisionNodeEvaluator; +using tensorflow::decision_trees::InequalityTest; +using tensorflow::decision_trees::MatchingValuesTest; + +TEST(InequalityDecisionNodeEvaluatorTest, TestLessOrEqual) { + InequalityTest test; + test.mutable_feature_id()->mutable_id()->set_value("0"); + test.mutable_threshold()->set_float_value(3.0); + test.set_type(InequalityTest::LESS_OR_EQUAL); + std::unique_ptr eval( + new InequalityDecisionNodeEvaluator(test, 0, 1)); + + std::unique_ptr dataset( + new tensorflow::tensorforest::TestableDataSet( + {0.0, 1.0, 2.0, 3.0, 4.0, 5.0}, 1)); + + ASSERT_EQ(eval->Decide(dataset, 2), 0); + ASSERT_EQ(eval->Decide(dataset, 3), 0); + ASSERT_EQ(eval->Decide(dataset, 4), 1); 
+} + +TEST(InequalityDecisionNodeEvaluatorTest, TestStrictlyLess) { + InequalityTest test; + test.mutable_feature_id()->mutable_id()->set_value("0"); + test.mutable_threshold()->set_float_value(3.0); + test.set_type(InequalityTest::LESS_THAN); + std::unique_ptr eval( + new InequalityDecisionNodeEvaluator(test, 0, 1)); + + std::unique_ptr dataset( + new tensorflow::tensorforest::TestableDataSet( + {0.0, 1.0, 2.0, 3.0, 4.0, 5.0}, 1)); + + ASSERT_EQ(eval->Decide(dataset, 2), 0); + ASSERT_EQ(eval->Decide(dataset, 3), 1); + ASSERT_EQ(eval->Decide(dataset, 4), 1); +} + +TEST(MatchingDecisionNodeEvaluatorTest, Basic) { + MatchingValuesTest test; + test.mutable_feature_id()->mutable_id()->set_value("0"); + test.add_value()->set_float_value(3.0); + test.add_value()->set_float_value(5.0); + + std::unique_ptr eval( + new MatchingValuesDecisionNodeEvaluator(test, 0, 1)); + + std::unique_ptr dataset( + new tensorflow::tensorforest::TestableDataSet( + {0.0, 1.0, 2.0, 3.0, 4.0, 5.0}, 1)); + + ASSERT_EQ(eval->Decide(dataset, 2), 1); + ASSERT_EQ(eval->Decide(dataset, 3), 0); + ASSERT_EQ(eval->Decide(dataset, 4), 1); + ASSERT_EQ(eval->Decide(dataset, 5), 0); +} + +TEST(MatchingDecisionNodeEvaluatorTest, Inverse) { + MatchingValuesTest test; + test.mutable_feature_id()->mutable_id()->set_value("0"); + test.add_value()->set_float_value(3.0); + test.add_value()->set_float_value(5.0); + test.set_inverse(true); + + std::unique_ptr eval( + new MatchingValuesDecisionNodeEvaluator(test, 0, 1)); + + std::unique_ptr dataset( + new tensorflow::tensorforest::TestableDataSet( + {0.0, 1.0, 2.0, 3.0, 4.0, 5.0}, 1)); + + ASSERT_EQ(eval->Decide(dataset, 2), 0); + ASSERT_EQ(eval->Decide(dataset, 3), 1); + ASSERT_EQ(eval->Decide(dataset, 4), 0); + ASSERT_EQ(eval->Decide(dataset, 5), 1); +} + +TEST(ObliqueDecisionNodeEvaluatorTest, Basic) { + InequalityTest test; + auto* feat1 = test.mutable_oblique()->add_features(); + feat1->mutable_id()->set_value("0"); + test.mutable_oblique()->add_weights(1.0); + 
auto* feat2 = test.mutable_oblique()->add_features(); + feat2->mutable_id()->set_value("1"); + test.mutable_oblique()->add_weights(1.0); + + test.mutable_threshold()->set_float_value(3.0); + test.set_type(InequalityTest::LESS_OR_EQUAL); + + std::unique_ptr eval( + new ObliqueInequalityDecisionNodeEvaluator(test, 0, 1)); + + std::unique_ptr dataset( + new tensorflow::tensorforest::TestableDataSet( + {0.0, 1.0, 2.0, 3.0, 4.0, 5.0}, 2)); + + ASSERT_EQ(eval->Decide(dataset, 0), 0); + ASSERT_EQ(eval->Decide(dataset, 1), 1); +} + +} // namespace +} // namespace tensorflow + diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.cc b/tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.cc new file mode 100644 index 00000000000..9f5d9485143 --- /dev/null +++ b/tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.cc @@ -0,0 +1,98 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ============================================================================= +#include "tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.h" + +#include + +namespace tensorflow { +namespace tensorforest { + +void FertileStatsResource::AddExampleToStatsAndInitialize( + const std::unique_ptr& input_data, + const InputTarget* target, const std::vector& examples, + int32 node_id, int32 node_depth, bool* is_finished) { + // Set leaf's counts for calculating probabilities. + for (int example : examples) { + model_op_->UpdateModel(&leaf_stats_[node_id], target, example); + } + + // Update stats or initialize if needed. + if (collection_op_->IsInitialized(node_id)) { + collection_op_->AddExample(input_data, target, examples, node_id); + } else { + // This throws away any extra examples, which is more inefficient towards + // the top but gradually becomes less of an issue as the tree grows. + for (int example : examples) { + collection_op_->CreateAndInitializeCandidateWithExample( + input_data, example, node_id); + if (collection_op_->IsInitialized(node_id)) { + break; + } + } + } + + *is_finished = collection_op_->IsFinished(node_id); +} + +void FertileStatsResource::AllocateNode(int32 node_id, int32 depth) { + leaf_stats_[node_id] = LeafStat(); + model_op_->InitModel(&leaf_stats_[node_id]); + collection_op_->InitializeSlot(node_id, depth); +} + +void FertileStatsResource::Allocate(int32 parent_depth, + const std::vector& new_children) { + const int32 children_depth = parent_depth + 1; + for (const int32 child : new_children) { + AllocateNode(child, children_depth); + } +} + +void FertileStatsResource::Clear(int32 node) { + collection_op_->ClearSlot(node); + leaf_stats_.erase(node); +} + +bool FertileStatsResource::BestSplit(int32 node_id, SplitCandidate* best, + int32* depth) { + return collection_op_->BestSplit(node_id, best, depth); +} + +void FertileStatsResource::MaybeInitialize() { + if (leaf_stats_.empty()) { + AllocateNode(0, 0); + } +} + 
+void FertileStatsResource::ExtractFromProto(const FertileStats& stats) { + collection_op_ = + SplitCollectionOperatorFactory::CreateSplitCollectionOperator(params_); + collection_op_->ExtractFromProto(stats); + for (int i = 0; i < stats.node_to_slot_size(); ++i) { + const auto& slot = stats.node_to_slot(i); + leaf_stats_[slot.node_id()] = slot.leaf_stats(); + } +} + +void FertileStatsResource::PackToProto(FertileStats* stats) const { + for (const auto& entry : leaf_stats_) { + auto* slot = stats->add_node_to_slot(); + *slot->mutable_leaf_stats() = entry.second; + slot->set_node_id(entry.first); + } + collection_op_->PackToProto(stats); +} +} // namespace tensorforest +} // namespace tensorflow diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.h b/tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.h new file mode 100644 index 00000000000..34ec945e846 --- /dev/null +++ b/tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.h @@ -0,0 +1,110 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ============================================================================= +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_FERTILE_STATS_RESOURCE_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_FERTILE_STATS_RESOURCE_H_ + +#include + +#include "tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.h" +#include "tensorflow/contrib/tensor_forest/kernels/v4/input_data.h" +#include "tensorflow/contrib/tensor_forest/kernels/v4/input_target.h" +#include "tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.h" +#include "tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.h" +#include "tensorflow/contrib/tensor_forest/proto/fertile_stats.pb.h" +#include "tensorflow/contrib/tensor_forest/proto/tensor_forest_params.pb.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/platform/mutex.h" + +namespace tensorflow { +namespace tensorforest { + +// Stores a FertileStats proto and implements operations on it. +class FertileStatsResource : public ResourceBase { + public: + // Constructor. + explicit FertileStatsResource(const TensorForestParams& params) + : params_(params) { + model_op_ = LeafModelOperatorFactory::CreateLeafModelOperator(params_); + } + + string DebugString() override { + return "FertileStats"; + } + + void ExtractFromProto(const FertileStats& stats); + + void PackToProto(FertileStats* stats) const; + + // Resets the resource and frees the proto. + // Caller needs to hold the mutex lock while calling this. + void Reset() { + leaf_stats_.clear(); + } + + // Reset the stats for a node, but leave the leaf_stats intact. + void ResetSplitStats(int32 node_id, int32 depth) { + collection_op_->ClearSlot(node_id); + collection_op_->InitializeSlot(node_id, depth); + } + + mutex* get_mutex() { return &mu_; } + + void MaybeInitialize(); + + // Applies the example to the given leaf's statistics. 
Also applies it to the + // node's fertile slot's statistics if or initializes a split candidate, + // where applicable. Returns if the node is finished or if it's ready to + // allocate to a fertile slot. + void AddExampleToStatsAndInitialize( + const std::unique_ptr& input_data, + const InputTarget* target, const std::vector& examples, + int32 node_id, int32 node_depth, bool* is_finished); + + // Allocate a fertile slot for each ready node, then new children up to + // max_fertile_nodes_. + void Allocate(int32 parent_depth, const std::vector& new_children); + + // Remove a node's fertile slot. Should only be called when the node is + // no longer a leaf. + void Clear(int32 node); + + // Return the best SplitCandidate for a node, or NULL if no suitable split + // was found. + bool BestSplit(int32 node_id, SplitCandidate* best, int32* depth); + + const LeafStat& leaf_stat(int32 node_id) { + return leaf_stats_[node_id]; + } + + void set_leaf_stat(const LeafStat& stat, int32 node_id) { + leaf_stats_[node_id] = stat; + } + + private: + mutex mu_; + std::shared_ptr model_op_; + std::unique_ptr collection_op_; + std::unordered_map leaf_stats_; + const TensorForestParams params_; + + void AllocateNode(int32 node_id, int32 depth); +}; + + +} // namespace tensorforest +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_FERTILE_STATS_RESOURCE_H_ diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.cc b/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.cc new file mode 100644 index 00000000000..fe7b24d36c9 --- /dev/null +++ b/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.cc @@ -0,0 +1,762 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +#include "tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.h" + +#include +#include +#include "tensorflow/contrib/tensor_forest/kernels/tree_utils.h" +#include "tensorflow/contrib/tensor_forest/kernels/v4/stat_utils.h" +#include "tensorflow/core/lib/random/distribution_sampler.h" + + +namespace tensorflow { +namespace tensorforest { + +// When creating evaluators for the split candidates, use these +// for the left and right return values. +static const int32 LEFT_INDEX = 0; +static const int32 RIGHT_INDEX = 1; + +GrowStats::GrowStats(const TensorForestParams& params, int32 depth) + : depth_(depth), + params_(params), + split_after_samples_(ResolveParam(params.split_after_samples(), depth)), + num_splits_to_consider_( + ResolveParam(params.num_splits_to_consider(), depth)), + num_outputs_(params.num_outputs()) {} + +void GrowStats::AddSplit(const decision_trees::BinaryNode& split) { + splits_.push_back(split); + evaluators_.emplace_back( + CreateBinaryDecisionNodeEvaluator(split, LEFT_INDEX, RIGHT_INDEX)); + AddSplitStats(); +} + +void GrowStats::RemoveSplit(int split_num) { + splits_.erase(splits_.begin() + split_num); + evaluators_.erase(evaluators_.begin() + split_num); + RemoveSplitStats(split_num); +} + +// ------------------------ Classification --------------------------- // + +ClassificationStats::ClassificationStats(const TensorForestParams& params, + int32 depth) + : GrowStats(params, depth), finish_early_(false) { + // Early splitting params. 
+ if (params.finish_type().type() == SPLIT_FINISH_BASIC) { + min_split_samples_ = split_after_samples_; + } else { + if (!params.has_dominate_fraction() || !params.has_min_split_samples()) { + LOG(FATAL) << "dominate_fraction and min_split_samples " + << "required for early-finish strategy."; + } else { + min_split_samples_ = ResolveParam(params.min_split_samples(), depth); + finish_check_every_ = + ResolveParam(params.finish_type().check_every_steps(), depth); + finish_sample_epoch_ = min_split_samples_ / finish_check_every_; + + dominate_fraction_ = ResolveParam(params.dominate_fraction(), depth_); + if (dominate_fraction_ <= 0 || dominate_fraction_ > 1.0) { + LOG(FATAL) << "Invalid dominate fraction " << dominate_fraction_; + } + } + } + + // Pruning params. + if (params.pruning_type().type() != SPLIT_PRUNE_NONE) { + prune_check_every_ = + ResolveParam(params.pruning_type().prune_every_samples(), depth); + prune_sample_epoch_ = 1; + prune_fraction_ = 0.0; + switch (params_.pruning_type().type()) { + case SPLIT_PRUNE_HALF: + prune_fraction_ = 0.5; + break; + case SPLIT_PRUNE_QUARTER: + prune_fraction_ = 0.25; + break; + case SPLIT_PRUNE_10_PERCENT: + prune_fraction_ = 0.10; + break; + case SPLIT_PRUNE_HOEFFDING: + dominate_fraction_ = ResolveParam(params.dominate_fraction(), depth_); + half_ln_dominate_frac_ = 0.5 * log(1.0 / (1.0 - dominate_fraction_)); + break; + default: + LOG(WARNING) << "Unknown pruning type"; + } + } + + if (params.use_running_stats_method()) { + left_gini_.reset(new RunningGiniScores()); + right_gini_.reset(new RunningGiniScores()); + } + + uint64 time_seed = static_cast(std::clock()); + single_rand_ = std::unique_ptr( + new random::PhiloxRandom(time_seed)); + rng_ = std::unique_ptr( + new random::SimplePhilox(single_rand_.get())); +} + +bool ClassificationStats::IsFinished() const { + bool basic = weight_sum_ >= split_after_samples_ && num_outputs_seen() > 1; + return basic || finish_early_; +} + +float 
ClassificationStats::MaybeCachedGiniScore(int split, float* left_sum, + float* right_sum) const { + if (left_gini_ == nullptr) { + return GiniScore(split, left_sum, right_sum); + } else { + *left_sum = left_gini_->sum(split); + const float left = WeightedSmoothedGini( + *left_sum, left_gini_->square(split), num_outputs_); + + *right_sum = right_gini_->sum(split); + const float right = WeightedSmoothedGini( + *right_sum, right_gini_->square(split), num_outputs_); + + return left + right; + } +} + +void ClassificationStats::AddExample( + const std::unique_ptr& input_data, const InputTarget* target, + int example) { + const int64 int_label = target->GetTargetAsClassIndex(example, 0); + const float weight = target->GetTargetWeight(example); + + for (int i = 0; i < num_splits(); ++i) { + auto& eval = evaluators_[i]; + if (eval->Decide(input_data, example) == LEFT_INDEX) { + if (left_gini_ != nullptr) { + left_gini_->update(i, left_count(i, int_label), weight); + } + ClassificationAddLeftExample(i, int_label, weight); + } else if (right_gini_ != nullptr) { + right_gini_->update(i, right_count(i, int_label), weight); + } + } + + ClassificationAddTotalExample(int_label, weight); + + weight_sum_ += weight; + + CheckFinishEarly(); + CheckPrune(); +} + +void ClassificationStats::CheckPrune() { + if (IsFinished() || weight_sum_ < prune_sample_epoch_ * prune_check_every_) { + return; + } + ++prune_sample_epoch_; + + if (params_.pruning_type().type() == SPLIT_PRUNE_HOEFFDING) { + CheckPruneHoeffding(); + return; + } + + const int to_remove = num_splits() * prune_fraction_; + if (to_remove <= 0) { + return; + } + + // pair ordering is first-then-second by default, no need for custom + // comparison. Use std::greater to make it a min-heap. + std::priority_queue, std::vector>, + std::greater>> + worst; + + // Track indices that are in the heap so we can iterate over them + // by largest-first later. 
+ std::set indices; + + for (int i = 0; i < num_splits(); ++i) { + float left, right; + const float split_score = MaybeCachedGiniScore(i, &left, &right); + if (worst.size() < to_remove) { + worst.push(std::pair(split_score, i)); + indices.insert(i); + } else if (worst.top().first < split_score) { + indices.erase(worst.top().second); + worst.pop(); + worst.push(std::pair(split_score, i)); + indices.insert(i); + } + } + + // traverse indices from the back so that they are removed correctly. + for (auto it = indices.rbegin(); it != indices.rend(); ++it) { + RemoveSplit(*it); + } +} + +void ClassificationStats::CheckPruneHoeffding() { + std::vector split_scores(num_splits()); + // Find best split score + float best_split_score = FLT_MAX; + for (int i = 0; i < num_splits(); ++i) { + float left, right; + split_scores[i] = MaybeCachedGiniScore(i, &left, &right); + if (split_scores[i] < best_split_score) { + best_split_score = split_scores[i]; + } + } + + // We apply the Hoeffding bound to the difference between the best split + // score and the i-th split score. + // Raw Gini ranges from 0 to 1 - (1/n), but our gini score is weighted. 
+ const float num_classes = params_.num_outputs(); + const float gini_diff_range = weight_sum_ * (1.0 - 1.0 / num_classes); + float epsilon = gini_diff_range * sqrt(half_ln_dominate_frac_ / weight_sum_); + for (int i = num_splits() - 1; i >= 0; i--) { + if (split_scores[i] - best_split_score > epsilon) { + RemoveSplit(i); + } + } +} + +void ClassificationStats::CheckFinishEarly() { + if (weight_sum_ < min_split_samples_ || + weight_sum_ < finish_sample_epoch_ * finish_check_every_) { + return; + } + ++finish_sample_epoch_; + + if (params_.finish_type().type() == SPLIT_FINISH_DOMINATE_HOEFFDING) { + CheckFinishEarlyHoeffding(); + } else if (params_.finish_type().type() == SPLIT_FINISH_DOMINATE_BOOTSTRAP) { + CheckFinishEarlyBootstrap(); + } +} + +void ClassificationStats::CheckFinishEarlyHoeffding() { + // Each term in the Gini impurity can range from 0 to 0.5 * 0.5. + float range = 0.25 * static_cast(params_.num_outputs()) * weight_sum_; + + float hoeffding_bound = + range * sqrt(log(1.0 / (1.0 - dominate_fraction_)) / (2.0 * weight_sum_)); + + float unused_left_sum, unused_right_sum; + std::function score_fn = + std::bind(&ClassificationStats::MaybeCachedGiniScore, this, + std::placeholders::_1, &unused_left_sum, &unused_right_sum); + + float best_score; + int32 best_index; + float second_best_score; + int32 second_best_index; + GetTwoBest(num_splits(), score_fn, &best_score, &best_index, + &second_best_score, &second_best_index); + + finish_early_ = (second_best_score - best_score) > hoeffding_bound; +} + +void ClassificationStats::MakeBootstrapWeights(int index, + std::vector* weights) { + int n = weight_sum_; + float denom = static_cast(n) + static_cast(num_outputs_); + for (int i = 0; i < num_outputs_; ++i) { + // Use the Laplace smoothed per-class probabilities when generating the + // bootstrap samples. 
+ (*weights)[i] = (left_count(index, i) + 1.0) / denom; + (*weights)[num_outputs_ + i] = (right_count(index, i) + 1.0) / denom; + } +} + +int ClassificationStats::NumBootstrapSamples() const { + float p = 1.0 - dominate_fraction_; + int bootstrap_samples = 1; + while (p < 1.0) { + ++bootstrap_samples; + p = p * 2; + } + return bootstrap_samples; +} + +void ClassificationStats::CheckFinishEarlyBootstrap() { + float unused_left_sum, unused_right_sum; + std::function score_fn = + std::bind(&ClassificationStats::MaybeCachedGiniScore, this, + std::placeholders::_1, &unused_left_sum, &unused_right_sum); + + float best_score; + int32 best_index; + float second_best_score; + int32 second_best_index; + GetTwoBest(num_splits(), score_fn, &best_score, &best_index, + &second_best_score, &second_best_index); + + std::vector weights1(num_outputs_ * 2); + MakeBootstrapWeights(best_index, &weights1); + random::DistributionSampler ds1(weights1); + + std::vector weights2(num_outputs_ * 2); + MakeBootstrapWeights(second_best_index, &weights2); + random::DistributionSampler ds2(weights2); + + const int bootstrap_samples = NumBootstrapSamples(); + + int worst_g1 = 0; + for (int i = 0; i < bootstrap_samples; i++) { + int g1 = BootstrapGini(weight_sum_, 2 * num_outputs_, ds1, rng_.get()); + worst_g1 = std::max(worst_g1, g1); + } + + int best_g2 = 99; + for (int i = 0; i < bootstrap_samples; i++) { + int g2 = BootstrapGini(weight_sum_, 2 * num_outputs_, ds2, rng_.get()); + best_g2 = std::min(best_g2, g2); + } + + finish_early_ = worst_g1 < best_g2; +} + +// ------------------------ Dense Classification --------------------------- // +void DenseClassificationGrowStats::ExtractFromProto(const FertileSlot& slot) { + Initialize(); + if (!slot.has_post_init_leaf_stats()) { + return; + } + const int32 num_classes = params_.num_outputs(); + weight_sum_ = slot.post_init_leaf_stats().weight_sum(); + const auto& class_stats = + slot.post_init_leaf_stats().classification().dense_counts(); + + // 
Total counts. + for (int i = 0; i < num_classes; ++i) { + total_counts_[i] = class_stats.value(i).float_value(); + num_outputs_seen_ += total_counts_[i] != 0; + } + + // Candidate counts and splits. + int split_num = 0; + for (const auto& cand : slot.candidates()) { + AddSplit(cand.split()); + const auto& left_stats = cand.left_stats().classification().dense_counts(); + for (int i = 0; i < num_classes; ++i) { + const float val = left_stats.value(i).float_value(); + mutable_left_count(split_num, i) = val; + MaybeInitializeRunningCount(split_num, val); + } + ++split_num; + } +} + +void DenseClassificationGrowStats::PackToProto(FertileSlot* slot) const { + auto* slot_stats = slot->mutable_post_init_leaf_stats(); + slot_stats->set_weight_sum(weight_sum_); + + auto* class_stats = slot->mutable_post_init_leaf_stats() + ->mutable_classification() + ->mutable_dense_counts(); + for (int i = 0; i < num_outputs_; ++i) { + class_stats->add_value()->set_float_value(total_counts_[i]); + } + + for (int split_num = 0; split_num < num_splits(); ++split_num) { + auto* cand = slot->add_candidates(); + *cand->mutable_split() = splits_[split_num]; + auto* left_stats = cand->mutable_left_stats() + ->mutable_classification() + ->mutable_dense_counts(); + for (int i = 0; i < num_outputs_; ++i) { + left_stats->add_value()->set_float_value(left_count(split_num, i)); + } + } +} + +float DenseClassificationGrowStats::GiniScore(int split, float* left_sum, + float* right_sum) const { + float left_square = 0, right_square = 0; + *left_sum = 0; + *right_sum = 0; + for (int j = 0; j < num_outputs_; ++j) { + const float left = left_count(split, j); + *left_sum += left; + left_square += left * left; + const float right = right_count(split, j); + *right_sum += right; + right_square += right * right; + } + + const float left_score = + WeightedSmoothedGini(*left_sum, left_square, num_outputs_); + const float right_score = + WeightedSmoothedGini(*right_sum, right_square, num_outputs_); + return 
left_score + right_score; +} + +bool DenseClassificationGrowStats::BestSplit(SplitCandidate* best) const { + float min_score = FLT_MAX; + int best_index = -1; + float best_left_sum, best_right_sum; + + // Calculate sums. + for (int i = 0; i < num_splits(); ++i) { + float left_sum, right_sum; + const float split_score = MaybeCachedGiniScore(i, &left_sum, &right_sum); + // Find the lowest gini. + if (left_sum > 0 && right_sum > 0 && + split_score < min_score) { // useless check + min_score = split_score; + best_index = i; + best_left_sum = left_sum; + best_right_sum = right_sum; + } + } + + // This could happen if all the splits are useless. + if (best_index < 0) { + return false; + } + + // Fill in stats to be used for leaf model. + *best->mutable_split() = splits_[best_index]; + // Left + auto* left = best->mutable_left_stats(); + auto* left_class_stats = left->mutable_classification(); + left->set_weight_sum(best_left_sum); + auto* left_counts = left_class_stats->mutable_dense_counts(); + for (int i = 0; i < params_.num_outputs(); ++i) { + left_counts->add_value()->set_float_value( + left_count(best_index, i)); + } + + // Right + auto* right = best->mutable_right_stats(); + auto* right_class_stats = right->mutable_classification(); + right->set_weight_sum(best_right_sum); + auto* right_counts = right_class_stats->mutable_dense_counts(); + for (int i = 0; i < params_.num_outputs(); ++i) { + right_counts->add_value()->set_float_value( + total_counts_[i] - left_count(best_index, i)); + } + return true; +} + +// ------------------------ Sparse Classification --------------------------- // +void SparseClassificationGrowStats::ExtractFromProto(const FertileSlot& slot) { + Initialize(); + if (!slot.has_post_init_leaf_stats()) { + return; + } + weight_sum_ = slot.post_init_leaf_stats().weight_sum(); + const auto& class_stats = + slot.post_init_leaf_stats().classification().sparse_counts(); + + // Total counts. 
+ for (auto const& entry : class_stats.sparse_value()) { + total_counts_[entry.first] = entry.second.float_value(); + } + + // Candidate counts and splits. + int split_num = 0; + for (const auto& cand : slot.candidates()) { + AddSplit(cand.split()); + const auto& left_stats = cand.left_stats().classification().sparse_counts(); + for (auto const& entry : left_stats.sparse_value()) { + const float val = entry.second.float_value(); + left_counts_[split_num][entry.first] = val; + MaybeInitializeRunningCount(split_num, val); + } + ++split_num; + } +} + +void SparseClassificationGrowStats::PackToProto(FertileSlot* slot) const { + auto* slot_stats = slot->mutable_post_init_leaf_stats(); + slot_stats->set_weight_sum(weight_sum_); + + auto* class_stats = slot->mutable_post_init_leaf_stats() + ->mutable_classification() + ->mutable_sparse_counts() + ->mutable_sparse_value(); + for (const auto& entry : total_counts_) { + decision_trees::Value val; + val.set_float_value(entry.second); + (*class_stats)[entry.first] = val; + } + + for (int split_num = 0; split_num < num_splits(); ++split_num) { + auto* cand = slot->add_candidates(); + *cand->mutable_split() = splits_[split_num]; + auto* left_stats = cand->mutable_left_stats() + ->mutable_classification() + ->mutable_sparse_counts() + ->mutable_sparse_value(); + for (const auto& entry : left_counts_[split_num]) { + decision_trees::Value val; + val.set_float_value(entry.second); + (*left_stats)[entry.first] = val; + } + } +} + +float SparseClassificationGrowStats::GiniScore( + int split, float* left_sum, float* right_sum) const { + float left_square = 0, right_square = 0; + *left_sum = 0; + *right_sum = 0; + for (const auto& entry : total_counts_) { + const int label = entry.first; + float left = 0; + float right = 0; + auto it = left_counts_[split].find(label); + if (it == left_counts_[split].end()) { + right = entry.second; + } else { + left = it->second; + right = entry.second - it->second; + } + *left_sum += left; + 
left_square += left * left; + *right_sum += right; + right_square += right * right; + } + const int32 num_classes = params_.num_outputs(); + const float left_score = + WeightedSmoothedGini(*left_sum, left_square, num_classes); + const float right_score = + WeightedSmoothedGini(*right_sum, right_square, num_classes); + return left_score + right_score; +} + +bool SparseClassificationGrowStats::BestSplit(SplitCandidate* best) const { + float min_score = FLT_MAX; + int best_index = -1; + float best_left_sum = -1; + float best_right_sum = -1; + + // Find the lowest gini. + for (int i = 0; i < num_splits(); ++i) { + float left_sum, right_sum; + const float split_score = MaybeCachedGiniScore(i, &left_sum, &right_sum); + if (left_sum > 0 && right_sum > 0 && + split_score < min_score) { // useless check + min_score = split_score; + best_index = i; + best_left_sum = left_sum; + best_right_sum = right_sum; + } + } + + // This could happen if all the splits are useless. + if (best_index < 0) { + return false; + } + + // Fill in stats to be used for leaf model. 
+ *best->mutable_split() = splits_[best_index]; + // Left + auto* left = best->mutable_left_stats(); + auto* left_class_stats = left->mutable_classification(); + left->set_weight_sum(best_left_sum); + auto* left_counts = + left_class_stats->mutable_sparse_counts()->mutable_sparse_value(); + + // Right + auto* right = best->mutable_right_stats(); + auto* right_class_stats = right->mutable_classification(); + right->set_weight_sum(best_right_sum); + auto* right_counts = + right_class_stats->mutable_sparse_counts()->mutable_sparse_value(); + + for (const auto& entry : total_counts_) { + auto it = left_counts_[best_index].find(entry.first); + if (it == left_counts_[best_index].end()) { + (*right_counts)[entry.first].set_float_value(entry.second); + } else { + const float left = it->second; + const float right = entry.second - it->second; + (*left_counts)[entry.first].set_float_value(left); + if (right > 0) { + (*right_counts)[entry.first].set_float_value(right); + } + } + } + return true; +} + +// --------------------- Least Squares Regression --------------------------- // +void LeastSquaresRegressionGrowStats::ExtractFromProto( + const FertileSlot& slot) { + const int32 num_outputs = params_.num_outputs(); + Initialize(); + if (!slot.has_post_init_leaf_stats()) { + return; + } + weight_sum_ = slot.post_init_leaf_stats().weight_sum(); + const auto& total_sums = + slot.post_init_leaf_stats().regression().mean_output(); + const auto& total_squares = + slot.post_init_leaf_stats().regression().mean_output_squares(); + + // Total counts. + for (int i = 0; i < num_outputs; ++i) { + total_sum_[i] = total_sums.value(i).float_value(); + total_sum_squares_[i] = total_squares.value(i).float_value(); + } + + // Candidate counts and splits. 
+ int split_num = 0; + for (const auto& cand : slot.candidates()) { + AddSplit(cand.split()); + const auto& sums = cand.left_stats().regression().mean_output(); + const auto& squares = cand.left_stats().regression().mean_output_squares(); + for (int i = 0; i < num_outputs; ++i) { + left_sum(split_num, i) = sums.value(i).float_value(); + left_square(split_num, i) = squares.value(i).float_value(); + } + left_counts_[split_num] = cand.left_stats().weight_sum(); + ++split_num; + } +} + +void LeastSquaresRegressionGrowStats::PackToProto(FertileSlot* slot) const { + const int32 num_outputs = params_.num_outputs(); + auto* slot_stats = slot->mutable_post_init_leaf_stats(); + slot_stats->set_weight_sum(weight_sum_); + + auto* total_sums = slot->mutable_post_init_leaf_stats() + ->mutable_regression() + ->mutable_mean_output(); + auto* total_squares = slot->mutable_post_init_leaf_stats() + ->mutable_regression() + ->mutable_mean_output_squares(); + + for (int i = 0; i < total_sum_.size(); ++i) { + total_sums->add_value()->set_float_value(total_sum_[i]); + total_squares->add_value()->set_float_value(total_sum_squares_[i]); + } + + for (int split_num = 0; split_num < num_splits(); ++split_num) { + auto* cand = slot->add_candidates(); + *cand->mutable_split() = splits_[split_num]; + auto* sums = cand->mutable_left_stats() + ->mutable_regression() + ->mutable_mean_output(); + auto* squares = cand->mutable_left_stats() + ->mutable_regression() + ->mutable_mean_output_squares(); + for (int i = 0; i < num_outputs; ++i) { + sums->add_value()->set_float_value(left_sum(split_num, i)); + squares->add_value()->set_float_value(left_square(split_num, i)); + } + cand->mutable_left_stats()->set_weight_sum(left_counts_[split_num]); + } +} + +void LeastSquaresRegressionGrowStats::AddExample( + const std::unique_ptr& input_data, const InputTarget* target, + int example) { + const int32 num_outputs = params_.num_outputs(); + // Update splits. 
+ for (int i = 0; i < num_splits(); ++i) { + auto& eval = evaluators_[i]; + if (eval->Decide(input_data, example) == LEFT_INDEX) { + for (int j = 0; j < num_outputs; ++j) { + const float output = target->GetTargetAsContinuous(example, j); + left_sum(i, j) += output; + left_square(i, j) += output * output; + } + ++left_counts_[i]; + } + } + + // Update totals. + for (int i = 0; i < num_outputs; ++i) { + const float output = target->GetTargetAsContinuous(example, i); + total_sum_[i] += output; + total_sum_squares_[i] += output * output; + } + weight_sum_ += 1.0; +} + +float LeastSquaresRegressionGrowStats::SplitVariance(int split) const { + float total_variance = 0; + for (int i = 0; i < params_.num_outputs(); ++i) { + // Left side + const float le_x = + left_sum(split, i) / left_counts_[split]; + + const float le_x2 = + left_square(split, i) / left_counts_[split]; + total_variance += le_x2 - le_x * le_x; + + // Right side + const float re_x = (total_sum_[i] - left_sum(split, i)) / + (weight_sum_ - left_counts_[split]); + + const float re_x2 = + (total_sum_squares_[i] - left_square(split, i)) / + (weight_sum_ - left_counts_[split]); + total_variance += re_x2 - re_x * re_x; + } + return total_variance; +} + +bool LeastSquaresRegressionGrowStats::BestSplit(SplitCandidate* best) const { + float min_score = FLT_MAX; + int best_index = -1; + const int32 num_outputs = params_.num_outputs(); + for (int i = 0; i < num_splits(); ++i) { + if (left_counts_[i] > 0 && weight_sum_ - left_counts_[i] > 0) { + const float split_score = SplitVariance(i); + if (split_score < min_score) { + min_score = split_score; + best_index = i; + } + } + } + + // This could happen if all the splits are useless. + if (best_index < 0) { + return false; + } + + // Fill in right stats to be used for leaf model. 
+ *best->mutable_split() = splits_[best_index]; + // Left + auto* left = best->mutable_left_stats(); + auto* left_reg_stats = left->mutable_regression(); + left->set_weight_sum(left_counts_[best_index]); + auto* left_output_sum = left_reg_stats->mutable_mean_output(); + for (int i = 0; i < num_outputs; ++i) { + left_output_sum->add_value()->set_float_value( + left_sum(best_index, i)); + } + + // Right + auto* right = best->mutable_right_stats(); + auto* right_reg_stats = right->mutable_regression(); + right->set_weight_sum(weight_sum_ - left_counts_[best_index]); + auto* right_output_sum = right_reg_stats->mutable_mean_output(); + for (int i = 0; i < num_outputs; ++i) { + right_output_sum->add_value()->set_float_value( + total_sum_[i] - left_sum(best_index, i)); + } + return true; +} + +bool LeastSquaresRegressionGrowStats::IsFinished() const { + return weight_sum_ >= split_after_samples_; +} + +} // namespace tensorforest +} // namespace tensorflow diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.h b/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.h new file mode 100644 index 00000000000..8d32b4961b1 --- /dev/null +++ b/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.h @@ -0,0 +1,470 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ============================================================================= +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_GROW_STATS_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_GROW_STATS_H_ +#include +#include + +#include "tensorflow/contrib/decision_trees/proto/generic_tree_model.pb.h" +#include "tensorflow/contrib/tensor_forest/kernels/v4/decision_node_evaluator.h" +#include "tensorflow/contrib/tensor_forest/kernels/v4/input_data.h" +#include "tensorflow/contrib/tensor_forest/kernels/v4/input_target.h" +#include "tensorflow/contrib/tensor_forest/kernels/v4/params.h" +#include "tensorflow/contrib/tensor_forest/proto/fertile_stats.pb.h" +#include "tensorflow/contrib/tensor_forest/proto/tensor_forest_params.pb.h" +#include "tensorflow/core/lib/random/philox_random.h" +#include "tensorflow/core/lib/random/simple_philox.h" + +namespace tensorflow { +namespace tensorforest { + +// Base class for tracking stats necessary to split a leaf. +// Holds and tracks stats for every candidate split. +class GrowStats { + public: + virtual ~GrowStats() {} + // Perform any initialization. + virtual void Initialize() = 0; + + // Add an example to any stats being collected. + virtual void AddExample(const std::unique_ptr& input_data, + const InputTarget* target, int example) = 0; + + // Fill in the best split, return false if none were valid. + virtual bool BestSplit(SplitCandidate* best) const = 0; + + // Return true if this leaf is finished splitting. + virtual bool IsFinished() const = 0; + + // Get the split_num BinaryNode. + const decision_trees::BinaryNode& Split(int split_num) const { + return splits_[split_num]; + } + + // Clear all state. + virtual void Clear() { + weight_sum_ = 0; + splits_.clear(); + evaluators_.clear(); + ClearInternal(); + } + + virtual void ExtractFromProto(const FertileSlot& slot) = 0; + virtual void PackToProto(FertileSlot* slot) const = 0; + + // Add split to the list of candidate splits. 
+ void AddSplit(const decision_trees::BinaryNode& split); + void RemoveSplit(int split_num); + + int num_splits() const { + return splits_.size(); + } + + float weight_sum() const { + return weight_sum_; + } + + bool IsInitialized() const { + return weight_sum_ > 0 || splits_.size() == num_splits_to_consider_; + } + + int32 depth() const { + return depth_; + } + + protected: + GrowStats(const TensorForestParams& params, int32 depth); + + // Function called by AddSplit for subclasses to initialize stats for a split. + virtual void AddSplitStats() = 0; + + virtual void RemoveSplitStats(int split_num) = 0; + + // Function called by Clear for subclasses to clear their state. + virtual void ClearInternal() = 0; + + std::vector splits_; + std::vector> evaluators_; + + float weight_sum_; + + const int32 depth_; + + const TensorForestParams& params_; + + // We cache these beacuse they're used often. + const int split_after_samples_; + const int num_splits_to_consider_; + + const int32 num_outputs_; +}; + +// Don't track anything, useful for systems that want to track split +// candidates but train the model in some other way. +class SimpleStats : public GrowStats { + public: + SimpleStats(const TensorForestParams& params, int32 depth) + : GrowStats(params, depth) {} + void Initialize() override {} + + void ExtractFromProto(const FertileSlot& slot) override {} + void PackToProto(FertileSlot* slot) const override {} + + void AddExample(const std::unique_ptr& input_data, + const InputTarget* target, int example) override { + weight_sum_ += target->GetTargetWeight(example); + } + + bool BestSplit(SplitCandidate* best) const override { return false; } + + bool IsFinished() const override { + return weight_sum_ >= split_after_samples_; + } + + protected: + void AddSplitStats() override {} + void RemoveSplitStats(int split_num) override {} + void ClearInternal() override {} +}; + +// Tracks the sum and square of one side of a split for each Gini calculation. 
+class RunningGiniScores { + public: + float sum(int split) const { return sum_[split]; } + float square(int split) const { return square_[split]; } + + void update(int split, float old_val, float weight) { + sum_[split] += weight; + const float new_val = old_val + weight; + square_[split] = square_[split] - old_val * old_val + new_val * new_val; + } + + void add_split() { + sum_.push_back(0); + square_.push_back(0); + } + + void remove_split(int i) { + sum_.erase(sum_.begin() + i); + square_.erase(square_.begin() + i); + } + + private: + std::vector sum_; + std::vector square_; +}; + +class ClassificationStats : public GrowStats { + public: + ClassificationStats(const TensorForestParams& params, int32 depth); + + bool IsFinished() const override; + + void AddExample(const std::unique_ptr& input_data, + const InputTarget* target, int example) override; + + protected: + virtual float GiniScore(int split, float* left_sum, + float* right_sum) const = 0; + virtual int num_outputs_seen() const = 0; + virtual float left_count(int split, int class_num) const = 0; + virtual float right_count(int split, int class_num) const = 0; + + virtual void ClassificationAddLeftExample( + int split, int64 int_label, float weight) = 0; + virtual void ClassificationAddTotalExample(int64 int_label, float weight) = 0; + + virtual void ClassificationAddSplitStats() = 0; + virtual void ClassificationRemoveSplitStats(int split) = 0; + + void AddSplitStats() override { + if (left_gini_ != nullptr) { + left_gini_->add_split(); + right_gini_->add_split(); + } + ClassificationAddSplitStats(); + } + void RemoveSplitStats(int split) override { + if (left_gini_ != nullptr) { + left_gini_->remove_split(split); + right_gini_->remove_split(split); + } + ClassificationRemoveSplitStats(split); + } + + // Virtual so we can override these to test. 
+ virtual void CheckFinishEarly(); + virtual void CheckFinishEarlyHoeffding(); + virtual void CheckFinishEarlyBootstrap(); + + virtual void CheckPrune(); + + // Implement SplitPruningStrategyType::SPLIT_PRUNE_HOEFFDING. + void CheckPruneHoeffding(); + + // Return the gini score, possibly being calculated from sums and squares + // saved in left_gini_ and right_gini_, otherwise calculated from raw counts. + float MaybeCachedGiniScore(int split, float* left_sum, + float* right_sum) const; + + // Initialize the sum and squares of left_gini_ and right_gini_ for given + // split and value (being extracted from a proto), if left_gini_ isn't null. + void MaybeInitializeRunningCount(int split, float val) { + if (left_gini_ != nullptr) { + left_gini_->update(split, 0, val); + right_gini_->update(split, 0, val); + } + } + + int NumBootstrapSamples() const; + + // Populate *weights with the smoothed per-class frequencies needed to + // initialize a DistributionSampler. + void MakeBootstrapWeights(int index, std::vector* weights); + + // Accessors for RunningGiniScores objects, for testing. + virtual const std::unique_ptr& get_left_gini() const { + return left_gini_; + } + virtual const std::unique_ptr& get_right_gini() const { + return right_gini_; + } + + private: + // Tracks how many check_every_samples epochs we've seen go by in weight_sum. + int32 finish_sample_epoch_; + int32 finish_check_every_; + int32 prune_sample_epoch_; + int32 prune_check_every_; + bool finish_early_; + int32 min_split_samples_; + float dominate_fraction_; + float prune_fraction_; + + // When using SPLIT_PRUNE_HOEFFDING, we precompute and store + // 0.5 * ln(1 / (1.0 - dominate_fraction_)). + float half_ln_dominate_frac_; + + std::unique_ptr single_rand_; + std::unique_ptr rng_; + + std::unique_ptr left_gini_; + std::unique_ptr right_gini_; +}; + +// Tracks classification stats by storing class counts densely. 
+class DenseClassificationGrowStats : public ClassificationStats { + public: + DenseClassificationGrowStats(const TensorForestParams& params, int32 depth) + : ClassificationStats(params, depth) {} + + void Initialize() override { + Clear(); + total_counts_.resize(num_outputs_); + } + + void ExtractFromProto(const FertileSlot& slot) override; + void PackToProto(FertileSlot* slot) const override; + + bool BestSplit(SplitCandidate* best) const override; + + protected: + void ClassificationAddSplitStats() override { + left_counts_.resize(num_outputs_ * num_splits()); + } + void ClassificationRemoveSplitStats(int split_num) override { + left_counts_.erase(left_counts_.begin() + num_outputs_ * split_num, + left_counts_.begin() + num_outputs_ * (split_num + 1)); + } + void ClearInternal() override { + total_counts_.clear(); + left_counts_.clear(); + num_outputs_seen_ = 0; + } + + int num_outputs_seen() const override { + return num_outputs_seen_; + } + + void ClassificationAddLeftExample(int split, int64 int_label, + float weight) override { + mutable_left_count(split, int_label) += weight; + } + void ClassificationAddTotalExample(int64 int_label, float weight) override { + num_outputs_seen_ += total_counts_[int_label] == 0 && weight > 0; + total_counts_[int_label] += weight; + } + + float GiniScore(int split, float* left_sum, float* right_sum) const override; + + float left_count(int split, int class_num) const override { + return left_counts_[split * num_outputs_ + class_num]; + } + float right_count(int split, int class_num) const override { + return total_counts_[class_num] - + left_counts_[split * num_outputs_ + class_num]; + } + + private: + inline float& mutable_left_count(int split, int class_num) { + return left_counts_[split * num_outputs_ + class_num]; + } + // Total class counts seen at this leaf + std::vector total_counts_; + + // Also track the number of classes seen for not splitting pure leaves. 
+ int num_outputs_seen_; + + // Left-branch taken class counts at this leaf for each split. + // This is a flat vector for memory-performance reasons. + // left_counts_[i * num_outputs_ + j] has the j-th class count for split i. + std::vector left_counts_; +}; + +// Tracks classification stats by storing class counts sparsely. +class SparseClassificationGrowStats : public ClassificationStats { + public: + SparseClassificationGrowStats(const TensorForestParams& params, int32 depth) + : ClassificationStats(params, depth) {} + + void Initialize() override { + Clear(); + } + + void ExtractFromProto(const FertileSlot& slot) override; + void PackToProto(FertileSlot* slot) const override; + + bool BestSplit(SplitCandidate* best) const override; + + protected: + void ClassificationAddSplitStats() override { + left_counts_.resize(num_splits()); + } + void ClassificationRemoveSplitStats(int split_num) override { + left_counts_.erase(left_counts_.begin() + split_num, + left_counts_.begin() + (split_num + 1)); + } + void ClearInternal() override { + total_counts_.clear(); + left_counts_.clear(); + } + + int num_outputs_seen() const override { return total_counts_.size(); } + + void ClassificationAddLeftExample(int split, int64 int_label, + float weight) override { + left_counts_[split][int_label] += weight; + } + void ClassificationAddTotalExample(int64 int_label, float weight) override { + total_counts_[int_label] += weight; + } + + float GiniScore(int split, float* left_sum, float* right_sum) const override; + + float left_count(int split, int class_num) const override { + return left_counts_[split].at(class_num); + } + float right_count(int split, int class_num) const override { + return total_counts_.at(class_num) - left_counts_[split].at(class_num); + } + + private: + // Total class counts seen at this leaf + std::unordered_map total_counts_; + + // Left-branch taken class counts at this leaf for each split. + // left_counts_[i][j] has the j-th class count for split i. 
+ std::vector> left_counts_; +}; + +// Tracks regression stats using least-squares minimization. +class LeastSquaresRegressionGrowStats : public GrowStats { + public: + LeastSquaresRegressionGrowStats(const TensorForestParams& params, int32 depth) + : GrowStats(params, depth) {} + + void Initialize() override { + Clear(); + total_sum_.resize(num_outputs_); + total_sum_squares_.resize(num_outputs_); + } + + void ExtractFromProto(const FertileSlot& slot) override; + void PackToProto(FertileSlot* slot) const override; + + void AddExample(const std::unique_ptr& input_data, + const InputTarget* target, int example) override; + bool BestSplit(SplitCandidate* best) const override; + bool IsFinished() const override; + + protected: + // Returns the variance of split. + float SplitVariance(int split) const; + + void AddSplitStats() override { + left_sums_.resize(num_outputs_ * num_splits()); + left_squares_.resize(num_outputs_ * num_splits()); + left_counts_.push_back(0); + } + void RemoveSplitStats(int split_num) override { + left_sums_.erase(left_sums_.begin() + num_outputs_ * split_num, + left_sums_.begin() + num_outputs_ * (split_num + 1)); + left_squares_.erase(left_squares_.begin() + num_outputs_ * split_num, + left_squares_.begin() + num_outputs_ * (split_num + 1)); + left_counts_.erase(left_counts_.begin() + split_num, + left_counts_.begin() + (split_num + 1)); + } + + void ClearInternal() override { + total_sum_.clear(); + total_sum_squares_.clear(); + left_sums_.clear(); + left_squares_.clear(); + } + + private: + // Convenience methods for accessing the flat count vectors. 
+ inline const float& left_sum(int split, int output_num) const { + return left_sums_[split * num_outputs_ + output_num]; + } + inline float& left_sum(int split, int output_num) { + return left_sums_[split * num_outputs_ + output_num]; + } + inline const float& left_square(int split, int output_num) const { + return left_squares_[split * num_outputs_ + output_num]; + } + inline float& left_square(int split, int output_num) { + return left_squares_[split * num_outputs_ + output_num]; + } + + // Total sums and squares seen at this leaf. + // sum[i] is the sum of the i-th output. + std::vector<float> total_sum_; + std::vector<float> total_sum_squares_; + + // Per-split sums and squares, stored flat for performance. + // left_sums_[i * num_outputs_ + j] has the j-th sum for split i. + std::vector<float> left_sums_; + std::vector<float> left_squares_; + + // The number of examples seen at each split. + std::vector<int64> left_counts_; +}; + + +} // namespace tensorforest +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_GROW_STATS_H_ diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats_test.cc b/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats_test.cc new file mode 100644 index 00000000000..8d51456e303 --- /dev/null +++ b/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats_test.cc @@ -0,0 +1,364 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+// ============================================================================= +#include "tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.h" + +#include "tensorflow/contrib/decision_trees/proto/generic_tree_model.pb.h" +#include "tensorflow/contrib/tensor_forest/kernels/v4/test_utils.h" +#include "tensorflow/contrib/tensor_forest/proto/tensor_forest_params.pb.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace { + +using tensorflow::tensorforest::GrowStats; +using tensorflow::tensorforest::TestableInputTarget; +using tensorflow::tensorforest::FertileSlot; +using tensorflow::tensorforest::DenseClassificationGrowStats; +using tensorflow::tensorforest::SparseClassificationGrowStats; +using tensorflow::tensorforest::LeastSquaresRegressionGrowStats; +using tensorflow::tensorforest::TensorForestParams; +using tensorflow::tensorforest::SPLIT_FINISH_BASIC; +using tensorflow::tensorforest::SPLIT_FINISH_DOMINATE_HOEFFDING; +using tensorflow::tensorforest::SPLIT_PRUNE_HOEFFDING; +using tensorflow::decision_trees::BinaryNode; +using tensorflow::decision_trees::InequalityTest; +using tensorflow::decision_trees::FeatureId; + +BinaryNode MakeSplit(const string& feat, float val) { + BinaryNode split; + InequalityTest* test = split.mutable_inequality_left_child_test(); + FeatureId feature_id; + feature_id.mutable_id()->set_value(feat); + *test->mutable_feature_id() = feature_id; + test->mutable_threshold()->set_float_value(val); + test->set_type(InequalityTest::LESS_OR_EQUAL); + + return split; +} + +void RunBatch(GrowStats* stats, + const TestableInputTarget* target) { + stats->AddSplit(MakeSplit("0", 10.0)); + stats->AddSplit(MakeSplit("1", 4.0)); + + std::unique_ptr dataset( + new tensorflow::tensorforest::TestableDataSet( + {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, 2)); + + for (int i = 0; i < target->NumItems(); ++i) { + stats->AddExample(dataset, 
target, i); + } +} + +TEST(GrowStatsDenseClassificationTest, Basic) { + TensorForestParams params; + params.set_num_outputs(2); + params.mutable_split_after_samples()->set_constant_value(2); + params.mutable_num_splits_to_consider()->set_constant_value(2); + std::unique_ptr stat( + new DenseClassificationGrowStats(params, 1)); + stat->Initialize(); + + std::vector labels = {1, 0, 1}; + std::vector weights = {2.3, 20.3, 1.1}; + std::unique_ptr target( + new TestableInputTarget(&labels, &weights, 1)); + + RunBatch(stat.get(), target.get()); + CHECK(stat->IsFinished()); + + FertileSlot slot; + stat->PackToProto(&slot); + + string serialized = slot.DebugString(); + + std::unique_ptr new_stat( + new DenseClassificationGrowStats(params, 1)); + new_stat->ExtractFromProto(slot); + FertileSlot second_one; + new_stat->PackToProto(&second_one); + string serialized_again = second_one.DebugString(); + ASSERT_EQ(serialized_again, serialized); +} + +class TestableRunningStats : public DenseClassificationGrowStats { + public: + TestableRunningStats(const TensorForestParams& params, int32 depth) + : DenseClassificationGrowStats(params, depth) {} + + float test_left_sum(int split) { + return get_left_gini()->sum(split); + } + float test_left_square(int split) { + return get_left_gini()->square(split); + } + float test_right_sum(int split) { + return get_right_gini()->sum(split); + } + float test_right_square(int split) { + return get_right_gini()->square(split); + } +}; + +TEST(GrowStatsDenseClassificationTest, BasicRunningStats) { + TensorForestParams params; + params.set_num_outputs(2); + params.mutable_split_after_samples()->set_constant_value(2); + params.mutable_num_splits_to_consider()->set_constant_value(2); + params.set_use_running_stats_method(true); + std::unique_ptr stat( + new TestableRunningStats(params, 1)); + stat->Initialize(); + + std::vector labels = {1, 0, 1}; + std::vector weights = {2.3, 20.3, 1.1}; + std::unique_ptr target( + new TestableInputTarget(&labels, 
&weights, 1)); + + RunBatch(stat.get(), target.get()); + CHECK(stat->IsFinished()); + + ASSERT_FLOAT_EQ(stat->test_left_sum(0), 2.3 + 20.3 + 1.1); + ASSERT_FLOAT_EQ(stat->test_left_square(0), 3.4 * 3.4 + 20.3 * 20.3); + ASSERT_FLOAT_EQ(stat->test_right_sum(0), 0.0); + ASSERT_FLOAT_EQ(stat->test_right_square(0), 0.0); + + ASSERT_FLOAT_EQ(stat->test_left_sum(1), 2.3 + 20.3); + ASSERT_FLOAT_EQ(stat->test_left_square(1), 2.3 * 2.3 + 20.3 * 20.3); + ASSERT_FLOAT_EQ(stat->test_right_sum(1), 1.1); + ASSERT_FLOAT_EQ(stat->test_right_square(1), 1.1 * 1.1); + + FertileSlot slot; + stat->PackToProto(&slot); + + string serialized = slot.DebugString(); + + std::unique_ptr new_stat( + new DenseClassificationGrowStats(params, 1)); + new_stat->ExtractFromProto(slot); + FertileSlot second_one; + new_stat->PackToProto(&second_one); + string serialized_again = second_one.DebugString(); + ASSERT_EQ(serialized_again, serialized); +} + +class TestableFinishEarly : public DenseClassificationGrowStats { + public: + TestableFinishEarly(const TensorForestParams& params, int32 depth) + : DenseClassificationGrowStats(params, depth), num_times_called_(0) {} + + int num_times_called_; + + protected: + void CheckFinishEarlyHoeffding() override { + ++num_times_called_; + } +}; + +TEST(GrowStatsDenseClassificationTest, TestFinishEarly) { + TensorForestParams params; + params.set_num_outputs(2); + params.mutable_split_after_samples()->set_constant_value(2); + params.mutable_num_splits_to_consider()->set_constant_value(2); + params.mutable_min_split_samples()->set_constant_value(15); + params.mutable_dominate_fraction()->set_constant_value(0.99); + auto* finish = params.mutable_finish_type(); + finish->set_type(SPLIT_FINISH_DOMINATE_HOEFFDING); + finish->mutable_check_every_steps()->set_constant_value(5); + std::unique_ptr stat(new TestableFinishEarly(params, 1)); + stat->Initialize(); + + std::vector labels = {1, 0, 1}; + std::vector weights = {1, 1, 1}; + std::unique_ptr target( + new 
TestableInputTarget(&labels, &weights, 1)); + std::unique_ptr dataset( + new tensorflow::tensorforest::TestableDataSet( + {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, 2)); + + // Run through the 3 examples + RunBatch(stat.get(), target.get()); + + ASSERT_EQ(stat->num_times_called_, 0); + + // Go over min_split_samples. + for (int i = 0; i < 13; ++i) { + stat->AddExample(dataset, target.get(), 0); + } + + ASSERT_EQ(stat->num_times_called_, 1); + + // More examples up to 55. + for (int i = 0; i < 39; ++i) { + stat->AddExample(dataset, target.get(), 0); + } + + ASSERT_EQ(stat->num_times_called_, 9); +} + + +TEST(GrowStatsDenseClassificationTest, TestCheckPruneHoeffding) { + TensorForestParams params; + params.set_num_outputs(2); + params.mutable_split_after_samples()->set_constant_value(2000); + params.mutable_num_splits_to_consider()->set_constant_value(2); + params.mutable_min_split_samples()->set_constant_value(15); + params.mutable_dominate_fraction()->set_constant_value(0.99); + auto* finish = params.mutable_finish_type(); + finish->set_type(SPLIT_FINISH_BASIC); + finish->mutable_check_every_steps()->set_constant_value(100); + params.mutable_pruning_type()->set_type(SPLIT_PRUNE_HOEFFDING); + params.mutable_pruning_type()->mutable_prune_every_samples() + ->set_constant_value(1); + + DenseClassificationGrowStats stats(params, 1); + stats.Initialize(); + stats.AddSplit(MakeSplit("0", 0.0)); + stats.AddSplit(MakeSplit("1", 0.0)); + + // On each iteration, we add two examples, one of class 0 and one + // of class 1. Split #0 classifies them perfectly, while split #1 + // sends them both to the left. + std::vector labels = {0, 1}; + std::vector weights = {1, 1}; + TestableInputTarget target(&labels, &weights, 1); + std::unique_ptr dataset( + new tensorflow::tensorforest::TestableDataSet( + {-1.0, -1.0, 1.0, -1.0}, 2)); + + // Math time! 
+ // After 2n samples, + // split 0 has smoothed counts (n+1,1);(1,n+1) and + // split 1 has smoothed counts (n+1,n+1);(1,1) + // split 0 smoothed ginis are both 1 - (n+1)^2/(n+2)^2 - 1/(n+2)^2 and + // split 1 smoothed ginis are 1 - 2 (n+1)^2 / (2n+2)^2 and 1 - 2 (1/4) = 1/2 + // split 0 weighted smoothed ginis are both n (1 - (n^2 + 2n + 2) / (n+2)^2) + // split 1 weighted smoothed ginis are 0 and 2n (1 - 2(n+1)^2 / (2n+2)^2) + // split 0 split score = 2n (1 - (n^2 + 2n + 2) / (n+2)^2) + // split 1 split score = 2n (1 - 2(n+1)^2 / (2n+2)^2) + // split 1 score - split 0 score = + // 2n ( (n^2 + 2n + 2) / (n+2)^2 - 2(n+1)^2 / (2n+2)^2 ) + // = 2n ( (n^2 + 2n + 2) (2n+2)^2 - 2(n+1)^2 (n+2)^2 ) / ((n+2)^2 (2n+2)^2 ) + // = 2n ((n^2+2n+2)(4n^2+8n+4) - 2(n^2+2n+1)(n^2+4n+4)) / ((n+2)^2 (2n+2)^2) + // = 2n (4n^4+8n^3+4n^2+8n^3+16n^2+8n+8n^2+16n+8 + // - (2n^4+8n^3+8n^2+4n^3+16n^2+16n+2n^2+8n+8)) / ((n+2)^2 (2n+2)^2) + // = 2n (2n^4 + 4n^3 + 2n^2) / ((n+2)^2 (2n+2)^2) + // = 4n^3 (n^2 + 2n + 1) / ((n+2)^2 (2n+2)^2) + // = n^3 / (n+2)^2 + // Meanwhile, after 2n samples, + // epsilon = 2n (1 - 1/2) sqrt(0.5 ln(1/0.01) / 2n) + // = n sqrt( ln(10) / 2n) + // Graphical comparison says that epsilon is greater between 0 and 4.5, + // and then the split score difference is greater for n >= 5.
+ // n = 1 + stats.AddExample(dataset, &target, 0); + stats.AddExample(dataset, &target, 1); + ASSERT_EQ(stats.num_splits(), 2); + + // n = 2 + stats.AddExample(dataset, &target, 0); + stats.AddExample(dataset, &target, 1); + ASSERT_EQ(stats.num_splits(), 2); + + // n = 3 + stats.AddExample(dataset, &target, 0); + stats.AddExample(dataset, &target, 1); + ASSERT_EQ(stats.num_splits(), 2); + + // n = 4 + stats.AddExample(dataset, &target, 0); + stats.AddExample(dataset, &target, 1); + ASSERT_EQ(stats.num_splits(), 2); + + // n = 5 + stats.AddExample(dataset, &target, 0); + stats.AddExample(dataset, &target, 1); + ASSERT_EQ(stats.num_splits(), 1); + + // n = 6 + stats.AddExample(dataset, &target, 0); + stats.AddExample(dataset, &target, 1); + ASSERT_EQ(stats.num_splits(), 1); +} + +TEST(GrowStatsLeastSquaresRegressionTest, Basic) { + TensorForestParams params; + params.set_num_outputs(1); + params.mutable_split_after_samples()->set_constant_value(2); + params.mutable_num_splits_to_consider()->set_constant_value(2); + std::unique_ptr stat( + new LeastSquaresRegressionGrowStats(params, 1)); + stat->Initialize(); + + std::vector labels = {2.3, 5.6, 1.1}; + std::unique_ptr target( + new TestableInputTarget(&labels, {}, 1)); + std::vector branches = {1, 0, 1, 1, 0, 0}; + + RunBatch(stat.get(), target.get()); + CHECK(stat->IsFinished()); + + FertileSlot slot; + stat->PackToProto(&slot); + + string serialized = slot.DebugString(); + + std::unique_ptr new_stat( + new LeastSquaresRegressionGrowStats(params, 1)); + new_stat->ExtractFromProto(slot); + FertileSlot second_one; + new_stat->PackToProto(&second_one); + string serialized_again = second_one.DebugString(); + + ASSERT_EQ(serialized_again, serialized); +} + + +TEST(GrowStatsSparseClassificationTest, Basic) { + TensorForestParams params; + params.set_num_outputs(2); + params.mutable_split_after_samples()->set_constant_value(2); + params.mutable_num_splits_to_consider()->set_constant_value(2); + std::unique_ptr stat( + new 
SparseClassificationGrowStats(params, 1)); + stat->Initialize(); + + std::vector labels = {100, 1000, 1}; + std::vector weights = {2.3, 20.3, 1.1}; + std::unique_ptr target( + new TestableInputTarget(&labels, &weights, 1)); + std::vector branches = {1, 0, 1, 1, 0, 0}; + + RunBatch(stat.get(), target.get()); + CHECK(stat->IsFinished()); + + FertileSlot slot; + stat->PackToProto(&slot); + + string serialized = slot.DebugString(); + + std::unique_ptr new_stat( + new SparseClassificationGrowStats(params, 1)); + new_stat->ExtractFromProto(slot); + FertileSlot second_one; + new_stat->PackToProto(&second_one); + string serialized_again = second_one.DebugString(); + ASSERT_EQ(serialized_again, serialized); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/input_data.cc b/tensorflow/contrib/tensor_forest/kernels/v4/input_data.cc new file mode 100644 index 00000000000..f5f07bea5c7 --- /dev/null +++ b/tensorflow/contrib/tensor_forest/kernels/v4/input_data.cc @@ -0,0 +1,154 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ============================================================================= +#include "tensorflow/contrib/tensor_forest/kernels/v4/input_data.h" +#include "tensorflow/contrib/decision_trees/proto/generic_tree_model_extensions.pb.h" +#include "tensorflow/contrib/tensor_forest/kernels/tree_utils.h" +#include "tensorflow/core/lib/strings/numbers.h" + +namespace tensorflow { +namespace tensorforest { +namespace { + +const int32 SPARSE_DEFAULT = 0; + +bool DecideInequalityTest(const decision_trees::InequalityTest& test, + float value) { + float bias = test.threshold().float_value(); + switch (test.type()) { + case decision_trees::InequalityTest::LESS_OR_EQUAL: + return value <= bias; + + case decision_trees::InequalityTest::LESS_THAN: + return value < bias; + + case decision_trees::InequalityTest::GREATER_OR_EQUAL: + return value >= bias; + + case decision_trees::InequalityTest::GREATER_THAN: + return value > bias; + + default: + return false; + } +} + +bool DecideMatchingValuesTest(const decision_trees::MatchingValuesTest& test, + float value) { + for (const decision_trees::Value& test_value : test.value()) { + if (test_value.float_value() == value) { + return true; + } + } + return false; +} + +} // namespace + +bool TensorDataSet::Decide(const decision_trees::BinaryNode& node, + int example) const { + // TODO(gilberth): Support missing values. 
+ float val = 0; + const auto& test = node.inequality_left_child_test(); + + if (test.has_oblique()) { + for (int i = 0; i < test.oblique().features_size(); ++i) { + val += test.oblique().weights(i) * + GetExampleValue(example, test.oblique().features(i)); + } + } else { + val = GetExampleValue(example, test.feature_id()); + } + + if (node.has_inequality_left_child_test()) { + return DecideInequalityTest(node.inequality_left_child_test(), val); + } else { + decision_trees::MatchingValuesTest test; + if (node.custom_left_child_test().UnpackTo(&test)) { + return DecideMatchingValuesTest(test, val); + } else { + return false; + } + } +} + +float TensorDataSet::GetExampleValue( + int example, const decision_trees::FeatureId& feature_id) const { + int32 feature; + safe_strto32(feature_id.id().value(), &feature); + if (feature >= input_spec_.dense_features_size()) { + return FindSparseValue(*sparse_indices_, *sparse_values_, example, feature); + } else { + return (*dense_data_)(example, feature); + } +} + +float TensorDataSet::GetExampleValue(int example, int32 feature_id) const { + if (feature_id >= input_spec_.dense_features_size()) { + return FindSparseValue(*sparse_indices_, *sparse_values_, example, + feature_id); + } else { + return (*dense_data_)(example, feature_id); + } +} + +void TensorDataSet::set_input_tensors(const Tensor& dense, + const Tensor& sparse_indices, + const Tensor& sparse_values) { + if (dense.shape().dims() == 2) { + dense_data_.reset(new DenseStorageType(dense.tensor())); + } + if (sparse_indices.shape().dims() == 2) { + sparse_indices_.reset(new SparseIndicesStorageType( + sparse_indices.tensor())); + sparse_values_.reset(new SparseValuesStorageType( + sparse_values.tensor())); + } + original_dense_tensor_ = dense; +} + +void TensorDataSet::RandomSample(int example, + decision_trees::FeatureId* feature_id, + float* bias, int* type) const { + int32 num_total_features = input_spec_.dense_features_size(); + int64 sparse_input_start; + if 
(sparse_indices_ != nullptr) { + const int32 num_sparse = tensorforest::GetNumSparseFeatures( + *sparse_indices_, example, &sparse_input_start); + if (sparse_input_start >= 0) { + num_total_features += num_sparse; + } + } + int rand_feature = rng_->Uniform(num_total_features); + if (rand_feature < available_features_.size()) { // it's dense. + *feature_id = available_features_[rand_feature]; + *type = input_spec_.GetDenseFeatureType(rand_feature); + } else { + const int32 sparse_index = + sparse_input_start + rand_feature - input_spec_.dense_features_size(); + const int32 saved_index = + (*sparse_indices_)(sparse_index, 1) + input_spec_.dense_features_size(); + *feature_id = decision_trees::FeatureId(); + feature_id->mutable_id()->set_value(strings::StrCat(saved_index)); + + // TODO(gilberth): Remove this shortcut when different sparse types are + // allowed. + *type = input_spec_.sparse(0).original_type(); + } + + *bias = GetExampleValue(example, *feature_id); +} + +} // namespace tensorforest +} // namespace tensorflow diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/input_data.h b/tensorflow/contrib/tensor_forest/kernels/v4/input_data.h new file mode 100644 index 00000000000..261a1f2d5e4 --- /dev/null +++ b/tensorflow/contrib/tensor_forest/kernels/v4/input_data.h @@ -0,0 +1,124 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ============================================================================= +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_INPUT_DATA_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_INPUT_DATA_H_ +#include +#include +#include "google/protobuf/any.pb.h" +#include "google/protobuf/wrappers.pb.h" +#include "tensorflow/contrib/decision_trees/proto/generic_tree_model.pb.h" +#include "tensorflow/contrib/tensor_forest/kernels/data_spec.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/lib/random/philox_random.h" +#include "tensorflow/core/lib/random/simple_philox.h" + +namespace tensorflow { +namespace tensorforest { + +typedef TTypes::ConstTensor DenseStorageType; +typedef TTypes::ConstTensor SparseIndicesStorageType; +typedef TTypes::ConstTensor SparseValuesStorageType; + +class TensorDataSet { + public: + TensorDataSet(const tensorforest::TensorForestDataSpec& input_spec, + int32 seed) + : dense_data_(nullptr), + sparse_indices_(nullptr), + sparse_values_(nullptr), + input_spec_(input_spec), + split_sampling_random_seed_(seed) { + int column_count = 0; + for (int i = 0; i < input_spec_.dense_size(); ++i) { + for (int j = 0; j < input_spec_.dense(i).size(); ++j) { + decision_trees::FeatureId id; + id.mutable_id()->set_value(strings::StrCat(column_count)); + available_features_.push_back(id); + ++column_count; + } + } + + // Set up the random number generator. 
+ if (split_sampling_random_seed_ == 0) { + uint64 time_seed = static_cast<uint64>(std::clock()); + single_rand_ = std::unique_ptr<random::PhiloxRandom>( + new random::PhiloxRandom(time_seed)); + } else { + single_rand_ = std::unique_ptr<random::PhiloxRandom>( + new random::PhiloxRandom(split_sampling_random_seed_)); + } + + rng_ = std::unique_ptr<random::SimplePhilox>( + new random::SimplePhilox(single_rand_.get())); + } + virtual ~TensorDataSet() {} + + void set_input_tensors(const Tensor& dense, const Tensor& sparse_indices, + const Tensor& sparse_values); + + float get_input_value(int offset, int col) { + return (*dense_data_)(offset, col); + } + + int NumItems() const { + if (dense_data_ != nullptr) { + return dense_data_->dimensions()[0]; + } else if (sparse_indices_ != nullptr) { + return sparse_indices_->dimensions()[0]; + } else { + return 0; + } + } + + // This looks up a value by example and int32_id, which is much faster than + // GetFeature. + float GetExampleValue(int example, + const decision_trees::FeatureId& feature_id) const; + + // Same as overload with FeatureId, but if you already have the feature as + // an int32 you can avoid the atoi32. + virtual float GetExampleValue(int example, int32 feature_id) const; + + int num_features() { + return available_features_.size(); + } + + const Tensor& original_tensor() const { return original_dense_tensor_; } + + bool Decide(const decision_trees::BinaryNode& node, int example) const; + + // Randomly samples a feature from example, returns its id in feature_name, + // the value in bias, and its type from input_spec in type.
+ void RandomSample(int example, decision_trees::FeatureId* feature_name, + float* bias, int* type) const; + + private: + std::unique_ptr dense_data_; + std::unique_ptr sparse_indices_; + std::unique_ptr sparse_values_; + + Tensor original_dense_tensor_; + const tensorforest::TensorForestDataSpec input_spec_; + std::vector available_features_; + + int32 split_sampling_random_seed_; + std::unique_ptr single_rand_; + std::unique_ptr rng_; +}; +} // namespace tensorforest +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_INPUT_DATA_H_ diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/input_target.h b/tensorflow/contrib/tensor_forest/kernels/v4/input_target.h new file mode 100644 index 00000000000..97b2314f0fb --- /dev/null +++ b/tensorflow/contrib/tensor_forest/kernels/v4/input_target.h @@ -0,0 +1,91 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ============================================================================= +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_INPUT_TARGET_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_INPUT_TARGET_H_ +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_types.h" + +namespace tensorflow { +namespace tensorforest { + +typedef Eigen::TensorMap< + Eigen::Tensor, 0> // NOLINT(runtime/int) + SingleDimStorageType; + +// Base class for classes that hold labels and weights. Mostly for testing +// purposes, because it's inconvenient to construct nasty Eigen::things. +class InputTarget { + public: + virtual ~InputTarget() {} + virtual int32 GetTargetAsClassIndex(int example_index, + int target_index) const = 0; + + virtual float GetTargetWeight(int example_index) const = 0; + + virtual float GetTargetAsContinuous(int example_index, + int target_index) const = 0; +}; + +template +class StoredInputTarget : public InputTarget { + protected: + StoredInputTarget(const T* t, const T* w, int num_targets) + : target_(t), weight_(w), num_targets_(num_targets) {} + + const T* target_; + const T* weight_; + int num_targets_; +}; + +// Holds labels/targets and weights. Assumes that tensors are passed as +// t.unaligned_flat(). For multi-output, specifying the number of +// outputs will correctly index the flattened data. +class TensorInputTarget : public StoredInputTarget { + public: + TensorInputTarget(const SingleDimStorageType* t, + const SingleDimStorageType* w, const Tensor& tensor, + int num_targets) + : StoredInputTarget(t, w, num_targets), original_tensor_(tensor) {} + + int32 GetTargetAsClassIndex(int example_index, + int target_index) const override { + return static_cast( + GetTargetAsContinuous(example_index, target_index)); + } + + float GetTargetWeight(int example_index) const override { + const size_t num_weights = weight_->size(); + return num_weights > 0 && example_index < num_weights + ? 
(*weight_)(example_index) + : 1.0; + } + + float GetTargetAsContinuous(int example_index, + int target_index) const override { + QCHECK_LT(target_index, num_targets_); + return (*target_)(example_index * num_targets_ + target_index); + } + + const Tensor& original_tensor() const { + return original_tensor_; + } + + protected: + Tensor original_tensor_; +}; +} // namespace tensorforest +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_INPUT_TARGET_H_ diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.cc b/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.cc new file mode 100644 index 00000000000..4b9bb0f9c97 --- /dev/null +++ b/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.cc @@ -0,0 +1,160 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ============================================================================= +#include "tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.h" + +namespace tensorflow { +namespace tensorforest { + +std::unique_ptr +LeafModelOperatorFactory::CreateLeafModelOperator( + const TensorForestParams& params) { + switch (params.leaf_type()) { + case MODEL_DENSE_CLASSIFICATION: + return std::unique_ptr( + new DenseClassificationLeafModelOperator(params)); + + case MODEL_SPARSE_CLASSIFICATION: + return std::unique_ptr( + new SparseClassificationLeafModelOperator(params)); + + case MODEL_SPARSE_OR_DENSE_CLASSIFICATION: + return std::unique_ptr( + new SparseOrDenseClassificationLeafModelOperator(params)); + + case MODEL_REGRESSION: + return std::unique_ptr( + new RegressionLeafModelOperator(params)); + + default: + LOG(ERROR) << "Unknown model operator: " << params.leaf_type(); + return nullptr; + } +} + +// ------------------------ Dense ----------------------------- // +float DenseClassificationLeafModelOperator::GetOutputValue( + const decision_trees::Leaf& leaf, int32 o) const { + return leaf.vector().value(o).float_value(); +} + +void DenseClassificationLeafModelOperator::UpdateModel( + LeafStat* leaf, const InputTarget* target, + int example) const { + const int32 int_label = target->GetTargetAsClassIndex(example, 0); + auto* val = leaf->mutable_classification()->mutable_dense_counts() + ->mutable_value(int_label); + float weight = target->GetTargetWeight(example); + val->set_float_value(val->float_value() + weight); + leaf->set_weight_sum(leaf->weight_sum() + weight); +} + +void DenseClassificationLeafModelOperator::InitModel( + LeafStat* leaf) const { + for (int i = 0; i < params_.num_outputs(); ++i) { + leaf->mutable_classification()->mutable_dense_counts()->add_value(); + } +} + +void DenseClassificationLeafModelOperator::ExportModel( + const LeafStat& stat, decision_trees::Leaf* leaf) const { + *leaf->mutable_vector() = 
stat.classification().dense_counts(); +} + +// ------------------------- Sparse -------------------------- // +float SparseClassificationLeafModelOperator::GetOutputValue( + const decision_trees::Leaf& leaf, int32 o) const { + const auto it = leaf.sparse_vector().sparse_value().find(o); + if (it == leaf.sparse_vector().sparse_value().end()) { + return 0; // default value + } else { + return it->second.float_value(); + } +} + +void SparseClassificationLeafModelOperator::UpdateModel( + LeafStat* leaf, const InputTarget* target, + int example) const { + const int32 int_label = target->GetTargetAsClassIndex(example, 0); + const float weight = target->GetTargetWeight(example); + leaf->set_weight_sum(leaf->weight_sum() + weight); + auto value_map = leaf->mutable_classification()->mutable_sparse_counts() + ->mutable_sparse_value(); + auto it = value_map->find(int_label); + if (it == value_map->end()) { + (*value_map)[int_label].set_float_value(weight); + } else { + it->second.set_float_value(it->second.float_value() + weight); + } +} + +void SparseClassificationLeafModelOperator::ExportModel( + const LeafStat& stat, decision_trees::Leaf* leaf) const { + *leaf->mutable_sparse_vector() = stat.classification().sparse_counts(); +} + +// ------------------------- SparseOrDense -------------------------- // +float SparseOrDenseClassificationLeafModelOperator::GetOutputValue( + const decision_trees::Leaf& leaf, int32 o) const { + if (leaf.has_vector()) { + return dense_->GetOutputValue(leaf, o); + } else { + return sparse_->GetOutputValue(leaf, o); + } +} + +void SparseOrDenseClassificationLeafModelOperator::UpdateModel( + LeafStat* leaf, const InputTarget* target, int example) const { + if (leaf->classification().has_dense_counts()) { + return dense_->UpdateModel(leaf, target, example); + } else { + return sparse_->UpdateModel(leaf, target, example); + } +} + +void SparseOrDenseClassificationLeafModelOperator::ExportModel( + const LeafStat& stat, decision_trees::Leaf* leaf) 
const { + if (stat.classification().has_dense_counts()) { + return dense_->ExportModel(stat, leaf); + } else { + return sparse_->ExportModel(stat, leaf); + } +} + +// ------------------------ Regression ----------------------------- // +float RegressionLeafModelOperator::GetOutputValue( + const decision_trees::Leaf& leaf, int32 o) const { + return leaf.vector().value(o).float_value(); +} + +void RegressionLeafModelOperator::InitModel( + LeafStat* leaf) const { + for (int i = 0; i < params_.num_outputs(); ++i) { + leaf->mutable_regression()->mutable_mean_output()->add_value(); + } +} + +void RegressionLeafModelOperator::ExportModel( + const LeafStat& stat, decision_trees::Leaf* leaf) const { + for (int i = 0; i < params_.num_outputs(); ++i) { + const float new_val = + stat.regression().mean_output().value(i).float_value() / + stat.weight_sum(); + leaf->mutable_vector()->add_value()->set_float_value(new_val); + } +} + + +} // namespace tensorforest +} // namespace tensorflow diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.h b/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.h new file mode 100644 index 00000000000..8aadefc4033 --- /dev/null +++ b/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.h @@ -0,0 +1,150 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ============================================================================= +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_LEAF_MODEL_OPERATORS_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_LEAF_MODEL_OPERATORS_H_ + +#include "tensorflow/contrib/decision_trees/proto/generic_tree_model.pb.h" +#include "tensorflow/contrib/tensor_forest/kernels/v4/input_target.h" +#include "tensorflow/contrib/tensor_forest/kernels/v4/params.h" +#include "tensorflow/contrib/tensor_forest/proto/fertile_stats.pb.h" +#include "tensorflow/contrib/tensor_forest/proto/tensor_forest_params.pb.h" + +namespace tensorflow { +namespace tensorforest { + +// Abstract base class for classes that can initialize, get, and update leaf +// models. +class LeafModelOperator { + public: + // Number of outputs is interpreted differently for classification and + // regression. For classification, it's the number of possible classes. + // For regression, it's the target dimensions. + explicit LeafModelOperator(const TensorForestParams& params) + : params_(params) {} + virtual ~LeafModelOperator() {} + + // Returns the value of the requested output, which should be + // in [0, num_outputs_). For classification, it's the class count (weighted + // number of instances seen). For regression, it's e.g. the average value. + virtual float GetOutputValue(const decision_trees::Leaf& leaf, + int32 o) const = 0; + + // Update the given Leaf's model with the given example. + virtual void UpdateModel(LeafStat* leaf, + const InputTarget* target, + int example) const = 0; + + // Initialize an empty Leaf model. + virtual void InitModel(LeafStat* leaf) const = 0; + + virtual void ExportModel(const LeafStat& stat, + decision_trees::Leaf* leaf) const = 0; + + protected: + const TensorForestParams& params_; +}; + +// LeafModelOperator that stores class counts in a dense vector. 
+class DenseClassificationLeafModelOperator : public LeafModelOperator {
+ public:
+  explicit DenseClassificationLeafModelOperator(
+      const TensorForestParams& params)
+      : LeafModelOperator(params) {}
+  float GetOutputValue(const decision_trees::Leaf& leaf,
+                       int32 o) const override;
+
+  void UpdateModel(LeafStat* leaf, const InputTarget* target,
+                   int example) const override;
+
+  void InitModel(LeafStat* leaf) const override;
+
+  void ExportModel(const LeafStat& stat,
+                   decision_trees::Leaf* leaf) const override;
+};
+
+// LeafModelOperator that stores class counts sparsely in a map. Assumes default
+// value for yet-unseen classes is 0.
+class SparseClassificationLeafModelOperator : public LeafModelOperator {
+ public:
+  explicit SparseClassificationLeafModelOperator(
+      const TensorForestParams& params)
+      : LeafModelOperator(params) {}
+  float GetOutputValue(const decision_trees::Leaf& leaf,
+                       int32 o) const override;
+
+  void UpdateModel(LeafStat* leaf, const InputTarget* target,
+                   int example) const override;
+
+  void InitModel(LeafStat* leaf) const override {}
+
+  void ExportModel(const LeafStat& stat,
+                   decision_trees::Leaf* leaf) const override;
+};
+
+class SparseOrDenseClassificationLeafModelOperator : public LeafModelOperator {
+ public:
+  explicit SparseOrDenseClassificationLeafModelOperator(
+      const TensorForestParams& params)
+      : LeafModelOperator(params),
+        dense_(new DenseClassificationLeafModelOperator(params)),
+        sparse_(new SparseClassificationLeafModelOperator(params)) {}
+  float GetOutputValue(const decision_trees::Leaf& leaf,
+                       int32 o) const override;
+
+  void UpdateModel(LeafStat* leaf, const InputTarget* target,
+                   int example) const override;
+
+  void InitModel(LeafStat* leaf) const override {}
+
+  void ExportModel(const LeafStat& stat,
+                   decision_trees::Leaf* leaf) const override;
+
+ protected:
+  std::unique_ptr<DenseClassificationLeafModelOperator> dense_;
+  std::unique_ptr<SparseClassificationLeafModelOperator> sparse_;
+};
+
+// LeafModelOperator that stores regression leaf models with constant-value
+// prediction.
+class RegressionLeafModelOperator : public LeafModelOperator {
+ public:
+  explicit RegressionLeafModelOperator(const TensorForestParams& params)
+      : LeafModelOperator(params) {}
+  float GetOutputValue(const decision_trees::Leaf& leaf,
+                       int32 o) const override;
+
+  // TODO(gilberth): Quick experimentation suggests it's not even worth
+  // updating model and just using the seeded values. Can add this in
+  // with additional_data, though protobuf::Any is slow. Maybe make it
+  // optional. Maybe make any update optional.
+  void UpdateModel(LeafStat* leaf, const InputTarget* target,
+                   int example) const override {}
+
+  void InitModel(LeafStat* leaf) const override;
+
+  void ExportModel(const LeafStat& stat,
+                   decision_trees::Leaf* leaf) const override;
+};
+
+class LeafModelOperatorFactory {
+ public:
+  static std::unique_ptr<LeafModelOperator> CreateLeafModelOperator(
+      const TensorForestParams& params);
+};
+
+}  // namespace tensorforest
+}  // namespace tensorflow
+
+#endif  // THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_LEAF_MODEL_OPERATORS_H_
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators_test.cc b/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators_test.cc
new file mode 100644
index 00000000000..037e4e244c4
--- /dev/null
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators_test.cc
@@ -0,0 +1,223 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +#include "tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.h" +#include "tensorflow/contrib/decision_trees/proto/generic_tree_model.pb.h" +#include "tensorflow/contrib/tensor_forest/kernels/v4/test_utils.h" +#include "tensorflow/contrib/tensor_forest/proto/tensor_forest_params.pb.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace { + +using tensorflow::decision_trees::Leaf; +using tensorflow::tensorforest::DenseClassificationLeafModelOperator; +using tensorflow::tensorforest::LeafModelOperator; +using tensorflow::tensorforest::SparseClassificationLeafModelOperator; +using tensorflow::tensorforest::SparseOrDenseClassificationLeafModelOperator; +using tensorflow::tensorforest::LeafStat; +using tensorflow::tensorforest::RegressionLeafModelOperator; +using tensorflow::tensorforest::TestableInputTarget; +using tensorflow::tensorforest::TensorForestParams; + +const int32 kNumClasses = 3; + +constexpr char kRegressionStatProto[] = + "weight_sum: 3 " + "regression { " + "mean_output { " + "value { " + " float_value: 27 " + "} " + "value { " + " float_value: 282 " + "} " + "value { " + " float_value: 10 " + "} " + "} " + "mean_output_squares { " + "value {" + " float_value: 245" + "}" + "value {" + " float_value: 26564" + "}" + "value {" + " float_value: 46" + "}" + "}" +"}"; + +void TestClassificationNormalUse(const std::unique_ptr& op) { + std::unique_ptr leaf(new LeafStat); + op->InitModel(leaf.get()); + + Leaf l; + op->ExportModel(*leaf, &l); + + // Make sure it was initialized correctly. 
+ for (int i = 0; i < kNumClasses; ++i) { + EXPECT_EQ(op->GetOutputValue(l, i), 0); + } + + std::vector labels = {1, 0, 1}; + std::vector weights = {2.3, 20.3, 1.1}; + std::unique_ptr target( + new TestableInputTarget(&labels, &weights, 1)); + + // Update and check value. + op->UpdateModel(leaf.get(), target.get(), 0); + op->UpdateModel(leaf.get(), target.get(), 1); + op->UpdateModel(leaf.get(), target.get(), 2); + + op->ExportModel(*leaf, &l); + EXPECT_FLOAT_EQ(op->GetOutputValue(l, 1), 3.4); +} + + +TEST(DenseLeafModelOperatorsTest, NormalUse) { + TensorForestParams params; + params.set_num_outputs(kNumClasses); + std::unique_ptr op( + new DenseClassificationLeafModelOperator(params)); + TestClassificationNormalUse(op); +} + +TEST(SparseLeafModelOperatorsTest, NormalUse) { + TensorForestParams params; + params.set_num_outputs(kNumClasses); + std::unique_ptr op( + new SparseClassificationLeafModelOperator(params)); + TestClassificationNormalUse(op); +} + +TEST(DenseLeafModelOperatorsTest, InitWithExisting) { + TensorForestParams params; + params.set_num_outputs(kNumClasses); + std::unique_ptr op( + new DenseClassificationLeafModelOperator(params)); + + std::unique_ptr stat(new LeafStat); + stat->mutable_classification() + ->mutable_dense_counts() + ->add_value() + ->set_float_value(1.1); + stat->mutable_classification() + ->mutable_dense_counts() + ->add_value() + ->set_float_value(2.2); + stat->mutable_classification() + ->mutable_dense_counts() + ->add_value() + ->set_float_value(3.3); + + std::unique_ptr leaf(new Leaf); + + op->ExportModel(*stat, leaf.get()); + + // Make sure it was initialized correctly. 
+ EXPECT_EQ(leaf->vector().value_size(), kNumClasses); + EXPECT_FLOAT_EQ(op->GetOutputValue(*leaf, 0), 1.1); + EXPECT_FLOAT_EQ(op->GetOutputValue(*leaf, 1), 2.2); + EXPECT_FLOAT_EQ(op->GetOutputValue(*leaf, 2), 3.3); +} + +TEST(SparseOrDenseClassificationLeafModelOperator, InitWithExisting) { + TensorForestParams params; + params.set_num_outputs(kNumClasses); + std::unique_ptr op( + new SparseOrDenseClassificationLeafModelOperator(params)); + + std::unique_ptr stat(new LeafStat); + (*stat->mutable_classification() + ->mutable_sparse_counts() + ->mutable_sparse_value())[0] + .set_float_value(1.1); + (*stat->mutable_classification() + ->mutable_sparse_counts() + ->mutable_sparse_value())[1] + .set_float_value(2.2); + (*stat->mutable_classification() + ->mutable_sparse_counts() + ->mutable_sparse_value())[2] + .set_float_value(3.3); + + std::unique_ptr leaf(new Leaf); + + op->ExportModel(*stat, leaf.get()); + + // Make sure it was initialized correctly. + EXPECT_FLOAT_EQ(op->GetOutputValue(*leaf, 0), 1.1); + EXPECT_FLOAT_EQ(op->GetOutputValue(*leaf, 1), 2.2); + EXPECT_FLOAT_EQ(op->GetOutputValue(*leaf, 2), 3.3); +} + +TEST(SparseLeafModelOperatorsTest, InitWithExisting) { + TensorForestParams params; + params.set_num_outputs(kNumClasses); + std::unique_ptr op( + new SparseClassificationLeafModelOperator(params)); + std::unique_ptr stat(new LeafStat); + (*stat->mutable_classification() + ->mutable_sparse_counts() + ->mutable_sparse_value())[0] + .set_float_value(1.1); + (*stat->mutable_classification() + ->mutable_sparse_counts() + ->mutable_sparse_value())[1] + .set_float_value(2.2); + (*stat->mutable_classification() + ->mutable_sparse_counts() + ->mutable_sparse_value())[2] + .set_float_value(3.3); + + std::unique_ptr leaf(new Leaf); + + op->ExportModel( *stat, leaf.get()); + + // Make sure it was initialized correctly. 
+ EXPECT_FLOAT_EQ(op->GetOutputValue(*leaf, 0), 1.1); + EXPECT_FLOAT_EQ(op->GetOutputValue(*leaf, 1), 2.2); + EXPECT_FLOAT_EQ(op->GetOutputValue(*leaf, 2), 3.3); + + // check default value. + EXPECT_FLOAT_EQ(op->GetOutputValue(*leaf, 100), 0); + EXPECT_EQ(leaf->sparse_vector().sparse_value().size(), kNumClasses); +} + + +TEST(RegressionLeafModelOperatorsTest, NormalUse) { + TensorForestParams params; + params.set_num_outputs(kNumClasses); + std::unique_ptr op( + new RegressionLeafModelOperator(params)); + + std::unique_ptr stat(new LeafStat()); + const string contents(kRegressionStatProto); + ::tensorflow::protobuf::TextFormat::ParseFromString(contents, stat.get()); + + std::unique_ptr leaf(new Leaf); + op->ExportModel(*stat, leaf.get()); + + // Make sure it was initialized correctly. + EXPECT_FLOAT_EQ(op->GetOutputValue(*leaf, 0), 9); + EXPECT_FLOAT_EQ(op->GetOutputValue(*leaf, 1), 94); + EXPECT_FLOAT_EQ(op->GetOutputValue(*leaf, 2), 3.3333333); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/params.cc b/tensorflow/contrib/tensor_forest/kernels/v4/params.cc new file mode 100644 index 00000000000..a3b09c17d51 --- /dev/null +++ b/tensorflow/contrib/tensor_forest/kernels/v4/params.cc @@ -0,0 +1,54 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// =============================================================================
+#include "tensorflow/contrib/tensor_forest/kernels/v4/params.h"
+#include <algorithm>
+#include <cmath>
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+namespace tensorforest {
+
+float ResolveParam(const DepthDependentParam& param, int32 depth) {
+  float val;
+  switch (param.ParamType_case()) {
+    case DepthDependentParam::kConstantValue:
+      return param.constant_value();
+
+    case DepthDependentParam::kLinear:
+      val = depth * param.linear().slope() + param.linear().y_intercept();
+      return std::min(std::max(val, param.linear().min_val()),
+                      param.linear().max_val());
+
+    case DepthDependentParam::kExponential:
+      return param.exponential().bias() +
+             param.exponential().multiplier() *
+                 static_cast<float>(
+                     pow(param.exponential().base(),
+                         param.exponential().depth_multiplier() * depth));
+
+    case DepthDependentParam::kThreshold:
+      if (depth >= param.threshold().threshold()) {
+        return param.threshold().on_value();
+      } else {
+        return param.threshold().off_value();
+      }
+
+    default:
+      LOG(FATAL) << "unknown parameter type";
+  }
+}
+
+}  // namespace tensorforest
+}  // namespace tensorflow
diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/params.h b/tensorflow/contrib/tensor_forest/kernels/v4/params.h
new file mode 100644
index 00000000000..97a9d8d0963
--- /dev/null
+++ b/tensorflow/contrib/tensor_forest/kernels/v4/params.h
@@ -0,0 +1,32 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_PARAMS_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_PARAMS_H_ + +#include "tensorflow/contrib/tensor_forest/proto/tensor_forest_params.pb.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace tensorforest { + +// Return the value of the given depth-dependent parameter given a leaf's depth. +float ResolveParam(const DepthDependentParam& param, int32 depth); + + +} // namespace tensorforest +} // namespace tensorflow + + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_PARAMS_H_ diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/params_test.cc b/tensorflow/contrib/tensor_forest/kernels/v4/params_test.cc new file mode 100644 index 00000000000..801881af136 --- /dev/null +++ b/tensorflow/contrib/tensor_forest/kernels/v4/params_test.cc @@ -0,0 +1,75 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ============================================================================= +#include "tensorflow/contrib/tensor_forest/kernels/v4/params.h" +#include "tensorflow/contrib/tensor_forest/proto/tensor_forest_params.pb.h" +#include "tensorflow/core/platform/test.h" + +namespace { + +using tensorflow::tensorforest::DepthDependentParam; +using tensorflow::tensorforest::ResolveParam; + +TEST(ParamsTest, TestConstant) { + DepthDependentParam param; + param.set_constant_value(10.0); + + ASSERT_EQ(ResolveParam(param, 0), 10.0); + ASSERT_EQ(ResolveParam(param, 100), 10.0); +} + +TEST(ParamsTest, TestLinear) { + DepthDependentParam param; + auto* linear = param.mutable_linear(); + linear->set_y_intercept(100.0); + linear->set_slope(-10.0); + linear->set_min_val(23.0); + linear->set_max_val(90.0); + + ASSERT_EQ(ResolveParam(param, 0), 90); + ASSERT_EQ(ResolveParam(param, 1), 90); + ASSERT_EQ(ResolveParam(param, 2), 80); + + ASSERT_EQ(ResolveParam(param, 30), 23); +} + +TEST(ParamsTest, TestExponential) { + DepthDependentParam param; + auto* expo = param.mutable_exponential(); + expo->set_bias(100.0); + expo->set_base(10.0); + expo->set_multiplier(-1.0); + expo->set_depth_multiplier(1.0); + + ASSERT_EQ(ResolveParam(param, 0), 99); + ASSERT_EQ(ResolveParam(param, 1), 90); + ASSERT_EQ(ResolveParam(param, 2), 0); +} + +TEST(ParamsTest, TestThreshold) { + DepthDependentParam param; + auto* threshold = param.mutable_threshold(); + threshold->set_on_value(100.0); + threshold->set_off_value(10.0); + threshold->set_threshold(5.0); + + ASSERT_EQ(ResolveParam(param, 0), 10); + ASSERT_EQ(ResolveParam(param, 4), 10); + ASSERT_EQ(ResolveParam(param, 5), 100); + ASSERT_EQ(ResolveParam(param, 6), 100); +} + +} // namespace + + diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.cc b/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.cc new file mode 100644 index 00000000000..ddf4be87996 --- /dev/null +++ 
b/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.cc @@ -0,0 +1,257 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +#include "tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.h" + +#include + +#include "tensorflow/contrib/decision_trees/proto/generic_tree_model_extensions.pb.h" +#include "tensorflow/contrib/tensor_forest/kernels/tree_utils.h" +#include "tensorflow/contrib/tensor_forest/kernels/v4/stat_utils.h" + +namespace tensorflow { +namespace tensorforest { + +std::unique_ptr +SplitCollectionOperatorFactory::CreateSplitCollectionOperator( + const TensorForestParams& params) { + switch (params.collection_type()) { + case COLLECTION_BASIC: + return std::unique_ptr( + new SplitCollectionOperator(params)); + + case GRAPH_RUNNER_COLLECTION: + return std::unique_ptr( + new GraphRunnerSplitCollectionOperator(params)); + + default: + LOG(ERROR) << "Unknown split collection operator: " + << params.collection_type(); + return nullptr; + } +} + +std::unique_ptr SplitCollectionOperator::CreateGrowStats( + int32 node_id, int32 depth) const { + switch (params_.stats_type()) { + case STATS_DENSE_GINI: + return std::unique_ptr( + new DenseClassificationGrowStats(params_, depth)); + + case STATS_SPARSE_GINI: + return std::unique_ptr( + new SparseClassificationGrowStats(params_, 
depth)); + + case STATS_LEAST_SQUARES_REGRESSION: + return std::unique_ptr(new LeastSquaresRegressionGrowStats( + params_, depth)); + + default: + LOG(ERROR) << "Unknown grow stats type: " << params_.stats_type(); + return nullptr; + } +} + +void SplitCollectionOperator::ExtractFromProto( + const FertileStats& stats_proto) { + for (int i = 0; i < stats_proto.node_to_slot_size(); ++i) { + const auto& slot = stats_proto.node_to_slot(i); + stats_[slot.node_id()] = CreateGrowStats(slot.node_id(), slot.depth()); + stats_[slot.node_id()]->ExtractFromProto(slot); + } +} + +void SplitCollectionOperator::PackToProto(FertileStats* stats_proto) const { + for (int i = 0; i < stats_proto->node_to_slot_size(); ++i) { + auto* new_slot = stats_proto->mutable_node_to_slot(i); + const auto& stats = stats_.at(new_slot->node_id()); + if (params_.checkpoint_stats()) { + stats->PackToProto(new_slot); + } + new_slot->set_depth(stats->depth()); + } +} + +void SplitCollectionOperator::InitializeSlot(int32 node_id, int32 depth) { + stats_[node_id] = std::unique_ptr(CreateGrowStats(node_id, depth)); + stats_[node_id]->Initialize(); +} + +void SplitCollectionOperator::AddExample( + const std::unique_ptr& input_data, const InputTarget* target, + const std::vector& examples, int32 node_id) const { + auto* slot = stats_.at(node_id).get(); + for (int example : examples) { + slot->AddExample(input_data, target, example); + } +} + +bool SplitCollectionOperator::IsInitialized(int32 node_id) const { + return stats_.at(node_id)->IsInitialized(); +} + +void SplitCollectionOperator::CreateAndInitializeCandidateWithExample( + const std::unique_ptr& input_data, int example, + int32 node_id) const { + // Assumes split_initializations_per_input == 1. 
+ decision_trees::BinaryNode split; + float bias; + int type; + decision_trees::FeatureId feature_id; + input_data->RandomSample(example, &feature_id, &bias, &type); + + if (type == kDataFloat) { + decision_trees::InequalityTest* test = + split.mutable_inequality_left_child_test(); + *test->mutable_feature_id() = feature_id; + test->mutable_threshold()->set_float_value(bias); + test->set_type(params_.inequality_test_type()); + } else if (type == kDataCategorical) { + decision_trees::MatchingValuesTest test; + *test.mutable_feature_id() = feature_id; + test.add_value()->set_float_value(bias); + split.mutable_custom_left_child_test()->PackFrom(test); + } else { + LOG(ERROR) << "Unknown feature type " << type << ", not sure which " + << "node type to use."; + } + stats_.at(node_id)->AddSplit(split); +} + +bool SplitCollectionOperator::BestSplit(int32 node_id, + SplitCandidate* best, + int32* depth) const { + auto* slot = stats_.at(node_id).get(); + *depth = slot->depth(); + return slot->BestSplit(best); +} + +// -------------------------------- GraphRunner ------------------ // + +std::unique_ptr GraphRunnerSplitCollectionOperator::CreateGrowStats( + int32 node_id, int32 depth) const { + return std::unique_ptr(new SimpleStats(params_, depth)); +} + +int64 GraphRunnerSplitCollectionOperator::UniqueId(int32 node_id, + int32 split_id) const { + return node_id * num_splits_to_consider_ + split_id; +} + +bool GraphRunnerSplitCollectionOperator::BestSplit(int32 node_id, + SplitCandidate* best, + int32* depth) const { + float min_score = FLT_MAX; + int best_index = -1; + auto* slot = stats_.at(node_id).get(); + *depth = slot->depth(); + for (int i = 0; i < slot->num_splits(); ++i) { + // TODO(gilberth): Support uselessness. + auto& runner = runners_[UniqueId(node_id, i)]; + const float split_score = runner->SplitScore(); + if (split_score < min_score) { + min_score = split_score; + best_index = i; + } + } + + // This could happen if all the splits are useless. 
+ if (best_index < 0) { + return false; + } + + // Fill in split info and left/right stats to initialize models with. + *best = SplitCandidate(); + auto& runner = runners_[UniqueId(node_id, best_index)]; + runner->GetLeftStats(best->mutable_left_stats()); + runner->GetRightStats(best->mutable_right_stats()); + runner->GetSplit(best->mutable_split()); + return true; +} + +void GraphRunnerSplitCollectionOperator::AddExample( + const std::unique_ptr& input_data, const InputTarget* target, + const std::vector& examples, int32 node_id) const { + // Build input Tensors. + int size = examples.size(); + Tensor examples_t(tensorflow::DT_INT32, TensorShape({size})); + auto ex_data = examples_t.flat(); + std::copy(examples.begin(), examples.end(), ex_data.data()); + + const TensorInputTarget* tensor_target = + dynamic_cast(target); + CHECK_NOTNULL(tensor_target); + + const Tensor& data_t = input_data->original_tensor(); + const Tensor& target_t = tensor_target->original_tensor(); + + // Add to candidates. + auto* slot = stats_.at(node_id).get(); + for (int i = 0; i < slot->num_splits(); ++i) { + auto& runner = runners_[UniqueId(node_id, i)]; + runner->AddExample(data_t, target_t, examples_t); + } + + // Update simple weight sums so we know when we're done. + for (int example : examples) { + slot->AddExample(input_data, target, example); + } +} + +void GraphRunnerSplitCollectionOperator:: + CreateAndInitializeCandidateWithExample( + const std::unique_ptr& input_data, int example, + int32 node_id) const { + auto* slot = stats_.at(node_id).get(); + int cand_num = slot->num_splits(); + const int64 unique_id = UniqueId(node_id, cand_num); + + decision_trees::BinaryNode split; + + decision_trees::InequalityTest* test = + split.mutable_inequality_left_child_test(); + auto* oblique = test->mutable_oblique(); + for (int i = 0; i < features_per_node_; ++i) { + float bias; + int type; + // This is really just a way to select a list of random features. 
+ // Also a way to warn the user that categoricals don't make sense here. + input_data->RandomSample(example, oblique->add_features(), &bias, &type); + + if (type == kDataFloat) { + test->set_type(decision_trees::InequalityTest::LESS_OR_EQUAL); + + // The comparison bias is assumed to be zero. + test->mutable_threshold()->set_float_value(0); + } else { + LOG(ERROR) << "Categorical features not supported with this system."; + return; + } + } + + slot->AddSplit(split); + + runners_[unique_id].reset(new CandidateGraphRunner(graph_dir_, split)); + runners_[unique_id]->Init(); +} + +void GraphRunnerSplitCollectionOperator::ClearSlot(int32 node_id) { + SplitCollectionOperator::ClearSlot(node_id); + for (int i = 0; i < num_splits_to_consider_; ++i) { + runners_.erase(UniqueId(node_id, i)); + } +} + +} // namespace tensorforest +} // namespace tensorflow diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.h b/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.h new file mode 100644 index 00000000000..d0ea33612aa --- /dev/null +++ b/tensorflow/contrib/tensor_forest/kernels/v4/split_collection_operators.h @@ -0,0 +1,148 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ============================================================================= +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_SPLIT_COLLECTION_OPERATORS_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_SPLIT_COLLECTION_OPERATORS_H_ + +#include +#include "tensorflow/contrib/decision_trees/proto/generic_tree_model.pb.h" +#include "tensorflow/contrib/tensor_forest/kernels/v4/candidate_graph_runner.h" +#include "tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.h" +#include "tensorflow/contrib/tensor_forest/kernels/v4/input_data.h" +#include "tensorflow/contrib/tensor_forest/kernels/v4/input_target.h" +#include "tensorflow/contrib/tensor_forest/kernels/v4/leaf_model_operators.h" +#include "tensorflow/contrib/tensor_forest/kernels/v4/params.h" +#include "tensorflow/contrib/tensor_forest/proto/fertile_stats.pb.h" +#include "tensorflow/contrib/tensor_forest/proto/tensor_forest_params.pb.h" + +namespace tensorflow { +namespace tensorforest { + +// Class that can initialize and update split collections, and +// report if one is finished and ready to split. Designed to be inherited +// from to implement techniques such as pruning and early/delayed finishing. +class SplitCollectionOperator { + public: + explicit SplitCollectionOperator(const TensorForestParams& params) + : params_(params) {} + virtual ~SplitCollectionOperator() {} + + // Return a new GrowStats object according to stats_type_; + virtual std::unique_ptr CreateGrowStats(int32 node_id, + int32 depth) const; + + // Initialize from a previously serialized proto. + virtual void ExtractFromProto(const FertileStats& stats); + + // Serialize contents to the given proto. + virtual void PackToProto(FertileStats* stats) const; + + // Updates the slot's candidates with the new example. + // Assumes slot has been initialized. 
+ virtual void AddExample(const std::unique_ptr& input_data, + const InputTarget* target, + const std::vector& examples, + int32 node_id) const; + + // Create a new candidate and initialize it with the given example. + virtual void CreateAndInitializeCandidateWithExample( + const std::unique_ptr& input_data, int example, + int32 node_id) const; + + // Create a new GrowStats for the given node id and initialize it. + virtual void InitializeSlot(int32 node_id, int32 depth); + + // Perform any necessary cleanup for any tracked state for the slot. + virtual void ClearSlot(int32 node_id) { + stats_.erase(node_id); + } + + // Return true if slot is fully initialized. + virtual bool IsInitialized(int32 node_id) const; + + // Return true if slot is finished. + virtual bool IsFinished(int32 node_id) const { + return stats_.at(node_id)->IsFinished(); + } + + // Fill in best with the best split that node_id has, return true if this + // was successful, false if no good split was found. + virtual bool BestSplit(int32 node_id, SplitCandidate* best, + int32* depth) const; + + protected: + const TensorForestParams& params_; + std::unordered_map> stats_; +}; + + +class GraphRunnerSplitCollectionOperator : public SplitCollectionOperator { + public: + explicit GraphRunnerSplitCollectionOperator(const TensorForestParams& params) + : SplitCollectionOperator(params) { + if (params.num_splits_to_consider().ParamType_case() == + DepthDependentParam::PARAMTYPE_NOT_SET) { + LOG(FATAL) << "GRAPH_RUNNER_COLLECTION must specify a constant value for " + << " num_splits_to_consider"; + } else { + num_splits_to_consider_ = + params.num_splits_to_consider().constant_value(); + } + } + + std::unique_ptr CreateGrowStats(int32 node_id, + int32 depth) const override; + + // Updates the slot's candidates with the new example. + // Assumes slot has been initialized. 
+ void AddExample(const std::unique_ptr& input_data, + const InputTarget* target, const std::vector& examples, + int32 node_id) const override; + + // Create a new candidate and initialize it with the given example. + void CreateAndInitializeCandidateWithExample( + const std::unique_ptr& input_data, int example, + int32 node_id) const override; + + bool BestSplit(int32 node_id, SplitCandidate* best, + int32* depth) const override; + + void ClearSlot(int32 node_id) override; + + protected: + int64 UniqueId(int32 node_id, int32 split_id) const; + + mutable std::unordered_map> + runners_; + int features_per_node_; + string graph_dir_; + // Must have a constant value because of how we make unique ids right now. + int32 num_splits_to_consider_; +}; + +// Creates a type of SplitCollectionOperator depending on the type passed, +// which is SplitCollectionType in fertile_stats.proto. +// Can create a SplitCollectionOperator itself, known as "basic". +class SplitCollectionOperatorFactory { + public: + static std::unique_ptr CreateSplitCollectionOperator( + const TensorForestParams& params); +}; + +} // namespace tensorforest +} // namespace tensorflow + + + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_SPLIT_COLLECTION_OPERATORS_H_ diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/stat_utils.cc b/tensorflow/contrib/tensor_forest/kernels/v4/stat_utils.cc new file mode 100644 index 00000000000..0bec198e97e --- /dev/null +++ b/tensorflow/contrib/tensor_forest/kernels/v4/stat_utils.cc @@ -0,0 +1,87 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +#include "tensorflow/contrib/tensor_forest/kernels/v4/stat_utils.h" +#include + +#include "tensorflow/contrib/decision_trees/proto/generic_tree_model.pb.h" + +namespace tensorflow { +namespace tensorforest { + +// When using smoothing but only tracking sum and squares, and we're adding +// num_classes for smoothing each class, then Gini looks more like this: +// Gini = 1 - \sum_i (c_i + 1)^2 / C^2 +// = 1 - (1 / C^2) ( (\sum_i c_i)^2 + 2 (\sum_i c_i) + (\sum_i 1)) +// = 1 - (1 / C^2) ( stats.square() + 2 stats.sum() + #_classes) +// = 1 - ( stats.square() + 2 stats.sum() + #_classes) / (smoothed_sum * +// smoothed_sum) +// +// where +// smoothed_sum = stats.sum() + #_classes +float GiniImpurity(const LeafStat& stats, int32 num_classes) { + const float smoothed_sum = num_classes + stats.weight_sum(); + return 1.0 - ( + (stats.classification().gini().square() + + 2 * stats.weight_sum() + num_classes) / (smoothed_sum * smoothed_sum)); +} + +float WeightedGiniImpurity(const LeafStat& stats, int32 num_classes) { + return stats.weight_sum() * GiniImpurity(stats, num_classes); +} + +void UpdateGini(LeafStat* stats, float old_val, float weight) { + stats->set_weight_sum(stats->weight_sum() + weight); + // Equivalent to stats->square() - old_val * old_val + new_val * new_val, + // (for new_val = old_val + weight), but more numerically stable. 
+ stats->mutable_classification()->mutable_gini()->set_square( + stats->classification().gini().square() + + weight * weight + 2 * old_val * weight); +} + + +float Variance(const LeafStat& stats, int output) { + if (stats.weight_sum() == 0) { + return 0; + } + const float e_x = + stats.regression().mean_output().value(output).float_value() + / stats.weight_sum(); + const auto e_x2 = + stats.regression().mean_output_squares().value(output).float_value() + / stats.weight_sum(); + return e_x2 - e_x * e_x; +} + +float TotalVariance(const LeafStat& stats) { + float sum = 0; + for (int i = 0; i < stats.regression().mean_output().value_size(); ++i) { + sum += Variance(stats, i); + } + return sum; +} + +float SmoothedGini(float sum, float square, int num_classes) { + // See comments for GiniImpurity above. + const float smoothed_sum = num_classes + sum; + return 1.0 - + (square + 2 * sum + num_classes) / (smoothed_sum * smoothed_sum); +} + +float WeightedSmoothedGini(float sum, float square, int num_classes) { + return sum * SmoothedGini(sum, square, num_classes); +} + +} // namespace tensorforest +} // namespace tensorflow diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/stat_utils.h b/tensorflow/contrib/tensor_forest/kernels/v4/stat_utils.h new file mode 100644 index 00000000000..8e002d0414f --- /dev/null +++ b/tensorflow/contrib/tensor_forest/kernels/v4/stat_utils.h @@ -0,0 +1,50 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_STAT_UTILS_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_STAT_UTILS_H_ +#include "tensorflow/contrib/tensor_forest/proto/fertile_stats.pb.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace tensorforest { + +// Returns the smoothed, unweighted Gini impurity. +float GiniImpurity(const LeafStat& stats, int32 num_classes); + +// Returns the smoothed, weighted Gini impurity +float WeightedGiniImpurity(const LeafStat& stats, int32 num_classes); + +// Updates the GiniStats given the old and new values of a class count that +// was updated. +void UpdateGini(LeafStat* stats, float old_val, float weight); + +// Returns the variance in stats for the given output. +float Variance(const LeafStat& stats, int output); + +// Returns the variance sum for all outputs. +float TotalVariance(const LeafStat& stats); + +// ------- functions used by C++ stats classes -------- // +// Returns the smoothed gini score given the sum and sum of the squares of the +// class counts. +float SmoothedGini(float sum, float square, int num_classes); + +// Returns the smoothed gini score weighted by the sum. +float WeightedSmoothedGini(float sum, float square, int num_classes); + +} // namespace tensorforest +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_STAT_UTILS_H_ diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/test_utils.h b/tensorflow/contrib/tensor_forest/kernels/v4/test_utils.h new file mode 100644 index 00000000000..4ac23ceb3e5 --- /dev/null +++ b/tensorflow/contrib/tensor_forest/kernels/v4/test_utils.h @@ -0,0 +1,73 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_TEST_UTILS_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_TEST_UTILS_H_ +#include "tensorflow/contrib/tensor_forest/kernels/v4/input_data.h" +#include "tensorflow/contrib/tensor_forest/kernels/v4/input_target.h" + +namespace tensorflow { +namespace tensorforest { + +class TestableInputTarget : public StoredInputTarget> { + public: + TestableInputTarget(const std::vector* t, const std::vector* w, + int num_t) + : StoredInputTarget(t, w, num_t) {} + + int NumItems() const { + return target_->size(); + } + + int32 GetTargetAsClassIndex(int example_index, + int target_index) const override { + return static_cast( + GetTargetAsContinuous(example_index, target_index)); + } + + float GetTargetWeight(int example_index) const override { + const size_t num_weights = weight_->size(); + return num_weights > 0 && example_index < num_weights + ? 
(*weight_)[example_index] + : 1.0; + } + + float GetTargetAsContinuous(int example_index, + int target_index) const override { + QCHECK_LT(target_index, num_targets_); + return (*target_)[example_index * num_targets_ + target_index]; + } +}; + + +class TestableDataSet : public TensorDataSet { + public: + TestableDataSet(const std::vector& data, int num_features) + : TensorDataSet(TensorForestDataSpec(), 11), + num_features_(num_features), + data_(data) {} + + float GetExampleValue(int example, int32 feature_id) const override { + return data_[example * num_features_ + feature_id]; + } + + protected: + int num_features_; + std::vector data_; +}; + +} // namespace tensorforest +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_TEST_UTILS_H_ diff --git a/tensorflow/contrib/tensor_forest/proto/BUILD b/tensorflow/contrib/tensor_forest/proto/BUILD new file mode 100644 index 00000000000..1cfef44af1a --- /dev/null +++ b/tensorflow/contrib/tensor_forest/proto/BUILD @@ -0,0 +1,31 @@ +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library") + +package(default_visibility = ["//visibility:public"]) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + ), + visibility = ["//tensorflow:__subpackages__"], +) + +tf_proto_library( + name = "fertile_stats_proto", + srcs = ["fertile_stats.proto"], + cc_api_version = 2, + protodeps = ["//tensorflow/contrib/decision_trees/proto:generic_tree_model"], + visibility = ["//visibility:public"], +) + +tf_proto_library( + name = "tensor_forest_params_proto", + srcs = ["tensor_forest_params.proto"], + cc_api_version = 2, + protodeps = ["//tensorflow/contrib/decision_trees/proto:generic_tree_model"], + visibility = ["//visibility:public"], +) diff --git a/tensorflow/contrib/tensor_forest/proto/fertile_stats.proto b/tensorflow/contrib/tensor_forest/proto/fertile_stats.proto new file mode 100644 index 
00000000000..0ded04ad75a --- /dev/null +++ b/tensorflow/contrib/tensor_forest/proto/fertile_stats.proto @@ -0,0 +1,92 @@ +syntax = "proto3"; +option cc_enable_arenas = true; + +package tensorflow.tensorforest; + +import "tensorflow/contrib/decision_trees/proto/generic_tree_model.proto"; + + +message FertileStats { + // Tracks stats for each node. node_to_slot[i] is the FertileSlot for node i. + // This may be sized to max_nodes initially, or grow dynamically as needed. + repeated FertileSlot node_to_slot = 1; +} + + +message GiniStats { + // This allows us to quickly track and calculate impurity (classification) + // by storing the sum of input weights and the sum of the squares of the + // input weights. Weighted gini is then: 1 - (square / sum * sum). + // Updates to these numbers are: + // old_i = leaf->value(label) + // new_i = old_i + incoming_weight + // sum -> sum + incoming_weight + // square -> square - (old_i ^ 2) + (new_i ^ 2) + // total_left_sum -> total_left_sum - old_left_i * old_total_i + + // new_left_i * new_total_i + float square = 2; +} + +message LeafStat { + // The sum of the weights of the training examples that we have seen. + // This is here, outside of the leaf_stat oneof, because almost all + // types will want it. + float weight_sum = 3; + + // TODO(thomaswc): Move the GiniStats out of LeafStats and into something + // that only tracks them for splits. + message GiniImpurityClassificationStats { + oneof counts { + decision_trees.Vector dense_counts = 1; + decision_trees.SparseVector sparse_counts = 2; + } + GiniStats gini = 3; + } + + // This is the info needed for calculating variance for regression. + // Variance will still have to be summed over every output, but the + // number of outputs in regression problems is almost always 1. 
+ message LeastSquaresRegressionStats { + decision_trees.Vector mean_output = 1; + decision_trees.Vector mean_output_squares = 2; + } + + oneof leaf_stat { + GiniImpurityClassificationStats classification = 1; + LeastSquaresRegressionStats regression = 2; + // TODO(thomaswc): Add in v5's SparseClassStats. + } +} + +message FertileSlot { + // The statistics for *all* the examples seen at this leaf. + LeafStat leaf_stats = 4; + + repeated SplitCandidate candidates = 1; + + // The statistics for the examples seen at this leaf after all the + // splits have been initialized. If post_init_leaf_stats.weight_sum + // is > 0, then all candidates have been initialized. We need to track + // both leaf_stats and post_init_leaf_stats because the first is used + // to create the decision_tree::Leaf and the second is used to infer + // the statistics for the right side of a split (given the leaf side + // stats). + LeafStat post_init_leaf_stats = 6; + + int32 node_id = 5; + int32 depth = 7; +} + +message SplitCandidate { + // proto representing the potential node. + decision_trees.BinaryNode split = 1; + + // Right counts are inferred from FertileSlot.leaf_stats and left. + LeafStat left_stats = 4; + + // Right stats (not full counts) are kept here. + LeafStat right_stats = 5; + + // Fields used when training with a graph runner. + string unique_id = 6; +} diff --git a/tensorflow/contrib/tensor_forest/proto/tensor_forest_params.proto b/tensorflow/contrib/tensor_forest/proto/tensor_forest_params.proto new file mode 100644 index 00000000000..49b19e0b623 --- /dev/null +++ b/tensorflow/contrib/tensor_forest/proto/tensor_forest_params.proto @@ -0,0 +1,146 @@ +syntax = "proto3"; + +package tensorflow.tensorforest; + +import "tensorflow/contrib/decision_trees/proto/generic_tree_model.proto"; + +// Leaf models specify what is returned at inference time, and how it is +// stored in the decision_trees.Leaf protos. 
+enum LeafModelType { + MODEL_DENSE_CLASSIFICATION = 0; + MODEL_SPARSE_CLASSIFICATION = 1; + MODEL_REGRESSION = 2; + MODEL_SPARSE_OR_DENSE_CLASSIFICATION = 3; +} + +// Stats models generally specify information that is collected which is +// necessary to choose a split at a node. Specifically, they operate on +// a SplitCandidate::LeafStat proto. +enum StatsModelType { + STATS_DENSE_GINI = 0; + STATS_SPARSE_GINI = 1; + STATS_LEAST_SQUARES_REGRESSION = 2; + STATS_SPARSE_THEN_DENSE_GINI = 3; +} + +// Allows selection of operations on the collection of split candidates. +// Basic infers right split stats from the leaf stats and each candidate's +// left stats. +enum SplitCollectionType { + COLLECTION_BASIC = 0; + GRAPH_RUNNER_COLLECTION = 1; +} + +// Pruning strategies define how candidates are pruned over time. +// SPLIT_PRUNE_HALF prunes the worst half of splits every prune_ever_samples, +// etc. Note that prune_every_samples plays against the depth-dependent +// split_after_samples, so they should be set together. +enum SplitPruningStrategyType { + SPLIT_PRUNE_NONE = 0; + SPLIT_PRUNE_HALF = 1; + SPLIT_PRUNE_QUARTER = 2; + SPLIT_PRUNE_10_PERCENT = 3; + // SPLIT_PRUNE_HOEFFDING prunes splits whose Gini impurity is worst than + // the best split's by more than the Hoeffding bound. + SPLIT_PRUNE_HOEFFDING = 4; +} + +message SplitPruningConfig { + DepthDependentParam prune_every_samples = 1; + SplitPruningStrategyType type = 2; +} + +// Finish strategies define when slots are considered finished. +// Basic requires at least split_after_samples, and doesn't allow slots to +// finish until the leaf has received more than one class. Hoeffding splits +// early after min_split_samples if one split is dominating the rest according +// to hoeffding bounds. Bootstrap does the same but compares gini's calculated +// with sampled smoothed counts. 
+enum SplitFinishStrategyType { + SPLIT_FINISH_BASIC = 0; + SPLIT_FINISH_DOMINATE_HOEFFDING = 2; + SPLIT_FINISH_DOMINATE_BOOTSTRAP = 3; +} + +message SplitFinishConfig { + // Configure how often we check for finish, because some finish methods + // are expensive to perform. + DepthDependentParam check_every_steps = 1; + SplitFinishStrategyType type = 2; +} + +// A parameter that changes linearly with depth, with upper and lower bounds. +message LinearParam { + float slope = 1; + float y_intercept = 2; + float min_val = 3; + float max_val = 4; +} + +// A parameter that changes expoentially with the form +// f = c + mb^(k*d) +// where: +// c: constant bias +// b: base +// m: multiplier +// k: depth multiplier +// d: depth +message ExponentialParam { + float bias = 1; + float base = 2; + float multiplier = 3; + float depth_multiplier = 4; +} + +// A parameter that is 'off' until depth >= a threshold, then is 'on'. +message ThresholdParam { + float on_value = 1; + float off_value = 2; + float threshold = 3; +} + +// A parameter that may change with node depth. 
+message DepthDependentParam { + oneof ParamType { + float constant_value = 1; + LinearParam linear = 2; + ExponentialParam exponential = 3; + ThresholdParam threshold = 4; + } +} + +message TensorForestParams { + // ------------ Types that control training subsystems ------ // + LeafModelType leaf_type = 1; + StatsModelType stats_type = 2; + SplitCollectionType collection_type = 3; + SplitPruningConfig pruning_type = 4; + SplitFinishConfig finish_type = 5; + + // --------- Parameters that can't change by definition --------------- // + int32 num_trees = 6; + int32 max_nodes = 7; + int32 num_features = 21; + + decision_trees.InequalityTest.Type inequality_test_type = 19; + + // Some booleans controlling execution + bool is_regression = 8; + bool drop_final_class = 9; + bool collate_examples = 10; + bool checkpoint_stats = 11; + bool use_running_stats_method = 20; + + // Number of classes (classification) or targets (regression) + int32 num_outputs = 12; + + // --------- Parameters that could be depth-dependent --------------- // + DepthDependentParam num_splits_to_consider = 13; + DepthDependentParam split_after_samples = 14; + DepthDependentParam dominate_fraction = 15; + DepthDependentParam min_split_samples = 18; + + // --------- Parameters for experimental features ---------------------- // + string graph_dir = 16; + int32 num_select_features = 17; +} diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest.py b/tensorflow/contrib/tensor_forest/python/tensor_forest.py index 177783c207e..bdfe981baf3 100644 --- a/tensorflow/contrib/tensor_forest/python/tensor_forest.py +++ b/tensorflow/contrib/tensor_forest/python/tensor_forest.py @@ -435,7 +435,7 @@ class RandomForestGraphs(object): if processed_sparse_features is not None: raise NotImplementedError( 'Feature bagging not supported with sparse features.') - tree_data = self._bag_features(i, input_data) + tree_data = self._bag_features(i, tree_data) probabilities.append(self.trees[i].inference_graph( 
tree_data, data_spec, diff --git a/tensorflow/contrib/tensorboard/BUILD b/tensorflow/contrib/tensorboard/BUILD index 13de7fb39d9..2e0a46ffe43 100644 --- a/tensorflow/contrib/tensorboard/BUILD +++ b/tensorflow/contrib/tensorboard/BUILD @@ -42,8 +42,8 @@ py_library( srcs = ["plugins/projector/__init__.py"], srcs_version = "PY2AND3", deps = [ + ":protos_all_py", "//tensorflow/python:lib", - "//tensorflow/tensorboard/plugins/projector:protos_all_py", ], ) @@ -54,10 +54,10 @@ py_test( srcs_version = "PY2AND3", deps = [ ":projector", + ":protos_all_py", "//tensorflow/python:client_testlib", "//tensorflow/python:platform", "//tensorflow/python:summary", - "//tensorflow/tensorboard/plugins/projector:protos_all_py", ], ) diff --git a/tensorflow/contrib/tensorboard/plugins/projector/__init__.py b/tensorflow/contrib/tensorboard/plugins/projector/__init__.py index be2398cdc0c..7b9be76757c 100644 --- a/tensorflow/contrib/tensorboard/plugins/projector/__init__.py +++ b/tensorflow/contrib/tensorboard/plugins/projector/__init__.py @@ -28,11 +28,11 @@ from __future__ import print_function import os from google.protobuf import text_format -from tensorflow.python.lib.io import file_io -from tensorflow.tensorboard.plugins.projector import projector_config_pb2 +from tensorflow.contrib.tensorboard.plugins.projector import projector_config_pb2 # pylint: disable=wildcard-import -from tensorflow.tensorboard.plugins.projector.projector_config_pb2 import * +from tensorflow.contrib.tensorboard.plugins.projector.projector_config_pb2 import * # pylint: enable=wildcard-import +from tensorflow.python.lib.io import file_io def visualize_embeddings(summary_writer, config): diff --git a/tensorflow/contrib/tensorboard/plugins/projector/projector_api_test.py b/tensorflow/contrib/tensorboard/plugins/projector/projector_api_test.py index 5f86f57a1c6..9ad42bff47f 100644 --- a/tensorflow/contrib/tensorboard/plugins/projector/projector_api_test.py +++ 
b/tensorflow/contrib/tensorboard/plugins/projector/projector_api_test.py @@ -24,10 +24,10 @@ import shutil from google.protobuf import text_format from tensorflow.contrib.tensorboard.plugins import projector +from tensorflow.contrib.tensorboard.plugins.projector import projector_config_pb2 from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.summary.writer import writer as writer_lib -from tensorflow.tensorboard.plugins.projector import projector_config_pb2 class ProjectorApiTest(test.TestCase): diff --git a/tensorflow/tensorboard/plugins/projector/projector_config.proto b/tensorflow/contrib/tensorboard/plugins/projector/projector_config.proto similarity index 100% rename from tensorflow/tensorboard/plugins/projector/projector_config.proto rename to tensorflow/contrib/tensorboard/plugins/projector/projector_config.proto diff --git a/tensorflow/contrib/testing/BUILD b/tensorflow/contrib/testing/BUILD index 225a1ccd126..0be6aa755be 100644 --- a/tensorflow/contrib/testing/BUILD +++ b/tensorflow/contrib/testing/BUILD @@ -16,6 +16,7 @@ py_library( ], srcs_version = "PY2AND3", deps = [ + "//tensorflow/core:protos_all_py", "//tensorflow/python:summary", "//tensorflow/python:training", "//third_party/py/numpy", diff --git a/tensorflow/contrib/text/BUILD b/tensorflow/contrib/text/BUILD index 6bcb03238cc..8a2cb28684f 100644 --- a/tensorflow/contrib/text/BUILD +++ b/tensorflow/contrib/text/BUILD @@ -101,8 +101,6 @@ py_test( "//tensorflow/python:errors", "//tensorflow/python:lookup_ops", "//tensorflow/python:math_ops", - "//tensorflow/python:platform", - "//tensorflow/python:platform_test", "//tensorflow/python:random_seed", "//tensorflow/python:training", ], diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/BUILD b/tensorflow/contrib/tfprof/python/tools/tfprof/BUILD index 8040c791ee4..32403c3af45 100644 --- a/tensorflow/contrib/tfprof/python/tools/tfprof/BUILD +++ 
b/tensorflow/contrib/tfprof/python/tools/tfprof/BUILD @@ -2,6 +2,7 @@ package(default_visibility = ["//visibility:public"]) licenses(["notice"]) # Apache 2.0 +load("//tensorflow:tensorflow.bzl", "cuda_py_test") load("//tensorflow:tensorflow.bzl", "py_test") load("//tensorflow:tensorflow.bzl", "tf_py_test") @@ -16,12 +17,10 @@ py_library( ], ) -py_test( +cuda_py_test( name = "model_analyzer_test", srcs = ["model_analyzer_test.py"], - srcs_version = "PY2AND3", - tags = ["no_pip"], - deps = [ + additional_deps = [ ":model_analyzer", "//tensorflow/contrib/tfprof/python/tools/tfprof/internal:model_analyzer_testlib", "//tensorflow/python:client", @@ -30,14 +29,13 @@ py_test( "//tensorflow/python:platform", "//tensorflow/python:variables", ], + tags = ["no_pip"], ) -py_test( +cuda_py_test( name = "profiler_test", srcs = ["profiler_test.py"], - srcs_version = "PY2AND3", - tags = ["no_pip"], - deps = [ + additional_deps = [ ":model_analyzer", "//tensorflow/contrib/tfprof/python/tools/tfprof/internal:model_analyzer_testlib", "//tensorflow/python:client", @@ -46,6 +44,7 @@ py_test( "//tensorflow/python:platform", "//tensorflow/python:variables", ], + tags = ["no_pip"], ) py_library( diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/internal/BUILD b/tensorflow/contrib/tfprof/python/tools/tfprof/internal/BUILD index 3fa5b7867d4..e8440468119 100644 --- a/tensorflow/contrib/tfprof/python/tools/tfprof/internal/BUILD +++ b/tensorflow/contrib/tfprof/python/tools/tfprof/internal/BUILD @@ -2,6 +2,7 @@ package(default_visibility = ["//tensorflow/contrib/tfprof/python/tools/tfprof:_ licenses(["notice"]) # Apache 2.0 +load("//tensorflow:tensorflow.bzl", "cuda_py_test") load("//tensorflow:tensorflow.bzl", "py_test") load("//tensorflow:tensorflow.bzl", "tf_py_test") load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc") @@ -60,6 +61,19 @@ py_test( ], ) +cuda_py_test( + name = "run_metadata_test", + srcs = ["run_metadata_test.py"], + additional_deps = [ + 
"//tensorflow/contrib/tfprof/python/tools/tfprof:model_analyzer", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:math_ops", + "//tensorflow/python:random_ops", + "//tensorflow/tools/tfprof:protos_all_py", + ], +) + # ----------------------------------------------------------------------------- # Google-internal targets. These must be at the end for syncrepo. diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/internal/print_model_analysis_test.py b/tensorflow/contrib/tfprof/python/tools/tfprof/internal/print_model_analysis_test.py index 76e7d627cea..091fcdef885 100644 --- a/tensorflow/contrib/tfprof/python/tools/tfprof/internal/print_model_analysis_test.py +++ b/tensorflow/contrib/tfprof/python/tools/tfprof/internal/print_model_analysis_test.py @@ -153,6 +153,10 @@ class PrintModelAnalysisTest(test.TestCase): } } } + accelerator_exec_micros: 0 + cpu_exec_micros: 0 + total_accelerator_exec_micros: 0 + total_cpu_exec_micros: 0 } children { name: "DW" @@ -205,6 +209,10 @@ class PrintModelAnalysisTest(test.TestCase): } } } + accelerator_exec_micros: 0 + cpu_exec_micros: 0 + total_accelerator_exec_micros: 0 + total_cpu_exec_micros: 0 } children { name: "DW/Initializer" @@ -237,6 +245,10 @@ class PrintModelAnalysisTest(test.TestCase): } } } + accelerator_exec_micros: 0 + cpu_exec_micros: 0 + total_accelerator_exec_micros: 0 + total_cpu_exec_micros: 0 } children { name: "DW/Initializer/random_normal/mean" @@ -247,6 +259,10 @@ class PrintModelAnalysisTest(test.TestCase): total_parameters: 0 float_ops: 0 total_float_ops: 0 + accelerator_exec_micros: 0 + cpu_exec_micros: 0 + total_accelerator_exec_micros: 0 + total_cpu_exec_micros: 0 } children { name: "DW/Initializer/random_normal/mul" @@ -282,6 +298,10 @@ class PrintModelAnalysisTest(test.TestCase): } } } + accelerator_exec_micros: 0 + cpu_exec_micros: 0 + total_accelerator_exec_micros: 0 + total_cpu_exec_micros: 0 } children { name: 
"DW/Initializer/random_normal/shape" @@ -292,6 +312,10 @@ class PrintModelAnalysisTest(test.TestCase): total_parameters: 0 float_ops: 0 total_float_ops: 0 + accelerator_exec_micros: 0 + cpu_exec_micros: 0 + total_accelerator_exec_micros: 0 + total_cpu_exec_micros: 0 } children { name: "DW/Initializer/random_normal/stddev" @@ -302,6 +326,10 @@ class PrintModelAnalysisTest(test.TestCase): total_parameters: 0 float_ops: 0 total_float_ops: 0 + accelerator_exec_micros: 0 + cpu_exec_micros: 0 + total_accelerator_exec_micros: 0 + total_cpu_exec_micros: 0 } float_ops: 0 total_float_ops: 0 @@ -330,9 +358,17 @@ class PrintModelAnalysisTest(test.TestCase): } } } + accelerator_exec_micros: 0 + cpu_exec_micros: 0 + total_accelerator_exec_micros: 0 + total_cpu_exec_micros: 0 } float_ops: 0 total_float_ops: 0 + accelerator_exec_micros: 0 + cpu_exec_micros: 0 + total_accelerator_exec_micros: 0 + total_cpu_exec_micros: 0 } children { name: "DW/read" @@ -360,9 +396,17 @@ class PrintModelAnalysisTest(test.TestCase): } } } + accelerator_exec_micros: 0 + cpu_exec_micros: 0 + total_accelerator_exec_micros: 0 + total_cpu_exec_micros: 0 } float_ops: 0 total_float_ops: 0 + accelerator_exec_micros: 0 + cpu_exec_micros: 0 + total_accelerator_exec_micros: 0 + total_cpu_exec_micros: 0 } children { name: "zeros" @@ -373,9 +417,17 @@ class PrintModelAnalysisTest(test.TestCase): total_parameters: 0 float_ops: 0 total_float_ops: 0 + accelerator_exec_micros: 0 + cpu_exec_micros: 0 + total_accelerator_exec_micros: 0 + total_cpu_exec_micros: 0 } float_ops: 0 - total_float_ops: 0""", expected_pb) + total_float_ops: 0 + accelerator_exec_micros: 0 + cpu_exec_micros: 0 + total_accelerator_exec_micros: 0 + total_cpu_exec_micros: 0""", expected_pb) self.assertEqual(expected_pb, tfprof_pb) diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/internal/run_metadata_test.py b/tensorflow/contrib/tfprof/python/tools/tfprof/internal/run_metadata_test.py new file mode 100644 index 00000000000..feb20dc0f4b 
--- /dev/null +++ b/tensorflow/contrib/tfprof/python/tools/tfprof/internal/run_metadata_test.py @@ -0,0 +1,111 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""test the RunMetadata proto.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from collections import defaultdict + +import six + +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.client import session +from tensorflow.python.framework import ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.platform import test + +# pylint: disable=g-bad-import-order +# XXX: this depends on pywrap_tensorflow and must come later +from tensorflow.contrib.tfprof.python.tools.tfprof import model_analyzer + +SIZE = 1300 + + +def _extract_node(run_meta, node_name): + ret = defaultdict(list) + for dev_stat in run_meta.step_stats.dev_stats: + dev = dev_stat.device + for node_stat in dev_stat.node_stats: + if node_stat.node_name == node_name: + ret[dev].append(node_stat) + return ret + + +def _run_model(): + x = random_ops.random_normal(shape=[1, SIZE]) + w = random_ops.random_normal(shape=[SIZE, 2 * SIZE]) + y = math_ops.matmul(x, w) + + with session.Session() as sess: + run_metadata = config_pb2.RunMetadata() + opts = 
model_analyzer.PRINT_ALL_TIMING_MEMORY + opts['min_micros'] = 0 + opts['min_bytes'] = 0 + _ = sess.run(y, + options=config_pb2.RunOptions( + trace_level=config_pb2.RunOptions.FULL_TRACE), + run_metadata=run_metadata) + tfprof_node = model_analyzer.print_model_analysis( + sess.graph, + run_meta=run_metadata, + tfprof_options=opts) + + return tfprof_node, run_metadata + + +class RunMetadataTest(test.TestCase): + + def testGPU(self): + if not test.is_gpu_available(): + return + + with ops.device('/gpu:0'): + tfprof_node, run_meta = _run_model() + + self.assertEqual(tfprof_node.children[0].name, 'MatMul') + self.assertGreater(tfprof_node.children[0].exec_micros, 10) + + ret = _extract_node(run_meta, 'MatMul') + self.assertEqual(len(ret), 1) + self.assertTrue('/job:localhost/replica:0/task:0/gpu:0' in ret) + + ret = _extract_node(run_meta, 'MatMul:MatMul') + self.assertEqual(len(ret), 2) + has_all_stream = False + for k, _ in six.iteritems(ret): + self.assertTrue('gpu:0/stream' in k) + if 'gpu:0/stream:all' in k: + has_all_stream = True + self.assertTrue(has_all_stream) + + def testCPU(self): + with ops.device('/cpu:0'): + tfprof_node, run_meta = _run_model() + + self.assertEqual(tfprof_node.children[0].name, 'MatMul') + self.assertGreater(tfprof_node.children[0].exec_micros, 10) + + ret = _extract_node(run_meta, 'MatMul') + self.assertEqual(len(ret), 1) + self.assertTrue('/job:localhost/replica:0/task:0/cpu:0' in ret) + + ret = _extract_node(run_meta, 'MatMul:MatMul') + self.assertEqual(len(ret), 0) + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer_test.py b/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer_test.py index 913971afaf1..1b5041441f9 100644 --- a/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer_test.py +++ b/tensorflow/contrib/tfprof/python/tools/tfprof/model_analyzer_test.py @@ -38,7 +38,7 @@ class PrintModelAnalysisTest(test.TestCase): outfile = 
os.path.join(test.get_temp_dir(), 'dump') opts['output'] = 'file:outfile=' + outfile - with session.Session() as sess, ops.device('/cpu:0'): + with session.Session() as sess: _ = lib.BuildSmallModel() model_analyzer.print_model_analysis(sess.graph, tfprof_options=opts) @@ -57,7 +57,7 @@ class PrintModelAnalysisTest(test.TestCase): opts['output'] = 'file:outfile=' + outfile opts['account_type_regexes'] = ['.*'] opts['select'] = [ - 'bytes', 'params', 'float_ops', 'occurrence', 'device', 'op_types', + 'params', 'float_ops', 'occurrence', 'device', 'op_types', 'input_shapes' ] @@ -77,7 +77,7 @@ class PrintModelAnalysisTest(test.TestCase): with gfile.Open(outfile, 'r') as f: # pylint: disable=line-too-long self.assertEqual( - 'node name | # parameters | # float_ops | output bytes | assigned devices | op types | input shapes\n_TFProfRoot (--/451 params, --/10.44k flops, --/5.28KB, _kTFScopeParent, )\n Conv2D (0/0 params, 5.83k/5.83k flops, 432B/432B, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Conv2D, 0:2x6x6x3|1:3x3x3x6)\n Conv2D_1 (0/0 params, 4.61k/4.61k flops, 384B/384B, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Conv2D, 0:2x3x3x6|1:2x2x6x12)\n DW (3x3x3x6, 162/162 params, 0/0 flops, 648B/1.30KB, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|VariableV2|_trainable_variables, )\n DW/Assign (0/0 params, 0/0 flops, 0B/0B, Assign, 0:3x3x3x6|1:3x3x3x6)\n DW/Initializer (0/0 params, 0/0 flops, 0B/0B, _kTFScopeParent, )\n DW/Initializer/random_normal (0/0 params, 0/0 flops, 0B/0B, Add, 0:3x3x3x6|1:1)\n DW/Initializer/random_normal/RandomStandardNormal (0/0 params, 0/0 flops, 0B/0B, RandomStandardNormal, 0:4)\n DW/Initializer/random_normal/mean (0/0 params, 0/0 flops, 0B/0B, Const, )\n DW/Initializer/random_normal/mul (0/0 params, 0/0 flops, 0B/0B, Mul, 0:3x3x3x6|1:1)\n DW/Initializer/random_normal/shape (0/0 params, 0/0 flops, 0B/0B, Const, )\n 
DW/Initializer/random_normal/stddev (0/0 params, 0/0 flops, 0B/0B, Const, )\n DW/read (0/0 params, 0/0 flops, 648B/648B, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Identity, 0:3x3x3x6)\n DW2 (2x2x6x12, 288/288 params, 0/0 flops, 1.15KB/2.30KB, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|VariableV2|_trainable_variables, )\n DW2/Assign (0/0 params, 0/0 flops, 0B/0B, Assign, 0:2x2x6x12|1:2x2x6x12)\n DW2/Initializer (0/0 params, 0/0 flops, 0B/0B, _kTFScopeParent, )\n DW2/Initializer/random_normal (0/0 params, 0/0 flops, 0B/0B, Add, 0:2x2x6x12|1:1)\n DW2/Initializer/random_normal/RandomStandardNormal (0/0 params, 0/0 flops, 0B/0B, RandomStandardNormal, 0:4)\n DW2/Initializer/random_normal/mean (0/0 params, 0/0 flops, 0B/0B, Const, )\n DW2/Initializer/random_normal/mul (0/0 params, 0/0 flops, 0B/0B, Mul, 0:2x2x6x12|1:1)\n DW2/Initializer/random_normal/shape (0/0 params, 0/0 flops, 0B/0B, Const, )\n DW2/Initializer/random_normal/stddev (0/0 params, 0/0 flops, 0B/0B, Const, )\n DW2/read (0/0 params, 0/0 flops, 1.15KB/1.15KB, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Identity, 0:2x2x6x12)\n ScalarW (1, 1/1 params, 0/0 flops, 0B/0B, VariableV2|_trainable_variables, )\n ScalarW/Assign (0/0 params, 0/0 flops, 0B/0B, Assign, 0:1|1:1)\n ScalarW/Initializer (0/0 params, 0/0 flops, 0B/0B, _kTFScopeParent, )\n ScalarW/Initializer/random_normal (0/0 params, 0/0 flops, 0B/0B, Add, 0:1|1:1)\n ScalarW/Initializer/random_normal/RandomStandardNormal (0/0 params, 0/0 flops, 0B/0B, RandomStandardNormal, 0:0)\n ScalarW/Initializer/random_normal/mean (0/0 params, 0/0 flops, 0B/0B, Const, )\n ScalarW/Initializer/random_normal/mul (0/0 params, 0/0 flops, 0B/0B, Mul, 0:1|1:1)\n ScalarW/Initializer/random_normal/shape (0/0 params, 0/0 flops, 0B/0B, Const, )\n ScalarW/Initializer/random_normal/stddev (0/0 params, 0/0 flops, 0B/0B, Const, )\n ScalarW/read (0/0 params, 0/0 flops, 0B/0B, Identity, 
0:1)\n init (0/0 params, 0/0 flops, 0B/0B, NoOp, 0:1|1:3x3x3x6|2:2x2x6x12)\n zeros (0/0 params, 0/0 flops, 864B/864B, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Const, )\n', + 'node name | # parameters | # float_ops | assigned devices | op types | input shapes\n_TFProfRoot (--/451 params, --/10.44k flops, _kTFScopeParent, )\n Conv2D (0/0 params, 5.83k/5.83k flops, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Conv2D, 0:2x6x6x3|1:3x3x3x6)\n Conv2D_1 (0/0 params, 4.61k/4.61k flops, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Conv2D, 0:2x3x3x6|1:2x2x6x12)\n DW (3x3x3x6, 162/162 params, 0/0 flops, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|VariableV2|_trainable_variables, )\n DW/Assign (0/0 params, 0/0 flops, Assign, 0:3x3x3x6|1:3x3x3x6)\n DW/Initializer (0/0 params, 0/0 flops, _kTFScopeParent, )\n DW/Initializer/random_normal (0/0 params, 0/0 flops, Add, 0:3x3x3x6|1:1)\n DW/Initializer/random_normal/RandomStandardNormal (0/0 params, 0/0 flops, RandomStandardNormal, 0:4)\n DW/Initializer/random_normal/mean (0/0 params, 0/0 flops, Const, )\n DW/Initializer/random_normal/mul (0/0 params, 0/0 flops, Mul, 0:3x3x3x6|1:1)\n DW/Initializer/random_normal/shape (0/0 params, 0/0 flops, Const, )\n DW/Initializer/random_normal/stddev (0/0 params, 0/0 flops, Const, )\n DW/read (0/0 params, 0/0 flops, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Identity, 0:3x3x3x6)\n DW2 (2x2x6x12, 288/288 params, 0/0 flops, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|VariableV2|_trainable_variables, )\n DW2/Assign (0/0 params, 0/0 flops, Assign, 0:2x2x6x12|1:2x2x6x12)\n DW2/Initializer (0/0 params, 0/0 flops, _kTFScopeParent, )\n DW2/Initializer/random_normal (0/0 params, 0/0 flops, Add, 0:2x2x6x12|1:1)\n DW2/Initializer/random_normal/RandomStandardNormal (0/0 params, 0/0 flops, RandomStandardNormal, 
0:4)\n DW2/Initializer/random_normal/mean (0/0 params, 0/0 flops, Const, )\n DW2/Initializer/random_normal/mul (0/0 params, 0/0 flops, Mul, 0:2x2x6x12|1:1)\n DW2/Initializer/random_normal/shape (0/0 params, 0/0 flops, Const, )\n DW2/Initializer/random_normal/stddev (0/0 params, 0/0 flops, Const, )\n DW2/read (0/0 params, 0/0 flops, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Identity, 0:2x2x6x12)\n ScalarW (1, 1/1 params, 0/0 flops, VariableV2|_trainable_variables, )\n ScalarW/Assign (0/0 params, 0/0 flops, Assign, 0:1|1:1)\n ScalarW/Initializer (0/0 params, 0/0 flops, _kTFScopeParent, )\n ScalarW/Initializer/random_normal (0/0 params, 0/0 flops, Add, 0:1|1:1)\n ScalarW/Initializer/random_normal/RandomStandardNormal (0/0 params, 0/0 flops, RandomStandardNormal, 0:0)\n ScalarW/Initializer/random_normal/mean (0/0 params, 0/0 flops, Const, )\n ScalarW/Initializer/random_normal/mul (0/0 params, 0/0 flops, Mul, 0:1|1:1)\n ScalarW/Initializer/random_normal/shape (0/0 params, 0/0 flops, Const, )\n ScalarW/Initializer/random_normal/stddev (0/0 params, 0/0 flops, Const, )\n ScalarW/read (0/0 params, 0/0 flops, Identity, 0:1)\n init (0/0 params, 0/0 flops, NoOp, 0:1|1:3x3x3x6|2:2x2x6x12)\n zeros (0/0 params, 0/0 flops, /job:localhost/replica:0/task:0/cpu:0, /job:localhost/replica:0/task:0/cpu:0|Const, )\n', f.read()) # pylint: enable=line-too-long @@ -96,7 +96,7 @@ class PrintModelAnalysisTest(test.TestCase): 'input_shapes' ] - with session.Session() as sess, ops.device('/cpu:0'): + with session.Session() as sess: x = lib.BuildSmallModel() sess.run(variables.global_variables_initializer()) @@ -177,7 +177,7 @@ class PrintModelAnalysisTest(test.TestCase): 'bytes', 'params', 'float_ops', 'device' ] - with session.Session() as sess, ops.device('/cpu:0'): + with session.Session() as sess: x = lib.BuildSmallModel() sess.run(variables.global_variables_initializer()) @@ -205,7 +205,7 @@ class PrintModelAnalysisTest(test.TestCase): opts['max_depth'] = 
100000 opts['step'] = 0 - with session.Session() as sess, ops.device('/cpu:0'): + with session.Session() as sess: x = lib.BuildFullModel() sess.run(variables.global_variables_initializer()) @@ -233,7 +233,7 @@ class PrintModelAnalysisTest(test.TestCase): opts['select'] = ['params', 'micros', 'occurrence', 'input_shapes'] opts['order_by'] = 'occurrence' - with session.Session() as sess, ops.device('/cpu:0'): + with session.Session() as sess: x = lib.BuildFullModel() sess.run(variables.global_variables_initializer()) @@ -247,9 +247,11 @@ class PrintModelAnalysisTest(test.TestCase): sess.graph, run_meta, tfprof_cmd='op', tfprof_options=opts) with gfile.Open(outfile, 'r') as f: + # pylint: disable=line-too-long self.assertEqual( - 'nodename|executiontime|#parameters|opoccurrence|inputshapes\n', - f.read().replace('\t', '').replace(' ', '')[0:60]) + 'nodename|totalexecutiontime|acceleratorexecutiontime|cpuexecutiontime|#parameters|opoccurrence|input', + f.read().replace('\t', '').replace(' ', '')[0:100]) + # pylint: enable=line-too-long total_children = 0 last_occurrence = 1e32 diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD new file mode 100644 index 00000000000..0a83070c345 --- /dev/null +++ b/tensorflow/contrib/tpu/BUILD @@ -0,0 +1,224 @@ +# Description: Operations defined for Cloud TPUs + +package( + default_visibility = [ + "//learning/brain:__subpackages__", + "//tensorflow:__subpackages__", + ], +) + +licenses(["notice"]) # Apache 2.0 + +load( + "//tensorflow:tensorflow.bzl", + "tf_custom_op_library", + "tf_gen_op_libs", + "tf_gen_op_wrapper_py", +) +load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") +load("//tensorflow:tensorflow.bzl", "tf_py_test") + +cc_library( + name = "all_ops", + deps = [ + ":cross_replica_ops_op_lib", + ":infeed_ops_op_lib", + ":outfeed_ops_op_lib", + ":replication_ops_op_lib", + ":tpu_configuration_ops_op_lib", + ":tpu_sendrecv_ops_op_lib", + ], +) + +py_library( + name = "tpu_estimator", + srcs = [ + 
"python/tpu/tpu_config.py", + "python/tpu/tpu_estimator.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":tpu", + ":tpu_py", + ":training_loop", + "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform", + ], +) + +tf_gen_op_libs( + op_lib_names = [ + "cross_replica_ops", + "infeed_ops", + "outfeed_ops", + "replication_ops", + "tpu_configuration_ops", + "tpu_sendrecv_ops", + ], + deps = [ + "//tensorflow/core:lib", + ], +) + +tf_custom_op_library( + name = "python/ops/_tpu_ops.so", + srcs = [ + "ops/cross_replica_ops.cc", + "ops/infeed_ops.cc", + "ops/outfeed_ops.cc", + "ops/replication_ops.cc", + "ops/tpu_configuration_ops.cc", + "ops/tpu_sendrecv_ops.cc", + ], +) + +tf_gen_op_wrapper_py( + name = "tpu_ops", + deps = [ + ":cross_replica_ops_op_lib", + ":infeed_ops_op_lib", + ":outfeed_ops_op_lib", + ":replication_ops_op_lib", + ":tpu_configuration_ops_op_lib", + ":tpu_sendrecv_ops_op_lib", + ], +) + +tf_custom_op_py_library( + name = "tpu_py", + srcs = glob(["python/ops/*.py"]) + ["__init__.py"], + dso = [":python/ops/_tpu_ops.so"], + kernels = [ + ":all_ops", + ], + srcs_version = "PY2AND3", + deps = [ + ":tpu_ops", + "//tensorflow/contrib/util:util_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform", + "//tensorflow/python:state_ops", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + ], +) + +py_library( + name = "tpu_helper_library", + srcs_version = "PY2AND3", + deps = [ + ":tpu", + ":tpu_feed", + ":tpu_function", + ":tpu_py", + ":tpu_sharding", + ":training_loop", + ], +) + +py_library( + name = "tpu_function", + srcs = ["python/tpu/tpu_function.py"], + srcs_version = "PY2AND3", + deps = [ + ":tpu_feed", + ":tpu_py", + "//tensorflow/python:framework", + ], +) + +py_library( + name = "tpu", + srcs = [ + 
"python/tpu/__init__.py", + "python/tpu/tpu.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":tpu_py", + ":training_loop", + "//tensorflow/python:framework", + ], +) + +py_library( + name = "tpu_sharding", + srcs = ["python/tpu/tpu_sharding.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:framework", + ], +) + +py_library( + name = "tpu_feed", + srcs = ["python/tpu/tpu_feed.py"], + srcs_version = "PY2AND3", + deps = [ + ":tpu_py", + ":tpu_sharding", + "//tensorflow/python:framework", + ], +) + +py_library( + name = "training_loop", + srcs = [ + "python/tpu/tpu_optimizer.py", + "python/tpu/training_loop.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":tpu_function", + "//tensorflow/python:framework", + ], +) + +tf_py_test( + name = "tpu_sharding_test", + size = "small", + srcs = ["python/tpu/tpu_sharding_test.py"], + additional_deps = [ + ":tpu_sharding", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + ], +) + +tf_py_test( + name = "tpu_infeed_test", + size = "small", + srcs = ["python/tpu/tpu_infeed_test.py"], + additional_deps = [ + ":tpu_feed", + ":tpu_sharding", + "//tensorflow/python:framework", + "//tensorflow/python:framework_test_lib", + ], +) + +tf_py_test( + name = "tpu_function_test", + size = "small", + srcs = ["python/tpu/tpu_function_test.py"], + additional_deps = [ + ":tpu_function", + "//tensorflow/python:framework", + "//tensorflow/python:framework_test_lib", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), +) diff --git a/tensorflow/contrib/tpu/__init__.py b/tensorflow/contrib/tpu/__init__.py new file mode 100644 index 00000000000..bfd7887c516 --- /dev/null +++ b/tensorflow/contrib/tpu/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +"""Ops related to Tensor Processing Units.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=wildcard-import,unused-import +from tensorflow.contrib.tpu.python.ops.tpu_ops import * +from tensorflow.contrib.tpu.python.tpu import * +# pylint: enable=wildcard-import,unused-import + +from tensorflow.python.util.all_util import remove_undocumented +remove_undocumented(__name__) diff --git a/tensorflow/contrib/tpu/ops/cross_replica_ops.cc b/tensorflow/contrib/tpu/ops/cross_replica_ops.cc new file mode 100644 index 00000000000..cbbd19800eb --- /dev/null +++ b/tensorflow/contrib/tpu/ops/cross_replica_ops.cc @@ -0,0 +1,37 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { + +REGISTER_OP("CrossReplicaSum") + .Input("input: T") + .Output("output: T") + .Attr("T: {float}") + .SetShapeFn(shape_inference::UnchangedShape) + .Doc(R"doc( +An Op to sum inputs across replicated TPU instances. Each +instance supplies its own input, and the output of each is the sum of +all the inputs. + +input: The local input to the sum. +output: The sum of all the distributed inputs. +T: The type of elements to be summed. +)doc"); + +} // namespace tensorflow diff --git a/tensorflow/contrib/tpu/ops/infeed_ops.cc b/tensorflow/contrib/tpu/ops/infeed_ops.cc new file mode 100644 index 00000000000..be4d4f96493 --- /dev/null +++ b/tensorflow/contrib/tpu/ops/infeed_ops.cc @@ -0,0 +1,107 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { + +using shape_inference::InferenceContext; +using shape_inference::ShapeHandle; + +REGISTER_OP("InfeedDequeue") + .Output("output: dtype") + .Attr("dtype: type") + .Attr("shape: shape") + .SetIsStateful() + .SetShapeFn([](InferenceContext* c) { + PartialTensorShape shape; + TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape)); + TensorShapeProto shape_proto; + shape.AsProto(&shape_proto); + ShapeHandle out; + TF_RETURN_IF_ERROR(c->MakeShapeFromShapeProto(shape_proto, &out)); + c->set_output(0, out); + return Status::OK(); + }) + .Doc(R"doc( +A placeholder op for a value that will be fed into the computation. + +output: A tensor that will be provided using the infeed mechanism. +dtype: The type of elements in the tensor. +shape: The shape of the tensor. +)doc"); + +REGISTER_OP("InfeedEnqueue") + .Input("input: dtype") + .Attr("dtype: type") + .Attr("shape: shape = {}") + .Attr("device_ordinal: int = -1") + .SetIsStateful() + .Doc(R"doc( +An op which feeds a single Tensor value into the computation. + +input: A tensor that will be provided using the infeed mechanism. +dtype: The type of elements in the tensor. +shape: The shape of the tensor. +device_ordinal: The TPU device to use. This should be -1 when the Op +is running on a TPU device, and >= 0 when the Op is running on the CPU +device. +)doc"); + +REGISTER_OP("InfeedEnqueueTuple") + .Input("inputs: dtypes") + .Attr("dtypes: list(type)") + .Attr("shapes: list(shape)") + .Attr("device_ordinal: int = -1") + .SetIsStateful() + .Doc(R"doc( +An op which feeds multiple Tensor values into the computation as an XLA tuple. + +inputs: A list of tensors that will be provided using the infeed mechanism. +dtypes: The element types of each element in `inputs`. +shapes: The shapes of each tensor in `inputs`. 
+device_ordinal: The TPU device to use. This should be -1 when the Op +is running on a TPU device, and >= 0 when the Op is running on the CPU +device. +)doc"); + +REGISTER_OP("InfeedDequeueTuple") + .Output("outputs: dtypes") + .Attr("dtypes: list(type)") + .Attr("shapes: list(shape)") + .SetIsStateful() + .SetShapeFn([](InferenceContext* c) { + std::vector shapes; + TF_RETURN_IF_ERROR(c->GetAttr("shapes", &shapes)); + for (int i = 0; i < shapes.size(); ++i) { + TensorShapeProto shape_proto; + shapes[i].AsProto(&shape_proto); + ShapeHandle out; + TF_RETURN_IF_ERROR(c->MakeShapeFromShapeProto(shape_proto, &out)); + c->set_output(i, out); + } + return Status::OK(); + }) + .Doc(R"doc( +A placeholder op for multiple values that will be fed into the computation +simultaneously as an XLA tuple. + +outputs: A list of tensors that will be provided using the infeed mechanism. +dtypes: The element types of each element in `outputs`. +shapes: The shapes of each tensor in `outputs`. +)doc"); + +} // namespace tensorflow diff --git a/tensorflow/contrib/tpu/ops/outfeed_ops.cc b/tensorflow/contrib/tpu/ops/outfeed_ops.cc new file mode 100644 index 00000000000..16c57a1c2b2 --- /dev/null +++ b/tensorflow/contrib/tpu/ops/outfeed_ops.cc @@ -0,0 +1,106 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { + +using shape_inference::InferenceContext; +using shape_inference::ShapeHandle; + +REGISTER_OP("OutfeedEnqueue") + .Input("input: dtype") + .Attr("dtype: type") + .SetIsStateful() + .Doc(R"doc( +An op which emits a single Tensor value from an XLA computation. + +input: A tensor that will be inserted into the outfeed queue. +)doc"); + +REGISTER_OP("OutfeedEnqueueTuple") + .Input("inputs: dtypes") + .Attr("dtypes: list(type)") + .SetIsStateful() + .Doc(R"doc( +An op which emits multiple Tensor values from an XLA computation. + +inputs: A list of tensors that will be inserted into the outfeed queue as an +XLA tuple. +)doc"); + +REGISTER_OP("OutfeedDequeue") + .Output("output: dtype") + .Attr("dtype: type") + .Attr("shape: shape") + .Attr("device_ordinal: int = -1") + .SetIsStateful() + .SetShapeFn([](InferenceContext* c) { + PartialTensorShape shape; + TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape)); + ShapeHandle out; + TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &out)); + c->set_output(0, out); + return Status::OK(); + }) + .Doc(R"doc( +Retrieves a single tensor from the computation outfeed. This operation will +block indefinitely until data is available. + +output: A tensor that will be read from the device outfeed. +dtype: The type of elements in the tensor. +shape: The shape of the tensor. +device_ordinal: The TPU device to use. This should be -1 when the Op +is running on a TPU device, and >= 0 when the Op is running on the CPU +device. 
+)doc"); + +REGISTER_OP("OutfeedDequeueTuple") + .Output("outputs: dtypes") + .Attr("dtypes: list(type)") + .Attr("shapes: list(shape)") + .Attr("device_ordinal: int = -1") + .SetIsStateful() + .SetShapeFn([](InferenceContext* c) { + std::vector shapes; + std::vector dtypes; + TF_RETURN_IF_ERROR(c->GetAttr("shapes", &shapes)); + TF_RETURN_IF_ERROR(c->GetAttr("dtypes", &dtypes)); + if (shapes.size() != dtypes.size()) { + return errors::InvalidArgument( + "Incorrect number of output shapes specified"); + } + for (int i = 0; i < shapes.size(); ++i) { + ShapeHandle out; + TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shapes[i], &out)); + c->set_output(i, out); + } + return Status::OK(); + }) + .Doc(R"doc( +Retrieve multiple values that will be emitted by the computation as an XLA +tuple. This operations will block indefinitely until data is available. +Output `i` corresponds to XLA tuple element `i`. + +outputs: A list of tensors that will be read from the outfeed. +dtypes: The element types of each element in `outputs`. +shapes: The shapes of each tensor in `outputs`. +device_ordinal: The TPU device to use. This should be -1 when the Op +is running on a TPU device, and >= 0 when the Op is running on the CPU +device. +)doc"); + +} // namespace tensorflow diff --git a/tensorflow/contrib/tpu/ops/replication_ops.cc b/tensorflow/contrib/tpu/ops/replication_ops.cc new file mode 100644 index 00000000000..282a00b52c6 --- /dev/null +++ b/tensorflow/contrib/tpu/ops/replication_ops.cc @@ -0,0 +1,87 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { + +using shape_inference::InferenceContext; +using shape_inference::ShapeHandle; + +REGISTER_OP("TPUReplicatedInput") + .Input("inputs: N * T") + .Output("output: T") + .Attr("N: int >= 1") + .Attr("T: type") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle cur = c->input(c->num_inputs() - 1); + for (int i = c->num_inputs() - 2; i >= 0; --i) { + TF_RETURN_WITH_CONTEXT_IF_ERROR(c->Merge(c->input(i), cur, &cur), + "From merging shape ", i, + " with other shapes."); + } + c->set_output(0, cur); + return Status::OK(); + }) + .Doc( + "Operator that connects N unreplicated inputs to an N-way " + "replicated TPU computation."); + +REGISTER_OP("TPUReplicatedOutput") + .Input("input: T") + .Output("outputs: num_replicas * T") + .Attr("num_replicas: int >= 1") + .Attr("T: type") + .SetShapeFn([](InferenceContext* c) { + for (int i = 0; i < c->num_outputs(); ++i) { + c->set_output(i, c->input(0)); + } + return Status::OK(); + }) + .Doc( + "Operator that connects the output of an N-way replicated TPU " + "computation to N separate outputs."); + +REGISTER_OP("TPUReplicate") + .Attr("computation: func") + .Attr("num_replicas: int >= 1") + .Attr("global_tpu_id: list(int) = []") + .Attr("Tinputs: list(type) >= 0") + .Attr("Tbroadcast_inputs: list(type) >= 0") + .Attr("NumVariables: int >= 0") + .Attr("output_types: list(type) >= 0") + .Input("inputs: Tinputs") + 
.Input("broadcast_inputs: Tbroadcast_inputs") + .Input("variables: NumVariables * resource") + .Output("outputs: output_types") + .Doc(R"doc( +Runs replicated computations on a distributed TPU system. + +computation: a function containing the computation to run. +num_replicas: the number of replicas of the computation to run. +global_tpu_id: map from device to global tpu id. +Tinputs: the types of the arguments to 'computation'. +inputs: the inputs to 'computation', flattened, in replica-major order. +Tbroadcast_inputs: the types of the additional arguments to broadcast to all + replicas. +broadcast_inputs: additional arguments to broadcast to all replicas. The + broadcast inputs are appended to the per-replica inputs when calling + computation. +output_types: the types of the outputs of 'computation'. +outputs: the outputs of 'computation'. +)doc"); + +} // namespace tensorflow diff --git a/tensorflow/contrib/tpu/ops/tpu_configuration_ops.cc b/tensorflow/contrib/tpu/ops/tpu_configuration_ops.cc new file mode 100644 index 00000000000..5dc564ed27a --- /dev/null +++ b/tensorflow/contrib/tpu/ops/tpu_configuration_ops.cc @@ -0,0 +1,213 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +using shape_inference::InferenceContext; +using shape_inference::ShapeHandle; + +// Configuring a distributed TPU system is achieved by running +// the following Ops: +// +// 1 Run _DisconnectHostFromDistributedTPUSystem on the CPU of each +// host. This is needed in case the system had previously been +// configured. It returns, for each host, the number of TPU chips on +// the host. +// +// 2 Run _ConfigureDistributedTPU on TPU_SYSTEM. Takes as input the +// number of chips on each host. Validates that all hosts have the +// same number of chips, and that the chips are consistent with the +// topology set by flags. Has a single output which is a proto +// describing the requested system configuration, which is sent to all +// hosts. +// +// 3 Run _InitializeHostForDistributedTPU on the CPU of each host, +// taking as input the output from ConfigureDistributedTPU. Has a +// single Tensor output which is a vector of int32 indicating, for +// each TPU on the host, what its global TPU system id is. +// +// 4 Run _WaitForDistributedTPU on TPU_SYSTEM, taking as input the +// outputs from all the _InitializeHostForDistributedTPU +// Ops. _WaitForDistributedTPU has an attr host_specs which is a +// vector giving the partial device spec for each host. These +// partial specs are combined in the Op with the outputs from the host +// initialization Ops to construct a mapping from full TPU device +// specs to global TPU ids. Has a single Tensor output which is a +// matrix of int32 indicating, for each host (outer dimension) and for +// each TPU on the host (inner dimension) what that TPU's global id +// is. 
_WaitForDistributedTPU also waits for the TPU distributed +// system to initialize fully, which may take several minutes for a +// large system. +// +// 5 Run _SetGlobalTPUArray on the CPU of each host, taking as input +// the output from _WaitForDistributedTPU. This Op tells each host the +// global Id of every TPU on every host. +// +// Most user code works by placing the ConfigureDistributedTPU Op on +// the desired TPU_SYSTEM device, and a graph rewrite replaces it by +// the subgraph described above. +// +// +// A distributed TPU system can be cleanly shut down by running +// the following Ops: +// +// 1 Run _DisconnectHostFromDistributedTPUSystem on the CPU of each +// host. +// +// 2 Run _ShutdownDistributedTPU on the TPU_SYSTEM where +// _ConfigureDistributedTPU was run. The Op will return an error if no +// system is configured. +// +// +// Most user code works by placing the ShutdownDistributedTPU Op on +// the desired TPU_SYSTEM device, and a graph rewrite replaces it by +// the subgraph described above. + +REGISTER_OP("_ConfigureDistributedTPU") + .Input("inputs: N * int32") + .Output("output: string") + .Attr("N: int >= 1") + .SetIsStateful() + .SetShapeFn([](InferenceContext* c) { + ShapeHandle input; + // Validate that all the inputs are scalars. + for (int i = 0; i < c->num_inputs(); ++i) { + TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 0, &input)); + } + c->set_output(0, c->Scalar()); + return Status::OK(); + }) + .Doc(R"doc( +An op that sets up the centralized structures for a distributed TPU +system. + +inputs: A scalar tensor for each host indicating how many TPU chips +there are on the host. +output: A tensor containing a TPUHostConfiguration proto serialized to +a string, containing the information necessary to initialize the chips +in a host. 
+)doc"); + +REGISTER_OP("_WaitForDistributedTPU") + .Input("inputs: N * int32") + .Output("global_tpu_array: int32") + .Attr("host_specs: list(string)") + .Attr("startup_timeout_sec: int = 20") + .Attr("N: int") + .SetIsStateful() + .SetShapeFn([](InferenceContext* c) { + ShapeHandle input; + // Validate that all the inputs have the same vector shape. + for (int i = 0; i < c->num_inputs(); ++i) { + TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &input)); + } + c->set_output(0, c->UnknownShapeOfRank(2)); + return ::tensorflow::Status::OK(); + }) + .Doc(R"doc( +An op that blocks execution until a distributed TPU system has +started up. This Op must be run on the same TPU_SYSTEM device as +_ConfigureDistributedTPU, and takes an inputs the outputs from the +_InitializeHostForDistributedTPU Ops. + +inputs: For each initialized host, a vector giving the global TPU id +of each TPU on the host. +global_tpu_array: A two-dimensional array. For each host (the outer +dimension) the array lists the global ids of the TPUs on that host. +host_specs: For each initialized host, the partial device specification +indicating job, replica, and task. Combining this spec with +'/device:TPU:k' gives the full device name of the k'th TPU on the +host. +startup_timeout_sec: The number of seconds to wait for the TPU system +to stabilize. +)doc"); + +REGISTER_OP("_SetGlobalTPUArray") + .Input("global_tpu_array: int32") + .SetIsStateful() + .SetShapeFn([](InferenceContext* c) { + ShapeHandle input; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &input)); + return ::tensorflow::Status::OK(); + }) + .Doc(R"doc( +An op that informs a host of the global ids of all the of TPUs in the +system. + +global_tpu_array: A two-dimensional array. For each host (the outer +dimension) the array lists the global ids of the TPUs on that host. +)doc"); + +REGISTER_OP("_ShutdownDistributedTPU").SetIsStateful().Doc(R"doc( +An op that shuts down a running distributed TPU system. 
The Op returns +an error if no system is running. This Op must be run on the same +TPU_SYSTEM device as the corresponding _ConfigureDistributedTPU was run +to start the system, and must be run only after +_DisconnectHostFromDistributedTPUSystem has completed on every host in +the system. +)doc"); + +REGISTER_OP("_InitializeHostForDistributedTPU") + .Input("input: string") + .Output("tpu_ids: int32") + .SetIsStateful() + .SetShapeFn([](InferenceContext* c) { + ShapeHandle input; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &input)); + c->set_output(0, c->Vector(c->UnknownDim())); + return ::tensorflow::Status::OK(); + }) + .Doc(R"doc( +An op that connects each chip on the host to a centralized UberDriver to allow +them to operate as a distributed system with chips in other hosts. + +input: A string containing the address of the UberDriver to connect to. +tpu_ids: A vector containing the global TPU id of each TPU on the host. +)doc"); + +REGISTER_OP("_DisconnectHostFromDistributedTPUSystem") + .Output("number_of_tpu_chips: int32") + .SetIsStateful() + .Doc(R"doc( +An op that disconnects the TPUs on a host from a running distributed +TPU system. + +number_of_tpu_chips: A scalar tensor containing the number of TPU +chips on the host. +)doc"); + +REGISTER_OP("ConfigureDistributedTPU") + .Output("global_tpu_array: int32") + .Attr("embedding_config: string = ''") + .SetIsStateful() + .Doc(R"doc( +An op that sets up the centralized structures for a distributed TPU +system. + +global_tpu_array: A two-dimensional array. For each host (the outer +dimension) the array lists the global ids of the TPUs on that host. +embedding_config: Internal use. +)doc"); + +REGISTER_OP("ShutdownDistributedTPU").SetIsStateful().Doc(R"doc( +An op that shuts down a running distributed TPU system. The Op returns +an error if no system is running. 
+)doc"); + +} // namespace tensorflow diff --git a/tensorflow/contrib/tpu/ops/tpu_sendrecv_ops.cc b/tensorflow/contrib/tpu/ops/tpu_sendrecv_ops.cc new file mode 100644 index 00000000000..6d7c11a315a --- /dev/null +++ b/tensorflow/contrib/tpu/ops/tpu_sendrecv_ops.cc @@ -0,0 +1,46 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/op.h" + +namespace tensorflow { + +REGISTER_OP("_TPUSend") + .Input("tensor: T") + .Attr("T: type") + .Attr("tensor_name: string") + .SetIsStateful() + .Doc(R"doc( +Sends the named tensor over the TPU fabric. + +tensor: The tensor to send. +tensor_name: The name of the tensor to send. +)doc"); + +REGISTER_OP("_TPURecv") + .Output("tensor: T") + .Attr("T: type") + .Attr("tensor_name: string") + .Attr("shape: shape") + .SetIsStateful() + .Doc(R"doc( +Receives the named tensor over the TPU fabric. + +tensor: The tensor to receive. +tensor_name: The name of the tensor to receive. +shape: The shape of the input tensor. +)doc"); + +} // namespace tensorflow diff --git a/tensorflow/contrib/tpu/python/ops/tpu_ops.py b/tensorflow/contrib/tpu/python/ops/tpu_ops.py new file mode 100644 index 00000000000..8d3344fac36 --- /dev/null +++ b/tensorflow/contrib/tpu/python/ops/tpu_ops.py @@ -0,0 +1,38 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +"""Operations for TPUs.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import platform + + +if platform.system() != "Windows": + # pylint: disable=wildcard-import,unused-import,g-import-not-at-top + from tensorflow.contrib.tpu.ops.gen_tpu_ops import * + + from tensorflow.contrib.util import loader + from tensorflow.python.platform import resource_loader + # pylint: enable=wildcard-import,unused-import,g-import-not-at-top + + _tpu_ops = loader.load_op_library( + resource_loader.get_path_to_datafile("_tpu_ops.so")) +else: + # We have already built the appropriate libraries into the binary via CMake + # if we have built contrib, so we don't need this + pass diff --git a/tensorflow/tensorboard/defs/defs.bzl b/tensorflow/contrib/tpu/python/tpu/__init__.py similarity index 62% rename from tensorflow/tensorboard/defs/defs.bzl rename to tensorflow/contrib/tpu/python/tpu/__init__.py index 94e2d7c540f..0dffd7064b1 100644 --- a/tensorflow/tensorboard/defs/defs.bzl +++ b/tensorflow/contrib/tpu/python/tpu/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,14 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# ============================================================================= -def tensorboard_webcomponent_library(**kwargs): - """Rules referencing this will be deleted from the codebase soon.""" - pass +"""Ops related to Tensor Processing Units.""" -def _legacy_js_impl(target, ctx): - return struct() - -legacy_js = aspect( - implementation=_legacy_js_impl, - attr_aspects=["exports"]) +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py new file mode 100644 index 00000000000..157b0fc1ac9 --- /dev/null +++ b/tensorflow/contrib/tpu/python/tpu/tpu.py @@ -0,0 +1,583 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ====================================== + +"""Library of TPU helper functions.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import contextlib +from six.moves import xrange # pylint: disable=redefined-builtin + +from tensorflow.contrib.tpu.python.ops import tpu_ops +from tensorflow.contrib.tpu.python.tpu import tpu_function + +from tensorflow.core.framework import attr_value_pb2 +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import variable_scope + + +def initialize_system(embedding_config=None, job=None): + """Initializes a distributed TPU system for use with TensorFlow. + + Args: + embedding_config: If not None, an EmbeddingLayerConfiguration proto + describing the desired configuration of the hardware embedding lookup + tables. If embedding_config is None, no hardware embeddings can be used. + job: The job (the XXX in TensorFlow device specification /job:XXX) + that contains the TPU devices that will be initialized. If job=None + it is assumed there is only one job in the TensorFlow flock, and an + error will be returned if this assumption does not hold. + Returns: + Op which, when executed, will initialize the system. 
+ """ + if job is None: + device_name = "/replica:0/task:0/device:TPU_SYSTEM:0" + else: + device_name = "/job:%s/replica:0/task:0/device:TPU_SYSTEM:0" % job + config_string = ("" if embedding_config is None else + embedding_config.SerializeToString()) + with ops.device(device_name): + init_distributed_tpu = tpu_ops.configure_distributed_tpu( + embedding_config=config_string) + return init_distributed_tpu + + +def shutdown_system(job=None): + """Shuts down a running a distributed TPU system.""" + if job is None: + device_name = "/replica:0/task:0/device:TPU_SYSTEM:0" + else: + device_name = "/job:%s/replica:0/task:0/device:TPU_SYSTEM:0" % job + with ops.device(device_name): + shutdown_distributed_tpu = tpu_ops.shutdown_distributed_tpu() + return shutdown_distributed_tpu + + +def core(num): + """Returns the device name for a core in a replicated TPU computation. + + Args: + num: the virtual core number within each replica to which operators should + be assigned. + Returns: + A device name, suitable for passing to tf.device(). + """ + return "device:TPU_REPLICATED_CORE:{}".format(num) + + +# Experimental API to 'break out' of a tpu.rewrite() (or shard(), etc.) context. +# In +# +# XXX +# with tpu.rewrite(...): +# YYY +# with tpu.outside_all_rewrites(): +# ZZZ +# +# the Ops in ZZZ are added outside the scope of the rewrite(). +# TODO(phawkins): currently outside_all_rewrites() pops out of all nested +# control flow scopes, for example loops. It would make more sense if it only +# popped out of a single scope. +@contextlib.contextmanager +def outside_all_rewrites(): + """Experimental API to 'break out' of a tpu.rewrite() (or shard(), etc.).""" + with ops.control_dependencies(None): + yield + + +class TPUReplicateContext(control_flow_ops.ControlFlowContext): + """A ControlFlowContext for nodes inside a TPU computation. 
+ + The primary role of TPUReplicateContext is to mark operators inside a + tpu.replicate() computation with attributes: + * _tpu_replicate=XYZ, where XYZ is a unique name, and + * _tpu_num_replicas=k, where k is the number of replicas. + + We use a ControlFlowContext to perform the annotation since it + integrates with Tensorflow constructs like ResourceVariables. For example, + if a ResourceVariable is constructed inside a tpu.replicate() block, the + ResourceVariable implementation can use "with ops.control_dependencies(None)" + to build the variable's definition outside the replicated computation. + """ + + def __init__(self, name, num_replicas, global_tpu_id=None): + control_flow_ops.ControlFlowContext.__init__(self) + self._name = name + self._num_replicas = num_replicas + self._global_tpu_id = [] if global_tpu_id is None else global_tpu_id + + def AddOp(self, op): + self._AddOpInternal(op) + + def _AddOpInternal(self, op): + # pylint: disable=protected-access + if any(x.dtype._is_ref_dtype for x in op.inputs): + raise NotImplementedError( + "Non-resource Variables are not supported inside TPU computations " + "(operator name: %s)" % op.name) + # pylint: enable=protected-access + if "_tpu_replicate" in op.node_def.attr: + raise ValueError("TPU computations cannot be nested") + op.node_def.attr["_tpu_replicate"].s = self._name + op.node_def.attr["_tpu_num_replicas"].i = self._num_replicas + op.node_def.attr["_tpu_global_id"].list.i.extend(self._global_tpu_id) + op.graph.prevent_feeding(op) + op.graph.prevent_fetching(op) + + def AddValue(self, val): + result = val + if self._outer_context: + result = self._outer_context.AddValue(val) + return result + + def AddInnerOp(self, op): + self._AddOpInternal(op) + if self._outer_context: + self._outer_context.AddInnerOp(op) + + +def replicate(computation, + inputs=None, + infeed_queue=None, + global_tpu_id=None, + name=None): + """Builds a graph operator that runs a replicated TPU computation. 
+ + Args: + computation: a Python function that builds the computation to replicate. + inputs: a list of lists of input tensors or None (equivalent to + [[]]), indexed by [replica_num][input_num]. All replicas must + have the same number of inputs. + infeed_queue: if not None, the InfeedQueue from which to append a tuple + of arguments as inputs to computation. + global_tpu_id: if not None, a Numpy 2D array indicating the global + id of each TPU device in the system. The outer dimension of the + array is host task id, and the inner dimension is device ordinal, + so e.g., global_tpu_id[x][y] indicates the global id of device + /task:x/device:TPU_NODE:y. + name: name of the operator. + Returns: + A list of lists of output tensors, indexed by [replica_num][output_num]. + Raises: + ValueError: if all replicas do not have equal numbers of input tensors. + ValueError: if the number of inputs per replica does not match + the number of formal parameters to `computation`. + """ + if name is None: + name = "TPUReplicate" + inputs = [[]] if inputs is None else inputs + + if global_tpu_id is not None: + # Turn the Numpy array into a flattened list. + global_tpu_id = global_tpu_id.flatten().tolist() + + if ((not isinstance(inputs, list)) or + any(not isinstance(inp, (list, tuple)) for inp in inputs)): + raise TypeError("tpu.replicate() inputs must be a list of lists/tuples") + + num_replicas = len(inputs) + + # No replicas? Nothing to do. + if num_replicas == 0: + return [] + + # Converts inputs to Tensors. + inputs = [[ops.convert_to_tensor(x) for x in inp] for inp in inputs] + + # Verifies that all replicas have matching numbers and types of inputs + input_types = [x.dtype for x in inputs[0]] + input_arity = len(input_types) + for i in range(num_replicas): + if len(inputs[i]) != input_arity: + raise ValueError("Replicas must have the same number of inputs. 
" + "Replica 0 had {} inputs, replica {} had {} " + "inputs.".format(input_arity, i, len(inputs[i]))) + + types = [x.dtype for x in inputs[i]] + if types != input_types: + raise ValueError( + "Replicas must have matching input types. Replica 0 had " + "input types {}, replica {} had input types {}".format( + input_types, i, types)) + + arg_error = tpu_function.check_function_argument_count( + computation, input_arity, infeed_queue) + if arg_error is not None: + if infeed_queue is None: + raise TypeError( + "Supplied computation cannot be called with the specified inputs. " + "You specified %d inputs: %s, but the computation needs %s" % ( + input_arity, str([i.name for i in inputs[0]]), arg_error)) + else: + raise TypeError( + "Supplied computation cannot be called with the specified inputs. " + "You specified %d inputs: %s and %d additional inputs from infeed," + " but the computation needs %s" % (input_arity, str( + [i.name + for i in inputs[0]]), infeed_queue.number_of_tuple_elements, + arg_error)) + + graph = ops.get_default_graph() + + with ops.name_scope(name, "replicate"): + # Fan-in: Builds a TPUReplicatedInput node for each input. + computation_inputs = [] + for i in range(0, input_arity): + replicas = [inputs[replica][i] for replica in xrange(num_replicas)] + computation_inputs.append( + tpu_ops.tpu_replicated_input(replicas, name="input{}".format(i))) + + context = TPUReplicateContext( + name=graph.unique_name("cluster"), + num_replicas=num_replicas, + global_tpu_id=global_tpu_id) + try: + context.Enter() + + with tpu_function.tpu_shard_context(num_replicas): + + # The EncapsulateTPUComputations rewrite needs to identify the + # replicated arguments inside each computation. Adds identity operators + # tagged with an attribute _tpu_replicated_input to identify the + # replicated inputs. 
+ # pylint: disable=protected-access + with graph._attr_scope({"_tpu_replicated_input": + attr_value_pb2.AttrValue(b=True)}): + computation_inputs = [ + array_ops.identity(x, name="replicated_input_{}".format(i)) + for i, x in enumerate(computation_inputs)] + # pylint: enable=protected-access + + # If there is an infeed queue, adds the dequeued values to the + # computation's inputs. + if infeed_queue is not None: + infeed_queue.set_number_of_shards(num_replicas) + for t in infeed_queue.generate_dequeue_op(): + computation_inputs.append(t) + + # Only resource variables work inside a TPU computation, so turn on + # resource variables for the computation. + # TODO(phawkins): consider removing this code. It will + # be less confusing to clients if they knowingly choose to use resource + # variables. + vscope = variable_scope.get_variable_scope() + saved_use_resource = vscope.use_resource + vscope.set_use_resource(True) + + outputs = computation(*computation_inputs) + + vscope.set_use_resource(saved_use_resource) + + # If the computation only returned one value, makes it a tuple. + if not isinstance(outputs, (list, tuple)): + outputs = (outputs,) + + try: + with ops.device(core(0)): + outputs = [ + o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o) + for o in outputs + ] + except Exception as e: + raise ValueError( + "TPU function return values must all either be Operations or " + "convertible to Tensors. Got '%s'" % str(e)) + + # Separates the returned Operations and Tensors. + output_operations = [o for o in outputs if isinstance(o, ops.Operation)] + output_tensors = [o for o in outputs + if not isinstance(o, ops.Operation)] + + if outputs != output_tensors + output_operations: + raise ValueError( + "TPU functions must return zero-or more Tensor values followed by " + "zero or more Operations.") + output_arity = len(output_tensors) + + # Wraps outputs in Identity ops. 
Otherwise a replicated input copied + # straight to an output would bypass the replicate(). This would be bad + # because the TPUReplicatedInput/TPUReplicatedOutput operator would not + # be rewritten away, leading to a runtime error. + # TODO(phawkins): extend the rewrite to elide these nodes instead. + with ops.device(core(0)): + output_tensors = [array_ops.identity(x) for x in output_tensors] + finally: + context.Exit() + + # Fan-out: Builds a TPUReplicatedOutput node for each output. + outputs = [tpu_ops.tpu_replicated_output(output_tensors[i], num_replicas, + name="output{}".format(i)) + for i in xrange(output_arity)] + + with ops.control_dependencies(output_operations): + if output_arity == 0: + # Returns a list of NoOps dependent on the replication Op, indexed by + # [replica_num]. + return [ + control_flow_ops.no_op(name="%s_shard_%d" % (name, i)) + for i in range(num_replicas) + ] + else: + # Wraps the outputs in identity operators so the names of any possible + # `fetch` nodes are preserved by the replication rewrite. + return [ + [array_ops.identity(outputs[out][replica], + name="output_%d_shard_%d" % (out, replica)) + for out in xrange(output_arity)] + for replica in xrange(num_replicas) + ] + + +def shard(computation, + inputs=None, + num_shards=1, + input_shard_axes=None, + outputs_from_all_shards=True, + output_shard_axes=None, + infeed_queue=None, + global_tpu_id=None, + name=None): + """Shards `computation` for parallel execution. + + `inputs` must be a list of Tensors or None (equivalent to an empty + list), each of which has a corresponding split axis (from + `input_shard_axes`). Each input is split into `num_shards` pieces + along the corresponding axis, and computation is applied to each + shard in parallel. + + Tensors are broadcast to all shards if they are lexically captured by + `computation`. e.g., + + x = tf.constant(7) + def computation(): + return x + 3 + ... = shard(computation, ...) 
+ + TODO(phawkins): consider adding support for broadcasting Tensors passed + as inputs. + + If `outputs_from_all_shards` is true, the outputs from all shards of + `computation` are concatenated back together along their `output_shards_axes`. + Otherwise, each output is taken from an arbitrary shard. + + Inputs and outputs of the computation must be at least rank-1 Tensors. + + Args: + computation: a Python function that builds a computation to apply to each + shard of the input. + inputs: a list of input tensors or None (equivalent to an empty + list). Each input tensor has a corresponding shard axes, given + by `input_shard_axes`, which must have size divisible by + `num_shards`. + num_shards: the number of shards. + input_shard_axes: a list of dimensions along which to shard `inputs`, or + `None`. `None` means "shard all inputs along dimension 0". If not `None`, + there must be one dimension per input. + outputs_from_all_shards: boolean or list of boolean. For each output, if + `True`, outputs from all shards are concatenated along the corresponding + `output_shard_axes` entry. Otherwise, each output is taken + from an arbitrary shard. If the argument is a boolean, the argument's + value is used for each output. + output_shard_axes: a list of dimensions along which to concatenate the + outputs of `computation`, or `None`. `None` means "concatenate all outputs + along dimension 0". If not `None`, there must be one dimension per output. + Ignored if `outputs_from_all_shards` is False. + infeed_queue: if not None, the InfeedQueue to use to augment the inputs of + `computation`. + global_tpu_id: if not None, a Numpy 2D array indicating the global + id of each TPU device in the system. The outer dimension of the + array is host task id, and the inner dimension is device ordinal, + so e.g., global_tpu_id[x][y] indicates the global id of device + /task:x/device:TPU_NODE:y. + name: name of the operator. + Returns: + A list of output tensors. 
+ Raises: + ValueError: if num_shards <= 0 + ValueError: if len(input_shard_axes) != len(inputs) + ValueError: if len(output_shard_axes) != len(outputs from `computation`) + """ + + if num_shards <= 0: + raise ValueError("num_shards must be a positive integer.") + + # Converts inputs to Tensors. + inputs = [] if inputs is None else [ops.convert_to_tensor(x) for x in inputs] + + if input_shard_axes is None: + input_shard_axes = [0] * len(inputs) + if len(inputs) != len(input_shard_axes): + raise ValueError("Length of input_shard_axes must be equal to the number " + "of inputs.") + + if inputs: + # Splits the `inputs` along the corresponding `input_shard_axes`, giving + # lists with layout [input][shard] + split_inputs = [ + array_ops.split(x, num_shards, axis=axis) + for (axis, x) in zip(input_shard_axes, inputs)] + + # Transposes the input lists to have layout [shard][input] + transposed_inputs = [list(i) for i in zip(*split_inputs)] + else: + transposed_inputs = [[]] * num_shards + + outputs = replicate( + computation, + transposed_inputs, + infeed_queue=infeed_queue, + global_tpu_id=global_tpu_id, + name=name) + + # There must be at least one shard since num_shards > 0. + # TODO(b/36647078) remove disable when pylint bug is fixed. + # pylint: disable=indexing-exception + if isinstance(outputs[0], ops.Operation): + # pylint: enable=indexing-exception + # There were no outputs from the computation and replicate returned a list + # of NoOps with control dependencies on the computation. Return the first + # one so it can be used as a control dependency or fetch node. + # TODO(b/36647078) remove disable when pylint bug is fixed. + # pylint: disable=indexing-exception + return [outputs[0]] + # pylint: enable=indexing-exception + + # TODO(b/36647078) remove disable when pylint bug is fixed. 
+ # pylint: disable=indexing-exception + num_outputs = len(outputs[0]) + # pylint: enable=indexing-exception + + if output_shard_axes is None: + output_shard_axes = [0] * num_outputs + if num_outputs != len(output_shard_axes): + raise ValueError("Length of output_shard_axes must be equal to the number " + "of outputs.") + + if isinstance(outputs_from_all_shards, bool): + outputs_from_all_shards = [outputs_from_all_shards] * num_outputs + + if num_outputs != len(outputs_from_all_shards): + raise ValueError("Length of outputs_from_all_shards must be equal to the " + "number of outputs.") + + results = [] + for (axis, all_shards, x) in zip(output_shard_axes, outputs_from_all_shards, + zip(*outputs)): + if all_shards: + # Concatenate all of the outputs together. + results.append(array_ops.concat(list(x), axis=axis)) + else: + # TODO(phawkins): use a smarter policy, e.g., round-robin across shards. + results.append(x[0]) + + return results + + +def batch_parallel(computation, + inputs=None, + num_shards=1, + infeed_queue=None, + global_tpu_id=None, + name=None): + """Shards `computation` along the batch dimension for parallel execution. + + Convenience wrapper around shard(). + + `inputs` must be a list of Tensors or None (equivalent to an empty + list). Each input is split into `num_shards` pieces along the 0-th + dimension, and computation is applied to each shard in parallel. + + Tensors are broadcast to all shards if they are lexically captured by + `computation`. e.g., + + x = tf.constant(7) + def computation(): + return x + 3 + ... = shard(computation, ...) + + The outputs from all shards are concatenated back together along their 0-th + dimension. + + Inputs and outputs of the computation must be at least rank-1 Tensors. + + Args: + computation: a Python function that builds a computation to apply to each + shard of the input. + inputs: a list of input tensors or None (equivalent to an empty + list). 
The 0-th dimension of each Tensor must have size + divisible by `num_shards`. + num_shards: the number of shards. + infeed_queue: if not None, the InfeedQueue from which to append a tuple + of arguments as inputs to `computation`. + global_tpu_id: if not None, a Numpy 2D array indicating the global + id of each TPU device in the system. The outer dimension of the + array is host task id, and the inner dimension is device ordinal, + so e.g., global_tpu_id[x][y] indicates the global id of device + /task:x/device:TPU_NODE:y. + name: name of the operator. + Returns: + A list of output tensors. + Raises: + ValueError: if num_shards <= 0 + """ + return shard( + computation, + inputs, + num_shards=num_shards, + infeed_queue=infeed_queue, + global_tpu_id=global_tpu_id, + name=name) + + +def rewrite(computation, + inputs=None, + infeed_queue=None, + global_tpu_id=None, + name=None): + """Rewrites `computation` for execution on a TPU system. + + Args: + computation: a Python function that builds a computation to apply + to the input. If the function takes n inputs, 'inputs' should be + a list of n tensors. If the function returns m outputs, rewrite + will return a list of m tensors. + inputs: a list of input tensors or None (equivalent to an empty list). + infeed_queue: if not None, the InfeedQueue from which to append a tuple + of arguments as inputs to `computation`. + global_tpu_id: if not None, a Numpy 2D array indicating the global + id of each TPU device in the system. The outer dimension of the + array is host task id, and the inner dimension is device ordinal, + so e.g., global_tpu_id[x][y] indicates the global id of device + /task:x/device:TPU_NODE:y. + name: name of the operator. + Returns: + A list of output tensors. + """ + if inputs is not None and not isinstance(inputs, (list, tuple)): + raise TypeError("tpu.rewrite() inputs must be a list or tuple") + + # TODO(b/36647078) remove disable when pylint bug is fixed. 
+ # pylint: disable=indexing-exception + return replicate( + computation, + None if inputs is None else [inputs], + infeed_queue=infeed_queue, + global_tpu_id=global_tpu_id, + name=name)[0] + # pylint: enable=indexing-exception diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_config.py b/tensorflow/contrib/tpu/python/tpu/tpu_config.py new file mode 100644 index 00000000000..a19d1db8312 --- /dev/null +++ b/tensorflow/contrib/tpu/python/tpu/tpu_config.py @@ -0,0 +1,47 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# =================================================================== + +"""A RunConfig subclass with TPU support.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +from tensorflow.contrib.learn.python.learn.estimators import run_config as run_config_lib + + +class TpuConfig(collections.namedtuple( + 'TpuConfig', ['iterations_per_loop', 'num_shards'])): + """TPU related configuration required by `TPUEstimator`.""" + + def __new__(cls, iterations_per_loop=2, num_shards=2): + return super(TpuConfig, cls).__new__( + cls, + iterations_per_loop=iterations_per_loop, + num_shards=num_shards) + + +class RunConfig(run_config_lib.RunConfig): + """RunConfig with TPU support.""" + + def __init__(self, tpu_config=None, **kwargs): + super(RunConfig, self).__init__(**kwargs) + self._tpu_config = tpu_config or TpuConfig() + + @property + def tpu_config(self): + return self._tpu_config diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py new file mode 100644 index 00000000000..b702ab91f6a --- /dev/null +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -0,0 +1,361 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# =================================================================== + +"""Tpu Estimator class.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import threading +from six.moves import queue as Queue # pylint: disable=redefined-builtin + +from tensorflow.contrib.tpu.python.tpu import tpu +from tensorflow.contrib.tpu.python.tpu import tpu_config +from tensorflow.contrib.tpu.python.tpu import tpu_feed +from tensorflow.contrib.tpu.python.tpu import training_loop + +from tensorflow.python.estimator import estimator as estimator_lib +from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import session_run_hook +from tensorflow.python.training import training + + +def _tpu_job(run_config): + # The tpu job is determined by the run_config. Right now, this method is + # required as tpu_config is not part of the RunConfig. + return None if run_config.master in ['', 'local'] else 'tpu_worker' + + +class _SIGNAL(object): + """Signal used to control the input thread of infeed.""" + NEXT_BATCH = 1 + STOP = 2 + + +class InfeedThreadController(object): + """This wraps the infeed thread and stops when Estimator train finishes. + + For model_fn wrapper, it is not possible to know when the `train` API will + stop. It could be the cases that the `max_steps` is reached or some hook + requests the stop in the monitored_session. + + This controller (with coordination with `TpuInfeedSessionHook`) does the + following: + + 1) It pre-infeeds one `batch` data for current TPU iterations. 
+ + 2) When `before_run` of `TpuInfeedSessionHook` is called, one more `batch` + data will be infed. + + 3) When `end` of `TpuInfeedSessionHook` is called, the thread will end + gracefully. + + So, we might need to adjust the algorithrm here if the IO is slower than the + computation. + """ + + def __init__(self, session, enqueue_ops, iterations): + self._signal_queue = Queue.Queue() + self._input_thd = threading.Thread(target=self._input_thread_fn_for_loading, + args=(session, enqueue_ops, iterations)) + self._input_thd.daemon = True + self._input_thd.start() + + def _input_thread_fn_for_loading(self, session, enqueue_ops, iterations): + count = 0 + while True: + signal = self._signal_queue.get() + if signal == _SIGNAL.STOP: + logging.info('Stop Infeed input thread.') + return + + for i in range(iterations): + logging.debug('InfeedEnqueue data for iteration (%d, %d)', count, i) + session.run(enqueue_ops) + count += 1 + + def load_next_batch(self): + self._signal_queue.put(_SIGNAL.NEXT_BATCH) + + def join(self): + logging.info('Waiting for InputThread to exit.') + self._signal_queue.put(_SIGNAL.STOP) + self._input_thd.join() + + +class TpuInfeedSessionHook(session_run_hook.SessionRunHook): + """A Session hook setting up the TPU initialization and infeed. + + This hook does two major things: + 1. initialize and shutdown TPU system (maybe a separated hook) + 2. launch and join the input thread for infeed. 
+ """ + + def __init__(self, run_config, enqueue_fn): + self._iterations = run_config.tpu_config.iterations_per_loop + self._enqueue_fn = enqueue_fn + self._tpu_job = _tpu_job(run_config) + + def begin(self): + self._enqueue_ops = self._enqueue_fn() + logging.info('TPU job name %s', self._tpu_job) + self._init_op = [tpu.initialize_system(job=self._tpu_job)] + self._finalize_op = [tpu.shutdown_system(job=self._tpu_job)] + + def after_create_session(self, session, coord): + logging.info('Init TPU system') + session.run(self._init_op) + + logging.info('Start infeed input thread controller') + self._infeed_thd_controller = InfeedThreadController( + session, self._enqueue_ops, self._iterations) + + def before_run(self, run_context): + logging.info('Load next batch of data to infeed.') + self._infeed_thd_controller.load_next_batch() + + def end(self, session): + logging.info('Stop infeed input thread controller') + self._infeed_thd_controller.join() + + logging.info('Shutdown TPU system.') + session.run(self._finalize_op) + + +class TpuEstimator(estimator_lib.Estimator): + """Estimator with TPU support. + + The only difference is a wrapped model_fn is set in the constructor. + """ + + def __init__(self, + model_fn=None, + model_dir=None, + config=None, + params=None, + use_tpu=True): + if use_tpu: + model_function = wrapped_model_fn(model_fn, config) + else: + model_function = model_fn + + super(TpuEstimator, self).__init__( + model_fn=model_function, + model_dir=model_dir, + config=config, + params=params) + if not isinstance(config, tpu_config.RunConfig): + raise ValueError('`config` must be `tpu_config.RunConfig`') + + def _create_global_step(self, graph): + """Creates a global step suitable for TPUs. + + Args: + graph: The graph in which to create the global step. + + Returns: + A global step `Tensor`. + + Raises: + ValueError: if the global step tensor is already defined. 
+ """ + graph = graph or ops.get_default_graph() + if training.get_global_step(graph) is not None: + raise ValueError('"global_step" already exists.') + # Create in proper graph and base name_scope. + with graph.as_default() as g, g.name_scope(None): + return variable_scope.get_variable( + ops.GraphKeys.GLOBAL_STEP, + shape=[], + dtype=dtypes.int32, + initializer=init_ops.zeros_initializer(), + trainable=False, + use_resource=True, + collections=[ops.GraphKeys.GLOBAL_VARIABLES, + ops.GraphKeys.GLOBAL_STEP]) + + +# TODO(xiejw): Improve the structure of this input_fn to infeed converion. +# The code now looks not like Estimator style. We need to abstract many +# details. +def _create_infeed_enqueue_ops_and_dequeue_fn(run_config, features, labels): + """Utility to convert input_fn to enqueue and dequeue fns for TPU. + + Mainly, three things need to be done here. + 1. Calls the input_fn many times (`num_shards`) to infeed the data into TPU + 2. Create a dequeue_fn used by the train_step inside TPU execution to + dequeue the tensors. + 3. Sets up the input thread to infeed. + + Args: + run_config: run_config + features: features + labels: labels + + Returns: + A tuple of (dequeue_fn, and thread main function) + """ + infeed_names = None + infeed_tuple = [] + if isinstance(features, dict): + # We need a fixed ordering for enqueueing and dequeueing. + infeed_names = [name for name in features] + infeed_tuple.extend([features[name] for name in infeed_names]) + else: + infeed_tuple.append(features) + # TODO(jhseu): Handle multi-head and None labels + infeed_tuple.append(labels) + # TODO(jhseu): Update when b/36470756 is settled. 
+ infeed_queue = tpu_feed.InfeedQueue( + tuple_types=[t.dtype for t in infeed_tuple], + tuple_shapes=[t.shape for t in infeed_tuple]) + infeed_queue.set_number_of_shards(run_config.tpu_config.num_shards) + + def dequeue_fn(): + """dequeue_fn is used by the train_step in TPU to retrieve the tensors.""" + values = infeed_queue.generate_dequeue_op() + if infeed_names is None: + return values + # Restore the feature dictionary and label. + dequeued_features = {} + for i in range(len(values) - 1): + dequeued_features[infeed_names[i]] = values[i] + label = values[-1] + return dequeued_features, label + + def enqueue_fn(): + """enqueue_fn is used to add ops to the graph to send tensors.""" + job = _tpu_job(run_config) + def placement_function(index): + if job is None: + return '/replica:0/task:0/device:CPU:0' + else: + return '/job:%s/replica:0/task:%d/device:CPU:0' % (job, index / 8) + return infeed_queue.split_inputs_and_generate_enqueue_ops( + infeed_tuple, placement_function=placement_function) + + return (dequeue_fn, enqueue_fn) + + +def wrapped_model_fn(model_fn, run_config): + """Returns a new model_fn, which wraps the TPU support.""" + + # Verifies the model_fn signature according to Estimator framework. + estimator_lib._verify_model_fn_args(model_fn, params=None) # pylint: disable=protected-access + + def _model_fn(features, labels, mode): + """model_fn.""" + # TODO(jhseu): Move to EVAL and PREDICT to TPU. + if mode != model_fn_lib.ModeKeys.TRAIN: + return model_fn(features, labels, mode) + + dequeue_fn, enqueue_fn = ( + _create_infeed_enqueue_ops_and_dequeue_fn(run_config, features, labels)) + + loss = _train_on_tpu_shards( + run_config, + train_step=_convert_model_fn_to_train_step( + model_fn, dequeue_fn, mode, run_config)) + + # Gets the variables back from TPU nodes. This means the variables updated + # by TPU will now be *synced* to host memory. 
+ update_ops = [ + array_ops.check_numerics(v.read_value(), + 'Gradient for %s is NaN' % v.name).op + for v in variables.trainable_variables() + ] + + hooks = [ + TpuInfeedSessionHook(run_config, enqueue_fn), + training.LoggingTensorHook( + {'loss': array_ops.identity(loss), + 'step': training.get_global_step()}, + every_n_secs=30) + ] + + return model_fn_lib.EstimatorSpec( + mode, + loss=array_ops.identity(loss), + training_hooks=hooks, + train_op=control_flow_ops.group(*update_ops)) + return _model_fn + + +def _convert_model_fn_to_train_step(model_fn, dequeue_fn, mode, run_config): + """generates a train step based on the model_fn.""" + + def _call_model_fn(features, labels): + """Calls the model_fn with required parameters.""" + model_fn_args = estimator_lib._model_fn_args(model_fn) # pylint: disable=protected-access + kwargs = {} + if 'mode' in model_fn_args: + kwargs['mode'] = mode + # Uncomment the following lines once `params` is supported. + # if 'params' in model_fn_args: + # kwargs['params'] = params + if 'config' in model_fn_args: + kwargs['config'] = run_config + return model_fn(features=features, labels=labels, **kwargs) + + def _verify_estimator_spec(estimator_spec): + """Validates the estimator_spec.""" + err_msg = '{} returned by EstimatorSpec is not supported in TPUEstimator.' + if estimator_spec.training_chief_hooks: + raise ValueError(err_msg.format('training_chief_hooks')) + if estimator_spec.training_hooks: + raise ValueError(err_msg.format('training_hooks')) + return estimator_spec + + def train_step(loss): + """Training step function for use inside a while loop.""" + del loss # unused; required in function signature. + features, labels = dequeue_fn() + + # TODO(xiejw): how to do we support hook and savers in the original + # model_fn. Realistically, the original + # model_fn will be excuted on TPU chips in a replica way. The hooks + # returned by the model_fn cannot be supported at all. 
If we have to, + # the graph construction part in the model_fn should be separated from the + # control part (such as hooks and savers). By that the graph construction + # could de defered on TPU chip, while the control logic can stay in host. + estimator_spec = _verify_estimator_spec(_call_model_fn(features, labels)) + loss, train_op = estimator_spec.loss, estimator_spec.train_op + with ops.control_dependencies([train_op]): + return array_ops.identity(loss) + return train_step + + +def _train_on_tpu_shards(run_config, train_step): + """Executes the `train_step` on all shards.""" + def train_shard(): + return training_loop.repeat(run_config.tpu_config.iterations_per_loop, + train_step, + [1e7], # initial_loss + name='loop') + + (loss,) = tpu.shard(train_shard, + inputs=[], + num_shards=run_config.tpu_config.num_shards, + outputs_from_all_shards=False) + return loss diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_feed.py b/tensorflow/contrib/tpu/python/tpu/tpu_feed.py new file mode 100644 index 00000000000..668b5b8b911 --- /dev/null +++ b/tensorflow/contrib/tpu/python/tpu/tpu_feed.py @@ -0,0 +1,613 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# =================================================================== + +"""Helper library for handling infeed between hosts and TPUs. 
+""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from six.moves import xrange # pylint: disable=redefined-builtin + +from tensorflow.contrib.tpu.python.ops import tpu_ops +from tensorflow.contrib.tpu.python.tpu import tpu_sharding + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops + + +class InfeedQueue(object): + """A helper object to build a device infeed queue. + + The InfeedQueue builds the host-side and device-side Ops to enqueue and + dequeue elements, respectively, and ensures that their types and + shapes match. + """ + + def __init__(self, + number_of_tuple_elements=None, + tuple_types=None, + tuple_shapes=None, + shard_dimensions=None, + name=None): + """Creates a new InfeedQueue with the given configuration. + + The configuration need not be fully specified at creation since it + can be modified subsequently by methods that set the values + explicitly or infer them from the shapes of inputs. + + Args: + number_of_tuple_elements: the number of Tensors fed atomically through the + queue, must be present unless it can be inferred from other arguments. + tuple_types: if not None, a list of types of the elements of the queue. + tuple_shapes: if not None, a list of shapes of the elements of the queue. + shard_dimensions: if not None, a list of dimensions on which the + elements of the queue should be sharded during automatic + parallelization. + name: the name of the queue. + + Raises: + ValueError: if number_of_tuple_elements <= 0; or + number_of_tuple_arguments, tuple_types, tuple_shapes, and + shard_dimensions are all None; or the length of tuple_types, + tuple_shapes, or shard_dimensions is not equal to + number_of_tuple_elements; or any element of shard_dimensions + can't be converted to a Dimension. 
+ TypeError: if any element of tuple_types or tuple_shapes can't + be converted to a dtype or TensorShape, respectively. + """ + self._frozen = False + self._generated_enqueue_ops = False + self._generated_dequeue_op = False + self._name = "InfeedQueue" if name is None else name + if number_of_tuple_elements is None: + if tuple_types is not None: + number_of_tuple_elements = len(tuple_types) + elif tuple_shapes is not None: + number_of_tuple_elements = len(tuple_shapes) + elif shard_dimensions is not None: + number_of_tuple_elements = len(shard_dimensions) + else: + raise ValueError( + "number of tuple elements cannot be inferred from InfeedQueue " + "constructor" + ) + if number_of_tuple_elements <= 0: + raise ValueError("number_of_tuple_elements %d must be > 0" % + number_of_tuple_elements) + # Make an empty sharding policy for each tuple element. + self._sharding_policies = [ + tpu_sharding.ShardingPolicy() + for _ in xrange(number_of_tuple_elements) + ] + if tuple_types is not None: + self.set_tuple_types(tuple_types) + else: + self._tuple_types = None + if tuple_shapes is not None: + self.set_tuple_shapes(tuple_shapes) + else: + self._tuple_shapes = None + if shard_dimensions is not None: + self.set_shard_dimensions(shard_dimensions) + self._validate() + + def _validate(self): + """Checks that the configuration is self-consistent. + + Raises: + ValueError: if the shapes and sharding policies don't match. + """ + if self.tuple_shapes is not None: + for (policy, shape) in zip(self._sharding_policies, self._tuple_shapes): + # Raise an error if the policy is incompatible with the shape. 
+ _ = policy.get_sharded_shape(shape) + + @property + def number_of_tuple_elements(self): + """Returns the number of InfeedQueue tuple elements.""" + return len(self._sharding_policies) + + @property + def tuple_types(self): + """Returns the types of the InfeedQueue tuple elements.""" + return self._tuple_types + + def set_tuple_types(self, tuple_types): + """Sets the type of each element of the queue. + + tuple_types must be a list of length + self.number_of_tuple_elements, and each element must be + convertible to a dtype. + + Args: + tuple_types: the types of each queue element. + + Raises: + ValueError: if tuple_types is not of length + self.number_of_tuple_elements. + TypeError: if an element of tuple_types cannot be converted to a + dtype. + """ + if len(tuple_types) != self.number_of_tuple_elements: + raise ValueError("tuple_types is %s, but must be a list of length %d" % + (str(tuple_types), self.number_of_tuple_elements)) + if self._frozen: + for (frozen, updated) in zip(self._tuple_types, tuple_types): + if frozen != updated: + raise ValueError( + "Trying to update InfeedQueue with frozen configuration with an " + "incompatible type. Frozen types are %s, updated types are %s" % ( + str(self._tuple_types), str(tuple_types))) + else: + try: + self._tuple_types = [dtypes.as_dtype(t) for t in tuple_types] + except (TypeError) as e: + raise TypeError( + "tuple_types is %s, but must be a list of elements each " + "convertible to dtype: got error %s" % (str(tuple_types), str(e))) + + @property + def tuple_shapes(self): + """Returns the shapes of the InfeedQueue tuple elements.""" + return self._tuple_shapes + + def set_tuple_shapes(self, tuple_shapes): + """Sets the shape of each element of the queue. + + tuple_shapes must be a list of length + self.number_of_tuple_elements, and each element must be + convertible to a TensorShape. + + Args: + tuple_shapes: the shapes of each queue element. 
+ + Raises: + ValueError: if tuple_shapes is not of length + self.number_of_tuple_elements. + TypeError: if an element of tuple_shapes cannot be converted to + a TensorShape. + """ + if len(tuple_shapes) != self.number_of_tuple_elements: + raise ValueError("tuple_shapes is %s, but must be a list of length %d" % + (str(tuple_shapes), self.number_of_tuple_elements)) + try: + tuple_shapes = [tensor_shape.as_shape(shape) for shape in tuple_shapes] + except (ValueError, TypeError) as e: + raise TypeError( + "tuple_shapes is %s, but must be a list of elements each " + "convertible to TensorShape: got error %s" % (str(tuple_shapes), + str(e))) + if self._frozen: + for (frozen, updated) in zip(self._tuple_shapes, tuple_shapes): + if frozen != updated: + raise ValueError( + "Trying to update InfeedQueue with frozen configuration with an " + "incompatible shape. Frozen shapes are %s, updated shapes are %s" + % (str(self._tuple_shapes), str(tuple_shapes))) + else: + self._tuple_shapes = tuple_shapes + self._validate() + + @property + def sharding_policies(self): + """Returns the sharding policies of the InfeedQueue tuple elements.""" + return self._sharding_policies + + @property + def shard_dimensions(self): + """Gets the shard dimension of each tuple element. + + Returns: + A list of length number_of_tuple_elements, where each list entry + is the shard dimension of that tuple element or None if the + shard dimension has not been set. + """ + # The number of shards is always the same for all the policies. + return [policy.shard_dimension for policy in self._sharding_policies] + + def set_shard_dimensions(self, shard_dimensions): + """Sets the shard_dimension of each element of the queue. + + shard_dimensions must be a list of length + self.number_of_tuple_elements, and each element must be + convertible to a Dimension compatible with self.tuple_shapes. + + Args: + shard_dimensions: the dimensions of each queue element. 
+ + Raises: + ValueError: if shard_dimensions is not of length + self.number_of_tuple_elements; or an element of + shard_dimensions cannot be converted to a Dimension; or an + element of shard_dimensions is a Dimension that is out of + range for the corresponding tuple element shape. + """ + if len(shard_dimensions) != self.number_of_tuple_elements: + raise ValueError("shard_dimensions is %s, but must be a list of length %d" + % (str(shard_dimensions), + self.number_of_tuple_elements)) + for (policy, dimension) in zip(self._sharding_policies, shard_dimensions): + policy.set_shard_dimension(dimension) + self._validate() + + @property + def number_of_shards(self): + """Gets the number of shards to use for the InfeedQueue. + + Returns: + Number of shards or None if the number of shards has not been set. + """ + # The number of shards is always the same for all the policies. + return self._sharding_policies[0].number_of_shards + + def set_number_of_shards(self, number_of_shards): + """Sets the number of shards to use for the InfeedQueue. + + Args: + number_of_shards: number of ways to shard the InfeedQueue. + + Raises: + ValueError: if number_of_shards is not > 0; or the policies have + been frozen and number_of_shards was already set to something + else. + """ + for policy in self._sharding_policies: + policy.set_number_of_shards(number_of_shards) + self._validate() + + def set_configuration_from_input_tensors(self, input_tensors): + """Sets the shapes and types of the queue tuple elements. + + input_tensors is a list of Tensors whose types and shapes are used + to set the queue configuration. + + Args: + input_tensors: list of Tensors of the same types and shapes as + the desired queue Tuple. 
+ + Raises: + ValueError: if input_tensors is not a list of length + self.number_of_tuple_elements + """ + if len(input_tensors) != self.number_of_tuple_elements: + raise ValueError( + "input_tensors is %s, but should be a list of %d Tensors", ( + str(input_tensors), self.number_of_tuple_elements)) + self.set_tuple_shapes([t.shape for t in input_tensors]) + self.set_tuple_types([t.dtype for t in input_tensors]) + + def set_configuration_from_sharded_input_tensors(self, input_tensors): + """Sets the shapes and types of the queue tuple elements. + + input_tensors is a list of lists of Tensors whose types and shapes are used + to set the queue configuration. The length of the outer list is the number + of shards required, and each inner list is the tuple of Tensors to use to + determine the types and shapes of the correponding shard. This method + depends on the shard dimension, and calling it freezes the shard policy. + + Args: + input_tensors: list of lists of Tensors. The outer list length corresponds + to the desired number of shards, and each inner list is the size + and shape of the desired configuration of the corresponding shard. + + Raises: + ValueError: if any inner list is not a list of length + self.number_of_tuple_elements; or the inner lists do not combine to + form a consistent unsharded shape. + TypeError: if the types of the Tensors in the inner lists do not match. + """ + if not self._frozen: + # Unset the tuple shapes in case the configuration becomes + # transiently inconsistent. + self._tuple_shapes = None + number_of_shards = len(input_tensors) + self.set_number_of_shards(number_of_shards) + for t in input_tensors: + if len(t) != self.number_of_tuple_elements: + raise ValueError( + "input_tensors is %s but must be a list of lists, where each inner" + " list has length number_of_tuple_elements=%d" % ( + str(input_tensors), self.number_of_tuple_elements)) + # Transpose the inputs to make a list of shard shapes for each tuple + # element. 
+ sharded_shapes = [[t[i].shape for t in input_tensors] + for i in xrange(self.number_of_tuple_elements)] + # For each tuple, get the unsharded shape using that tuple's policy. + unsharded_shapes = [ + policy.get_unsharded_shape(s) + for (policy, s) in zip(self._sharding_policies, sharded_shapes) + ] + self.set_tuple_shapes(unsharded_shapes) + for i in xrange(1, self.number_of_shards): + for (t1, t2) in zip(input_tensors[0], input_tensors[i]): + if t1.dtype != t2.dtype: + raise TypeError( + "types of the tuple elements of input_tensors %s are not " + "consistent" % str(input_tensors)) + self.set_tuple_types([t.dtype for t in input_tensors[0]]) + + def freeze(self): + """Freezes the InfeedQueue so it can no longer be modified. + + The configuration is implicitly frozen before any host-side or + device-side Ops are generated. The configuration cannot be frozen + until the types and shapes of the tuple elements have been set. + + Raises: + ValueError: if the types or shapes of the tuple elements have not been + set. + """ + self._frozen = True + if self._tuple_types is None: + raise ValueError( + "Can't freeze an InfeedQueue without setting all tuple types.") + if self._tuple_shapes is None: + raise ValueError( + "Can't freeze an InfeedQueue without setting all tuple shapes.") + for shape in self._tuple_shapes: + if shape.dims is None: + raise ValueError( + "Can't freeze an InfeedQueue without setting all tuple shapes.") + for policy in self._sharding_policies: + policy.freeze() + self._validate() + + def generate_dequeue_op(self): + """Generates the device-side Op to dequeue a tuple from the queue. + + Implicitly freezes the queue configuration if it is not already + frozen, which will raise errors if the shapes and types have not + been fully specified. + + Returns: + A list of Outputs corresponding to a shard of infeed dequeued + into XLA, suitable for use within a replicated block. 
+ + Raises: + ValueError: if the types or shapes of the tuple elements have not been + set; or if a dequeue op has already been generated. + """ + self.freeze() + if self._generated_dequeue_op: + raise ValueError("Can't generate two dequeue Ops from the same queue") + self._generated_dequeue_op = True + full_name = "%s/dequeue" % self._name + sharded_shapes = [ + policy.get_sharded_shape(shape) + for (shape, policy) in zip(self._tuple_shapes, self._sharding_policies) + ] + return tpu_ops.infeed_dequeue_tuple( + dtypes=self._tuple_types, shapes=sharded_shapes, name=full_name) + + def _generate_enqueue_op(self, + inputs, + name_prefix, + index, + device=None, + tpu_ordinal=-1): + """Generate a host-side Op to enqueue a tuple to the queue. + + If device is None the inputs are all required to have the same + device specification, and the enqueue Op is colocated with + inputs[0]. Otherwise the enqueue Op is placed on 'device'. + + Args: + inputs: a list of Tensors with the types and shapes of the tuple elements. + name_prefix: the base name for the Op. + index: the shard index, used to uniquify the Op name. + device: device to place the Op on, or None if it should be + colocated with the inputs. + tpu_ordinal: ordinal of the TPU device on the host to use for + infeed if device is a CPU device. Should be set to -1 if device + is a TPU device. + + Returns: + An Op corresponding to a shard of infeed enqueued at the host, + suitable for use within a replicated block. + + Raises: + ValueError: if device is None and inputs do not all have the + same device specification. 
+ """ + full_name = "%s/%d" % (name_prefix, index) + shapes = [t.shape for t in inputs] + if device is None: + devices = [t.device for t in inputs] + for i in xrange(1, self.number_of_tuple_elements): + if devices[0] != devices[i]: + raise ValueError( + "input devices for shard %d are %s, but should all be the same", + index, str(devices)) + with ops.colocate_with(inputs[0]): + return tpu_ops.infeed_enqueue_tuple( + inputs=inputs, + shapes=shapes, + name=full_name, + device_ordinal=tpu_ordinal) + else: + with ops.device(device): + return tpu_ops.infeed_enqueue_tuple( + inputs=inputs, + shapes=shapes, + name=full_name, + device_ordinal=tpu_ordinal) + + def generate_enqueue_ops(self, sharded_inputs): + """Generates the host-side Ops to enqueue the shards of a tuple. + + sharded_inputs is a list, one for each shard, of lists of + Tensors. sharded_inputs[0] is the tuple of Tensors to use to feed + shard 0 if the queue. Returns the host-side Ops that must be run to + enqueue the sharded tuple. The Op for shard i is colocated with the inputs + for shard i. + + Implicitly freezes the queue configuration if it is not already + frozen. If the configuration has already been frozen, and is not + compatible with the types and shapes of sharded_inputs, an error + will be raised. + + Args: + sharded_inputs: a list of lists of Tensors. The length of the outer list + determines the number of shards. Each inner list indicates the types + and shapes of the tuples in the corresponding shard. + + Returns: + A list of host-side Ops, one for each shard, that when executed together + will enqueue a full-size element of infeed. + + Raises: + ValueError: if the queue configuration has previously been frozen and the + shapes of the elements of sharded_inputs are not compatible with the + frozen configuration; or if the shapes of the elements of sharded_inputs + don't form a consistent unsharded tuple; or if the elements of a tuple + have different device constraints. 
+ TypeError: if the queue configuration has previously been frozen and the + types of the elements of sharded_inputs are not compatible with the + frozen configuration; or if the types of the elements of sharded_inputs + don't form a consistent unsharded tuple. + """ + self.set_configuration_from_sharded_input_tensors(sharded_inputs) + self.freeze() + if self._generated_enqueue_ops: + raise ValueError("Can't generate two enqueue Ops from the same queue") + self._generated_enqueue_ops = True + name_prefix = "%s/enqueue" % self._name + return [ + self._generate_enqueue_op(shard, name_prefix, index) + for (shard, index) in zip(sharded_inputs, xrange(self.number_of_shards)) + ] + + # TODO(misard) Generalize this to the case of systems that don't + # have 8 devices per host, and figure out what to do with + # model-parallelism. + def _default_placement_function(self, index): + return "/task:%d/device:CPU:0" % (index / 8) + + def _default_ordinal_function(self, index): + return index % 8 + + # TODO(b/36470756) remove this from tutorials once we have a better story + # for automatic placement of input pipelines. + def split_inputs_and_generate_enqueue_ops(self, + inputs, + global_tpu_id=None, + placement_function=None, + tpu_ordinal_function=None): + """POORLY-PERFORMING ON MULTI-HOST SYSTEMS. + + Generates the host-side Ops to enqueue a tuple. + + This method performs poorly because it takes an entire input on a single + host, splits it, and distributes it to all of the cores. It is present only + to simplify tutorial examples. + + inputs is a list of Tensors to use to feed the queue. Each input is split + into self.number_of_shards shards. Returns an Op for each shard to enqueue + the shard. The Op for shard i is placed on device placement_function(i). + + Implicitly freezes the queue configuration if it is not already + frozen. If the configuration has already been frozen, and is not + compatible with the types and shapes of inputs, an error + will be raised. 
+ + Args: + inputs: a list of Tensors which indicates the types and shapes of the + queue tuple. + global_tpu_id: if not None, a Numpy 2D array indicating the global + id of each TPU device in the system. The outer dimension of the + array is host task id, and the inner dimension is device ordinal, + so e.g., global_tpu_id[x][y] indicates the global id of device + /task:x/device:TPU_NODE:y. If global_tpu_id is not None, but + placement_function and ordinal_function are None, then global_tpu_id + will be used to place infeed on the TPUs with the first k global ids, + where k is the number of shards in the queue. + placement_function: if not None, a function that takes the shard + index as input and returns a device string indicating which + device the shard's infeed should be placed on. If placement_function + and tpu_ordinal_function are None, inputs are sharded round-robin + across the devices in the system. + tpu_ordinal_function: if not None, a function that takes the + shard index as input and returns the ordinal of the TPU device + the shard's infeed should be placed on. If placement_function + and tpu_ordinal_function are None, inputs are sharded round-robin + across the devices in the system. + + Returns: + A list of host-side Ops, one for each shard, that when executed together + will enqueue a full-size element of infeed. + + Raises: + ValueError: if the queue configuration has previously been frozen and the + shapes of the elements of inputs are not compatible with the frozen + configuration. + TypeError: if the queue configuration has previously been frozen and the + types of the elements of inputs are not compatible with the frozen + configuration. 
+ """ + if global_tpu_id is None: + if placement_function is None: + placement_function = self._default_placement_function + if tpu_ordinal_function is None: + tpu_ordinal_function = self._default_ordinal_function + else: + global_id_map = {} + for host, devices in enumerate(global_tpu_id): + for ordinal, global_id in enumerate(devices): + global_id_map[global_id] = (host, ordinal) + + def _placement_function_from_map(index): + return "/task:%d/device:CPU:0" % global_id_map[index][0] + + def _ordinal_function_from_map(index): + return global_id_map[index][1] + + if placement_function is None: + placement_function = _placement_function_from_map + if tpu_ordinal_function is None: + tpu_ordinal_function = _ordinal_function_from_map + self.set_configuration_from_input_tensors(inputs) + self.freeze() + if self._generated_enqueue_ops: + raise ValueError("Can't generate two enqueue Ops from the same queue") + self._generated_enqueue_ops = True + split_name_prefix = "%s/split" % self._name + if self.number_of_shards == 1: + transposed_sharded_inputs = [[inp] for inp in inputs] + else: + transposed_sharded_inputs = [ + array_ops.split( + inp, + self.number_of_shards, + axis=policy.shard_dimension, + name="%s/%d" % (split_name_prefix, index)) + for (inp, policy, index) in zip(inputs, self._sharding_policies, + xrange(self.number_of_tuple_elements)) + ] + sharded_inputs = [[shard[i] for shard in transposed_sharded_inputs] + for i in xrange(self.number_of_shards)] + name_prefix = "%s/enqueue" % self._name + return [ + self._generate_enqueue_op( + shard, + name_prefix, + index, + device=placement_function(index), + tpu_ordinal=tpu_ordinal_function(index)) + for (shard, index) in zip(sharded_inputs, xrange(self.number_of_shards)) + ] diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_function.py b/tensorflow/contrib/tpu/python/tpu/tpu_function.py new file mode 100644 index 00000000000..de16e3b1572 --- /dev/null +++ b/tensorflow/contrib/tpu/python/tpu/tpu_function.py @@ -0,0 
+1,106 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +"""Helper library for functions used during TPU compilation.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import contextlib + +from tensorflow.python.util import tf_inspect + + +class TpuContext(object): + """A context object holding state about the TPU computation being built.""" + + def __init__(self): + """Creates a new TpuContext.""" + self._number_of_shards = None + + @property + def number_of_shards(self): + return self._number_of_shards + + def set_number_of_shards(self, number_of_shards): + self._number_of_shards = number_of_shards + + +# The Tpu context holds the number of shards when a sharded computation is +# being built, or None if no computation is being built. 
+_current_tpu_context = TpuContext() + + +@contextlib.contextmanager +def tpu_shard_context(number_of_shards): + if _current_tpu_context.number_of_shards is not None: + raise NotImplementedError("tpu_shard_context cannot be nested.") + try: + _current_tpu_context.set_number_of_shards(number_of_shards) + yield + finally: + _current_tpu_context.set_number_of_shards(None) + + +def get_tpu_context(): + return _current_tpu_context + + +def check_function_argument_count(func, input_arity, infeed_queue): + """Validate the number of input arguments to a tpu function. + + Args: + func: the Python function that will be called to generate the body + of a TPUFunction. + input_arity: the number of explicit arguments supplied by the + caller. + infeed_queue: if not None, the infeed queue that will supply + additional arguments to the function. + + Returns: + None if function can be called with the supplied number of + arguments, or an error string if it cannot. + """ + def format_error(complaint, quantity): + return "%s %d argument%s" % (complaint, quantity, "" + if quantity == 1 else "s") + + number_of_arguments_needed = input_arity + if infeed_queue is not None: + number_of_arguments_needed += infeed_queue.number_of_tuple_elements + arg_spec = tf_inspect.getargspec(func) + number_of_args = len(arg_spec.args) + if arg_spec.defaults is None: + number_of_defaults = 0 + else: + number_of_defaults = len(arg_spec.defaults) + min_required_arguments = number_of_args - number_of_defaults + if number_of_arguments_needed < min_required_arguments: + # The required number of arguments is not enough to call the function. + if number_of_defaults == 0 and arg_spec.varargs is None: + return format_error("exactly", number_of_args) + else: + return format_error("at least", min_required_arguments) + if arg_spec.varargs is None and number_of_arguments_needed > number_of_args: + # The required number of arguments is too many to call the function. 
+ if number_of_defaults == 0: + return format_error("exactly", number_of_args) + else: + return format_error("at most", number_of_args) + # Since there are varargs, func can accept any number of arguments + # greater than the minimum. + return None + diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_function_test.py b/tensorflow/contrib/tpu/python/tpu/tpu_function_test.py new file mode 100644 index 00000000000..463c249a95c --- /dev/null +++ b/tensorflow/contrib/tpu/python/tpu/tpu_function_test.py @@ -0,0 +1,125 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +"""Tests for tpu_function helpers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.tpu.python.tpu import tpu_feed +from tensorflow.contrib.tpu.python.tpu import tpu_function + +from tensorflow.python.platform import test + + +class FunctionArgCheckTest(test.TestCase): + + def testSimple(self): + """Tests that arg checker works for functions with no varargs or defaults. 
+ """ + + def func(x, y, z): + return x + y + z + + self.assertEqual(None, + tpu_function.check_function_argument_count(func, 3, None)) + self.assertEqual("exactly 3 arguments", + tpu_function.check_function_argument_count(func, 2, None)) + queue = tpu_feed.InfeedQueue(2) + self.assertEqual(None, + tpu_function.check_function_argument_count(func, 1, queue)) + self.assertEqual("exactly 3 arguments", + tpu_function.check_function_argument_count(func, 2, queue)) + + def testDefaultArgs(self): + """Tests that arg checker works for a function with no varargs.""" + + def func(x, y, z=17): + return x + y + z + + self.assertEqual(None, + tpu_function.check_function_argument_count(func, 3, None)) + self.assertEqual(None, + tpu_function.check_function_argument_count(func, 2, None)) + self.assertEqual("at least 2 arguments", + tpu_function.check_function_argument_count(func, 1, None)) + self.assertEqual("at most 3 arguments", + tpu_function.check_function_argument_count(func, 4, None)) + queue = tpu_feed.InfeedQueue(1) + self.assertEqual(None, + tpu_function.check_function_argument_count(func, 2, queue)) + self.assertEqual(None, + tpu_function.check_function_argument_count(func, 1, queue)) + self.assertEqual("at least 2 arguments", + tpu_function.check_function_argument_count(func, 0, queue)) + self.assertEqual("at most 3 arguments", + tpu_function.check_function_argument_count(func, 4, queue)) + + def testVarArgs(self): + """Tests that arg checker works for a function with varargs.""" + + def func(x, y, *z): + return x + y + len(z) + + self.assertEqual(None, + tpu_function.check_function_argument_count(func, 2, None)) + self.assertEqual(None, + tpu_function.check_function_argument_count(func, 3, None)) + self.assertEqual(None, + tpu_function.check_function_argument_count(func, 4, None)) + self.assertEqual("at least 2 arguments", + tpu_function.check_function_argument_count(func, 1, None)) + queue = tpu_feed.InfeedQueue(1) + self.assertEqual(None, + 
tpu_function.check_function_argument_count(func, 1, queue)) + self.assertEqual(None, + tpu_function.check_function_argument_count(func, 2, queue)) + self.assertEqual(None, + tpu_function.check_function_argument_count(func, 3, queue)) + self.assertEqual("at least 2 arguments", + tpu_function.check_function_argument_count(func, 0, queue)) + + def testVarArgsAndDefaults(self): + """Tests that arg checker works for a function with varargs and defaults.""" + + def func(x, y, z=17, *q): + return x + y + z + len(q) + + self.assertEqual(None, + tpu_function.check_function_argument_count(func, 2, None)) + self.assertEqual(None, + tpu_function.check_function_argument_count(func, 3, None)) + self.assertEqual(None, + tpu_function.check_function_argument_count(func, 4, None)) + self.assertEqual(None, + tpu_function.check_function_argument_count(func, 5, None)) + self.assertEqual("at least 2 arguments", + tpu_function.check_function_argument_count(func, 1, None)) + queue = tpu_feed.InfeedQueue(1) + self.assertEqual(None, + tpu_function.check_function_argument_count(func, 1, queue)) + self.assertEqual(None, + tpu_function.check_function_argument_count(func, 2, queue)) + self.assertEqual(None, + tpu_function.check_function_argument_count(func, 3, queue)) + self.assertEqual(None, + tpu_function.check_function_argument_count(func, 4, queue)) + self.assertEqual("at least 2 arguments", + tpu_function.check_function_argument_count(func, 0, queue)) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_infeed_test.py b/tensorflow/contrib/tpu/python/tpu/tpu_infeed_test.py new file mode 100644 index 00000000000..a41ff60d0af --- /dev/null +++ b/tensorflow/contrib/tpu/python/tpu/tpu_infeed_test.py @@ -0,0 +1,130 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +"""Tests for TPU InfeedQueue methods.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.tpu.python.tpu import tpu_feed + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.platform import test + + +class InfeedTest(test.TestCase): + + def testConstructor(self): + """Tests that the constructor can be called with different arguments.""" + i = tpu_feed.InfeedQueue(number_of_tuple_elements=2) + self.assertEqual(i.number_of_tuple_elements, 2) + self.assertEqual(i.tuple_types, None) + self.assertEqual(i.tuple_shapes, None) + self.assertEqual(i.number_of_shards, None) + i = tpu_feed.InfeedQueue( + tuple_types=[dtypes.float32, dtypes.int32, dtypes.int32]) + self.assertEqual(i.number_of_tuple_elements, 3) + self.assertEqual(i.tuple_types, + [dtypes.float32, dtypes.int32, dtypes.int32]) + self.assertEqual(i.tuple_shapes, None) + self.assertEqual(i.number_of_shards, None) + i = tpu_feed.InfeedQueue(tuple_shapes=[[1], [2, 3]]) + self.assertEqual(i.number_of_tuple_elements, 2) + self.assertEqual(i.tuple_types, None) + self.assertEqual(i.tuple_shapes, [[1], [2, 3]]) + self.assertEqual(i.number_of_shards, None) + i = tpu_feed.InfeedQueue(shard_dimensions=[1, 0, 7]) + self.assertEqual(i.number_of_tuple_elements, 3) + self.assertEqual(i.tuple_types, None) + self.assertEqual(i.tuple_shapes, None) + 
self.assertEqual([p.shard_dimension + for p in i.sharding_policies], [1, 0, 7]) + with self.assertRaises(ValueError): + i = tpu_feed.InfeedQueue() + with self.assertRaises(ValueError): + i = tpu_feed.InfeedQueue( + number_of_tuple_elements=2, tuple_types=[dtypes.float32]) + with self.assertRaises(ValueError): + i = tpu_feed.InfeedQueue(number_of_tuple_elements=2, tuple_shapes=[[1]]) + with self.assertRaises(ValueError): + i = tpu_feed.InfeedQueue(number_of_tuple_elements=2, shard_dimensions=[1]) + with self.assertRaises(ValueError): + i = tpu_feed.InfeedQueue(tuple_shapes=[[1], [2, 3]], shard_dimensions=[1]) + + def testModification(self): + """Tests modification of the queue post-construction.""" + i = tpu_feed.InfeedQueue(number_of_tuple_elements=2) + i.set_tuple_types([dtypes.float32, dtypes.int32]) + self.assertEqual(i.tuple_types, [dtypes.float32, dtypes.int32]) + i.set_tuple_types([dtypes.float32, dtypes.float32]) + self.assertEqual(i.tuple_types, [dtypes.float32, dtypes.float32]) + with self.assertRaises(ValueError): + i.set_tuple_types([dtypes.float32]) + i.set_tuple_shapes([[1], [2, 3]]) + self.assertEqual(i.tuple_shapes, [[1], [2, 3]]) + i.set_tuple_shapes([[1, 2], [3, 4]]) + self.assertEqual(i.tuple_shapes, [[1, 2], [3, 4]]) + with self.assertRaises(ValueError): + i.set_tuple_shapes([[1, 2]]) + i.set_number_of_shards(2) + self.assertEqual(i.number_of_shards, 2) + i.set_number_of_shards(3) + self.assertEqual(i.number_of_shards, 3) + t1 = constant_op.constant(1, dtypes.int32, shape=[6]) + t2 = constant_op.constant(2.0, dtypes.float32, shape=[3, 18]) + i.set_configuration_from_input_tensors([t1, t2]) + self.assertEqual(i.tuple_shapes, [[6], [3, 18]]) + self.assertEqual(i.tuple_types, [dtypes.int32, dtypes.float32]) + i.set_configuration_from_sharded_input_tensors([[t2, t1], [t2, t1]]) + self.assertEqual(i.number_of_shards, 2) + self.assertEqual(i.tuple_shapes, [[6, 18], [12]]) + self.assertEqual(i.tuple_types, [dtypes.float32, dtypes.int32]) + 
i.set_shard_dimensions([1, 0]) + i.set_number_of_shards(3) + with self.assertRaises(ValueError): + i.set_number_of_shards(4) + + def testFreezing(self): + """Tests freezing the queue.""" + i = tpu_feed.InfeedQueue(number_of_tuple_elements=2) + t1 = constant_op.constant(1, dtypes.int32, shape=[2]) + t2 = constant_op.constant(2.0, dtypes.float32, shape=[2, 4]) + i.set_configuration_from_sharded_input_tensors([[t2, t1], [t2, t1]]) + self.assertEqual(i.number_of_shards, 2) + self.assertEqual(i.tuple_shapes, [[4, 4], [4]]) + self.assertEqual(i.tuple_types, [dtypes.float32, dtypes.int32]) + self.assertEqual(i.shard_dimensions, [0, 0]) + i.freeze() + i.set_number_of_shards(2) + i.set_tuple_shapes([[4, 4], [4]]) + i.set_tuple_types([dtypes.float32, dtypes.int32]) + i.set_shard_dimensions([0, 0]) + with self.assertRaises(ValueError): + i.set_number_of_shards(1) + with self.assertRaises(ValueError): + i.set_tuple_shapes([[8, 8], [8]]) + with self.assertRaises(ValueError): + i.set_tuple_types([dtypes.int32, dtypes.float32]) + with self.assertRaises(ValueError): + i.set_shard_dimensions([1, 0]) + self.assertEqual(i.number_of_shards, 2) + self.assertEqual(i.tuple_shapes, [[4, 4], [4]]) + self.assertEqual(i.tuple_types, [dtypes.float32, dtypes.int32]) + self.assertEqual(i.shard_dimensions, [0, 0]) + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py b/tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py new file mode 100644 index 00000000000..9d12a364c33 --- /dev/null +++ b/tensorflow/contrib/tpu/python/tpu/tpu_optimizer.py @@ -0,0 +1,106 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +"""Optimizer that implements cross-shard gradient reduction for TPU.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.tpu.python.ops import tpu_ops +from tensorflow.python.training import optimizer + + +class CrossShardOptimizer(optimizer.Optimizer): + """An optimizer that sums gradients across TPU shards.""" + + def __init__(self, opt, name="CrossShardOptimizer"): + super(CrossShardOptimizer, self).__init__(False, name) + self._opt = opt + + def compute_gradients(self, *args, **kwargs): + """Compute gradients of "loss" for the variables in "var_list". + + This simply wraps the compute_gradients() from the real optimizer. The + gradients will be aggregated in the apply_gradients() so that user can + modify the gradients like clipping with per replica global norm if needed. + The global norm with aggregated gradients can be bad as one replica's huge + gradients can hurt the gradients from other replicas. + + Args: + *args: Arguments for compute_gradients(). + **kwargs: Keyword arguments for compute_gradients(). + + Returns: + A list of (gradient, variable) pairs. + """ + return self._opt.compute_gradients(*args, **kwargs) + + def apply_gradients(self, grads_and_vars, global_step=None, name=None): + """Apply gradients to variables. + + Calls tpu_ops.cross_replica_sum() to sum gradient contributions across + replicas, and then applies the real optimizer. 
+ + Args: + grads_and_vars: List of (gradient, variable) pairs as returned by + compute_gradients(). + global_step: Optional Variable to increment by one after the + variables have been updated. + name: Optional name for the returned operation. Default to the + name passed to the Optimizer constructor. + + Returns: + An `Operation` that applies the gradients. If `global_step` was not None, + that operation also increments `global_step`. + + Raises: + ValueError: If the grads_and_vars is malformed. + """ + summed_grads_and_vars = [] + for (grad, var) in grads_and_vars: + if grad is None: + summed_grads_and_vars.append((grad, var)) + else: + summed_grads_and_vars.append((tpu_ops.cross_replica_sum(grad), var)) + return self._opt.apply_gradients(summed_grads_and_vars, global_step, name) + + def get_slot(self, *args, **kwargs): + """Return a slot named "name" created for "var" by the Optimizer. + + This simply wraps the get_slot() from the actual optimizer. + + Args: + *args: Arguments for get_slot(). + **kwargs: Keyword arguments for get_slot(). + + Returns: + The `Variable` for the slot if it was created, `None` otherwise. + """ + return self._opt.get_slot(*args, **kwargs) + + def get_slot_names(self, *args, **kwargs): + """Return a list of the names of slots created by the `Optimizer`. + + This simply wraps the get_slot_names() from the actual optimizer. + + Args: + *args: Arguments for get_slot(). + **kwargs: Keyword arguments for get_slot(). + + Returns: + A list of strings. + """ + return self._opt.get_slot_names(*args, **kwargs) diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_sharding.py b/tensorflow/contrib/tpu/python/tpu/tpu_sharding.py new file mode 100644 index 00000000000..d545a94ca6a --- /dev/null +++ b/tensorflow/contrib/tpu/python/tpu/tpu_sharding.py @@ -0,0 +1,248 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +"""Helper library for sharding during TPU compilation.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from six.moves import xrange # pylint: disable=redefined-builtin + +from tensorflow.python.framework import tensor_shape + +_DEFAULT_NUMBER_OF_SHARDS = 1 +_DEFAULT_SHARD_DIMENSION = 0 + + +# TODO(b/36777903) change other parts of tpu.py to use this class. +class ShardingPolicy(object): + """An object use to hold the sharding policy for a Tensor. + """ + + def __init__(self): + self._number_of_shards = None + self._shard_dimension = None + self._frozen = False + + def __str__(self): + if self.number_of_shards is None or self.shard_dimension is None: + return "ShardingPolicy(unset)" + else: + return ("ShardingPolicy(%d shards dimension %d)" % + (self.number_of_shards, self.shard_dimension)) + + def _fill_default_values(self): + if self._number_of_shards is None: + self._number_of_shards = _DEFAULT_NUMBER_OF_SHARDS + if self._shard_dimension is None: + self._shard_dimension = tensor_shape.as_dimension( + _DEFAULT_SHARD_DIMENSION) + + def freeze(self): + """Prevents further modification to the sharding policy. + + Any values that have not been set when freeze is called are set to + defaults. If the ShardingPolicy is already frozen, this is a NoOp. 
+ """ + if not self._frozen: + self._fill_default_values() + self._frozen = True + + @property + def number_of_shards(self): + """Returns the number of shards in the policy or None if unspecified.""" + return self._number_of_shards + + def set_number_of_shards(self, number_of_shards): + """Sets the number of shards for the current policy. + + If the policy has been frozen then number_of_shards must match the + existing setting. + + Args: + number_of_shards: The number of shards to use in the policy. + + Raises: + ValueError: If the policy has been frozen and number_of_shards + differs from the frozen value; or number_of_shards <= 0. + """ + if self._frozen: + if self._number_of_shards != number_of_shards: + raise ValueError( + "Can't set sharding policy to use %d shards since it has been " + "frozen to use %d." % (number_of_shards, self._number_of_shards)) + else: + if number_of_shards > 0: + self._number_of_shards = number_of_shards + else: + raise ValueError( + "Can't set sharding policy to use %s shards; value must be >0", + str(number_of_shards)) + + @property + def shard_dimension(self): + """Returns the shard dimension of the policy or None if unspecified.""" + return self._shard_dimension + + def set_shard_dimension(self, shard_dimension): + """Sets the shard dimension for the current policy. + + If the policy has been frozen then shard_dimension must match the + existing setting. + + Args: + shard_dimension: The shard dimension to use in the policy. + + Raises: + ValueError: If the policy has been frozen and shard_dimension + differs from the frozen value, or shard_dimension can't be + interpreted as a Dimension. + """ + if self._frozen: + if self._shard_dimension != shard_dimension: + raise ValueError( + "Can't set shard dimension to %d since it has been frozen to " + "use %d." 
% (shard_dimension, self._shard_dimension)) + else: + self._shard_dimension = tensor_shape.as_dimension(shard_dimension) + + def merge(self, other): + """Merges the policy of another policy into the current policy. + + Args: + other: The policy to merge into this one. + + Raises: + ValueError: If this policy has been frozen and the merge conflicts with + the frozen policy. + """ + if other.number_of_shards is not None: + self.set_number_of_shards(other.number_of_shards) + if other.shard_dimension is not None: + self.set_shard_dimension(other.shard_dimension) + + def get_sharded_shape(self, shape, shard_index=None): + """Returns the shape of a shard of a full Tensor. + + When given the shape of a 'full-size' Tensor, returns the shape of + the sub-Tensor after it has been sharded. Freezes the policy if it + has not yet been frozen. + + Args: + shape: The shape of the full-size Tensor to be sharded. + shard_index: The index of the shard whose shape should be returned. + shard_index can be None for sharding policies that use the same + shape for every shard. + freeze_config: + + Returns: + The shape of the sharded version of the Tensor. + + Raises: + ValueError: If shard_index is None when shards are of different + shapes; or shard_index is not None and + !(0<=shard_index<number_of_shards); or shape does not have at + least self.shard_dimension+1 dimensions; or the value of + shape's shard dimension is not a multiple of + self.number_of_shards + TypeError: if shape is not convertible to a TensorShape. + """ + self.freeze() + if shard_index is not None: + if shard_index < 0 or shard_index >= self.number_of_shards: + raise ValueError("shard_index %d, but must be in [0,%d)." % + (shard_index, self._number_of_shards)) + shape = tensor_shape.as_shape(shape) + if self._number_of_shards == 1: + # Don't do anything when there's only one shard. 
+ return shape + ndims = shape.ndims + if ndims is None: + raise ValueError("shape must be a specified shape not Unknown") + if ndims <= self._shard_dimension: + raise ValueError("shape %s does not contain shard_dimension %d" % + (shape.as_list(), self._shard_dimension)) + dims = shape.as_list() + if (dims[self._shard_dimension] % self._number_of_shards) != 0: + raise ValueError("shape %s cannot be sharded %d ways along dimension %d" % + (shape.as_list(), self._number_of_shards, + self._shard_dimension)) + dims[self._shard_dimension] /= self._number_of_shards + return tensor_shape.as_shape(dims) + + def _unshard_shape(self, shape): + """Return the unsharded shape that would generate a given sharded shape. + + Args: + shape: the sharded shape to unshard + + Returns: + The unsharded shape. + + Raises: + ValueError: if shape is unknown or does not contain + self.shard_dimension + TypeError: if shape is not convertible to a TensorShape + """ + shape = tensor_shape.as_shape(shape) + if self._number_of_shards == 1: + # Don't do anything when there's only one shard. + return shape + ndims = shape.ndims + if ndims is None: + raise ValueError("shape must be a specified shape not Unknown") + if ndims <= self._shard_dimension: + raise ValueError("shape %s does not contain shard_dimension %d" % + (shape.as_list(), self._shard_dimension)) + dims = shape.as_list() + dims[self._shard_dimension] *= self._number_of_shards + return tensor_shape.as_shape(dims) + + def get_unsharded_shape(self, shapes): + """Returns the shape of an unsharded Tensor given a list of shards. + + When given a list of shapes of shards, returns the shape of the + unsharded Tensor that would generate the shards. Sets defaults for the + policy if number_of_shards or shard_dimension is None. + + Args: + shapes: The shapes of the Tensor shards to be combined. + + Returns: + The shape of the unsharded version of the Tensor. 
+ + Raises: + ValueError: if shapes is not a list of length + self.number_of_shards; or any element of shapes is not a valid + shape consistent with the sharding policy; or the list of + shapes is not a valid sharding of a full shape. + TypeError: if an element of shapes is not convertible to a + TensorShape + """ + self._fill_default_values() + if len(shapes) != self.number_of_shards: + raise ValueError( + "shapes is %s but must be a list of length number_of_shards=%d" % ( + str(shapes), self.number_of_shards)) + unsharded_shapes = [self._unshard_shape(s) for s in shapes] + for i in xrange(self.number_of_shards - 1): + if unsharded_shapes[i] != unsharded_shapes[self.number_of_shards - 1]: + raise ValueError( + "sharded shapes %s are not consistent shards of a full shape " + "sharded %d ways along dimension %d" % ( + str(shapes), self.number_of_shards, self.shard_dimension)) + return unsharded_shapes[0] diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_sharding_test.py b/tensorflow/contrib/tpu/python/tpu/tpu_sharding_test.py new file mode 100644 index 00000000000..b0a5511d2d7 --- /dev/null +++ b/tensorflow/contrib/tpu/python/tpu/tpu_sharding_test.py @@ -0,0 +1,138 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================= + +"""Tests for tpu_function helpers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.tpu.python.tpu import tpu_sharding + +from tensorflow.python.framework import tensor_shape +from tensorflow.python.platform import test + + +class ShardingTest(test.TestCase): + + def testFreeze(self): + """Tests that freezing a policy applies default values.""" + p1 = tpu_sharding.ShardingPolicy() + p1.freeze() + self.assertEqual(p1.number_of_shards, + tpu_sharding._DEFAULT_NUMBER_OF_SHARDS) + self.assertEqual(p1.shard_dimension, tpu_sharding._DEFAULT_SHARD_DIMENSION) + p2 = tpu_sharding.ShardingPolicy() + p2.set_number_of_shards(17) + p2.set_shard_dimension(23) + p2.freeze() + self.assertEqual(p2.number_of_shards, 17) + self.assertEqual(p2.shard_dimension, 23) + + def testFrozen(self): + """Tests that frozen policies can't be changed.""" + p1 = tpu_sharding.ShardingPolicy() + p1.freeze() + with self.assertRaises(ValueError): + p1.set_number_of_shards(17) + with self.assertRaises(ValueError): + p1.set_shard_dimension(22) + + def testStr(self): + """Tests the string representation.""" + p1 = tpu_sharding.ShardingPolicy() + self.assertEqual(str(p1), "ShardingPolicy(unset)") + p1.set_number_of_shards(17) + self.assertEqual(str(p1), "ShardingPolicy(unset)") + p1.set_shard_dimension(8) + self.assertEqual(str(p1), "ShardingPolicy(17 shards dimension 8)") + + def testMerge(self): + """Tests that merging works.""" + p1 = tpu_sharding.ShardingPolicy() + p1.set_number_of_shards(17) + p1.set_shard_dimension(23) + p2 = tpu_sharding.ShardingPolicy() + p2.merge(p1) + self.assertEqual(p2.number_of_shards, 17) + self.assertEqual(p2.shard_dimension, 23) + p1 = tpu_sharding.ShardingPolicy() + p1.set_shard_dimension(12) + p2.merge(p1) + self.assertEqual(p2.number_of_shards, 17) + self.assertEqual(p2.shard_dimension, 12) + 
p2.freeze() + p2.merge(p1) + self.assertEqual(p2.number_of_shards, 17) + self.assertEqual(p2.shard_dimension, 12) + p1.set_number_of_shards(1) + with self.assertRaises(ValueError): + p2.merge(p1) + p1 = tpu_sharding.ShardingPolicy() + p1.set_number_of_shards(17) + p2.merge(p1) + p1.set_shard_dimension(2) + with self.assertRaises(ValueError): + p2.merge(p1) + + def testGetShardedShape(self): + """Tests getting a sharded shape.""" + p = tpu_sharding.ShardingPolicy() + p.set_number_of_shards(3) + p.set_shard_dimension(1) + self.assertEqual(p.get_sharded_shape([4, 9]), [4, 3]) + p.freeze() + with self.assertRaises(ValueError): + p.set_shard_dimension(0) + with self.assertRaises(ValueError): + _ = p.get_sharded_shape([4, 9], shard_index=4) + with self.assertRaises(ValueError): + _ = p.get_sharded_shape([4, 9], shard_index=-1) + with self.assertRaises(TypeError): + _ = p.get_sharded_shape("not_a_shape") + with self.assertRaises(ValueError): + _ = p.get_sharded_shape(tensor_shape.TensorShape(None)) + with self.assertRaises(ValueError): + _ = p.get_sharded_shape([4, 10], shard_index=-1) + + def testGetUnshardedShape(self): + """Tests getting an unsharded shape.""" + p = tpu_sharding.ShardingPolicy() + p.set_number_of_shards(2) + p.set_shard_dimension(1) + self.assertEqual(p.get_unsharded_shape([[4, 3], [4, 3]]), [4, 6]) + with self.assertRaises(ValueError): + _ = p.get_unsharded_shape([[4, 3]]) + with self.assertRaises(ValueError): + _ = p.get_unsharded_shape([[4, 3], [4, 3], [4, 3]]) + with self.assertRaises(ValueError): + _ = p.get_unsharded_shape([[4, 3], [4, 2]]) + with self.assertRaises(TypeError): + _ = p.get_unsharded_shape([[4, 3], "not_a_shape"]) + with self.assertRaises(ValueError): + _ = p.get_unsharded_shape([None, [4, 3]]) + with self.assertRaises(ValueError): + _ = p.get_unsharded_shape([[2], [4, 3]]) + + def testScalar(self): + """Tests sharding and unsharding scalars.""" + p = tpu_sharding.ShardingPolicy() + p.freeze() + 
self.assertEqual(p.get_sharded_shape([]), []) + self.assertEqual(p.get_unsharded_shape([[]]), []) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/tpu/python/tpu/training_loop.py b/tensorflow/contrib/tpu/python/tpu/training_loop.py new file mode 100644 index 00000000000..3d7896127a9 --- /dev/null +++ b/tensorflow/contrib/tpu/python/tpu/training_loop.py @@ -0,0 +1,213 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= + +"""Library for constructing a training loop, suitable for TPUs.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.tpu.python.tpu import tpu_function + +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops + + +def while_loop(condition, body, inputs=None, infeed_queue=None, name=None): + """Builds a training loop for TPUs. + + The set of loop-carried tensors corresponds to `inputs`. Both + `condition` and `body` take the current value of the loop-carried + tensors. 'body' additionally takes a tuple of infeed from + infeed_queue if infeed_queue is not None. `condition` must return a + single boolean value that determines whether iteration + continues. 
`body` must return an updated list of values for the + loop-carried tensors. + + Args: + condition: a Python function that builds the loop condition. + body: a Python function that builds the loop body. + inputs: a list of initial values passed into the training loop, or + None (equivalent to an empty list). + infeed_queue: if not None, the infeed queue from which to append a tuple + of arguments as inputs to condition. + name: an optional name for the loop. + + Returns: + The final values of the loop-carried tensors. + + Raises: + TypeError: if body or condition has the wrong signature. + """ + + # Converts inputs to Tensors. + inputs = [] if inputs is None else [ops.convert_to_tensor(x) for + x in inputs] + input_types = [x.dtype for x in inputs] + input_arity = len(inputs) + + body_arg_error = tpu_function.check_function_argument_count( + body, input_arity, infeed_queue) + if body_arg_error is not None: + if infeed_queue is None: + raise TypeError( + "Supplied loop body function cannot be called with the specified " + "inputs. You specified %d inputs: %s, but the loop body needs %s" % ( + input_arity, str([i.name for i in inputs]), body_arg_error)) + else: + raise TypeError( + "Supplied loop body function cannot be called with the specified " + "inputs. You specified %d inputs: %s and %d additional inputs from " + "infeed, but the computation needs %s" % (input_arity, str( + [i.name for i in inputs]), infeed_queue.number_of_tuple_elements, + body_arg_error)) + condition_arg_error = tpu_function.check_function_argument_count( + condition, input_arity, None) + if condition_arg_error is not None: + if infeed_queue is None: + raise TypeError( + "Supplied loop condition function cannot be called with the " + "specified inputs. 
You specified %d inputs: %s, but the loop " + "condition needs %s" % (input_arity, str([i.name for i in inputs]), + condition_arg_error)) + else: + raise TypeError( + "Supplied loop condition function cannot be called with the " + "specified inputs. You specified %d inputs: %s, but the loop " + "condition needs %s. Note that infeed is not passed to the loop " + "condition." % (input_arity, str([i.name for i in inputs]), + condition_arg_error)) + + def condition_wrapper(*inputs): + # Discards the dummy output added for arity-0 loops. + if input_arity == 0: + inputs = [] + return condition(*inputs) + + def body_wrapper(*inputs): + """Wrapper around `body` that handles infeed queues and control deps.""" + inputs = list(inputs) + + # Discards the dummy output added for arity-0 loops. + if input_arity == 0: + inputs = [] + + # Runs `body` with the dequeue_ops appended. + if infeed_queue: + number_of_shards = tpu_function.get_tpu_context().number_of_shards + if number_of_shards is None: + raise ValueError("Can't build training loop with infeed when there is " + "no tpu_shard_context. Are you building a loop or " + "graph directly rather than from inside tpu.rewrite, " + "tpu.batch_parallel, tpu.shard, or tpu.replicate?") + infeed_queue.set_number_of_shards(number_of_shards) + dequeue_ops = [d for d in infeed_queue.generate_dequeue_op()] + else: + dequeue_ops = [] + outputs = body(*(inputs + dequeue_ops)) + + # If the computation only returned one value, make it a tuple. + if not isinstance(outputs, (list, tuple)): + outputs = (outputs,) + + outputs = [ + o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o) + for o in outputs + ] + + # Separates the returned Operations and Tensors. 
+ output_operations = [o for o in outputs if isinstance(o, ops.Operation)] + output_tensors = [o for o in outputs + if not isinstance(o, ops.Operation)] + + if outputs != output_tensors + output_operations: + raise ValueError( + "TPU training loop body must return zero or more Tensor values " + "followed by zero or more Operations.") + + output_types = [op.dtype for op in output_tensors] + if input_types != output_types: + raise TypeError( + "Mismatch between input types and output types for training loop " + "body: {} vs {}".format(input_types, output_types)) + + # Add the dequeue operations to output_operations to ensure they are run + # by the loop, even if the programmer's loop body does not use them. + output_operations += dequeue_ops + + # Add a dummy output, if needed. + if not output_tensors: + output_tensors = array_ops.constant(0) + + if output_operations: + # TODO(phawkins): in principle this is too restrictive since it serializes + # the training loop steps. In practice it does not matter since this loop + # will be compiled by XLA. + return control_flow_ops.tuple(output_tensors, + control_inputs=output_operations) + else: + return output_tensors + + # If the body has arity 0, add a dummy loop-carried value to which we can add + # control dependencies from any side-effecting operations. + if input_arity == 0: + inputs = [array_ops.constant(0)] + return control_flow_ops.while_loop(condition_wrapper, body_wrapper, inputs, + name=name) + + +def repeat(n, body, inputs=None, infeed_queue=None, name=None): + """Builds a training loop that executes a fixed number of interations. + + The set of loop-carried tensors correspond to `inputs`. + `body` must be a function that takes and returns the values of the + loop-carried tensors. + + Args: + n: the number of loop iterations + body: a Python function that builds the loop body. + inputs: a list of initial values passed into the training loop or + None (equivalent to an empty list). 
+ infeed_queue: if not None, the infeed queue from which to append a tuple + of arguments as inputs to condition. + name: an optional name for the loop. + Returns: + The final values of the loop-carried tensors. + Raises: + ValueError: if there is a type error. + """ + def _convert_to_list(xs): + if not isinstance(xs, (list, tuple)): + return [xs] + else: + return list(xs) + + def cond(i, *args): + del args + return i < n + + def body_wrapper(i, *args): + return [i + 1] + _convert_to_list(body(*args)) + + inputs = [0] if inputs is None else [0] + _convert_to_list(inputs) + outputs = while_loop( + cond, body_wrapper, inputs=inputs, infeed_queue=infeed_queue, name=name) + outputs = _convert_to_list(outputs) + if len(outputs) == 1: + # Returns the Op rather than an empty list. + return outputs[0].op + else: + return outputs[1:] diff --git a/tensorflow/contrib/training/BUILD b/tensorflow/contrib/training/BUILD index 1180ea92994..086372019ca 100644 --- a/tensorflow/contrib/training/BUILD +++ b/tensorflow/contrib/training/BUILD @@ -41,24 +41,25 @@ py_library( "//tensorflow/python:control_flow_ops", "//tensorflow/python:data_flow_ops", "//tensorflow/python:errors", - "//tensorflow/python:framework", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:logging_ops", "//tensorflow/python:math_ops", - "//tensorflow/python:ops", "//tensorflow/python:parsing_ops", "//tensorflow/python:platform", "//tensorflow/python:random_ops", "//tensorflow/python:script_ops", + "//tensorflow/python:sparse_ops", + "//tensorflow/python:sparse_tensor", "//tensorflow/python:state_ops", "//tensorflow/python:string_ops", "//tensorflow/python:summary", "//tensorflow/python:tensor_array_ops", + "//tensorflow/python:tensor_util", "//tensorflow/python:training", "//tensorflow/python:util", "//tensorflow/python:variable_scope", "//tensorflow/python:variables", - "//tensorflow/python/estimator:estimator_py", + "//tensorflow/python/estimator:inputs_queues", 
"//third_party/py/numpy", "@six_archive//:six", ], @@ -74,7 +75,6 @@ py_test( "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", "//tensorflow/python:training", "//tensorflow/python:variables", ], @@ -91,7 +91,6 @@ py_test( "//tensorflow/python:client_testlib", "//tensorflow/python:errors", "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", "//tensorflow/python:math_ops", "//tensorflow/python:random_ops", "//tensorflow/python:string_ops", @@ -107,12 +106,13 @@ py_test( tags = ["manual"], deps = [ ":training_py", + "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:errors", "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", "//tensorflow/python:math_ops", "//tensorflow/python:random_ops", + "//tensorflow/python:sparse_tensor", "//tensorflow/python:string_ops", "//tensorflow/python:training", "//tensorflow/python:variables", @@ -128,6 +128,7 @@ py_test( deps = [ ":training_py", "//tensorflow/python:client_testlib", + "@six_archive//:six", ], ) @@ -142,7 +143,6 @@ py_test( "//tensorflow/python:client_testlib", "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", "//tensorflow/python:math_ops", "//tensorflow/python:variables", "//third_party/py/numpy", @@ -160,13 +160,12 @@ py_test( "//tensorflow/python:client_testlib", "//tensorflow/python:control_flow_ops", "//tensorflow/python:errors", - "//tensorflow/python:framework", "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", "//tensorflow/python:logging_ops", "//tensorflow/python:math_ops", "//tensorflow/python:platform", "//tensorflow/python:random_ops", + "//tensorflow/python:random_seed", "//tensorflow/python:training", 
"//tensorflow/python:variables", "//third_party/py/numpy", @@ -186,11 +185,10 @@ py_test( ":training_py", "//tensorflow/python:client_testlib", "//tensorflow/python:data_flow_ops", - "//tensorflow/python:framework", "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", "//tensorflow/python:math_ops", "//tensorflow/python:random_ops", + "//tensorflow/python:random_seed", "//tensorflow/python:training", "//tensorflow/python:variables", ], @@ -206,10 +204,11 @@ py_test( ":training_py", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", + "//tensorflow/python:control_flow_ops", "//tensorflow/python:data_flow_ops", "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", "//tensorflow/python:math_ops", + "//tensorflow/python:sparse_tensor", "//tensorflow/python:training", "//third_party/py/numpy", ], @@ -223,12 +222,9 @@ py_test( tags = ["manual"], deps = [ ":training_py", - "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", - "//tensorflow/python:data_flow_ops", + "//tensorflow/python:errors", "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:math_ops", "//tensorflow/python:parsing_ops", "//tensorflow/python:training", "//third_party/py/numpy", @@ -255,11 +251,10 @@ py_test( "//tensorflow/python:array_ops", "//tensorflow/python:client", "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", "//tensorflow/python:math_ops", "//tensorflow/python:platform", + "//tensorflow/python:random_seed", "//tensorflow/python:state_ops", "//tensorflow/python:summary", "//tensorflow/python:training", @@ -278,14 +273,11 @@ py_test( ":training_py", "//tensorflow/contrib/framework:framework_py", "//tensorflow/contrib/layers:layers_py", - 
"//tensorflow/contrib/losses:losses_py", - "//tensorflow/python:client", "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", "//tensorflow/python:math_ops", "//tensorflow/python:platform", + "//tensorflow/python:random_seed", "//tensorflow/python:training", "//tensorflow/python:variables", "//tensorflow/python/ops/losses", diff --git a/tensorflow/contrib/util/BUILD b/tensorflow/contrib/util/BUILD index 5ad8e3dd358..5575fb35702 100644 --- a/tensorflow/contrib/util/BUILD +++ b/tensorflow/contrib/util/BUILD @@ -27,7 +27,6 @@ cc_binary( srcs = ["convert_graphdef_memmapped_format.cc"], deps = [ ":convert_graphdef_memmapped_format_lib", - "//tensorflow/core:framework", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", ], @@ -69,6 +68,7 @@ py_library( deps = [ "//tensorflow/python:framework", "//tensorflow/python:platform", + "//tensorflow/python:tensor_util", "//tensorflow/python:util", ], ) diff --git a/tensorflow/contrib/verbs/BUILD b/tensorflow/contrib/verbs/BUILD index e747fa4c9e4..5f062cde890 100644 --- a/tensorflow/contrib/verbs/BUILD +++ b/tensorflow/contrib/verbs/BUILD @@ -66,9 +66,8 @@ cc_library( ":grpc_verbs_service_impl", ":rdma_mgr", ":verbs_service_proto_cc", - "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", "//tensorflow/core/distributed_runtime:session_mgr", - "//tensorflow/core/distributed_runtime:worker_env", "//tensorflow/core/distributed_runtime/rpc:async_service_interface", "//tensorflow/core/distributed_runtime/rpc:grpc_call", "//tensorflow/core/distributed_runtime/rpc:grpc_util", @@ -107,8 +106,8 @@ cc_library( hdrs = ["rdma_rendezvous_mgr.h"], deps = [ ":rdma_mgr", + ":verbs_util", "//tensorflow/core:core_cpu_internal", - "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core/distributed_runtime:base_rendezvous_mgr", "//tensorflow/core/distributed_runtime:worker_env", @@ -122,10 
+121,11 @@ cc_library( deps = [ ":grpc_verbs_client", ":rdma", + ":verbs_service_proto_cc", "//tensorflow/core:core_cpu_internal", - "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core/distributed_runtime:session_mgr", "//tensorflow/core/distributed_runtime:worker_env", "//tensorflow/core/distributed_runtime/rpc:grpc_channel", "//tensorflow/core/distributed_runtime/rpc:grpc_worker_cache", @@ -144,6 +144,7 @@ cc_library( ":verbs_util", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", "//tensorflow/core:gpu_runtime", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", @@ -162,6 +163,8 @@ cc_library( ":grpc_verbs_service", ":rdma_mgr", ":rdma_rendezvous_mgr", + "//tensorflow/core:lib", + "//tensorflow/core/distributed_runtime:server_lib", "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", ], alwayslink = 1, diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 1f0b100bbbd..27544c99610 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -329,6 +329,7 @@ cc_library( deps = [ ":lib", ":lib_internal", + ":protos_all_cc", "//tensorflow/core/platform/default/build_config:gtest", ] + tf_additional_test_deps(), ) @@ -422,6 +423,7 @@ cc_library( hdrs = ["util/overflow.h"], deps = [ ":framework_lite", + ":lib", ], ) @@ -433,6 +435,7 @@ cc_library( deps = [ ":framework", ":lib", + ":protos_all_cc", ], ) @@ -460,6 +463,10 @@ cc_library( name = "session_options", hdrs = ["public/session_options.h"], visibility = ["//visibility:public"], + deps = [ + ":lib", + ":protos_all_cc", + ], ) cc_library( @@ -495,6 +502,7 @@ cc_library( tf_gen_op_libs( op_lib_names = [ "array_ops", + "bitwise_ops", "candidate_sampling_ops", "control_flow_ops", "ctc_ops", @@ -573,6 +581,7 @@ cc_library( deps = [ ":array_ops_op_lib", ":audio_ops_op_lib", + ":bitwise_ops_op_lib", ":candidate_sampling_ops_op_lib", ":control_flow_ops_op_lib", 
":ctc_ops_op_lib", @@ -590,6 +599,7 @@ cc_library( ":no_op_op_lib", ":parsing_ops_op_lib", ":random_ops_op_lib", + ":remote_fused_graph_ops_op_lib", ":script_ops_op_lib", ":sdca_ops_op_lib", ":sendrecv_ops_op_lib", @@ -665,6 +675,7 @@ tf_cuda_library( name = "core_cpu", hdrs = [ "common_runtime/device.h", + "common_runtime/optimization_registry.h", "common_runtime/shape_refiner.h", "graph/algorithm.h", "graph/default_device.h", @@ -1130,7 +1141,6 @@ cc_library( ":framework_internal", ":lib", ":lib_internal", - ":proto_text", ":protos_all_cc", ":stream_executor", ], @@ -1348,6 +1358,7 @@ cc_library( }), deps = [ ":lib", + ":lib_internal", "//tensorflow/core/platform/default/build_config:gif", ], ) @@ -1370,6 +1381,7 @@ cc_library( }), deps = [ ":lib", + ":lib_internal", "//tensorflow/core/platform/default/build_config:jpeg", ], ) @@ -1733,6 +1745,7 @@ cc_library( linkstatic = 1, visibility = ["//visibility:public"], deps = [ + ":core_cpu", ":core_cpu_internal", ":framework", ":lib", @@ -1845,7 +1858,7 @@ cc_library( ":framework_internal", ":lib", ":lib_internal", - ":protos_all_cc", + ":proto_text", "//third_party/eigen3", "@local_config_sycl//sycl:sycl", ], @@ -1866,7 +1879,10 @@ cc_library( "lib/random/philox_random_test_utils.h", "platform/snappy.h", ], - deps = [":lib_internal"], + deps = [ + ":lib", + ":lib_internal", + ], ) cc_library( @@ -1892,6 +1908,7 @@ cc_library( ":framework", ":lib", ":lib_internal", + ":protos_all_cc", ], ) @@ -2833,7 +2850,6 @@ cc_test( ":testlib", "//tensorflow/cc:cc_ops", "//tensorflow/core/kernels:example_parsing_ops", - "//tensorflow/core/kernels:ops_util", ], ) diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc index 4970c2d252a..b0b834b66ae 100644 --- a/tensorflow/core/common_runtime/function.cc +++ b/tensorflow/core/common_runtime/function.cc @@ -326,13 +326,14 @@ Status FunctionLibraryRuntimeImpl::CreateKernel(const NodeDef& ndef, const FunctionBody* fbody = 
GetFunctionBody(handle); CHECK_NOTNULL(fbody); - // TODO(zhifengc): For now, we assume int32 is always on host memory - // and other types are always on device memory. We should do type - // inference over function body to derive the correct input/output - // memory types. + // TODO(zhifengc): For now, we assume int32 and resources are always on host + // memory and other types are always on device memory. We should do type + // inference over function body to derive the correct input/output memory + // types. MemoryTypeVector input_memory_types; for (const auto& t : fbody->arg_types) { - input_memory_types.push_back(t == DT_INT32 ? HOST_MEMORY : DEVICE_MEMORY); + input_memory_types.push_back( + (t == DT_INT32 || t == DT_RESOURCE) ? HOST_MEMORY : DEVICE_MEMORY); } MemoryTypeVector output_memory_types; for (const auto& t : fbody->ret_types) { @@ -1046,10 +1047,12 @@ void ToGraphDef(const Graph* g, GraphDef* gdef, bool pretty) { // to be unique and stable after optimization rewrites. Therefore, // we use "n" instead. for (const Edge* e : inputs) { - const string srcname = NewName(e->src(), pretty); if (e == nullptr) { ndef->add_input("unknown"); - } else if (!e->src()->IsOp()) { + continue; + } + const string srcname = NewName(e->src(), pretty); + if (!e->src()->IsOp()) { } else if (e->IsControlEdge()) { ndef->add_input(strings::StrCat("^", srcname)); } else if (e->src_output() == 0) { diff --git a/tensorflow/core/common_runtime/gpu/pool_allocator.cc b/tensorflow/core/common_runtime/gpu/pool_allocator.cc index 700ac347163..66fff16e8f7 100644 --- a/tensorflow/core/common_runtime/gpu/pool_allocator.cc +++ b/tensorflow/core/common_runtime/gpu/pool_allocator.cc @@ -239,11 +239,9 @@ void PoolAllocator::EvictOne() { (alloc_request_count == 0) ? 0.0 : allocated_count_ / static_cast(alloc_request_count); - static int log_counter = 0; - // (counter increment not thread safe but it's just for logging, so we - // don't care). 
- bool should_log = ((log_counter++ % 10) == 0); - if (should_log) { + // Can turn on for debugging purposes. + const bool kShouldLog = false; + if (kShouldLog) { LOG(INFO) << "PoolAllocator: After " << alloc_request_count << " get requests, put_count=" << put_count_ << " evicted_count=" << evicted_count_ @@ -255,7 +253,7 @@ void PoolAllocator::EvictOne() { size_t new_size_limit = (pool_size_limit_ < kMinPoolSize) ? kMinPoolSize : (kIncreaseFactor * pool_size_limit_); - if (should_log) { + if (kShouldLog) { LOG(INFO) << "Raising pool_size_limit_ from " << pool_size_limit_ << " to " << new_size_limit; } diff --git a/tensorflow/core/common_runtime/shape_refiner.cc b/tensorflow/core/common_runtime/shape_refiner.cc index 5103e852218..0f6c3c1cb19 100644 --- a/tensorflow/core/common_runtime/shape_refiner.cc +++ b/tensorflow/core/common_runtime/shape_refiner.cc @@ -139,7 +139,7 @@ Status ShapeRefiner::SetShape(const Node* node, int output_port, return Status::OK(); } -Status ShapeRefiner::UpdateNode(const Node* node, bool* refined) { +Status ShapeRefiner::UpdateNode(const Node* node, bool relax, bool* refined) { auto it = node_to_context_.find(node); if (it == node_to_context_.end()) { *refined = true; @@ -155,29 +155,55 @@ Status ShapeRefiner::UpdateNode(const Node* node, bool* refined) { for (const Edge* e : node->in_edges()) { if (e->IsControlEdge()) continue; + int dst_input = e->dst_input(); + int src_output = e->src_output(); + Node* input = e->src(); auto iter = node_to_context_.find(input); if (iter == node_to_context_.end()) { return errors::FailedPrecondition( - "Input ", e->dst_input(), " ('", input->name(), "') for '", - node->name(), "' was not previously added to ShapeRefiner."); + "Input ", dst_input, " ('", input->name(), "') for '", node->name(), + "' was not previously added to ShapeRefiner."); } InferenceContext* c = iter->second.get(); - DCHECK_GE(e->dst_input(), 0); - if (node_context->MergeInput(e->dst_input(), c->output(e->src_output()))) { + 
DCHECK_GE(dst_input, 0); + ShapeHandle existing_input = node_context->input(dst_input); + if (!relax && node_context->MergeInput(dst_input, c->output(src_output))) { *refined = true; + } else if (relax) { + if (node_context->RelaxInput(dst_input, c->output(src_output))) { + if (!SameDefinedShape(node_context, node_context->input(dst_input), + existing_input)) { + *refined = true; + } + } } // Also propagate handle shape and dtype of edges which are carrying // resource handles. - if (e->src()->output_type(e->src_output()) == DT_RESOURCE) { - auto* shapes_and_types = - c->output_handle_shapes_and_types(e->src_output()); - if (shapes_and_types != nullptr && - node_context->MergeInputHandleShapesAndTypes(e->dst_input(), - *shapes_and_types)) { + if (e->src()->output_type(src_output) == DT_RESOURCE) { + auto* outputs = c->output_handle_shapes_and_types(src_output); + if (!outputs) continue; + + if (!relax && + node_context->MergeInputHandleShapesAndTypes(dst_input, *outputs)) { *refined = true; + } else if (relax) { + std::vector existing_inputs; + const std::vector* inputs = + node_context->input_handle_shapes_and_types(dst_input); + if (inputs) { + existing_inputs = *inputs; + } + if (node_context->RelaxInputHandleShapesAndMergeTypes(dst_input, + *outputs)) { + if (IsUpdatedShapesOrTypes( + node_context, existing_inputs, + *node_context->input_handle_shapes_and_types(dst_input))) { + *refined = true; + } + } } } } @@ -638,4 +664,36 @@ Status ShapeRefiner::RunShapeFn(const Node* node, return Status::OK(); } +bool ShapeRefiner::SameDefinedShape(InferenceContext* c, ShapeHandle s0, + ShapeHandle s1) { + if (c->Rank(s0) != c->Rank(s1)) { + return false; + } else if (!c->RankKnown(s0)) { + return true; + } + + for (int i = 0; i < c->Rank(s0); ++i) { + if (c->Value(c->Dim(s0, i)) != c->Value(c->Dim(s1, i))) { + return false; + } + } + + return true; +} + +bool ShapeRefiner::IsUpdatedShapesOrTypes( + InferenceContext* c, const std::vector& existing, + const std::vector& 
updated) { + if (existing.size() != updated.size()) { + return true; + } + for (int i = 0; i < existing.size(); i++) { + if (!SameDefinedShape(c, existing[i].shape, updated[i].shape) || + existing[i].dtype != updated[i].dtype) { + return true; + } + } + return false; +} + } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/shape_refiner.h b/tensorflow/core/common_runtime/shape_refiner.h index 603659d54e2..1af7835392f 100644 --- a/tensorflow/core/common_runtime/shape_refiner.h +++ b/tensorflow/core/common_runtime/shape_refiner.h @@ -24,6 +24,9 @@ limitations under the License. #include "tensorflow/core/platform/macros.h" namespace tensorflow { +namespace grappler { +class GraphProperties; +} // ShapeRefiner performs shape inference for TensorFlow Graphs. It is // responsible for instantiating InferenceContext objects for each @@ -57,8 +60,13 @@ class ShapeRefiner { // Update the input shapes of node in case the shapes of the fan-ins of 'node' // have themselves been modified (For example, in case of incremental shape - // refinement). Sets refined to true if any of the node shape has changed. - Status UpdateNode(const Node* node, bool* refined); + // refinement). If 'relax' is true, a new shape with the broadest set of + // information will be set as the new input (see InferenceContext::RelaxInput + // for full details and examples). Sets refined to true if any shapes have + // changed (in their string representations). Note that shapes may have been + // updated to newer versions (but with identical string representations) even + // if <*refined> is set to false. + Status UpdateNode(const Node* node, bool relax, bool* refined); // Returns the InferenceContext for 'node', if present. 
shape_inference::InferenceContext* GetContext(const Node* node) const { @@ -78,6 +86,22 @@ class ShapeRefiner { } private: + friend class ShapeRefinerTest; + friend class ::tensorflow::grappler::GraphProperties; + + // Returns true if the ranks and all dimensions of and are either + // equal in value or both unknown. + static bool SameDefinedShape(shape_inference::InferenceContext* c, + shape_inference::ShapeHandle s0, + shape_inference::ShapeHandle s1); + + // Returns true if the shapes and types stored in <*existing> are identical in + // value to the shapes and types in <*updated>. + static bool IsUpdatedShapesOrTypes( + shape_inference::InferenceContext* c, + const std::vector& existing, + const std::vector& updated); + // Tries to infer tensor output based on the input shapes of the node. In some // cases, the shapes of the inputs are sufficient for inferring the contents // of the output tensor. For example, a Shape op with fully defined input diff --git a/tensorflow/core/common_runtime/shape_refiner_test.cc b/tensorflow/core/common_runtime/shape_refiner_test.cc index 55485dc979a..7ffab38ba2a 100644 --- a/tensorflow/core/common_runtime/shape_refiner_test.cc +++ b/tensorflow/core/common_runtime/shape_refiner_test.cc @@ -26,6 +26,35 @@ limitations under the License. #include "tensorflow/core/public/version.h" namespace tensorflow { + +class ShapeRefinerTest : public ::testing::Test { + protected: + // These give access to private functions of DimensionHandle and ShapeHandle. + bool SameHandle(shape_inference::DimensionHandle a, + shape_inference::DimensionHandle b) { + return a.SameHandle(b); + } + + bool SameHandle(shape_inference::ShapeHandle a, + shape_inference::ShapeHandle b) { + return a.SameHandle(b); + } + + // These give access to private functions of ShapeRefiner. 
+ bool SameDefinedShape(shape_inference::InferenceContext* c, + shape_inference::ShapeHandle s0, + shape_inference::ShapeHandle s1) { + return ShapeRefiner::SameDefinedShape(c, s0, s1); + } + + bool IsUpdatedShapesOrTypes( + shape_inference::InferenceContext* c, + const std::vector& existing, + const std::vector& updated) { + return ShapeRefiner::IsUpdatedShapesOrTypes(c, existing, updated); + } +}; + namespace { #define EXPECT_SHAPE(EXPECTED, M, OP, IDX) \ @@ -34,7 +63,7 @@ namespace { EXPECT_EQ(EXPECTED, ctx->DebugString(ctx->output(IDX))); \ } while (0); -TEST(ShapeRefinerTest, Constant) { +TEST_F(ShapeRefinerTest, Constant) { // Create a constant node and validate that adding it is successful // and that its shape is correct. Scope root = Scope::NewRootScope(); @@ -45,7 +74,7 @@ TEST(ShapeRefinerTest, Constant) { EXPECT_SHAPE("[]", m, c, 0); } -TEST(ShapeRefinerTest, MatMul) { +TEST_F(ShapeRefinerTest, MatMul) { ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global()); Scope root = Scope::NewRootScope(); @@ -62,7 +91,7 @@ TEST(ShapeRefinerTest, MatMul) { EXPECT_SHAPE("[2,2]", m, mm, 0); } -TEST(ShapeRefinerTest, InvalidOrder) { +TEST_F(ShapeRefinerTest, InvalidOrder) { ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global()); Scope root = Scope::NewRootScope(); auto a = ops::Const(root, {{1.0f}, {2.0f}}); @@ -77,7 +106,7 @@ TEST(ShapeRefinerTest, InvalidOrder) { s.error_message()); } -TEST(ShapeRefinerTest, BadShapes) { +TEST_F(ShapeRefinerTest, BadShapes) { ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global()); Scope root = Scope::NewRootScope(); auto a = ops::Const(root, {{1.0f}, {2.0f}}); @@ -94,7 +123,7 @@ TEST(ShapeRefinerTest, BadShapes) { .contains("Dimensions must be equal, but are 1 and 2")); } -TEST(ShapeRefinerTest, SetShape) { +TEST_F(ShapeRefinerTest, SetShape) { ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global()); Scope root = Scope::NewRootScope(); @@ -133,7 +162,7 @@ REGISTER_OP("TestOpWithNoShapeFn").Input("a: int32").Output("o: 
int32"); } // namespace -TEST(ShapeRefinerTest, MissingShapeInferenceFns) { +TEST_F(ShapeRefinerTest, MissingShapeInferenceFns) { Scope root = Scope::NewRootScope(); auto a = ops::Const(root, 42); Node* b; @@ -147,7 +176,7 @@ TEST(ShapeRefinerTest, MissingShapeInferenceFns) { TF_EXPECT_OK(m.AddNode(b)); } -TEST(ShapeRefinerTest, PropagateConstants) { +TEST_F(ShapeRefinerTest, PropagateConstants) { // Reduction dimension is a variable, so we don't know its value. // So the output shape value is unknown (though its rank is known). { @@ -220,7 +249,7 @@ REGISTER_OP("TestOp") } // namespace -TEST(ShapeRefinerTest, InputTensorDependencies) { +TEST_F(ShapeRefinerTest, InputTensorDependencies) { ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global()); Graph graph(OpRegistry::Global()); Node* node; @@ -289,7 +318,7 @@ REGISTER_OP("ShapeDataInt64") } // namespace -TEST(ShapeRefinerTest, PropagateShapeAcrossTensorContent) { +TEST_F(ShapeRefinerTest, PropagateShapeAcrossTensorContent) { Scope root = Scope::NewRootScope(); // Create variable 2x4 tensor. @@ -320,7 +349,7 @@ TEST(ShapeRefinerTest, PropagateShapeAcrossTensorContent) { EXPECT_EQ("[4]", ctx->DebugString(ctx->output(0))); } -TEST(ShapeRefinerTest, PropagateShapeAcrossTensorContentInt64) { +TEST_F(ShapeRefinerTest, PropagateShapeAcrossTensorContentInt64) { Scope root = Scope::NewRootScope(); // Create variable 2x4 tensor. @@ -354,7 +383,7 @@ TEST(ShapeRefinerTest, PropagateShapeAcrossTensorContentInt64) { EXPECT_EQ("[4]", ctx->DebugString(ctx->output(0))); } -TEST(ShapeRefinerTest, PropagateShapeAcrossTensorContentInt32Overflow) { +TEST_F(ShapeRefinerTest, PropagateShapeAcrossTensorContentInt32Overflow) { Scope root = Scope::NewRootScope(); // Create variable 2x4 tensor. 
@@ -386,7 +415,7 @@ TEST(ShapeRefinerTest, PropagateShapeAcrossTensorContentInt32Overflow) { EXPECT_FALSE(m.AddNode(shape_data).ok()); } -TEST(ShapeRefinerTest, PropagateRankAcrossTensorContent) { +TEST_F(ShapeRefinerTest, PropagateRankAcrossTensorContent) { Scope root = Scope::NewRootScope(); // Create variable 2x4x3 tensor. @@ -412,7 +441,7 @@ TEST(ShapeRefinerTest, PropagateRankAcrossTensorContent) { EXPECT_EQ("[3]", ctx->DebugString(ctx->output(0))); } -TEST(ShapeRefinerTest, PropagateSizeAcrossTensorContent) { +TEST_F(ShapeRefinerTest, PropagateSizeAcrossTensorContent) { Scope root = Scope::NewRootScope(); // Create variable. @@ -438,7 +467,7 @@ TEST(ShapeRefinerTest, PropagateSizeAcrossTensorContent) { EXPECT_EQ("[120]", ctx->DebugString(ctx->output(0))); } -TEST(ShapeRefinerTest, PropagateSizeAcrossTensorContentInt64) { +TEST_F(ShapeRefinerTest, PropagateSizeAcrossTensorContentInt64) { Scope root = Scope::NewRootScope(); // Create variable. @@ -469,7 +498,7 @@ TEST(ShapeRefinerTest, PropagateSizeAcrossTensorContentInt64) { EXPECT_EQ("[515396075280]", ctx->DebugString(ctx->output(0))); } -TEST(ShapeRefinerTest, PropagateSizeAcrossTensorContentInt32Overflow) { +TEST_F(ShapeRefinerTest, PropagateSizeAcrossTensorContentInt32Overflow) { Scope root = Scope::NewRootScope(); // Create variable. 
@@ -496,7 +525,7 @@ TEST(ShapeRefinerTest, PropagateSizeAcrossTensorContentInt32Overflow) { EXPECT_FALSE(m.AddNode(shape_data).ok()); } -TEST(ShapeRefinerTest, PropagateShape) { +TEST_F(ShapeRefinerTest, PropagateShape) { Scope root = Scope::NewRootScope(); // 3x2 input auto input = ops::Const(root, {{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}); @@ -518,7 +547,7 @@ TEST(ShapeRefinerTest, PropagateShape) { EXPECT_EQ("[3,2]", ctx->DebugString(ctx->output(0))); } -TEST(ShapeRefinerTest, PropagateSize) { +TEST_F(ShapeRefinerTest, PropagateSize) { Scope root = Scope::NewRootScope(); // 3x2 input auto input = ops::Const(root, {{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}); @@ -539,7 +568,7 @@ TEST(ShapeRefinerTest, PropagateSize) { EXPECT_EQ("[6]", ctx->DebugString(ctx->output(0))); } -TEST(ShapeRefinerTest, PropagateRank) { +TEST_F(ShapeRefinerTest, PropagateRank) { Scope root = Scope::NewRootScope(); // 3x2 input auto input = ops::Const(root, {{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}); @@ -560,7 +589,7 @@ TEST(ShapeRefinerTest, PropagateRank) { EXPECT_EQ("[2]", ctx->DebugString(ctx->output(0))); } -TEST(ShapeRefinerTest, PropagateRange) { +TEST_F(ShapeRefinerTest, PropagateRange) { Scope root = Scope::NewRootScope(); auto begin = ops::Const(root, 1); auto limit = ops::Const(root, 11); @@ -583,7 +612,7 @@ TEST(ShapeRefinerTest, PropagateRange) { EXPECT_EQ("[1,4,7,10]", ctx->DebugString(ctx->output(0))); } -TEST(ShapeRefinerTest, ConstantValueTwoInputsToSameNode) { +TEST_F(ShapeRefinerTest, ConstantValueTwoInputsToSameNode) { Scope root = Scope::NewRootScope(); // This node is used as two inputs to 'range'. auto begin_and_delta = ops::Const(root, 1); @@ -607,7 +636,7 @@ TEST(ShapeRefinerTest, ConstantValueTwoInputsToSameNode) { // Creates a graph where 'begin' is attempted to be visited during // constant value evaluation after having been processed once. 
-TEST(ShapeRefinerTest, ConstantValueVisitNodeTwice) { +TEST_F(ShapeRefinerTest, ConstantValueVisitNodeTwice) { Scope root = Scope::NewRootScope(); auto begin = ops::Const(root, 1); auto limit = ops::Const(root, 8); @@ -716,7 +745,7 @@ REGISTER_OP("WithUnknownShape") } // namespace -TEST(ShapeRefinerTest, ConstantValueAsShape_EmptyVector) { +TEST_F(ShapeRefinerTest, ConstantValueAsShape_EmptyVector) { Scope root = Scope::NewRootScope(); Node* input; TF_ASSERT_OK( @@ -734,7 +763,7 @@ TEST(ShapeRefinerTest, ConstantValueAsShape_EmptyVector) { EXPECT_EQ("[]", ctx->DebugString(ctx->output(0))); } -TEST(ShapeRefinerTest, ConstantValueAsShape_Shape) { +TEST_F(ShapeRefinerTest, ConstantValueAsShape_Shape) { for (int pass = 0; pass < 2; ++pass) { Scope root = Scope::NewRootScope(); Node* input; @@ -761,7 +790,7 @@ TEST(ShapeRefinerTest, ConstantValueAsShape_Shape) { } } -TEST(ShapeRefinerTest, ConstantValueAsShape_PackInt32) { +TEST_F(ShapeRefinerTest, ConstantValueAsShape_PackInt32) { Scope root = Scope::NewRootScope(); Node* scalar_non_const; TF_ASSERT_OK(NodeBuilder("in", "NonConstScalarInt32") @@ -793,7 +822,7 @@ TEST(ShapeRefinerTest, ConstantValueAsShape_PackInt32) { EXPECT_EQ("[10,20,?,40]", ctx->DebugString(ctx->output(0))); } -TEST(ShapeRefinerTest, ConstantValueAsShape_PackInt64) { +TEST_F(ShapeRefinerTest, ConstantValueAsShape_PackInt64) { Scope root = Scope::NewRootScope(); Node* scalar_non_const; TF_ASSERT_OK(NodeBuilder("in", "NonConstScalarInt64") @@ -825,7 +854,7 @@ TEST(ShapeRefinerTest, ConstantValueAsShape_PackInt64) { EXPECT_EQ("[10,20,?,1099511627776]", ctx->DebugString(ctx->output(0))); } -TEST(ShapeRefinerTest, ConstantValueAsShape_PackUnknownDim) { +TEST_F(ShapeRefinerTest, ConstantValueAsShape_PackUnknownDim) { Scope root = Scope::NewRootScope(); InputList inputs{ @@ -851,7 +880,7 @@ TEST(ShapeRefinerTest, ConstantValueAsShape_PackUnknownDim) { EXPECT_EQ("[10,?]", ctx->DebugString(ctx->output(0))); } -TEST(ShapeRefinerTest, 
ConstantValueAsShape_PackInvalidInput) { +TEST_F(ShapeRefinerTest, ConstantValueAsShape_PackInvalidInput) { Scope root = Scope::NewRootScope(); // Inputs are length 2 vectors instead of scalars. @@ -876,7 +905,7 @@ TEST(ShapeRefinerTest, ConstantValueAsShape_PackInvalidInput) { StringPiece(m.AddNode(result).error_message()).contains("but is rank 2")); } -TEST(ShapeRefinerTest, ConstantValueAsShape_Concat) { +TEST_F(ShapeRefinerTest, ConstantValueAsShape_Concat) { Scope root = Scope::NewRootScope(); Graph* g = root.graph(); Node* partial_1; @@ -913,7 +942,7 @@ TEST(ShapeRefinerTest, ConstantValueAsShape_Concat) { EXPECT_EQ("[1,?,3,?,5,6,?,8,9,10,11]", ctx->DebugString(ctx->output(0))); } -TEST(ShapeRefinerTest, ConstantValueAsShape_ConcatWithUnknown) { +TEST_F(ShapeRefinerTest, ConstantValueAsShape_ConcatWithUnknown) { Scope root = Scope::NewRootScope(); Graph* g = root.graph(); Node* scalar_non_const; @@ -956,7 +985,7 @@ TEST(ShapeRefinerTest, ConstantValueAsShape_ConcatWithUnknown) { EXPECT_EQ("?", ctx->DebugString(ctx->output(0))); } -TEST(ShapeRefinerTest, ConstantValueAsShape_ConcatInvalidDimValue) { +TEST_F(ShapeRefinerTest, ConstantValueAsShape_ConcatInvalidDimValue) { Scope root = Scope::NewRootScope(); Graph* g = root.graph(); Node* scalar_non_const; @@ -995,7 +1024,78 @@ TEST(ShapeRefinerTest, ConstantValueAsShape_ConcatInvalidDimValue) { m.AddNode(result).error_message()); } -TEST(ShapeRefinerTest, IncrementalUpdates) { +namespace { + +// Dummy op to test ShapeRefiner util functions +REGISTER_OP("Dummy"); + +} // namespace + +TEST_F(ShapeRefinerTest, SameDefinedShape) { + Scope root = Scope::NewRootScope(); + Graph* g = root.graph(); + Node* test; + TF_CHECK_OK(NodeBuilder("test", "Dummy").Finalize(g, &test)); + ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global()); + m.set_require_shape_inference_fns(false); + TF_ASSERT_OK(m.AddNode(test)); + shape_inference::InferenceContext* ctx = m.GetContext(test); + + auto unknown = ctx->UnknownShape(); + auto 
unknown_b = ctx->UnknownShape(); + auto s_1_2 = ctx->MakeShape({1, 2}); + auto s_1_2_b = ctx->MakeShape({1, 2}); + auto s_2_2 = ctx->MakeShape({2, 2}); + auto s_unknown_2 = ctx->MakeShape({-1, 2}); + auto s_unknown_2_b = ctx->MakeShape({-1, 2}); + + EXPECT_TRUE(SameDefinedShape(ctx, unknown, unknown_b)); + EXPECT_FALSE(SameDefinedShape(ctx, unknown, s_1_2)); + EXPECT_TRUE(SameDefinedShape(ctx, s_1_2, s_1_2_b)); + EXPECT_FALSE(SameDefinedShape(ctx, s_1_2, s_2_2)); + EXPECT_TRUE(SameDefinedShape(ctx, s_unknown_2, s_unknown_2_b)); +} + +TEST_F(ShapeRefinerTest, IsUpdatedShapesOrTypes) { + Scope root = Scope::NewRootScope(); + Graph* g = root.graph(); + Node* test; + TF_CHECK_OK(NodeBuilder("test", "Dummy").Finalize(g, &test)); + ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global()); + m.set_require_shape_inference_fns(false); + TF_ASSERT_OK(m.AddNode(test)); + shape_inference::InferenceContext* ctx = m.GetContext(test); + + std::vector t0{ + {ctx->MakeShape({1, 2, 3}), DT_FLOAT}, + {ctx->UnknownShape(), DT_INVALID}, + {ctx->MakeShape({4, 3, 2, 1}), DT_INT32}}; + + std::vector t1{ + {ctx->MakeShape({1, 2, 3}), DT_FLOAT}, + {ctx->UnknownShape(), DT_INVALID}, + {ctx->MakeShape({4, 3, 2, 1}), DT_INT32}}; + + std::vector t2{ + {ctx->MakeShape({1, 2, 4}), DT_FLOAT}, + {ctx->UnknownShape(), DT_INVALID}, + {ctx->MakeShape({4, 3, 2, 1}), DT_INT32}}; + + std::vector t3{ + {ctx->MakeShape({1, 2, 3}), DT_INT32}, + {ctx->UnknownShape(), DT_INVALID}, + {ctx->MakeShape({4, 3, 2, 1}), DT_INT32}}; + + EXPECT_FALSE(IsUpdatedShapesOrTypes(ctx, t0, t1)); + + // A shape has been modified + EXPECT_TRUE(IsUpdatedShapesOrTypes(ctx, t0, t2)); + + // A type has been modified + EXPECT_TRUE(IsUpdatedShapesOrTypes(ctx, t0, t3)); +} + +TEST_F(ShapeRefinerTest, IncrementalUpdates) { Scope root = Scope::NewRootScope(); Graph* g = root.graph(); Node* queue; @@ -1020,12 +1120,34 @@ TEST(ShapeRefinerTest, IncrementalUpdates) { shape_inference::ShapeHandle shp = ctx->MakeShape({3, 7}); 
ctx->set_output_handle_shapes_and_types( 0, std::vector{{shp, DT_FLOAT}}); - bool refined = false; - TF_ASSERT_OK(m.UpdateNode(dequeue, &refined)); + TF_ASSERT_OK(m.UpdateNode(dequeue, false /* relax */, &refined)); EXPECT_TRUE(refined); ctx = m.GetContext(dequeue); EXPECT_EQ("[3,7]", ctx->DebugString(ctx->output(0))); + + // Inject another shape, but relax instead of merge. + ctx = m.GetContext(queue); + shp = ctx->MakeShape({2, 7}); + ctx->set_output_handle_shapes_and_types( + 0, std::vector{{shp, DT_FLOAT}}); + refined = false; + TF_ASSERT_OK(m.UpdateNode(dequeue, true /* relax */, &refined)); + EXPECT_TRUE(refined); + ctx = m.GetContext(dequeue); + EXPECT_EQ("[?,7]", ctx->DebugString(ctx->output(0))); + + // Inject another partially unknown shape and attempt to relax it. + ctx = m.GetContext(queue); + shp = ctx->MakeShape({shape_inference::InferenceContext::kUnknownDim, 7}); + ctx->set_output_handle_shapes_and_types( + 0, std::vector{{shp, DT_FLOAT}}); + refined = false; + TF_ASSERT_OK(m.UpdateNode(dequeue, true /* relax */, &refined)); + EXPECT_FALSE(refined); + ctx = m.GetContext(dequeue); + EXPECT_EQ("[?,7]", ctx->DebugString(ctx->output(0))); + ASSERT_FALSE(SameHandle(ctx->Dim(ctx->output(0), 0), ctx->Dim(shp, 0))); } } // namespace diff --git a/tensorflow/core/debug/BUILD b/tensorflow/core/debug/BUILD index 2fc49d4412e..30824a9072f 100644 --- a/tensorflow/core/debug/BUILD +++ b/tensorflow/core/debug/BUILD @@ -123,6 +123,7 @@ tf_cuda_library( linkstatic = 1, deps = [ ":debug_service_proto_cc", + ":debugger_event_metadata_proto_cc", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -144,6 +145,7 @@ tf_cuda_library( ":debug_graph_utils", ":debug_io_utils", ":debug_service_proto_cc", + ":debugger_event_metadata_proto_cc", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", @@ -194,6 +196,7 @@ tf_cc_test( deps = [ ":debug_grpc_testlib", ":debug_io_utils", + 
":debugger_event_metadata_proto_cc", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", @@ -204,6 +207,7 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", + "//tensorflow/core/platform/default/build_config:platformlib", ], ) @@ -247,6 +251,12 @@ tf_cc_test( ], ) +tf_proto_library_cc( + name = "debugger_event_metadata_proto", + srcs = ["debugger_event_metadata.proto"], + cc_api_version = 2, +) + # TODO(cais): Add the following back in when tfdbg is supported on Android. # filegroup( # name = "android_srcs", diff --git a/tensorflow/core/debug/debug_grpc_testlib.cc b/tensorflow/core/debug/debug_grpc_testlib.cc index c19842a2f6c..7317aa03727 100644 --- a/tensorflow/core/debug/debug_grpc_testlib.cc +++ b/tensorflow/core/debug/debug_grpc_testlib.cc @@ -16,10 +16,12 @@ limitations under the License. #include "tensorflow/core/debug/debug_grpc_testlib.h" #include "tensorflow/core/debug/debug_graph_utils.h" +#include "tensorflow/core/debug/debugger_event_metadata.pb.h" #include "tensorflow/core/framework/summary.pb.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/tracing.h" namespace tensorflow { @@ -44,8 +46,6 @@ namespace test { tensorflow::str_util::Split(val.node_name(), ':'); const string node_name = name_items[0]; - int32 output_slot = 0; - tensorflow::strings::safe_strto32(name_items[1], &output_slot); const string debug_op = name_items[2]; const TensorProto& tensor_proto = val.tensor(); @@ -54,9 +54,24 @@ namespace test { return ::grpc::Status::CANCELLED; } - device_names.push_back(val.tag()); + // Obtain the device name, which is encoded in JSON. 
+ third_party::tensorflow::core::debug::DebuggerEventMetadata metadata; + for (int i = 0; i < val.metadata().plugin_data_size(); i++) { + if (val.metadata().plugin_data(i).plugin_name() != "debugger") { + // This plugin data was meant for another plugin. + continue; + } + auto status = tensorflow::protobuf::util::JsonStringToMessage( + val.metadata().plugin_data(i).content(), &metadata); + if (status.ok()) { + // The device name has been determined. + break; + } + } + + device_names.push_back(metadata.device()); node_names.push_back(node_name); - output_slots.push_back(output_slot); + output_slots.push_back(metadata.output_slot()); debug_ops.push_back(debug_op); debug_tensors.push_back(tensor); } diff --git a/tensorflow/core/debug/debug_io_utils.cc b/tensorflow/core/debug/debug_io_utils.cc index 54366ce2490..69fc3677892 100644 --- a/tensorflow/core/debug/debug_io_utils.cc +++ b/tensorflow/core/debug/debug_io_utils.cc @@ -26,11 +26,13 @@ limitations under the License. #pragma comment(lib,"Ws2_32.lib") #endif +#include "tensorflow/core/debug/debugger_event_metadata.pb.h" #include "tensorflow/core/framework/summary.pb.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/util/event.pb.h" #define GRPC_OSS_UNIMPLEMENTED_ERROR \ @@ -55,7 +57,32 @@ Event WrapTensorAsEvent(const DebugNodeKey& debug_node_key, // "DebugIdentity", the debug node_name in the Summary proto will be // "foo/node_a:0:DebugIdentity". summ_val->set_node_name(debug_node_key.debug_node_name); - summ_val->set_tag(debug_node_key.device_name); + + // Tag by the node name. This allows TensorBoard to quickly fetch data per op. + summ_val->set_tag(debug_node_key.node_name); + + // Store data within debugger metadata to be stored for each event. 
+ third_party::tensorflow::core::debug::DebuggerEventMetadata metadata; + metadata.set_device(debug_node_key.device_name); + metadata.set_output_slot(debug_node_key.output_slot); + + // Encode the data in JSON. + string json_output; + tensorflow::protobuf::util::JsonPrintOptions json_options; + json_options.always_print_primitive_fields = true; + auto status = tensorflow::protobuf::util::MessageToJsonString( + metadata, &json_output, json_options); + if (status.ok()) { + // Store summary metadata. Set the plugin to use this data as "debugger". + SummaryMetadata::PluginData* plugin_data = + summ_val->mutable_metadata()->add_plugin_data(); + plugin_data->set_plugin_name("debugger"); + plugin_data->set_content(json_output); + } else { + LOG(WARNING) << "Failed to convert DebuggerEventMetadata proto to JSON. " + << "The debug_node_name is " << debug_node_key.debug_node_name + << "."; + } if (tensor.dtype() == DT_STRING) { // Treat DT_STRING specially, so that tensor_util.MakeNdarray can convert @@ -131,6 +158,9 @@ const char* const DebugIO::kDeviceTag = "device_"; // static const char* const DebugIO::kGraphTag = "graph_"; +// static +const char* const DebugIO::kHashTag = "hash"; + DebugNodeKey::DebugNodeKey(const string& device_name, const string& node_name, const int32 output_slot, const string& debug_op) : device_name(device_name), @@ -351,8 +381,10 @@ Status DebugIO::PublishGraph(const Graph& graph, const string& device_name, const string dump_root_dir = io::JoinPath(debug_url.substr(strlen(kFileURLScheme)), DebugNodeKey::DeviceNameToDevicePath(device_name)); - const string file_name = strings::StrCat(DebugIO::kMetadataFilePrefix, - DebugIO::kGraphTag, now_micros); + const uint64 graph_hash = ::tensorflow::Hash64(buf); + const string file_name = + strings::StrCat(DebugIO::kMetadataFilePrefix, DebugIO::kGraphTag, + DebugIO::kHashTag, graph_hash, "_", now_micros); status.Update( DebugFileIO::DumpEventProtoToFile(event, dump_root_dir, file_name)); diff --git 
a/tensorflow/core/debug/debug_io_utils.h b/tensorflow/core/debug/debug_io_utils.h index 69d8c7bd4e0..4caa4b5e046 100644 --- a/tensorflow/core/debug/debug_io_utils.h +++ b/tensorflow/core/debug/debug_io_utils.h @@ -143,6 +143,7 @@ class DebugIO { static const char* const kCoreMetadataTag; static const char* const kDeviceTag; static const char* const kGraphTag; + static const char* const kHashTag; static const char* const kFileURLScheme; static const char* const kGrpcURLScheme; diff --git a/tensorflow/core/debug/debug_io_utils_test.cc b/tensorflow/core/debug/debug_io_utils_test.cc index 35c95fb98c4..08ef4001bc1 100644 --- a/tensorflow/core/debug/debug_io_utils_test.cc +++ b/tensorflow/core/debug/debug_io_utils_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/debug/debug_io_utils.h" +#include "tensorflow/core/debug/debugger_event_metadata.pb.h" #include "tensorflow/core/framework/summary.pb.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/lib/core/notification.h" @@ -124,10 +125,18 @@ TEST_F(DebugIOUtilsTest, DumpStringTensorToFileSunnyDay) { ASSERT_GE(wall_time, event.wall_time()); ASSERT_EQ(1, event.summary().value().size()); - ASSERT_EQ(kDebugNodeKey.device_name, event.summary().value(0).tag()); + ASSERT_EQ(kDebugNodeKey.node_name, event.summary().value(0).tag()); ASSERT_EQ(kDebugNodeKey.debug_node_name, event.summary().value(0).node_name()); + // Determine and validate some information from the metadata. 
+ third_party::tensorflow::core::debug::DebuggerEventMetadata metadata; + auto status = tensorflow::protobuf::util::JsonStringToMessage( + event.summary().value(0).metadata().plugin_data(0).content(), &metadata); + ASSERT_TRUE(status.ok()); + ASSERT_EQ(kDebugNodeKey.device_name, metadata.device()); + ASSERT_EQ(kDebugNodeKey.output_slot, metadata.output_slot()); + Tensor b_prime(DT_STRING); ASSERT_TRUE(b_prime.FromProto(event.summary().value(0).tensor())); @@ -229,10 +238,19 @@ TEST_F(DebugIOUtilsTest, PublishTensorToMultipleFileURLs) { ASSERT_GE(wall_time, event.wall_time()); ASSERT_EQ(1, event.summary().value().size()); - ASSERT_EQ(kDebugNodeKey.device_name, event.summary().value(0).tag()); + ASSERT_EQ(kDebugNodeKey.node_name, event.summary().value(0).tag()); ASSERT_EQ(kDebugNodeKey.debug_node_name, event.summary().value(0).node_name()); + // Determine and validate some information from the metadata. + third_party::tensorflow::core::debug::DebuggerEventMetadata metadata; + auto status = tensorflow::protobuf::util::JsonStringToMessage( + event.summary().value(0).metadata().plugin_data(0).content(), + &metadata); + ASSERT_TRUE(status.ok()); + ASSERT_EQ(kDebugNodeKey.device_name, metadata.device()); + ASSERT_EQ(kDebugNodeKey.output_slot, metadata.output_slot()); + Tensor a_prime(DT_FLOAT); ASSERT_TRUE(a_prime.FromProto(event.summary().value(0).tensor())); @@ -333,10 +351,19 @@ TEST_F(DebugIOUtilsTest, PublishTensorConcurrentlyToPartiallyOverlappingPaths) { ASSERT_GE(wall_time, event.wall_time()); ASSERT_EQ(1, event.summary().value().size()); - ASSERT_EQ(kDebugNodeKey.device_name, event.summary().value(0).tag()); + ASSERT_EQ(kDebugNodeKey.node_name, event.summary().value(0).tag()); ASSERT_EQ(kDebugNodeKey.debug_node_name, event.summary().value(0).node_name()); + // Determine and validate some information from the metadata. 
+ third_party::tensorflow::core::debug::DebuggerEventMetadata metadata; + auto status = tensorflow::protobuf::util::JsonStringToMessage( + event.summary().value(0).metadata().plugin_data(0).content(), + &metadata); + ASSERT_TRUE(status.ok()); + ASSERT_EQ(kDebugNodeKey.device_name, metadata.device()); + ASSERT_EQ(kDebugNodeKey.output_slot, metadata.output_slot()); + Tensor a_prime(DT_FLOAT); ASSERT_TRUE(a_prime.FromProto(event.summary().value(0).tensor())); diff --git a/tensorflow/core/debug/debugger_event_metadata.proto b/tensorflow/core/debug/debugger_event_metadata.proto new file mode 100644 index 00000000000..44ef305f5a5 --- /dev/null +++ b/tensorflow/core/debug/debugger_event_metadata.proto @@ -0,0 +1,9 @@ +syntax = "proto3"; + +package third_party.tensorflow.core.debug; + +// Encapsulates per-event data related to debugging. +message DebuggerEventMetadata { + string device = 1; + int32 output_slot = 2; +}; diff --git a/tensorflow/core/distributed_runtime/BUILD b/tensorflow/core/distributed_runtime/BUILD index efc08e4c9d0..f59e5f4dc21 100644 --- a/tensorflow/core/distributed_runtime/BUILD +++ b/tensorflow/core/distributed_runtime/BUILD @@ -399,7 +399,6 @@ cc_library( srcs = ["server_lib.cc"], hdrs = ["server_lib.h"], deps = [ - "//tensorflow/core:framework", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc index 94fec4f6d00..c79f68a0688 100644 --- a/tensorflow/core/distributed_runtime/master_session.cc +++ b/tensorflow/core/distributed_runtime/master_session.cc @@ -1165,7 +1165,6 @@ WorkerCacheInterface* MasterSession::get_worker_cache() const { Status MasterSession::StartStep(const BuildGraphOptions& opts, int64* count, ReffedClientGraph** rcg, bool is_partial) { const uint64 hash = HashBuildGraphOptions(opts); - ReffedClientGraph* to_unref = nullptr; { mutex_lock l(mu_); // Keep track of 
how many times this subgraph has been executed in @@ -1196,7 +1195,6 @@ Status MasterSession::StartStep(const BuildGraphOptions& opts, int64* count, *rcg = iter->second; (*rcg)->Ref(); } - if (to_unref) to_unref->Unref(); return Status::OK(); } diff --git a/tensorflow/core/distributed_runtime/rpc/BUILD b/tensorflow/core/distributed_runtime/rpc/BUILD index 3ebc11614de..c0918ef4452 100644 --- a/tensorflow/core/distributed_runtime/rpc/BUILD +++ b/tensorflow/core/distributed_runtime/rpc/BUILD @@ -119,6 +119,7 @@ cc_library( hdrs = ["grpc_call.h"], deps = [ "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", "@grpc//:grpc++_unsecure", ], ) @@ -381,6 +382,7 @@ cc_library( "//tensorflow/core/distributed_runtime:call_options", "//tensorflow/core/distributed_runtime:local_master", "//tensorflow/core/distributed_runtime:master_interface", + "//tensorflow/core/distributed_runtime:message_wrappers", ], alwayslink = 1, ) diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc index 1f9e98551f1..2c18ddd48fb 100644 --- a/tensorflow/core/framework/shape_inference.cc +++ b/tensorflow/core/framework/shape_inference.cc @@ -314,6 +314,19 @@ Status InferenceContext::WithValue(DimensionHandle dim, int64 value, existing); } +void InferenceContext::Relax(DimensionHandle d0, DimensionHandle d1, + DimensionHandle* out) { + if (d0.SameHandle(d1)) { + *out = d0; + } else if (!ValueKnown(d0) || !ValueKnown(d1)) { + *out = UnknownDim(); + } else if (Value(d0) == Value(d1)) { + *out = d0; + } else { + *out = UnknownDim(); + } +} + Status InferenceContext::Merge(DimensionHandle d0, DimensionHandle d1, DimensionHandle* out) { if (d0.SameHandle(d1) || !ValueKnown(d1)) { @@ -356,6 +369,48 @@ Status InferenceContext::MergePrefix(ShapeHandle s, ShapeHandle prefix, return Status::OK(); } +void InferenceContext::Relax(ShapeHandle s0, ShapeHandle s1, ShapeHandle* out) { + if (s0.SameHandle(s1)) { + *out = s0; + return; + } else if (!RankKnown(s0) 
|| !RankKnown(s1)) { + *out = UnknownShape(); + return; + } + + const int32 rank = Rank(s0); + if (rank != Rank(s1)) { + *out = UnknownShape(); + return; + } + + bool return_s0 = true; + for (int i = 0; i < rank; ++i) { + auto d0 = Dim(s0, i); + auto d1 = Dim(s1, i); + if (d0.SameHandle(d1)) continue; + + auto v0 = Value(d0); + auto v1 = Value(d1); + if (v0 == kUnknownDim || v1 == kUnknownDim || v0 != v1) { + return_s0 = false; + break; + } + } + if (return_s0) { + *out = s0; + return; + } + + // Relax dims. + std::vector dims(rank); + for (int i = 0; i < rank; ++i) { + // Invariant for relax was checked earlier, so CHECK is ok. + Relax(Dim(s0, i), Dim(s1, i), &dims[i]); + } + *out = MakeShape(dims); +} + Status InferenceContext::Merge(ShapeHandle s0, ShapeHandle s1, ShapeHandle* out) { if (s0.SameHandle(s1) || !RankKnown(s1)) { @@ -895,9 +950,15 @@ bool InferenceContext::MergeHandleShapesAndTypes( bool refined = false; for (int i = 0; i < shapes_and_types.size(); ++i) { const ShapeAndType& existing = (*to_update)[i]; - new_values[i].dtype = shapes_and_types[i].dtype; - if (new_values[i].dtype != existing.dtype && existing.dtype == DT_INVALID) { - refined = true; + if (shapes_and_types[i].dtype == existing.dtype) { + new_values[i].dtype = existing.dtype; + } else { + if (existing.dtype != DT_INVALID) { + return false; + } else { + new_values[i].dtype = shapes_and_types[i].dtype; + refined = true; + } } if (!Merge(existing.shape, shapes_and_types[i].shape, &new_values[i].shape) .ok()) { @@ -939,6 +1000,62 @@ bool InferenceContext::MergeInputHandleShapesAndTypes( input_handle_shapes_and_types_[idx].get()); } +bool InferenceContext::RelaxHandleShapesAndMergeTypes( + const std::vector& shapes_and_types, + std::vector* to_update) { + if (shapes_and_types.size() != to_update->size()) { + return false; + } + std::vector new_values(shapes_and_types.size()); + bool refined = false; + for (int i = 0; i < shapes_and_types.size(); ++i) { + const ShapeAndType& existing = 
(*to_update)[i]; + if (shapes_and_types[i].dtype == existing.dtype) { + new_values[i].dtype = existing.dtype; + } else { + if (existing.dtype != DT_INVALID) { + return false; + } else { + new_values[i].dtype = shapes_and_types[i].dtype; + refined = true; + } + } + Relax(existing.shape, shapes_and_types[i].shape, &new_values[i].shape); + if (!existing.shape.SameHandle(new_values[i].shape)) { + refined = true; + } + } + if (!refined) { + return false; + } + for (int i = 0; i < new_values.size(); ++i) { + (*to_update)[i] = new_values[i]; + } + return true; +} + +bool InferenceContext::RelaxOutputHandleShapesAndMergeTypes( + int idx, const std::vector& shapes_and_types) { + if (output_handle_shapes_and_types_[idx] == nullptr) { + output_handle_shapes_and_types_[idx].reset( + new std::vector(shapes_and_types)); + return true; + } + return RelaxHandleShapesAndMergeTypes( + shapes_and_types, output_handle_shapes_and_types_[idx].get()); +} + +bool InferenceContext::RelaxInputHandleShapesAndMergeTypes( + int idx, const std::vector& shapes_and_types) { + if (input_handle_shapes_and_types_[idx] == nullptr) { + input_handle_shapes_and_types_[idx].reset( + new std::vector(shapes_and_types)); + return true; + } + return RelaxHandleShapesAndMergeTypes( + shapes_and_types, input_handle_shapes_and_types_[idx].get()); +} + // ----------------------------------------------------------------------------- // ShapeManager // ----------------------------------------------------------------------------- diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h index 119bed4071f..56686676596 100644 --- a/tensorflow/core/framework/shape_inference.h +++ b/tensorflow/core/framework/shape_inference.h @@ -26,6 +26,13 @@ limitations under the License. 
#include "tensorflow/core/platform/macros.h" namespace tensorflow { + +class ShapeRefinerTest; + +namespace grappler { +class GraphProperties; +} + namespace shape_inference { struct DimensionOrConstant; @@ -62,6 +69,7 @@ class DimensionHandle { friend class InferenceContext; friend class ShapeInferenceTest; friend class ShapeInferenceTestutil; + friend class ::tensorflow::ShapeRefinerTest; friend class ShapeManager; // Intentionally copyable. @@ -98,6 +106,7 @@ class ShapeHandle { friend class InferenceContext; friend class ShapeInferenceTest; friend class ShapeInferenceTestutil; + friend class ::tensorflow::ShapeRefinerTest; friend class ShapeManager; // Intentionally copyable. @@ -201,10 +210,25 @@ class InferenceContext { return s; } - // Merge the stored shape of the input in position idx with the specified - // shape. This requires idx to be in the [0, num_inputs) range. If the merge - // is successful and the new shape differs from the old one, store the new - // shape and return true. Return false otherwise. + // Merge the stored shape of the input in position idx with according + // to the following rules: + // + // - If the ShapeHandles are the same or is unknown, there will be no + // change. Otherwise if the stored shape is unknown, the new shape will be + // . + // - If both shapes are known, then they must have the same rank. + // - For any one dimension, if the values for that dimension in both shapes + // are known, then the values must match. + // - If one shape has equal or more information than the other shape in every + // dimension, the shape with more information will be returned. Otherwise a + // new shape holding the combined information of the input shapes will be + // returned. + // - Example: merging [2,?] and [?,2] results in [2,2] + // - Example: [2,2] cannot be merged with [1,2] + // + // This requires idx to be in the [0, num_inputs) range. 
If the merge is + // successful and the new shape differs from the old one, store the new shape + // and return true. Return false otherwise. bool MergeInput(int idx, ShapeHandle shape) { ShapeHandle new_shape; if (!Merge(inputs_[idx], shape, &new_shape).ok() || @@ -214,6 +238,41 @@ class InferenceContext { inputs_[idx] = new_shape; return true; } + // Relax the stored shape of the input in position idx with according + // to the following rules: + // + // - If the ShapeHandles are the same then the stored shape will be returned. + // - If either of the ShapeHandles are unknown, then a new UnknownShape will + // be returned. A new shape must be returned because we cannot claim that + // the resulting shape is necessarily the same as either of the input + // shapes. + // - If the shapes both have known ranks but their ranks are different, a new + // UnknownShape will be returned. + // - For any one dimension, if the value for that dimension in either of the + // shapes is unknown, a new shape will be returned with a new UnknownDim in + // that dimension. + // - For any one dimension, if the values for that dimension in both shapes + // are known but do not match, a new shape will be returned with a new + // UnknownDim in that dimension. + // - If both shapes have the same known rank and match in every dimension, + // the stored shape will be returned. + // - Example: relaxing [2,?] and [?,2] results in [?,?] + // - Example: relaxing [2,2] and [3,2] results in [?,2] + // - Example: relaxing [2,2] with [1,2,3] results in ? + // + // This requires idx to be in the [0, num_inputs) range. If the relax is + // successful and the new shape differs from the old one, store the new + // shape and return true. Return false otherwise. 
+ bool RelaxInput(int idx, ShapeHandle shape) { + ShapeHandle new_shape; + Relax(inputs_[idx], shape, &new_shape); + if (inputs_[idx].SameHandle(new_shape)) { + return false; + } + inputs_[idx] = new_shape; + return true; + } + ShapeHandle input(int64 idx) const { return inputs_[idx]; } Status input(StringPiece input_name, std::vector* output) const; int num_inputs() const { return inputs_.size(); } @@ -313,12 +372,9 @@ class InferenceContext { Status WithValue(DimensionHandle dim, int64 value, DimensionHandle* out) TF_MUST_USE_RESULT; - // Merges and and returns the merged shape in <*out>. If and - // are incompatible in rank, or in the value of any dimension, returns - // an error. - // - // Note that <*out> may be set to or . - Status Merge(ShapeHandle in0, ShapeHandle in1, + // Merges and and returns the merged shape in <*out>. See + // 'MergeInput' function for full details and examples. + Status Merge(ShapeHandle s0, ShapeHandle s1, ShapeHandle* out) TF_MUST_USE_RESULT; // Asserts that 's rank >= 's rank, and the first @@ -471,13 +527,34 @@ class InferenceContext { // If the merge is successful and any of the new shapes differs from the old // one, or any of the old dtypes was DT_INVALID, store the new shapes and // return true. Return false otherwise. + // + // See 'MergeInput' function for full details and examples. bool MergeInputHandleShapesAndTypes( int idx, const std::vector& shapes_and_types) TF_MUST_USE_RESULT; // As MergeInputHandleShapesAndTypes, but for an output. bool MergeOutputHandleShapesAndTypes( - int idx, const std::vector& shapes) TF_MUST_USE_RESULT; + int idx, + const std::vector& shapes_and_types) TF_MUST_USE_RESULT; + + // Relaxes the stored shapes and types corresponding to the input handle in + // position idx with the specified shapes and types. This requires idx to be + // in the [0, num_inputs) range. 
+ // + // If the relax is successful and any of the new shapes differs from the old + // one, or any of the old dtypes was DT_INVALID, store the new shapes and + // return true. Return false otherwise. + // + // See 'RelaxInput' function for full details and examples. + bool RelaxInputHandleShapesAndMergeTypes( + int idx, + const std::vector& shapes_and_types) TF_MUST_USE_RESULT; + + // As RelaxInputHandleShapesAndTypes, but for an output. + bool RelaxOutputHandleShapesAndMergeTypes( + int idx, + const std::vector& shapes_and_types) TF_MUST_USE_RESULT; // Returns the output handle shapes and types, for the resource tensor output // at index . Returns NULL if the shape and types were never set. @@ -538,6 +615,8 @@ class InferenceContext { std::vector all_dims_; // values are owned. }; + friend class ::tensorflow::grappler::GraphProperties; + friend class ShapeInferenceTest; // For testing Relax functions. friend class ShapeInferenceTestutil; // For testing shapes. // Shared initialization across the two constructors. Remove @@ -563,11 +642,25 @@ class InferenceContext { // Adds additional context to the given status. Status AttachContext(const Status& status); + // Relaxes and and returns the relaxed dimension in <*out>. If + // and have incompatible values, returns an error. + // + // Note that <*out> may be set to or . + void Relax(DimensionHandle d0, DimensionHandle d1, DimensionHandle* out); + // Relaxes and and returns the relaxed shape in <*out>. See + // 'RelaxInput' function for full details and examples. + void Relax(ShapeHandle s0, ShapeHandle s1, ShapeHandle* out); + // Used to implement MergeInputHandleShapesAndTypes and // MergeOutputHandleShapesAndTypes. bool MergeHandleShapesAndTypes( const std::vector& shapes_and_types, std::vector* to_update) TF_MUST_USE_RESULT; + // Used to implement RelaxInputHandleShapesAndMergeTypes and + // RelaxOutputHandleShapesAndMergeTypes. 
+ bool RelaxHandleShapesAndMergeTypes( + const std::vector& shapes_and_types, + std::vector* to_update) TF_MUST_USE_RESULT; ShapeManager shape_manager_; diff --git a/tensorflow/core/framework/shape_inference_test.cc b/tensorflow/core/framework/shape_inference_test.cc index a9c0303d4cb..66cfbf87475 100644 --- a/tensorflow/core/framework/shape_inference_test.cc +++ b/tensorflow/core/framework/shape_inference_test.cc @@ -61,7 +61,16 @@ class ShapeInferenceTest : public ::testing::Test { bool SameHandle(ShapeHandle a, ShapeHandle b) { return a.SameHandle(b); } bool IsSet(DimensionHandle d) { return d.IsSet(); } bool IsSet(ShapeHandle s) { return s.IsSet(); } + void Relax(InferenceContext* c, DimensionHandle d0, DimensionHandle d1, + DimensionHandle* out) { + c->Relax(d0, d1, out); + } + void Relax(InferenceContext* c, ShapeHandle s0, ShapeHandle s1, + ShapeHandle* out) { + c->Relax(s0, s1, out); + } void TestMergeHandles(bool input_not_output); + void TestRelaxHandles(bool input_not_output); static const int kVersion = 0; // used for graph-def version. }; @@ -495,7 +504,7 @@ TEST_F(ShapeInferenceTest, MergeDim) { EXPECT_TRUE(c.Merge(d2_b, d2, &out).ok()); EXPECT_TRUE(SameHandle(d2_b, out)); - // Merging inequal values is an error. + // Merging unequal values is an error. EXPECT_TRUE( StringPiece(c.Merge(d2, d1, &out).ToString()) .contains( @@ -510,6 +519,122 @@ TEST_F(ShapeInferenceTest, MergeDim) { EXPECT_FALSE(IsSet(out)); } +TEST_F(ShapeInferenceTest, RelaxDim) { + NodeDef def; + InferenceContext c(kVersion, &def, MakeOpDef(1, 2), + {S({2, InferenceContext::kUnknownDim, 2, 1, + InferenceContext::kUnknownDim})}, + {}, {}, {}); + + auto d2 = c.Dim(c.input(0), 0); + auto d_unknown = c.Dim(c.input(0), 1); + auto d2_b = c.Dim(c.input(0), 2); + auto d1 = c.Dim(c.input(0), 3); + auto d_unknown_b = c.Dim(c.input(0), 4); + DimensionHandle out; + + // Relaxing anything with unknown returns a new unknown. 
+ Relax(&c, d2, d_unknown, &out); + EXPECT_FALSE(SameHandle(d_unknown, out)); + EXPECT_FALSE(SameHandle(d_unknown_b, out)); + EXPECT_EQ(InferenceContext::kUnknownDim, c.Value(out)); + Relax(&c, d_unknown, d2, &out); + EXPECT_FALSE(SameHandle(d_unknown, out)); + EXPECT_EQ(InferenceContext::kUnknownDim, c.Value(out)); + Relax(&c, d_unknown, d_unknown_b, &out); + EXPECT_FALSE(SameHandle(d_unknown, out)); + EXPECT_FALSE(SameHandle(d_unknown_b, out)); + EXPECT_EQ(InferenceContext::kUnknownDim, c.Value(out)); + + // Relaxing with self returns self. + Relax(&c, d2, d2, &out); + EXPECT_TRUE(SameHandle(d2, out)); + Relax(&c, d_unknown, d_unknown, &out); + EXPECT_TRUE(SameHandle(d_unknown, out)); + + // Relaxing equal values returns first one. + Relax(&c, d2, d2_b, &out); + EXPECT_TRUE(SameHandle(d2, out)); + Relax(&c, d2_b, d2, &out); + EXPECT_TRUE(SameHandle(d2_b, out)); + + // Relaxing unequal values returns a new unknown. + Relax(&c, d2, d1, &out); + EXPECT_EQ(InferenceContext::kUnknownDim, c.Value(out)); + Relax(&c, d1, d2, &out); + EXPECT_EQ(InferenceContext::kUnknownDim, c.Value(out)); +} + +TEST_F(ShapeInferenceTest, RelaxShape) { + NodeDef def; + InferenceContext c( + kVersion, &def, MakeOpDef(7, 2), + {Unknown(), S({1, 2}), S({InferenceContext::kUnknownDim, 2}), + S({1, InferenceContext::kUnknownDim}), S({1, 3}), Unknown(), S({1})}, + {}, {}, {}); + + auto s_unknown = c.input(0); + auto s_1_2 = c.input(1); + auto s_u_2 = c.input(2); + auto s_1_u = c.input(3); + auto s_1_3 = c.input(4); + auto s_unknown_b = c.input(5); + auto s_1 = c.input(6); + ShapeHandle out; + + // Relaxing any shape with unknown returns a new unknown. 
+ Relax(&c, s_unknown, s_1_2, &out); + EXPECT_FALSE(SameHandle(s_u_2, s_unknown)); + EXPECT_EQ("?", c.DebugString(out)); + Relax(&c, s_u_2, s_unknown, &out); + EXPECT_FALSE(SameHandle(s_u_2, out)); + EXPECT_EQ("?", c.DebugString(out)); + Relax(&c, s_unknown, s_unknown_b, &out); + EXPECT_FALSE(SameHandle(s_unknown, out)); + EXPECT_FALSE(SameHandle(s_unknown_b, out)); + EXPECT_EQ("?", c.DebugString(out)); + + // Relaxing with self returns self. + Relax(&c, s_1_2, s_1_2, &out); + EXPECT_TRUE(SameHandle(out, s_1_2)); + + // Relaxing where one of the inputs has less information. + out = ShapeHandle(); + Relax(&c, s_1_2, s_u_2, &out); + EXPECT_FALSE(SameHandle(s_u_2, out)); + EXPECT_EQ("[?,2]", c.DebugString(out)); + out = ShapeHandle(); + Relax(&c, s_u_2, s_1_2, &out); + EXPECT_FALSE(SameHandle(s_u_2, out)); + EXPECT_EQ("[?,2]", c.DebugString(out)); + + // Relaxing where each input has one distinct unknown dimension. + Relax(&c, s_u_2, s_1_u, &out); + EXPECT_EQ("[?,?]", c.DebugString(out)); + EXPECT_FALSE(SameHandle(c.Dim(s_u_2, 0), c.Dim(out, 0))); + EXPECT_FALSE(SameHandle(c.Dim(s_1_u, 1), c.Dim(out, 1))); + auto s_u1 = c.UnknownShapeOfRank(1); + auto s_u2 = c.UnknownShapeOfRank(1); + Relax(&c, s_u1, s_u2, &out); + EXPECT_FALSE(SameHandle(s_u1, out)); + + // Relaxing with mismatched values in a dimension returns a shape with that + // dimension unknown. + out = s_unknown; + Relax(&c, s_u_2, s_1_3, &out); + EXPECT_FALSE(SameHandle(c.Dim(s_u_2, 0), c.Dim(out, 0))); + EXPECT_EQ("[?,?]", c.DebugString(out)); + out = s_unknown; + Relax(&c, s_1_3, s_u_2, &out); + EXPECT_FALSE(SameHandle(c.Dim(s_u_2, 0), c.Dim(out, 0))); + EXPECT_EQ("[?,?]", c.DebugString(out)); + out = s_unknown; + + // Relaxing with mismatched ranks returns a new unknown. 
+ Relax(&c, s_1, s_1_2, &out); + EXPECT_EQ("?", c.DebugString(out)); +} + TEST_F(ShapeInferenceTest, MergeShape) { NodeDef def; InferenceContext c(kVersion, &def, MakeOpDef(7, 2), @@ -1473,8 +1598,8 @@ void ShapeInferenceTest::TestMergeHandles(bool input_not_output) { EXPECT_EQ(t[i].dtype, v[i].dtype); } - // Only difference is in a mismatched dtype. That is ignored, - // and there are no other changes, so nothing is done. + // Only difference is in a mismatched dtype, but that cannot be + // updated unless original dtype is DT_INVALID. t2 = t; t2[2].dtype = DT_FLOAT; ASSERT_FALSE(merge_shapes_and_types_to_context(0, t2)); @@ -1510,11 +1635,111 @@ void ShapeInferenceTest::TestMergeHandles(bool input_not_output) { } TEST_F(ShapeInferenceTest, MergeInputHandleShapesAndTypes) { - TestMergeHandles(true); + TestMergeHandles(true /* input_not_output */); } TEST_F(ShapeInferenceTest, MergeOutputHandleShapesAndTypes) { - TestMergeHandles(false); + TestMergeHandles(false /* input_not_output */); +} + +void ShapeInferenceTest::TestRelaxHandles(bool input_not_output) { + NodeDef def; + InferenceContext c(kVersion, &def, MakeOpDef(2, 2), {S({}), S({})}, {}, {}, + {}); + auto make_shape = [&c](std::initializer_list dim_sizes) { + ShapeHandle s; + TF_CHECK_OK(c.MakeShapeFromShapeProto(S(dim_sizes), &s)); + return s; + }; + auto get_shapes_and_types_from_context = [&](int idx) { + if (input_not_output) { + return c.input_handle_shapes_and_types(idx); + } else { + return c.output_handle_shapes_and_types(idx); + } + }; + auto relax_shapes_and_types_to_context = + [&](int idx, const std::vector& shapes_and_types) { + if (input_not_output) { + return c.RelaxInputHandleShapesAndMergeTypes(idx, shapes_and_types); + } else { + return c.RelaxOutputHandleShapesAndMergeTypes(idx, shapes_and_types); + } + }; + + EXPECT_TRUE(get_shapes_and_types_from_context(0) == nullptr); + EXPECT_TRUE(get_shapes_and_types_from_context(1) == nullptr); + + // First relax will take the input completely. 
+ std::vector t{{make_shape({1, 2, 3}), DT_FLOAT}, + {c.UnknownShape(), DT_INVALID}, + {make_shape({4, 3, 2, 1}), DT_INT32}}; + ASSERT_TRUE(relax_shapes_and_types_to_context(0, t)); + ASSERT_TRUE(get_shapes_and_types_from_context(0) != nullptr); + std::vector v = *get_shapes_and_types_from_context(0); + ASSERT_EQ(3, v.size()); + for (int i = 0; i < v.size(); ++i) { + EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i; + EXPECT_EQ(t[i].dtype, v[i].dtype); + } + + // Relax that fails because wrong number of values passed. + // Fails, and no changes made. + ASSERT_FALSE(relax_shapes_and_types_to_context( + 0, std::vector{{make_shape({1, 2, 3}), DT_FLOAT}})); + v = *get_shapes_and_types_from_context(0); + ASSERT_EQ(3, v.size()); + for (int i = 0; i < v.size(); ++i) { + EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i; + EXPECT_EQ(t[i].dtype, v[i].dtype); + } + + // Only difference is in a mismatched shape. This should replace + // the mismatched dimension with an UnknownDim. + auto t2 = t; + t2[2].shape = make_shape({4, 3, 4, 1}); + ASSERT_TRUE(relax_shapes_and_types_to_context(0, t2)); + v = *get_shapes_and_types_from_context(0); + EXPECT_EQ("[4,3,?,1]", c.DebugString(v[2].shape)); + for (int i = 0; i < v.size(); ++i) { + EXPECT_EQ(t[i].dtype, v[i].dtype); + } + + // Only difference is in a mismatched dtype, but that cannot be + // updated unless original dtype is DT_INVALID. + t2 = t; + t2[2].dtype = DT_FLOAT; + ASSERT_FALSE(relax_shapes_and_types_to_context(0, t2)); + v = *get_shapes_and_types_from_context(0); + ASSERT_EQ(3, v.size()); + for (int i = 0; i < v.size(); ++i) { + EXPECT_EQ(t[i].dtype, v[i].dtype); + } + + // Difference is a new shape, which will result in a new UnknownShape. 
+ t[1].shape = make_shape({1, 10}); + ASSERT_TRUE(relax_shapes_and_types_to_context(0, t)); + v = *get_shapes_and_types_from_context(0); + ASSERT_EQ(3, v.size()); + EXPECT_FALSE(SameHandle(t[1].shape, v[1].shape)); + EXPECT_EQ("?", c.DebugString(v[1].shape)); + for (int i = 0; i < v.size(); ++i) { + EXPECT_EQ(t[i].dtype, v[i].dtype); + } + + // Difference is relaxable (new type). + t[1].dtype = DT_DOUBLE; + ASSERT_TRUE(relax_shapes_and_types_to_context(0, t)); + v = *get_shapes_and_types_from_context(0); + EXPECT_EQ(t[1].dtype, v[1].dtype); +} + +TEST_F(ShapeInferenceTest, RelaxInputHandleShapesAndTypes) { + TestRelaxHandles(true /* input_not_output */); +} + +TEST_F(ShapeInferenceTest, RelaxOutputHandleShapesAndTypes) { + TestRelaxHandles(false /* input_not_output */); } } // namespace shape_inference diff --git a/tensorflow/core/framework/shape_inference_testutil.cc b/tensorflow/core/framework/shape_inference_testutil.cc index 7b3cd07429b..b4765ab0b2c 100644 --- a/tensorflow/core/framework/shape_inference_testutil.cc +++ b/tensorflow/core/framework/shape_inference_testutil.cc @@ -43,9 +43,26 @@ Status ShapeInferenceTestutil::InferShapes(ShapeInferenceTestOp op, in_shapes.push_back(shape); } - shape_inference::InferenceContext c(op.graph_def_version, &op.node_def, - op_reg_data->op_def, in_shapes, - op.input_tensors, {}, {}); + std::vector>> + input_resource_handle_shapes_and_types; + for (const auto p : op.input_resource_handle_shapes_and_types) { + if (p == nullptr) { + input_resource_handle_shapes_and_types.push_back(nullptr); + } else { + std::unique_ptr> v( + new std::vector()); + for (const auto& shape_and_type : *p) { + ShapeHandle shape; + TF_RETURN_IF_ERROR( + MakeShapeFromString(&manager, shape_and_type.first, &shape)); + v->emplace_back(shape, shape_and_type.second); + } + input_resource_handle_shapes_and_types.emplace_back(v.release()); + } + } + shape_inference::InferenceContext c( + op.graph_def_version, &op.node_def, op_reg_data->op_def, in_shapes, 
+ op.input_tensors, {}, std::move(input_resource_handle_shapes_and_types)); TF_RETURN_IF_ERROR(c.construction_status()); if (op_reg_data->shape_inference_fn == nullptr) { return errors::InvalidArgument( diff --git a/tensorflow/core/framework/shape_inference_testutil.h b/tensorflow/core/framework/shape_inference_testutil.h index 996281e70e6..6bd2cd42915 100644 --- a/tensorflow/core/framework/shape_inference_testutil.h +++ b/tensorflow/core/framework/shape_inference_testutil.h @@ -31,10 +31,13 @@ class NodeDef; class Tensor; struct ShapeInferenceTestOp { + typedef std::pair ShapeAndType; explicit ShapeInferenceTestOp(StringPiece name) : name(name.ToString()) {} string name; NodeDef node_def; std::vector input_tensors; + std::vector*> + input_resource_handle_shapes_and_types; int graph_def_version = TF_GRAPH_DEF_VERSION; }; diff --git a/tensorflow/core/graph/optimizer_cse.cc b/tensorflow/core/graph/optimizer_cse.cc index a22a9b3fa31..54cfd10cdfd 100644 --- a/tensorflow/core/graph/optimizer_cse.cc +++ b/tensorflow/core/graph/optimizer_cse.cc @@ -187,6 +187,12 @@ bool OptimizerCSE::Optimize( for (Node* n : order) { if (!n->IsOp()) continue; + // Don't prune placeholder nodes. 
+ if (n->def().op() == "Placeholder" || n->def().op() == "PlaceholderV2" || + n->def().op() == "PlaceholderWithDefault") { + continue; + } + // See if we should consider this node at all if (consider_fn != nullptr && !consider_fn(n)) continue; @@ -204,6 +210,7 @@ bool OptimizerCSE::Optimize( for (const Edge* e : n->out_edges()) { g_->AddEdge(*candidate, e->src_output(), e->dst(), e->dst_input()); } + g_->RemoveNode(n); changed = true; } diff --git a/tensorflow/core/grappler/clusters/BUILD b/tensorflow/core/grappler/clusters/BUILD index fd2f2b32492..3cdddeb38f1 100644 --- a/tensorflow/core/grappler/clusters/BUILD +++ b/tensorflow/core/grappler/clusters/BUILD @@ -65,7 +65,23 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":cluster", + "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/grappler/costs:op_level_cost_estimator", + "//tensorflow/core/grappler/costs:virtual_scheduler", + ], +) + +cc_test( + name = "virtual_cluster_test", + srcs = ["virtual_cluster_test.cc"], + deps = [ + ":virtual_cluster", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder", ], ) diff --git a/tensorflow/core/grappler/clusters/cluster.cc b/tensorflow/core/grappler/clusters/cluster.cc index 8690d9f24ad..dec51842e45 100644 --- a/tensorflow/core/grappler/clusters/cluster.cc +++ b/tensorflow/core/grappler/clusters/cluster.cc @@ -14,27 +14,15 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/core/grappler/clusters/cluster.h" -#include namespace tensorflow { namespace grappler { -static std::atomic already_created(false); - Cluster::Cluster(int timeout_s) : timeout_s_(timeout_s) { - // This is really ugly: to avoid leaking variables, we need to reset the tf - // session every time we're done processing a grappler item. However, - // variables are global, and therefore we can't have more than 1 session alive - // at a time. This check detects when more that one cluster is created. - CHECK(!already_created); - already_created = true; - DisableDetailedStats(false); } Cluster::~Cluster() { - CHECK(already_created); - already_created = false; } void Cluster::AllowSoftPlacement(bool soft_placement_state) { diff --git a/tensorflow/core/grappler/clusters/single_machine.cc b/tensorflow/core/grappler/clusters/single_machine.cc index 22ccf5208c1..58b9fc6429a 100644 --- a/tensorflow/core/grappler/clusters/single_machine.cc +++ b/tensorflow/core/grappler/clusters/single_machine.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/grappler/clusters/single_machine.h" +#include #include #include "tensorflow/cc/training/queue_runner.h" @@ -31,11 +32,22 @@ limitations under the License. namespace tensorflow { namespace grappler { +static std::atomic already_created(false); + SingleMachine::SingleMachine(int timeout_s, int num_cpu_cores, int num_gpus) : Cluster(timeout_s), num_gpus_(num_gpus), expected_init_time_s_(0), closing_(false) { + // This is really ugly: to avoid leaking variables, we need to reset the tf + // session every time we're done processing a grappler item. However, + // variables are global, and therefore we can't have more than 1 session alive + // at a time. This check detects when more that one cluster is created. 
+ CHECK(!already_created); + already_created = true; + + VLOG(1) << "Number of CPU cores: " << num_cpu_cores + << " Number of GPUs: " << num_gpus; thread_pool_.reset(new thread::ThreadPool( Env::Default(), SanitizeThreadSuffix("single_machine"), 2)); @@ -62,6 +74,9 @@ SingleMachine::~SingleMachine() { thread_pool_.reset(); Reset(options_, {}).IgnoreError(); + + CHECK(already_created); + already_created = false; } Status SingleMachine::Provision() { @@ -73,9 +88,12 @@ Status SingleMachine::Provision() { DeviceProperties attr = GetLocalCPUInfo(); devices_["/job:localhost/replica:0/task:0/cpu:0"] = GetLocalCPUInfo(); + VLOG(1) << "Number of GPUs: " << num_gpus_; for (int i = 0; i < num_gpus_; ++i) { - devices_[strings::StrCat("/job:localhost/replica:0/task:0/gpu:", i)] = - GetLocalGPUInfo(i); + string device_name = + strings::StrCat("/job:localhost/replica:0/task:0/gpu:", i); + VLOG(1) << "Adding GPU device " << device_name; + devices_[device_name] = GetLocalGPUInfo(i); } return Status::OK(); } diff --git a/tensorflow/core/grappler/clusters/virtual_cluster.cc b/tensorflow/core/grappler/clusters/virtual_cluster.cc index 4ca4c03dbb6..32d9be9aeee 100644 --- a/tensorflow/core/grappler/clusters/virtual_cluster.cc +++ b/tensorflow/core/grappler/clusters/virtual_cluster.cc @@ -14,16 +14,27 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/core/grappler/clusters/virtual_cluster.h" +#include "tensorflow/core/framework/cost_graph.pb.h" +#include "tensorflow/core/framework/step_stats.pb.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/grappler/costs/op_level_cost_estimator.h" +#include "tensorflow/core/grappler/costs/virtual_scheduler.h" namespace tensorflow { namespace grappler { VirtualCluster::VirtualCluster( const std::unordered_map& devices) - : Cluster(0) { + : Cluster(0), node_estimator_(new OpLevelCostEstimator()) { devices_ = devices; } +VirtualCluster::VirtualCluster( + const std::unordered_map& devices, + OpLevelCostEstimator* node_estimator) + : Cluster(0), node_estimator_(node_estimator) { + devices_ = devices; +} VirtualCluster::~VirtualCluster() {} Status VirtualCluster::Provision() { return Status::OK(); } @@ -32,12 +43,60 @@ Status VirtualCluster::Initialize(const GrapplerItem& item) { return Status::OK(); } -Status VirtualCluster::Run(const GraphDef& item, +Status VirtualCluster::Run(const GraphDef& graph, const std::vector>& feed, const std::vector& fetch, RunMetadata* metadata) { - return Status::OK(); + // Initialize a virtual scheduler to process the graph. Make sure to use + // static shape inference to prevent the schedulrer from calling the Run + // method on the cluster, and create an infinite loop. 
+ GrapplerItem item; + item.graph = graph; + item.feed = feed; + item.fetch = fetch; + VirtualScheduler scheduler(&item, true, this); + TF_RETURN_IF_ERROR(scheduler.Init()); + if (metadata) { + metadata->clear_step_stats(); + metadata->clear_cost_graph(); + } + + Costs node_costs; + do { + NodeInfo node_info = scheduler.GetCurrNodeInfo(); + const auto& op_info = node_info.op_info; + node_costs = node_estimator_->PredictCosts(op_info); + if (metadata) { + CostGraphDef::Node* cost_node = + metadata->mutable_cost_graph()->add_node(); + const string& op_name = node_info.name; + cost_node->set_name(op_name); + cost_node->set_device(node_info.device_name); + cost_node->set_compute_cost( + node_costs.execution_time.asMicroSeconds().count()); + cost_node->set_compute_time( + node_costs.compute_time.asMicroSeconds().count()); + cost_node->set_memory_time( + node_costs.memory_time.asMicroSeconds().count()); + for (const auto& output : node_info.op_info.outputs()) { + auto output_info = cost_node->add_output_info(); + output_info->set_dtype(output.dtype()); + *output_info->mutable_shape() = output.shape(); + + int64 size = DataTypeSize(output.dtype()); + for (const auto& dim : output.shape().dim()) { + size *= std::max(1, dim.size()); + } + output_info->set_size(size); + } + } + } while (scheduler.MarkCurrNodeExecuted(node_costs)); + + if (metadata) { + scheduler.Summary(metadata->mutable_step_stats()); + } + return Status::OK(); } } // namespace grappler diff --git a/tensorflow/core/grappler/clusters/virtual_cluster.h b/tensorflow/core/grappler/clusters/virtual_cluster.h index cd8436a9870..a74911cb23a 100644 --- a/tensorflow/core/grappler/clusters/virtual_cluster.h +++ b/tensorflow/core/grappler/clusters/virtual_cluster.h @@ -18,18 +18,20 @@ limitations under the License. 
#include #include "tensorflow/core/grappler/clusters/cluster.h" +#include "tensorflow/core/grappler/costs/op_level_cost_estimator.h" #include "tensorflow/core/protobuf/device_properties.pb.h" namespace tensorflow { namespace grappler { // Create a simple cluster that lists the devices (and their properties) -// available in a TensorFlow session. This cluster doesn't allow running an -// actual graph. It is useful however when used in conjusction with costs models -// that aren't based on the execution of the graph. +// available in a TensorFlow session. This cluster simulates the execution of +// actual graphs. class VirtualCluster : public Cluster { public: VirtualCluster(const std::unordered_map& devices); + VirtualCluster(const std::unordered_map& devices, + OpLevelCostEstimator* node_estimator); ~VirtualCluster() override; @@ -38,6 +40,9 @@ class VirtualCluster : public Cluster { Status Run(const GraphDef& item, const std::vector>& feed, const std::vector& fetch, RunMetadata* metadata) override; + + private: + std::unique_ptr node_estimator_; }; } // end namespace grappler diff --git a/tensorflow/core/grappler/clusters/virtual_cluster_test.cc b/tensorflow/core/grappler/clusters/virtual_cluster_test.cc new file mode 100644 index 00000000000..6f25e7b0d4d --- /dev/null +++ b/tensorflow/core/grappler/clusters/virtual_cluster_test.cc @@ -0,0 +1,95 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/grappler/clusters/virtual_cluster.h" +#include "tensorflow/core/framework/cost_graph.pb.h" +#include "tensorflow/core/framework/step_stats.pb.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace grappler { +namespace { + +class VirtualClusterTest : public ::testing::Test { + public: + void SetUp() override { + // Invent a CPU so that predictions remain the same from machine to machine. + DeviceProperties cpu_device; + cpu_device.set_type("CPU"); + cpu_device.set_frequency(1000); + cpu_device.set_num_cores(4); + cpu_device.set_bandwidth(32); + cpu_device.set_l1_cache_size(32 * 1024); + cpu_device.set_l2_cache_size(256 * 1024); + cpu_device.set_l3_cache_size(4 * 1024 * 1024); + std::unordered_map devices; + devices["/job:localhost/replica:0/task:0/cpu:0"] = cpu_device; + cluster_.reset(new VirtualCluster(devices)); + TF_CHECK_OK(cluster_->Provision()); + } + + void TearDown() override { cluster_.reset(); } + + protected: + std::unique_ptr cluster_; +}; + +TEST_F(VirtualClusterTest, CostModel) { + TrivialTestGraphInputYielder fake_input(4, 1, 10, false, + cluster_->GetDeviceNames()); + GrapplerItem item; + CHECK(fake_input.NextItem(&item)); + + TF_CHECK_OK(cluster_->Initialize(item)); + + RunMetadata metadata; + TF_CHECK_OK(cluster_->Run(item.graph, item.feed, item.fetch, &metadata)); + + // There should be at least 4 nodes corresponding to the 4 stages we created + // in the fake input. + EXPECT_LE(4, metadata.cost_graph().node_size()); + for (const auto& node : metadata.cost_graph().node()) { + // Skip the constant node that configures the random number generator. 
+ if (node.name().find("Const/Const") != string::npos) { + continue; + } + EXPECT_EQ(1, node.output_info_size()); + EXPECT_EQ(40, node.output_info(0).size()); + const TensorShapeProto& shape = node.output_info(0).shape(); + EXPECT_EQ(2, shape.dim_size()); + EXPECT_EQ(10, shape.dim(0).size()); + EXPECT_EQ(1, shape.dim(1).size()); + if (node.name() == "x") { + EXPECT_EQ(1500, node.compute_cost()); + } else { + EXPECT_EQ(2500, node.compute_cost()); + } + } + + for (const auto& dev_stat : metadata.step_stats().dev_stats()) { + EXPECT_EQ("/job:localhost/replica:0/task:0/cpu:0", dev_stat.device()); + for (const auto& node : dev_stat.node_stats()) { + if (node.node_name() == "AddN") { + EXPECT_EQ(2500, node.op_end_rel_micros()); + } + } + } +} + +} // namespace +} // namespace grappler +} // namespace tensorflow diff --git a/tensorflow/core/grappler/costs/BUILD b/tensorflow/core/grappler/costs/BUILD index 2b30facd84d..96d37c9a97e 100644 --- a/tensorflow/core/grappler/costs/BUILD +++ b/tensorflow/core/grappler/costs/BUILD @@ -42,6 +42,7 @@ cc_library( ":op_performance_data_cc", ":utils", "//tensorflow/core:core_cpu_base", + "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler/clusters:cluster", @@ -160,7 +161,7 @@ cc_test( srcs = ["virtual_placer_test.cc"], deps = [ ":virtual_placer", - "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", @@ -264,7 +265,6 @@ cc_library( ":virtual_scheduler", "//tensorflow/core:core_cpu_base", "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", "//tensorflow/core/grappler:grappler_item", ], ) diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc index 21b73b6618d..7ac35ef271c 100644 --- a/tensorflow/core/grappler/costs/graph_properties.cc +++ 
b/tensorflow/core/grappler/costs/graph_properties.cc @@ -32,9 +32,108 @@ using shape_inference::ShapeHandle; namespace { -// Merges shapes , determined from an EnqueueV2 node, into -// <*queue_shapes_and_types>. -Status MergeEnqueueShapesAndTypes( +// If a Merge node has a NextIteration node as an input then that input will +// try to forward an UnknownShape at graph construction time. However, the +// Merge shape function will always propagate an UnknownShape if any of its +// inputs are UnknownShapes. So we need to ignore the input from NextIteration +// nodes to propagate any known shape from the Merge node. +Status ShapeOfMergeNode(const Node* node, InferenceContext* c) { + ShapeHandle out = c->input(0); + if (!c->RankKnown(out)) { + out = c->UnknownShape(); + } else { + int32 rank = c->Rank(out); + for (const Edge* e : node->in_edges()) { + if (e->src()->IsNextIteration() || e->dst_input() <= 0) { + continue; + } + ShapeHandle input = c->input(e->dst_input()); + if (!c->RankKnown(input) || c->Rank(input) != rank) { + out = c->UnknownShape(); + break; + } + + for (int d = 0; d < rank; ++d) { + if (c->Value(c->Dim(input, d)) != c->Value(c->Dim(out, d))) { + TF_RETURN_IF_ERROR(c->ReplaceDim(out, d, c->UnknownDim(), &out)); + } + } + } + } + c->set_output(0, out); + c->set_output(1, c->Scalar()); + return Status::OK(); +} + +// Manually propagate the input shape for Enter nodes and update any Merge node +// outputs. 
+Status UpdateEnter(ShapeRefiner* shape_refiner, const Node* node, bool relax, + std::queue* new_shapes) { + auto enter_ctx = shape_refiner->GetContext(node); + for (int i = 0; i < enter_ctx->num_outputs(); i++) { + TF_RETURN_IF_ERROR(shape_refiner->SetShape(node, i, enter_ctx->input(0))); + } + for (const Edge* e : node->out_edges()) { + Node* dst = e->dst(); + if (dst->IsMerge()) { + bool updated = false; + TF_RETURN_IF_ERROR(shape_refiner->UpdateNode(dst, relax, &updated)); + if (!updated) { + continue; + } + InferenceContext* merge_ctx = shape_refiner->GetContext(dst); + DCHECK_NE(merge_ctx, nullptr); + TF_RETURN_IF_ERROR(ShapeOfMergeNode(dst, merge_ctx)); + new_shapes->push(dst); + } + } + return Status::OK(); +} + +// Propagates the shapes in the transitive fan-out of . +Status PropagateShapes(ShapeRefiner* shape_refiner, bool relax, + std::queue* new_shapes) { + while (!new_shapes->empty()) { + const Node* n = new_shapes->front(); + new_shapes->pop(); + for (const Node* fanout : n->out_nodes()) { + bool updated = false; + TF_RETURN_IF_ERROR(shape_refiner->UpdateNode(fanout, relax, &updated)); + if (fanout->IsEnter()) { + TF_RETURN_IF_ERROR( + UpdateEnter(shape_refiner, fanout, relax, new_shapes)); + } else if (updated) { + // We want to avoid propagating through loops on the merge pass because + // the shapes are not guaranteed to converge. 
+ if (!relax && fanout->IsNextIteration()) { + continue; + } + new_shapes->push(fanout); + } + } + } + return Status::OK(); +} + +} // namespace + +void GraphProperties::Relax(InferenceContext* c, ShapeHandle s0, ShapeHandle s1, + ShapeHandle* out) { + c->Relax(s0, s1, out); +} + +bool GraphProperties::SameDefinedShape(InferenceContext* c, ShapeHandle s0, + ShapeHandle s1) { + return ShapeRefiner::SameDefinedShape(c, s0, s1); +} + +bool GraphProperties::IsUpdatedShapesOrTypes( + InferenceContext* c, const std::vector& existing, + const std::vector& updated) { + return ShapeRefiner::IsUpdatedShapesOrTypes(c, existing, updated); +} + +Status GraphProperties::MergeEnqueueShapesAndTypes( const std::vector& shapes_and_types, InferenceContext* qctx, std::vector* queue_shapes_and_types) { if (shapes_and_types.size() != queue_shapes_and_types->size()) { @@ -56,7 +155,27 @@ Status MergeEnqueueShapesAndTypes( return Status::OK(); } -} // namespace +Status GraphProperties::RelaxEnqueueShapesAndMergeTypes( + const std::vector& shapes_and_types, InferenceContext* qctx, + std::vector* queue_shapes_and_types) { + if (shapes_and_types.size() != queue_shapes_and_types->size()) { + return errors::InvalidArgument( + "Enqueue nodes mixed number of tensors: ", shapes_and_types.size(), + " vs ", queue_shapes_and_types->size()); + } + for (int i = 0; i < shapes_and_types.size(); ++i) { + const ShapeAndType& a = shapes_and_types[i]; + ShapeAndType& b = (*queue_shapes_and_types)[i]; + if (a.dtype != b.dtype) { + return errors::InvalidArgument("Enqueue nodes mixed dtypes for tensor ", + i, ": ", DataTypeString(a.dtype), " vs ", + DataTypeString(b.dtype)); + } + + Relax(qctx, a.shape, b.shape, &b.shape); + } + return Status::OK(); +} Status GraphProperties::InferStatically() { Graph graph(OpRegistry::Global()); @@ -66,8 +185,11 @@ Status GraphProperties::InferStatically() { Status s = ImportGraphDef(options, item_.graph, &graph, &shape_refiner); TF_RETURN_IF_ERROR(s); - // List the resources 
and the nodes using them + // List the resources and the nodes using them. Also collect the Enter and + // Merge nodes. std::unordered_map> resources; + std::unordered_set enter_nodes; + std::unordered_set merge_nodes; for (const Node* const node : graph.nodes()) { for (int i = 0; i < node->num_inputs(); ++i) { if (node->input_type(i) == DataType::DT_RESOURCE) { @@ -76,82 +198,142 @@ Status GraphProperties::InferStatically() { resources[resource].insert(node); } } + if (node->IsEnter()) { + enter_nodes.insert(node); + } else if (node->IsMerge()) { + merge_nodes.insert(node); + } } - // If we found a resource, try to propagate the shapes through it. - bool done = true; - do { - std::queue new_shapes; - for (const auto& resource_data : resources) { - const Node* qnode = resource_data.first; - StringPiece type(qnode->type_string()); - if (!type.ends_with("QueueV2")) { - continue; + // Propagate the initial shapes of Enter nodes manually (the Enter shape + // function always forwards an UnknownShape). + std::queue new_shapes; + for (const Node* node : enter_nodes) { + TF_RETURN_IF_ERROR( + UpdateEnter(&shape_refiner, node, false /* relax */, &new_shapes)); + } + TF_RETURN_IF_ERROR( + PropagateShapes(&shape_refiner, false /* relax */, &new_shapes)); + + // We propagate shapes through the graph in two phases. In the first phase, we + // exclusively merge shapes but we do not propagate shapes through loops. Then + // on the second phase, we exclusively relax shapes and propagate shapes + // through loops until reaching fixed point. + for (int relax = 0; relax < 2; relax++) { + // We don't update Merge nodes with the input of NextIteration nodes on the + // merge pass. So we do that at the beginning of the relax pass instead. 
+ if (relax) { + bool updated = false; + for (const Node* node : merge_nodes) { + TF_RETURN_IF_ERROR( + shape_refiner.UpdateNode(node, false /* relax */, &updated)); } - auto qctx = shape_refiner.GetContext(qnode); - if (!qctx) { - continue; + } + + bool done = true; + do { + if (relax) { + // Propagate shapes through any loops in the graph by relaxing. + for (const Node* node : merge_nodes) { + new_shapes.push(node); + } + TF_RETURN_IF_ERROR(PropagateShapes(&shape_refiner, relax, &new_shapes)); } - // Check to see if the shape is fully defined. - auto* queue_handle_data = qctx->output_handle_shapes_and_types(0); - if (queue_handle_data != nullptr) { - bool fully_defined = true; - for (const auto& shape_and_type : *queue_handle_data) { - if (!qctx->FullyDefined(shape_and_type.shape) || - shape_and_type.dtype == DT_INVALID) { - fully_defined = false; - } - } - if (fully_defined) { + // If we found a resource, try to propagate the shapes through it. + new_shapes = std::queue(); + for (const auto& resource_data : resources) { + const Node* qnode = resource_data.first; + StringPiece type(qnode->type_string()); + if (!type.ends_with("QueueV2") && !qnode->IsEnter()) { continue; } - } - - std::vector queue_shapes_and_types; - if (queue_handle_data != nullptr) { - queue_shapes_and_types = *queue_handle_data; - } - for (const auto& node : resource_data.second) { - auto ctx = shape_refiner.GetContext(node); - if (!ctx) { + auto qctx = shape_refiner.GetContext(qnode); + if (!qctx) { continue; } - // TODO(bsteiner): handle EnqueueMany as well. - if (node->type_string().find("Enqueue") != std::string::npos && - node->type_string().find("EnqueueMany") == std::string::npos) { - std::vector shapes_and_types; - for (int i = 1; i < ctx->num_inputs(); ++i) { - shapes_and_types.push_back({ctx->input(i), node->input_type(i)}); - } - if (queue_shapes_and_types.empty()) { - queue_shapes_and_types = shapes_and_types; + // Check to see if the shape is fully defined. 
+ auto* queue_handle_data = qctx->output_handle_shapes_and_types(0); + if (queue_handle_data != nullptr) { + bool fully_defined = true; + for (const auto& shape_and_type : *queue_handle_data) { + if (!qctx->FullyDefined(shape_and_type.shape) || + shape_and_type.dtype == DT_INVALID) { + fully_defined = false; + } + } + // If we are merging, then we are done. If we are relaxing, then we + // could potentially propagate a less specific shape. + if (fully_defined && !relax) { + continue; + } + } + + // Merge all inputs into the enqueue node, regardless of which phase we + // are in. + std::vector queue_shapes_and_types; + for (const auto& node : resource_data.second) { + auto ctx = shape_refiner.GetContext(node); + if (!ctx) { + continue; + } + // TODO(bsteiner): handle EnqueueMany as well. + if (node->type_string().find("Enqueue") != std::string::npos && + node->type_string().find("EnqueueMany") == std::string::npos) { + std::vector shapes_and_types; + for (int i = 1; i < ctx->num_inputs(); ++i) { + shapes_and_types.push_back({ctx->input(i), node->input_type(i)}); + } + + if (queue_shapes_and_types.empty()) { + queue_shapes_and_types = shapes_and_types; + } else { + TF_RETURN_IF_ERROR(MergeEnqueueShapesAndTypes( + shapes_and_types, qctx, &queue_shapes_and_types)); + } + } + } + // Combine the input shapes with the existing output shape. We either + // merge or relax depending on which phase we are in. + if (queue_handle_data != nullptr) { + if (relax) { + TF_RETURN_IF_ERROR(RelaxEnqueueShapesAndMergeTypes( + *queue_handle_data, qctx, &queue_shapes_and_types)); } else { TF_RETURN_IF_ERROR(MergeEnqueueShapesAndTypes( - shapes_and_types, qctx, &queue_shapes_and_types)); + *queue_handle_data, qctx, &queue_shapes_and_types)); + } + } + // Set the output ShapeAndType handles. If we successfully update the + // resource node, add its fan-out to the queue. 
+ const std::vector* outputs = + qctx->output_handle_shapes_and_types(0); + std::vector existing_outputs; + if (outputs) { + existing_outputs = *outputs; + } + if (!queue_shapes_and_types.empty()) { + if (!relax && qctx->MergeOutputHandleShapesAndTypes( + 0, queue_shapes_and_types)) { + new_shapes.push(qnode); + } else if (relax && qctx->RelaxOutputHandleShapesAndMergeTypes( + 0, queue_shapes_and_types)) { + if (IsUpdatedShapesOrTypes( + qctx, existing_outputs, + *qctx->output_handle_shapes_and_types(0))) { + new_shapes.push(qnode); + } } } } - if (!queue_shapes_and_types.empty() && - qctx->MergeOutputHandleShapesAndTypes(0, queue_shapes_and_types)) { - new_shapes.push(qnode); + // Propagate the shapes in the transitive fan-out of the queue. + done = new_shapes.empty(); + if (!done) { + TF_RETURN_IF_ERROR(PropagateShapes(&shape_refiner, relax, &new_shapes)); } - } - // Propagate the shapes in the transitive fan-out of the queue. - done = new_shapes.empty(); - while (!new_shapes.empty()) { - const Node* n = new_shapes.front(); - new_shapes.pop(); - for (const Node* fanout : n->out_nodes()) { - bool updated = false; - TF_RETURN_IF_ERROR(shape_refiner.UpdateNode(fanout, &updated)); - if (updated) { - new_shapes.push(fanout); - } - } - } - } while (!done); + } while (!done); + } for (const Node* const node : graph.nodes()) { VLOG(1) << " " << node->name(); diff --git a/tensorflow/core/grappler/costs/graph_properties.h b/tensorflow/core/grappler/costs/graph_properties.h index b849c4b3f04..954aeb3905b 100644 --- a/tensorflow/core/grappler/costs/graph_properties.h +++ b/tensorflow/core/grappler/costs/graph_properties.h @@ -18,6 +18,7 @@ limitations under the License. 
#include #include +#include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/grappler/clusters/cluster.h" #include "tensorflow/core/grappler/costs/op_performance_data.pb.h" #include "tensorflow/core/grappler/grappler_item.h" @@ -49,6 +50,34 @@ class GraphProperties { GrapplerItem item_; std::map> input_properties_; std::map> output_properties_; + + // Merges shapes , determined from an EnqueueV2 node, into + // <*queue_shapes_and_types>. + Status MergeEnqueueShapesAndTypes( + const std::vector& shapes_and_types, + shape_inference::InferenceContext* qctx, + std::vector* queue_shapes_and_types); + // Relaxes shapes , determined from an EnqueueV2 node, into + // <*queue_shapes_and_types>. + Status RelaxEnqueueShapesAndMergeTypes( + const std::vector& shapes_and_types, + shape_inference::InferenceContext* qctx, + std::vector* queue_shapes_and_types); + + // This gives access to private function of InferenceContext. + static void Relax(shape_inference::InferenceContext* c, + shape_inference::ShapeHandle s0, + shape_inference::ShapeHandle s1, + shape_inference::ShapeHandle* out); + + // These give access to private functions of ShapeRefiner. 
+ static bool SameDefinedShape(shape_inference::InferenceContext* c, + shape_inference::ShapeHandle s0, + shape_inference::ShapeHandle s1); + static bool IsUpdatedShapesOrTypes( + shape_inference::InferenceContext* c, + const std::vector& existing, + const std::vector& updated); }; } // end namespace grappler diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc index 10a88b59a2f..cc6f097cd04 100644 --- a/tensorflow/core/grappler/costs/graph_properties_test.cc +++ b/tensorflow/core/grappler/costs/graph_properties_test.cc @@ -309,17 +309,8 @@ TEST_F(GraphPropertiesTest, Queues) { EXPECT_EQ("float: [1,2,3]", PropToString(props5[2])); } -TEST_F(GraphPropertiesTest, Loops) { - // Test graph produced in python using: - /* - with tf.Graph().as_default(): - i = tf.constant(0) - c = lambda i: tf.less(i, 10) - b = lambda i: tf.add(i, 1) - r = tf.while_loop(c, b, [i]) - with open('/tmp/graph.txt', 'w') as f: - f.write(str(tf.get_default_graph().as_graph_def())) - */ +TEST_F(GraphPropertiesTest, WhileLoop) { + // Python code used to generate the graph is below. 
const string gdef_ascii = R"EOF( node { name: "Const" @@ -342,6 +333,33 @@ node { } } } +node { + name: "ones" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 2 + } + dim { + size: 2 + } + } + float_val: 1.0 + } + } + } +} node { name: "while/Enter" op: "Enter" @@ -371,6 +389,35 @@ node { } } } +node { + name: "while/Enter_1" + op: "Enter" + input: "ones" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "frame_name" + value { + s: "while/while/" + } + } + attr { + key: "is_constant" + value { + b: false + } + } + attr { + key: "parallel_iterations" + value { + i: 10 + } + } +} node { name: "while/Merge" op: "Merge" @@ -389,6 +436,24 @@ node { } } } +node { + name: "while/Merge_1" + op: "Merge" + input: "while/Enter_1" + input: "while/NextIteration_1" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} node { name: "while/Less/y" op: "Const" @@ -448,6 +513,26 @@ node { } } } +node { + name: "while/Switch_1" + op: "Switch" + input: "while/Merge_1" + input: "while/LoopCond" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@while/Merge_1" + } + } + } +} node { name: "while/Identity" op: "Identity" @@ -460,7 +545,18 @@ node { } } node { - name: "while/Add/y" + name: "while/Identity_1" + op: "Identity" + input: "while/Switch_1:1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "while/add/y" op: "Const" input: "^while/Identity" attr { @@ -482,10 +578,10 @@ node { } } node { - name: "while/Add" + name: "while/add" op: "Add" input: "while/Identity" - input: "while/Add/y" + input: "while/add/y" attr { key: "T" value { @@ -493,10 +589,57 @@ node { } } } +node { + name: "while/concat/axis" + op: "Const" + input: "^while/Identity" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + 
key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } +} +node { + name: "while/concat" + op: "ConcatV2" + input: "while/Identity_1" + input: "while/Identity_1" + input: "while/concat/axis" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } +} node { name: "while/NextIteration" op: "NextIteration" - input: "while/Add" + input: "while/add" attr { key: "T" value { @@ -504,6 +647,17 @@ node { } } } +node { + name: "while/NextIteration_1" + op: "NextIteration" + input: "while/concat" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} node { name: "while/Exit" op: "Exit" @@ -515,21 +669,2333 @@ node { } } } +node { + name: "while/Exit_1" + op: "Exit" + input: "while/Switch_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} versions { - producer: 11 + producer: 21 } )EOF"; + // Test graph produced in python using: + /* + with tf.Graph().as_default(): + i0 = tf.constant(0) + m0 = tf.ones([2, 2]) + c = lambda i, m: i < 10 + b = lambda i, m: [i+1, tf.concat([m, m], axis=0)] + r = tf.while_loop( + c, b, loop_vars=[i0, m0], + shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])]) + with open('/tmp/graph.pbtxt', 'w') as f: + f.write(str(tf.get_default_graph().as_graph_def())) + */ + GrapplerItem item; CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii, &item.graph)); GraphProperties properties(item); TF_CHECK_OK(properties.InferStatically()); - const auto props = properties.GetOutputProperties("while/Exit"); - EXPECT_EQ(1, props.size()); + std::vector nodes{"while/Merge_1", "while/NextIteration_1", + "while/Exit_1"}; + for (const string& node : nodes) { + const auto props = properties.GetOutputProperties(node); + const OpInfo::TensorProperties& prop = props[0]; + EXPECT_EQ(DT_FLOAT, prop.dtype()); + EXPECT_EQ("float: [-1,2]", PropToString(prop)); + } +} + +TEST_F(GraphPropertiesTest, NestedLoop) { + 
// Python code used to generate the graph is below. + const string gdef_ascii = R"EOF( +node { + name: "Const" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } +} +node { + name: "ones" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 1 + } + dim { + size: 1 + } + dim { + size: 1 + } + } + float_val: 1.0 + } + } + } +} +node { + name: "while/Enter" + op: "Enter" + input: "Const" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "frame_name" + value { + s: "while/while/" + } + } + attr { + key: "is_constant" + value { + b: false + } + } + attr { + key: "parallel_iterations" + value { + i: 10 + } + } +} +node { + name: "while/Enter_1" + op: "Enter" + input: "ones" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "frame_name" + value { + s: "while/while/" + } + } + attr { + key: "is_constant" + value { + b: false + } + } + attr { + key: "parallel_iterations" + value { + i: 10 + } + } +} +node { + name: "while/Merge" + op: "Merge" + input: "while/Enter" + input: "while/NextIteration" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } +} +node { + name: "while/Merge_1" + op: "Merge" + input: "while/Enter_1" + input: "while/NextIteration_1" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "while/Less/y" + op: "Const" + input: "^while/Merge" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 3 + } + } + } +} +node { + name: "while/Less" + op: "Less" + input: "while/Merge" + input: "while/Less/y" + attr { + key: "T" + value { + type: DT_INT32 + } + } +} +node { + 
name: "while/LoopCond" + op: "LoopCond" + input: "while/Less" +} +node { + name: "while/Switch" + op: "Switch" + input: "while/Merge" + input: "while/LoopCond" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_class" + value { + list { + s: "loc:@while/Merge" + } + } + } +} +node { + name: "while/Switch_1" + op: "Switch" + input: "while/Merge_1" + input: "while/LoopCond" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@while/Merge_1" + } + } + } +} +node { + name: "while/Identity" + op: "Identity" + input: "while/Switch:1" + attr { + key: "T" + value { + type: DT_INT32 + } + } +} +node { + name: "while/Identity_1" + op: "Identity" + input: "while/Switch_1:1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "while/while/Const" + op: "Const" + input: "^while/Identity" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } +} +node { + name: "while/while/Enter" + op: "Enter" + input: "while/while/Const" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "frame_name" + value { + s: "while/while/while/while/" + } + } + attr { + key: "is_constant" + value { + b: false + } + } + attr { + key: "parallel_iterations" + value { + i: 10 + } + } +} +node { + name: "while/while/Enter_1" + op: "Enter" + input: "while/Identity_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "frame_name" + value { + s: "while/while/while/while/" + } + } + attr { + key: "is_constant" + value { + b: false + } + } + attr { + key: "parallel_iterations" + value { + i: 10 + } + } +} +node { + name: "while/while/Merge" + op: "Merge" + input: "while/while/Enter" + input: "while/while/NextIteration" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } +} +node { + name: "while/while/Merge_1" + op: 
"Merge" + input: "while/while/Enter_1" + input: "while/while/NextIteration_1" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "while/while/Less/y" + op: "Const" + input: "^while/while/Merge" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 3 + } + } + } +} +node { + name: "while/while/Less" + op: "Less" + input: "while/while/Merge" + input: "while/while/Less/y" + attr { + key: "T" + value { + type: DT_INT32 + } + } +} +node { + name: "while/while/LoopCond" + op: "LoopCond" + input: "while/while/Less" +} +node { + name: "while/while/Switch" + op: "Switch" + input: "while/while/Merge" + input: "while/while/LoopCond" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_class" + value { + list { + s: "loc:@while/while/Merge" + } + } + } +} +node { + name: "while/while/Switch_1" + op: "Switch" + input: "while/while/Merge_1" + input: "while/while/LoopCond" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@while/while/Merge_1" + } + } + } +} +node { + name: "while/while/Identity" + op: "Identity" + input: "while/while/Switch:1" + attr { + key: "T" + value { + type: DT_INT32 + } + } +} +node { + name: "while/while/Identity_1" + op: "Identity" + input: "while/while/Switch_1:1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "while/while/add/y" + op: "Const" + input: "^while/while/Identity" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } +} +node { + name: "while/while/add" + op: "Add" + input: "while/while/Identity" + input: "while/while/add/y" + attr { + key: "T" + value { + type: DT_INT32 + } + } +} +node { + name: "while/while/concat/axis" + op: "Const" + input: 
"^while/while/Identity" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 2 + } + } + } +} +node { + name: "while/while/concat" + op: "ConcatV2" + input: "while/while/Identity_1" + input: "while/while/Identity_1" + input: "while/while/concat/axis" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } +} +node { + name: "while/while/NextIteration" + op: "NextIteration" + input: "while/while/add" + attr { + key: "T" + value { + type: DT_INT32 + } + } +} +node { + name: "while/while/NextIteration_1" + op: "NextIteration" + input: "while/while/concat" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "while/while/Exit" + op: "Exit" + input: "while/while/Switch" + attr { + key: "T" + value { + type: DT_INT32 + } + } +} +node { + name: "while/while/Exit_1" + op: "Exit" + input: "while/while/Switch_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "while/add/y" + op: "Const" + input: "^while/Identity" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } +} +node { + name: "while/add" + op: "Add" + input: "while/Identity" + input: "while/add/y" + attr { + key: "T" + value { + type: DT_INT32 + } + } +} +node { + name: "while/concat/axis" + op: "Const" + input: "^while/Identity" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } +} +node { + name: "while/concat" + op: "ConcatV2" + input: "while/Identity_1" + input: "while/Identity_1" + input: "while/concat/axis" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { 
+ type: DT_INT32 + } + } +} +node { + name: "while/NextIteration" + op: "NextIteration" + input: "while/add" + attr { + key: "T" + value { + type: DT_INT32 + } + } +} +node { + name: "while/NextIteration_1" + op: "NextIteration" + input: "while/concat" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "while/Exit" + op: "Exit" + input: "while/Switch" + attr { + key: "T" + value { + type: DT_INT32 + } + } +} +node { + name: "while/Exit_1" + op: "Exit" + input: "while/Switch_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +versions { + producer: 21 +} + )EOF"; + + // Test graph produced in python using: + /* + with tf.Graph().as_default(): + i0 = tf.constant(0) + + def inner(j, y): + def inner_cond(j, y): + return j < 3 + + def inner_body(j, y): + return j+1, tf.concat([y, y], axis=2) + + return tf.while_loop(inner_cond, inner_body, loop_vars=[j, y], + shape_invariants=[i0.get_shape(), + tf.TensorShape([None, 1, None])]) + + def outer_cond(i, x): + return i < 3 + + def outer_body(i, x): + j, y = inner(0, x) + return i+1, tf.concat([x, x], axis=0) + + r = tf.while_loop(outer_cond, outer_body, + loop_vars=[i0, tf.ones([1, 1, 1])], + shape_invariants=[i0.get_shape(), + tf.TensorShape([None, 1, None])]) + + with open('/tmp/graph.pbtxt', 'w') as f: + f.write(str(tf.get_default_graph().as_graph_def())) + */ + + GrapplerItem item; + CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii, &item.graph)); + GraphProperties properties(item); + TF_CHECK_OK(properties.InferStatically()); + + std::vector outer_nodes{"while/Merge_1", "while/NextIteration_1", + "while/Exit_1"}; + std::vector inner_nodes{"while/while/Merge_1", + "while/while/NextIteration_1", + "while/while/Exit_1"}; + for (const string& node : outer_nodes) { + const auto props = properties.GetOutputProperties(node); + const OpInfo::TensorProperties& prop = props[0]; + EXPECT_EQ(DT_FLOAT, prop.dtype()); + EXPECT_EQ("float: [-1,1,1]", PropToString(prop)); + } + for (const string& 
node : inner_nodes) { + const auto props = properties.GetOutputProperties(node); + const OpInfo::TensorProperties& prop = props[0]; + EXPECT_EQ(DT_FLOAT, prop.dtype()); + EXPECT_EQ("float: [-1,1,-1]", PropToString(prop)); + } +} + +TEST_F(GraphPropertiesTest, LoopsAndQueues) { + // Python code used to generate the graph is below. + const string gdef_ascii = R"EOF( +node { + name: "Const" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } +} +node { + name: "fifo_queue" + op: "FIFOQueueV2" + attr { + key: "capacity" + value { + i: 1 + } + } + attr { + key: "component_types" + value { + list { + type: DT_FLOAT + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "shapes" + value { + list { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "ones" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 1 + } + dim { + size: 1 + } + dim { + size: 1 + } + } + float_val: 1.0 + } + } + } +} +node { + name: "while/Enter" + op: "Enter" + input: "Const" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "frame_name" + value { + s: "while/while/" + } + } + attr { + key: "is_constant" + value { + b: false + } + } + attr { + key: "parallel_iterations" + value { + i: 10 + } + } +} +node { + name: "while/Enter_1" + op: "Enter" + input: "ones" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "frame_name" + value { + s: "while/while/" + } + } + attr { + key: "is_constant" + value { + b: false + } + } + attr { + key: "parallel_iterations" + value { + i: 10 + } + } +} +node { + name: "while/Merge" + op: "Merge" + input: "while/Enter" + input: "while/NextIteration" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + 
type: DT_INT32 + } + } +} +node { + name: "while/Merge_1" + op: "Merge" + input: "while/Enter_1" + input: "while/NextIteration_1" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "while/Less/y" + op: "Const" + input: "^while/Merge" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 3 + } + } + } +} +node { + name: "while/Less" + op: "Less" + input: "while/Merge" + input: "while/Less/y" + attr { + key: "T" + value { + type: DT_INT32 + } + } +} +node { + name: "while/LoopCond" + op: "LoopCond" + input: "while/Less" +} +node { + name: "while/Switch" + op: "Switch" + input: "while/Merge" + input: "while/LoopCond" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_class" + value { + list { + s: "loc:@while/Merge" + } + } + } +} +node { + name: "while/Switch_1" + op: "Switch" + input: "while/Merge_1" + input: "while/LoopCond" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@while/Merge_1" + } + } + } +} +node { + name: "while/Identity" + op: "Identity" + input: "while/Switch:1" + attr { + key: "T" + value { + type: DT_INT32 + } + } +} +node { + name: "while/Identity_1" + op: "Identity" + input: "while/Switch_1:1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "while/fifo_queue_enqueue/Enter" + op: "Enter" + input: "fifo_queue" + attr { + key: "T" + value { + type: DT_RESOURCE + } + } + attr { + key: "frame_name" + value { + s: "while/while/" + } + } + attr { + key: "is_constant" + value { + b: true + } + } + attr { + key: "parallel_iterations" + value { + i: 10 + } + } +} +node { + name: "while/fifo_queue_enqueue" + op: "QueueEnqueueV2" + input: "while/fifo_queue_enqueue/Enter" + input: "while/Identity_1" + attr { + key: "Tcomponents" + value { + list { + type: DT_FLOAT + } + } + } + attr { + 
key: "timeout_ms" + value { + i: -1 + } + } +} +node { + name: "while/concat/axis" + op: "Const" + input: "^while/Identity" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 2 + } + } + } +} +node { + name: "while/concat" + op: "ConcatV2" + input: "while/Identity_1" + input: "while/Identity_1" + input: "while/concat/axis" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } +} +node { + name: "while/fifo_queue_Dequeue" + op: "QueueDequeueV2" + input: "while/fifo_queue_enqueue/Enter" + input: "^while/Identity" + attr { + key: "component_types" + value { + list { + type: DT_FLOAT + } + } + } + attr { + key: "timeout_ms" + value { + i: -1 + } + } +} +node { + name: "while/while/Const" + op: "Const" + input: "^while/Identity" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } +} +node { + name: "while/while/Enter" + op: "Enter" + input: "while/while/Const" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "frame_name" + value { + s: "while/while/while/while/" + } + } + attr { + key: "is_constant" + value { + b: false + } + } + attr { + key: "parallel_iterations" + value { + i: 10 + } + } +} +node { + name: "while/while/Enter_1" + op: "Enter" + input: "while/fifo_queue_Dequeue" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "frame_name" + value { + s: "while/while/while/while/" + } + } + attr { + key: "is_constant" + value { + b: false + } + } + attr { + key: "parallel_iterations" + value { + i: 10 + } + } +} +node { + name: "while/while/Merge" + op: "Merge" + input: "while/while/Enter" + input: "while/while/NextIteration" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: 
DT_INT32 + } + } +} +node { + name: "while/while/Merge_1" + op: "Merge" + input: "while/while/Enter_1" + input: "while/while/NextIteration_1" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "while/while/Less/y" + op: "Const" + input: "^while/while/Merge" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 3 + } + } + } +} +node { + name: "while/while/Less" + op: "Less" + input: "while/while/Merge" + input: "while/while/Less/y" + attr { + key: "T" + value { + type: DT_INT32 + } + } +} +node { + name: "while/while/LoopCond" + op: "LoopCond" + input: "while/while/Less" +} +node { + name: "while/while/Switch" + op: "Switch" + input: "while/while/Merge" + input: "while/while/LoopCond" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_class" + value { + list { + s: "loc:@while/while/Merge" + } + } + } +} +node { + name: "while/while/Switch_1" + op: "Switch" + input: "while/while/Merge_1" + input: "while/while/LoopCond" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@while/while/Merge_1" + } + } + } +} +node { + name: "while/while/Identity" + op: "Identity" + input: "while/while/Switch:1" + attr { + key: "T" + value { + type: DT_INT32 + } + } +} +node { + name: "while/while/Identity_1" + op: "Identity" + input: "while/while/Switch_1:1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "while/while/add/y" + op: "Const" + input: "^while/while/Identity" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } +} +node { + name: "while/while/add" + op: "Add" + input: "while/while/Identity" + input: "while/while/add/y" + attr { + key: "T" + value { + type: DT_INT32 + } + } +} +node { + name: 
"while/while/concat/axis" + op: "Const" + input: "^while/while/Identity" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } +} +node { + name: "while/while/concat" + op: "ConcatV2" + input: "while/while/Identity_1" + input: "while/while/Identity_1" + input: "while/while/concat/axis" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } +} +node { + name: "while/while/NextIteration" + op: "NextIteration" + input: "while/while/add" + attr { + key: "T" + value { + type: DT_INT32 + } + } +} +node { + name: "while/while/NextIteration_1" + op: "NextIteration" + input: "while/while/concat" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "while/while/Exit" + op: "Exit" + input: "while/while/Switch" + attr { + key: "T" + value { + type: DT_INT32 + } + } +} +node { + name: "while/while/Exit_1" + op: "Exit" + input: "while/while/Switch_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "while/add/y" + op: "Const" + input: "^while/Identity" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } +} +node { + name: "while/add" + op: "Add" + input: "while/Identity" + input: "while/add/y" + attr { + key: "T" + value { + type: DT_INT32 + } + } +} +node { + name: "while/NextIteration" + op: "NextIteration" + input: "while/add" + attr { + key: "T" + value { + type: DT_INT32 + } + } +} +node { + name: "while/NextIteration_1" + op: "NextIteration" + input: "while/concat" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "while/Exit" + op: "Exit" + input: "while/Switch" + attr { + key: "T" + value { + type: DT_INT32 + } + } +} +node { + name: "while/Exit_1" + op: "Exit" + input: 
"while/Switch_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +versions { + producer: 21 +} + )EOF"; + + // Test graph produced in python using: + /* + with tf.Graph().as_default(): + i0 = tf.constant(0) + q = tf.FIFOQueue(1, "float") + + def inner(j, y): + def inner_cond(j, y): + return j < 3 + + def inner_body(j, y): + return j+1, tf.concat([y, y], axis=0) + + return tf.while_loop(inner_cond, inner_body, + loop_vars=[j, y], + shape_invariants=[i0.get_shape(), + tf.TensorShape(None)]) + + def outer_cond(i, x): + return i < 3 + + def outer_body(i, x): + q.enqueue(x) + y = tf.concat([x, x], axis=2) + inner(0, q.dequeue()) + return i+1, y + + i, z = tf.while_loop(outer_cond, outer_body, + loop_vars=[i0, tf.ones([1, 1, 1])], + shape_invariants=[i0.get_shape(), + tf.TensorShape([None, 1, None])]) + + with open('/tmp/graph.pbtxt', 'w') as f: + f.write(str(tf.get_default_graph().as_graph_def())) + */ + + GrapplerItem item; + CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii, &item.graph)); + GraphProperties properties(item); + TF_CHECK_OK(properties.InferStatically()); + + std::vector outer_nodes{"while/Merge_1", "while/NextIteration_1", + "while/Exit_1"}; + std::vector inner_nodes{"while/while/Merge_1", + "while/while/NextIteration_1", + "while/while/Exit_1"}; + for (const string& node : outer_nodes) { + const auto props = properties.GetOutputProperties(node); + const OpInfo::TensorProperties& prop = props[0]; + EXPECT_EQ(DT_FLOAT, prop.dtype()); + EXPECT_EQ("float: [1,1,-1]", PropToString(prop)); + } + for (const string& node : inner_nodes) { + const auto props = properties.GetOutputProperties(node); + const OpInfo::TensorProperties& prop = props[0]; + EXPECT_EQ(DT_FLOAT, prop.dtype()); + EXPECT_EQ("float: [-1,1,-1]", PropToString(prop)); + } +} + +TEST_F(GraphPropertiesTest, QueuesAndLoops) { + // Python code used to generate the graph is below. 
+ const string gdef_ascii = R"EOF( +node { + name: "Const" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } +} +node { + name: "fifo_queue" + op: "FIFOQueueV2" + attr { + key: "capacity" + value { + i: 1 + } + } + attr { + key: "component_types" + value { + list { + type: DT_FLOAT + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "shapes" + value { + list { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "ones" + op: "Const" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 2 + } + dim { + size: 2 + } + } + float_val: 1.0 + } + } + } +} +node { + name: "fifo_queue_enqueue" + op: "QueueEnqueueV2" + input: "fifo_queue" + input: "ones" + attr { + key: "Tcomponents" + value { + list { + type: DT_FLOAT + } + } + } + attr { + key: "timeout_ms" + value { + i: -1 + } + } +} +node { + name: "fifo_queue_1" + op: "FIFOQueueV2" + attr { + key: "capacity" + value { + i: 1 + } + } + attr { + key: "component_types" + value { + list { + type: DT_FLOAT + } + } + } + attr { + key: "container" + value { + s: "" + } + } + attr { + key: "shapes" + value { + list { + } + } + } + attr { + key: "shared_name" + value { + s: "" + } + } +} +node { + name: "fifo_queue_Dequeue" + op: "QueueDequeueV2" + input: "fifo_queue" + attr { + key: "component_types" + value { + list { + type: DT_FLOAT + } + } + } + attr { + key: "timeout_ms" + value { + i: -1 + } + } +} +node { + name: "while/Enter" + op: "Enter" + input: "Const" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "frame_name" + value { + s: "while/while/" + } + } + attr { + key: "is_constant" + value { + b: false + } + } + attr { + key: "parallel_iterations" + value { + i: 10 + } + } +} +node { + name: 
"while/Enter_1" + op: "Enter" + input: "fifo_queue_Dequeue" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "frame_name" + value { + s: "while/while/" + } + } + attr { + key: "is_constant" + value { + b: false + } + } + attr { + key: "parallel_iterations" + value { + i: 10 + } + } +} +node { + name: "while/Merge" + op: "Merge" + input: "while/Enter" + input: "while/NextIteration" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_INT32 + } + } +} +node { + name: "while/Merge_1" + op: "Merge" + input: "while/Enter_1" + input: "while/NextIteration_1" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "while/Less/y" + op: "Const" + input: "^while/Merge" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 10 + } + } + } +} +node { + name: "while/Less" + op: "Less" + input: "while/Merge" + input: "while/Less/y" + attr { + key: "T" + value { + type: DT_INT32 + } + } +} +node { + name: "while/LoopCond" + op: "LoopCond" + input: "while/Less" +} +node { + name: "while/Switch" + op: "Switch" + input: "while/Merge" + input: "while/LoopCond" + attr { + key: "T" + value { + type: DT_INT32 + } + } + attr { + key: "_class" + value { + list { + s: "loc:@while/Merge" + } + } + } +} +node { + name: "while/Switch_1" + op: "Switch" + input: "while/Merge_1" + input: "while/LoopCond" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_class" + value { + list { + s: "loc:@while/Merge_1" + } + } + } +} +node { + name: "while/Identity" + op: "Identity" + input: "while/Switch:1" + attr { + key: "T" + value { + type: DT_INT32 + } + } +} +node { + name: "while/Identity_1" + op: "Identity" + input: "while/Switch_1:1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "while/add/y" + op: "Const" + input: "^while/Identity" + 
attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } +} +node { + name: "while/add" + op: "Add" + input: "while/Identity" + input: "while/add/y" + attr { + key: "T" + value { + type: DT_INT32 + } + } +} +node { + name: "while/concat/axis" + op: "Const" + input: "^while/Identity" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 0 + } + } + } +} +node { + name: "while/concat" + op: "ConcatV2" + input: "while/Identity_1" + input: "while/Identity_1" + input: "while/concat/axis" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } +} +node { + name: "while/NextIteration" + op: "NextIteration" + input: "while/add" + attr { + key: "T" + value { + type: DT_INT32 + } + } +} +node { + name: "while/NextIteration_1" + op: "NextIteration" + input: "while/concat" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "while/Exit" + op: "Exit" + input: "while/Switch" + attr { + key: "T" + value { + type: DT_INT32 + } + } +} +node { + name: "while/Exit_1" + op: "Exit" + input: "while/Switch_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } +} +node { + name: "fifo_queue_1_enqueue" + op: "QueueEnqueueV2" + input: "fifo_queue_1" + input: "while/Exit_1" + attr { + key: "Tcomponents" + value { + list { + type: DT_FLOAT + } + } + } + attr { + key: "timeout_ms" + value { + i: -1 + } + } +} +node { + name: "fifo_queue_1_Dequeue" + op: "QueueDequeueV2" + input: "fifo_queue_1" + attr { + key: "component_types" + value { + list { + type: DT_FLOAT + } + } + } + attr { + key: "timeout_ms" + value { + i: -1 + } + } +} +node { + name: "concat/axis" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value 
{ + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } +} +node { + name: "concat" + op: "ConcatV2" + input: "fifo_queue_1_Dequeue" + input: "fifo_queue_1_Dequeue" + input: "concat/axis" + attr { + key: "N" + value { + i: 2 + } + } + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } +} +versions { + producer: 21 +} + )EOF"; + + // Test graph produced in python using: + /* + with tf.Graph().as_default(): + i0 = tf.constant(0) + q0 = tf.FIFOQueue(1, "float") + q0.enqueue(tf.ones([2, 2])) + q1 = tf.FIFOQueue(1, "float") + + def c(i, m): + return i < 10 + + def b(i, m): + return i+1, tf.concat([m, m], axis=0) + + i, m = tf.while_loop( + c, b, loop_vars=[i0, q0.dequeue()], + shape_invariants=[i0.get_shape(), tf.TensorShape(None)]) + + q1.enqueue(m) + v = q1.dequeue(); + tf.concat([v, v], axis=1) + with open('/tmp/graph.pbtxt', 'w') as f: + f.write(str(tf.get_default_graph().as_graph_def())) + */ + + GrapplerItem item; + CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii, &item.graph)); + GraphProperties properties(item); + TF_CHECK_OK(properties.InferStatically()); + + std::vector nodes{"while/Merge_1", "while/NextIteration_1", + "while/Exit_1"}; + + for (const string& node : nodes) { + const auto props = properties.GetOutputProperties(node); + const OpInfo::TensorProperties& prop = props[0]; + EXPECT_EQ(DT_FLOAT, prop.dtype()); + EXPECT_EQ("float: [-1,2]", PropToString(prop)); + } + + const auto props = properties.GetOutputProperties("concat"); const OpInfo::TensorProperties& prop = props[0]; - EXPECT_EQ(DT_INT32, prop.dtype()); - EXPECT_TRUE(prop.shape().unknown_rank()); + EXPECT_EQ(DT_FLOAT, prop.dtype()); + EXPECT_EQ("float: [-1,4]", PropToString(prop)); } } // namespace diff --git a/tensorflow/core/grappler/costs/measuring_cost_estimator.cc b/tensorflow/core/grappler/costs/measuring_cost_estimator.cc index e4a0d6f1b86..8fd1801863a 100644 --- 
a/tensorflow/core/grappler/costs/measuring_cost_estimator.cc +++ b/tensorflow/core/grappler/costs/measuring_cost_estimator.cc @@ -101,6 +101,7 @@ Status MeasuringCostEstimator::PredictCosts(const GraphDef& optimized_graph, } // Run "measurement_steps_" and measure the time. + VLOG(1) << "Number of measurement steps: " << measurement_steps_; if (measurement_threads_ > 0) { for (int i = 0; i < measurement_steps_; ++i) { thread_pool_->Schedule([i, &measurement_fn]() { measurement_fn(i); }); diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc index d8b8a12eb29..ba6686e7df9 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc @@ -314,6 +314,8 @@ std::pair OpLevelCostEstimator::GetDeviceInfo( bandwidth = 100; } } + VLOG(1) << "Device: " << device.type() << " GFLOPS: " << gflops + << " Bandwidth: " << bandwidth; return std::make_pair(gflops, bandwidth); } @@ -461,7 +463,7 @@ int64 OpLevelCostEstimator::CountConv2DOperations( ops *= conv_dims.kx * conv_dims.ky; ops *= conv_dims.iz * conv_dims.oz; ops *= kOpsPerMac; - VLOG(1) << "Operations for Conv2D" << ops; + VLOG(1) << "Operations for Conv2D " << ops; if (conv_info != nullptr) { *conv_info = conv_dims; @@ -679,7 +681,7 @@ int64 OpLevelCostEstimator::CountConv2DBackPropInputOperations( ops *= conv_dims.iz * conv_dims.oz; ops *= kOpsPerMac; - VLOG(1) << "Operations for Conv2DBackPropInput" << ops; + VLOG(1) << "Operations for Conv2DBackPropInput " << ops; if (returned_conv_dims != nullptr) { *returned_conv_dims = conv_dims; diff --git a/tensorflow/core/grappler/costs/virtual_placer.cc b/tensorflow/core/grappler/costs/virtual_placer.cc index 0291bd04909..a3dfb278e2a 100644 --- a/tensorflow/core/grappler/costs/virtual_placer.cc +++ b/tensorflow/core/grappler/costs/virtual_placer.cc @@ -36,17 +36,19 @@ VirtualPlacer::VirtualPlacer(const Cluster* cluster) { } else { 
default_device_ = devices_.begin()->first; + VLOG(1) << "Number of devices: " << devices_.size(); for (const auto& device : devices_) { if (str_util::Lowercase(device.first).find("gpu") != string::npos) { default_device_ = device.first; + break; } - break; } } } const DeviceProperties& VirtualPlacer::get_device(const NodeDef& node) const { string device = get_canonical_device_name(node); + VLOG(3) << "Device name: " << device; auto it = devices_.find(device); DCHECK(it != devices_.end()); return it->second; diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc index c68d4e31c46..f6f91e37ac8 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.cc +++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc @@ -58,7 +58,9 @@ Costs CombineCosts(const Costs& left, const Costs& right) { VirtualScheduler::VirtualScheduler(const GrapplerItem* grappler_item, const bool use_static_shapes, Cluster* cluster) - : // TODO(dyoon): Use a better way than FIFO. + : // Allow LIFO as well as FIFO. LIFO allows an output node of an node to + // follow it in execution, saving addition memory time from having to + // write and read. For default cases, use FIFO for performance. ready_nodes_(new FIFOManager()), graph_costs_(Costs::ZeroCosts()), graph_properties_(*grappler_item), @@ -244,8 +246,8 @@ string VirtualScheduler::DeviceName(const NodeDef* node) const { string VirtualScheduler::ChannelDeviceName(const NodeDef* from, const NodeDef* to) const { CHECK(!initialized_) << "ChannelDeviceName is called after Init()."; - - return kChannelDevice + ": " + DeviceName(from) + " to " + DeviceName(to); + return kChannelDevice + ": from " + DeviceName(from) + " to " + + DeviceName(to); } std::pair VirtualScheduler::CreateSendRecv( @@ -318,8 +320,8 @@ NodeInfo VirtualScheduler::GetCurrNodeInfo() const { } // Construct NodeInfo. 
- const auto& node_state = node_map_.at(node); NodeInfo node_info; + const auto& node_state = node_map_.at(node); node_info.name = node->name(); node_info.device_name = node_state.device_name; auto& op_info = node_info.op_info; @@ -577,6 +579,7 @@ Costs VirtualScheduler::Summary() const { << " GB, at the end: " << state.memory_usage << " B"; VLOG(1) << "Per-op execution time (and memory usage at peak memory usage):"; + // Profile non-persistent op memory usage. for (const auto& node_port : state.mem_usage_snapshot_at_peak) { const auto* node = node_port.first; @@ -617,5 +620,30 @@ Costs VirtualScheduler::Summary() const { return critical_path_costs; } +Costs VirtualScheduler::Summary(StepStats* stepstats) { + if (stepstats != nullptr) { + for (const auto& device : device_) { + DeviceStepStats* device_stepstats = stepstats->add_dev_stats(); + device_stepstats->set_device(device.first); + for (const auto& node_def : device.second.nodes_executed) { + const NodeState& nodestate = node_map_.at(node_def); + NodeExecStats* node_stats = device_stepstats->add_node_stats(); + node_stats->set_node_name(node_def->op()); + node_stats->set_timeline_label(node_def->name()); + node_stats->set_op_start_rel_micros(0); + node_stats->set_all_start_micros( + nodestate.time_scheduled.asMicroSeconds().count()); + node_stats->set_op_end_rel_micros( + nodestate.time_finished.asMicroSeconds().count() - + nodestate.time_scheduled.asMicroSeconds().count()); + node_stats->set_all_end_rel_micros( + nodestate.time_finished.asMicroSeconds().count() - + nodestate.time_scheduled.asMicroSeconds().count()); + } + } + } + return Summary(); +} + } // end namespace grappler } // end namespace tensorflow diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.h b/tensorflow/core/grappler/costs/virtual_scheduler.h index 472ba90f7c5..8b9eccbd432 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.h +++ b/tensorflow/core/grappler/costs/virtual_scheduler.h @@ -19,7 +19,10 @@ limitations under 
the License. #include #include #include +#include +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/step_stats.pb.h" #include "tensorflow/core/grappler/costs/cost_estimator.h" #include "tensorflow/core/grappler/costs/graph_properties.h" #include "tensorflow/core/grappler/costs/virtual_placer.h" @@ -80,16 +83,27 @@ struct DeviceState { // Nodes executed on this device in execution order. std::vector nodes_executed; + struct NodePairHash { + public: + const std::size_t operator()( + const std::pair& element) const { + return std::hash()(element.first); + } + }; + // Nodes currently allocated in memory: set of NodeDef* and port_num pairs // so that we can track which output of the node is in memory. - std::set> nodes_in_memory; + std::unordered_set, NodePairHash> + nodes_in_memory; // Nodes allocated in memory persistently: e.g., Variables. - std::set> persistent_nodes; + std::unordered_set, NodePairHash> + persistent_nodes; // Snapshot of nodes_in_memory, when memory usage is at peak. // Same to nodes_in_memory, it's a set of NodeDef* and port_num pairs. - std::set> mem_usage_snapshot_at_peak; + std::unordered_set, NodePairHash> + mem_usage_snapshot_at_peak; Costs device_costs; std::map op_to_cost; // Per-op cost. 
@@ -113,7 +127,7 @@ class ReadyNodeManager { ReadyNodeManager() {} virtual ~ReadyNodeManager() {} virtual void AddNode(const NodeDef* node) = 0; - virtual const NodeDef* GetCurrNode() const = 0; + virtual const NodeDef* GetCurrNode() = 0; virtual void RemoveCurrNode() = 0; virtual bool Empty() const = 0; }; @@ -123,7 +137,7 @@ class FIFOManager : public ReadyNodeManager { FIFOManager() : ReadyNodeManager() {} ~FIFOManager() override {} void AddNode(const NodeDef* node) override { nodes_.push_back(node); } - const NodeDef* GetCurrNode() const override { return nodes_.front(); } + const NodeDef* GetCurrNode() override { return nodes_.front(); } void RemoveCurrNode() override { nodes_.pop_front(); } bool Empty() const override { return nodes_.empty(); } @@ -131,6 +145,40 @@ class FIFOManager : public ReadyNodeManager { std::list nodes_; }; +// The LIFOManager schedules nodes by returning the last one added to the +// scheduler. A node is executed and then its ready outputs are newly added to +// the scheduler, so the LIFOManager will return outputs to a node following +// that node's execution. +class LIFOManager : public ReadyNodeManager { + public: + LIFOManager() : ReadyNodeManager() {} + ~LIFOManager() override {} + void AddNode(const NodeDef* node) override { nodes_.push_back(node); } + const NodeDef* GetCurrNode() override { + curr_pos_ = nodes_.end(); + curr_pos_--; + return nodes_.back(); + } + void RemoveCurrNode() override { + if (curr_pos_ != nodes_.end()) { + nodes_.erase(curr_pos_); + } else if (!nodes_.empty()) { + nodes_.pop_back(); + } + curr_pos_ = nodes_.end(); + curr_pos_--; + } + bool Empty() const override { return nodes_.empty(); } + + private: + std::list nodes_; + // Keep track of the current node being executed by saving its position. + // Necessary because nodes may be added to the end of the list while a node is + // executing, and we want to remove the correct node (the one that is + // executing) rather than the new ones being added. 
+ std::list::iterator curr_pos_ = nodes_.end(); +}; + // A wrapper struct to OpInfo proto. // TODO(dyoon): once we extend OpInfo or implement a better interface, and then // delete this wrapper struct. @@ -158,6 +206,9 @@ class VirtualScheduler { // Prints out summary of execution (timing, memory usage, etc.) Costs Summary() const; + // Like the above, but writes detailed stats to stepstats. + // If stepstats is nullptr, then just calls and return Summary(). + Costs Summary(StepStats* stepstats); protected: // GetDeviceStates and GetNodeStates are currently for testing purpuse only. @@ -216,6 +267,7 @@ class VirtualScheduler { // Auxilliary data structures for constructing NodeState and DeviceState. GraphProperties graph_properties_; Cluster* cluster_; // Not owned. + const GrapplerItem* grappler_item_; // Not owned. bool use_static_shapes_; bool initialized_; diff --git a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc index 9e48c411dc0..a9174c7417c 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc +++ b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc @@ -14,7 +14,6 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/grappler/costs/virtual_scheduler.h" - #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/grappler/clusters/virtual_cluster.h" #include "tensorflow/core/grappler/costs/virtual_placer.h" @@ -39,15 +38,40 @@ class TestVirtualScheduler : public VirtualScheduler { class VirtualSchedulerTest : public ::testing::Test { protected: + NodeDef node1_, node2_, node3_, node4_, node5_, node6_; + const string kCPU0 = "/job:localhost/replica:0/task:0/cpu:0"; - void SetUp() override { - // Initializes cluster_ and placer_. - std::unordered_map devices; + DeviceProperties GetDummyCPUDevice() { + // Create CPU with 2 cores, 4 Ghz freq, 2 GB/s mem bandwidth. 
+ // - 8 Gflops + // - 2 GB/s DeviceProperties cpu_device; cpu_device.set_type("CPU"); - devices[kCPU0] = cpu_device; + cpu_device.set_frequency(4000); + cpu_device.set_num_cores(2); + cpu_device.set_bandwidth(2000000); + return cpu_device; + } + void SetUp() override { + // Initializes nodes for manager + node1_.set_name("Node1"); + node2_.set_name("Node2"); + node3_.set_name("Node3"); + node4_.set_name("Node4"); + node5_.set_name("Node5"); + node6_.set_name("Node6"); + + // Initializes cluster_ and placer_. + std::unordered_map devices; + + // Set some dummy CPU properties + DeviceProperties cpu_device = GetDummyCPUDevice(); + + // IMPORTANT: Device is not actually ever used in the test case since + // force_cpu_type is defaulted to "Haswell" + devices[kCPU0] = cpu_device; cluster_.reset(new VirtualCluster(devices)); placer_.reset(new VirtualPlacer(cluster_.get())); } @@ -102,6 +126,41 @@ class VirtualSchedulerTest : public ::testing::Test { dependency_["y"] = {"x", "f"}; } + void CreateGrapplerItemWithMatmulChain() { + tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(kCPU0); + // Add control dependencies to ensure tests do not rely on specific + // manager and the order remains consistent for the test. 
+ auto a = tensorflow::ops::RandomUniform(s.WithOpName("a"), {3200, 3200}, + DT_FLOAT); + auto b = tensorflow::ops::RandomUniform( + s.WithOpName("b").WithControlDependencies(a), {3200, 3200}, DT_FLOAT); + auto c = tensorflow::ops::RandomUniform( + s.WithOpName("c").WithControlDependencies(b), {3200, 3200}, DT_FLOAT); + auto d = tensorflow::ops::RandomUniform( + s.WithOpName("d").WithControlDependencies(c), {3200, 3200}, DT_FLOAT); + auto e = tensorflow::ops::RandomUniform( + s.WithOpName("e").WithControlDependencies(d), {3200, 3200}, DT_FLOAT); + + auto ab = tensorflow::ops::MatMul( + s.WithOpName("ab").WithControlDependencies(e), a, b); + auto abc = tensorflow::ops::MatMul(s.WithOpName("abc"), ab, c); + auto abcd = tensorflow::ops::MatMul(s.WithOpName("abcd"), abc, d); + auto abcde = tensorflow::ops::MatMul(s.WithOpName("abcde"), abcd, e); + + GraphDef def; + TF_CHECK_OK(s.ToGraphDef(&def)); + + grappler_item_.reset(new GrapplerItem); + grappler_item_->id = "test_matmul_sequence_graph"; + grappler_item_->graph = def; + grappler_item_->fetch = {"abcde"}; + + dependency_["ab"] = {"a", "b"}; + dependency_["abc"] = {"ab", "c"}; + dependency_["abcd"] = {"abc", "d"}; + dependency_["abcde"] = {"abcd", "e"}; + } + // AddN that takes 4 tensors with 10x10x10x10. void CreateGrapplerItemWithAddN() { tensorflow::Scope s = tensorflow::Scope::NewRootScope().WithDevice(kCPU0); @@ -201,6 +260,20 @@ class VirtualSchedulerTest : public ::testing::Test { TF_CHECK_OK(scheduler_->Init()); } + // Returns cost based on op. + Costs SimplePredictCosts(const NodeInfo& info) const { + Costs c; + int64 exec_cost = 0; + if (info.op_info.op() == "MatMul") { + exec_cost = 2000000000; + } + if (info.op_info.op() == "RandomUniform") { + exec_cost = 1000000000; + } + c.execution_time = Costs::NanoSeconds(exec_cost); + return c; + } + // Call this after init scheduler_. Scheduler stops after executing // target_node. 
std::unordered_map RunScheduler(const string& target_node) { @@ -211,6 +284,8 @@ class VirtualSchedulerTest : public ::testing::Test { NodeInfo node_info = scheduler_->GetCurrNodeInfo(); ops_executed[node_info.name] = node_info; + Costs node_costs = SimplePredictCosts(node_info); + // Check scheduling order. auto it = dependency_.find(node_info.name); if (it != dependency_.end()) { @@ -218,7 +293,7 @@ class VirtualSchedulerTest : public ::testing::Test { EXPECT_GT(ops_executed.count(preceding_node), 0); } } - more_nodes = scheduler_->MarkCurrNodeExecuted(zero_costs); + more_nodes = scheduler_->MarkCurrNodeExecuted(node_costs); if (node_info.name == target_node) { // Scheduler has the state after executing the target node. @@ -263,7 +338,8 @@ class VirtualSchedulerTest : public ::testing::Test { // Helper method tthat checks name - port pairs. void ValidateMemoryUsageSnapshot( const std::vector& expected_names, const int port_num_expected, - const std::set>& mem_usage_snapshot) { + const std::unordered_set, + DeviceState::NodePairHash>& mem_usage_snapshot) { std::set> nodes_at_peak_mem_usage; std::transform( mem_usage_snapshot.begin(), mem_usage_snapshot.end(), @@ -311,6 +387,218 @@ class VirtualSchedulerTest : public ::testing::Test { const int depth_out_ = 16; }; +// Test that FIFOManager correctly returns the current node with only 1 node. +TEST_F(VirtualSchedulerTest, GetSingleNodeFIFOManager) { + // Init. + FIFOManager manager = FIFOManager(); + + // Add the node to FIFOManager. + manager.AddNode(&node1_); + EXPECT_EQ("Node1", manager.GetCurrNode()->name()); +} + +// Test that FIFOManager removes the only node contained within. +TEST_F(VirtualSchedulerTest, RemoveSingleNodeFIFOManager) { + // Init. + FIFOManager manager = FIFOManager(); + + // Add the node to FIFOManager. + manager.AddNode(&node1_); + + // Remove the only node in FIFOManager. 
+ manager.RemoveCurrNode(); + EXPECT_TRUE(manager.Empty()); +} + +// Test that FIFOManager can remove multiple nodes and returns the current node +// in the right order +TEST_F(VirtualSchedulerTest, GetAndRemoveMultipleFIFOManager) { + // Init. + FIFOManager manager = FIFOManager(); + + // Add the nodes to FIFOManager. + manager.AddNode(&node1_); + manager.AddNode(&node2_); + manager.AddNode(&node3_); + manager.AddNode(&node4_); + + // Keep checking current node while removing nodes from manager. + EXPECT_EQ("Node1", manager.GetCurrNode()->name()); + manager.RemoveCurrNode(); + EXPECT_EQ("Node2", manager.GetCurrNode()->name()); + manager.RemoveCurrNode(); + EXPECT_EQ("Node3", manager.GetCurrNode()->name()); + manager.RemoveCurrNode(); + EXPECT_EQ("Node4", manager.GetCurrNode()->name()); + manager.RemoveCurrNode(); + EXPECT_TRUE(manager.Empty()); +} + +// Test that FIFOManager can remove multiple nodes and add more nodes, still +// returning the current node in the right order +TEST_F(VirtualSchedulerTest, AddAndRemoveMultipleFIFOManager) { + // Init. + FIFOManager manager = FIFOManager(); + + // Add the nodes to FIFOManager. + manager.AddNode(&node1_); + manager.AddNode(&node2_); + manager.AddNode(&node3_); + manager.AddNode(&node4_); + + // Keep checking current node as nodes are removed and added. + EXPECT_EQ("Node1", manager.GetCurrNode()->name()); + manager.RemoveCurrNode(); + EXPECT_EQ("Node2", manager.GetCurrNode()->name()); + manager.AddNode(&node5_); + manager.RemoveCurrNode(); + EXPECT_EQ("Node3", manager.GetCurrNode()->name()); + manager.RemoveCurrNode(); + EXPECT_EQ("Node4", manager.GetCurrNode()->name()); + manager.AddNode(&node6_); + manager.RemoveCurrNode(); + EXPECT_EQ("Node5", manager.GetCurrNode()->name()); + manager.RemoveCurrNode(); + EXPECT_EQ("Node6", manager.GetCurrNode()->name()); + manager.RemoveCurrNode(); + EXPECT_TRUE(manager.Empty()); +} + +// Test that LIFOManager correctly returns the current node with only 1 node. 
+TEST_F(VirtualSchedulerTest, GetSingleNodeLIFOManager) { + // Init. + LIFOManager manager = LIFOManager(); + + // Add the node to LIFOManager. + manager.AddNode(&node1_); + EXPECT_EQ("Node1", manager.GetCurrNode()->name()); +} + +// Test that LIFOManager removes the only node contained within. +TEST_F(VirtualSchedulerTest, RemoveSingleNodeLIFOManager) { + // Init. + LIFOManager manager = LIFOManager(); + + // Add the node to LIFOManager. + manager.AddNode(&node1_); + + // Remove the only node in LIFOManager. + manager.RemoveCurrNode(); + EXPECT_TRUE(manager.Empty()); +} + +// Test that LIFOManager can remove multiple nodes and returns the current node +// in the right order +TEST_F(VirtualSchedulerTest, GetAndRemoveMultipleLIFOManager) { + // Init. + LIFOManager manager = LIFOManager(); + + // Add the nodes to LIFOManager. + manager.AddNode(&node1_); + manager.AddNode(&node2_); + manager.AddNode(&node3_); + manager.AddNode(&node4_); + + // Keep checking current node while removing nodes from manager. + EXPECT_EQ("Node4", manager.GetCurrNode()->name()); + manager.RemoveCurrNode(); + EXPECT_EQ("Node3", manager.GetCurrNode()->name()); + manager.RemoveCurrNode(); + EXPECT_EQ("Node2", manager.GetCurrNode()->name()); + manager.RemoveCurrNode(); + EXPECT_EQ("Node1", manager.GetCurrNode()->name()); + manager.RemoveCurrNode(); + EXPECT_TRUE(manager.Empty()); +} + +// Test that LIFOManager can remove multiple nodes (must be removing the current +// node) and add more nodes, still returning the current node in the right order +TEST_F(VirtualSchedulerTest, AddAndRemoveMultipleLIFOManager) { + // Init. + LIFOManager manager = LIFOManager(); + + // Add the nodes to LIFOManager. + manager.AddNode(&node1_); + manager.AddNode(&node2_); + manager.AddNode(&node3_); + manager.AddNode(&node4_); + + // Keep checking current node as nodes are removed and added. 
+ EXPECT_EQ("Node4", manager.GetCurrNode()->name()); + manager.RemoveCurrNode(); + EXPECT_EQ("Node3", manager.GetCurrNode()->name()); + manager.AddNode(&node5_); + manager.RemoveCurrNode(); + EXPECT_EQ("Node5", manager.GetCurrNode()->name()); + manager.RemoveCurrNode(); + EXPECT_EQ("Node2", manager.GetCurrNode()->name()); + manager.AddNode(&node6_); + manager.RemoveCurrNode(); + EXPECT_EQ("Node6", manager.GetCurrNode()->name()); + manager.RemoveCurrNode(); + EXPECT_EQ("Node1", manager.GetCurrNode()->name()); + manager.RemoveCurrNode(); + EXPECT_TRUE(manager.Empty()); +} + +// Create small graph, run predict costs on it, make sure the costs from the +// summary match the hand-calculated costs. +TEST_F(VirtualSchedulerTest, SummaryCostTest) { + // Run matmul test. + CreateGrapplerItemWithMatmulChain(); + InitScheduler(); + auto ops_executed = RunScheduler(""); + Costs c = scheduler_->Summary(); + + // RandomUniform - 5 + // Matmuls - 4 * 2 = 8 + // Total: 13 + EXPECT_EQ(13000000, c.execution_time.asMicroSeconds().count()); +} + +// Like the above SummaryCostTest, but makes sure the stepstats timeline is +// correct. +TEST_F(VirtualSchedulerTest, SummaryCostStepStatsTest) { + // Run matmul test. + CreateGrapplerItemWithMatmulChain(); + InitScheduler(); + auto ops_executed = RunScheduler(""); + StepStats stepstats; + Costs c = scheduler_->Summary(&stepstats); + EXPECT_EQ(13000000, c.execution_time.asMicroSeconds().count()); + + // Should only be 1 device! + EXPECT_EQ(1, stepstats.dev_stats().size()); + + // Create a map of op name -> start and end times (micros). + std::map> start_end_times; + for (const auto& device_step_stats : stepstats.dev_stats()) { + for (const auto& stats : device_step_stats.node_stats()) { + // The node name is actually in the timeline_label. 
+ int64 start = stats.all_start_micros(); + int64 end = start + stats.all_end_rel_micros(); + start_end_times[stats.timeline_label()] = + std::pair(start, end); + } + } + + // The base start_time is the time to compute RandomUniforms + int64 cur_time = static_cast(5000000); + // The increment is the execution time of one matmul. See + // CreateGrapplerItemWithMatmulChain for details. + int64 increment = static_cast(2000000); + auto op_names = {"ab", "abc", "abcd", "abcde"}; + for (const auto& op_name : op_names) { + int64 actual_start = start_end_times[op_name].first; + int64 actual_end = start_end_times[op_name].second; + int64 expected_start = cur_time; + int64 expected_end = cur_time + increment; + EXPECT_EQ(expected_start, actual_start); + EXPECT_EQ(expected_end, actual_end); + cur_time += increment; + } +} + TEST_F(VirtualSchedulerTest, InitAndBasicScheduling) { // Init. CreateGrapplerItemWithConv2Ds(); diff --git a/tensorflow/core/grappler/grappler_item.cc b/tensorflow/core/grappler/grappler_item.cc index 312a457abf4..88ddd6c1b3c 100644 --- a/tensorflow/core/grappler/grappler_item.cc +++ b/tensorflow/core/grappler/grappler_item.cc @@ -55,7 +55,7 @@ std::vector ComputeTransitiveFanin( std::vector queue; for (const string& root : terminal_nodes) { const NodeDef* node = name_to_node[NodeName(root)]; - CHECK(node); + CHECK(node) << "Unknown root " << root; queue.push_back(node); } diff --git a/tensorflow/core/grappler/grappler_item_builder.cc b/tensorflow/core/grappler/grappler_item_builder.cc index bb36152bd87..969376917be 100644 --- a/tensorflow/core/grappler/grappler_item_builder.cc +++ b/tensorflow/core/grappler/grappler_item_builder.cc @@ -332,6 +332,29 @@ std::unique_ptr GrapplerItemFromMetaGraphDef( return nullptr; } + // Validate feed, fetch and init nodes + std::unordered_set nodes; + for (const auto& node : new_item->graph.node()) { + nodes.insert(node.name()); + } + for (const auto& feed : new_item->feed) { + if (nodes.find(feed.first) == nodes.end()) 
{ + LOG(ERROR) << "Feed node " << feed.first << " doesn't exist in graph"; + return nullptr; + } + } + for (const auto& fetch : new_item->fetch) { + if (nodes.find(fetch) == nodes.end()) { + LOG(ERROR) << "Fetch node " << fetch << " doesn't exist in graph"; + return nullptr; + } + } + for (const auto& init : new_item->init_ops) { + if (nodes.find(init) == nodes.end()) { + LOG(ERROR) << "Init node " << init << " doesn't exist in graph"; + return nullptr; + } + } return new_item; } diff --git a/tensorflow/core/grappler/grappler_item_builder.h b/tensorflow/core/grappler/grappler_item_builder.h index 3aa1d2027f5..7135c83801a 100644 --- a/tensorflow/core/grappler/grappler_item_builder.h +++ b/tensorflow/core/grappler/grappler_item_builder.h @@ -31,7 +31,7 @@ struct ItemConfig { : ignore_user_placement(true), ignore_colocation(true), placeholder_unknown_output_shape_dim(-1), - apply_optimizations(true), + apply_optimizations(false), inline_functions(true) {} // If true, ignore all user specified node placement. 
diff --git a/tensorflow/core/grappler/grappler_item_builder_test.cc b/tensorflow/core/grappler/grappler_item_builder_test.cc index 92225ffb1b4..048870f9e51 100644 --- a/tensorflow/core/grappler/grappler_item_builder_test.cc +++ b/tensorflow/core/grappler/grappler_item_builder_test.cc @@ -51,7 +51,11 @@ void SampleSumSymbolicGradientGraphdef( auto g0 = SymbolicGradient(scope, std::initializer_list{x, y, z}, {DT_FLOAT, DT_INT32}, fn); - fetches->mutable_node_list()->add_value(g0[0].name()); + // TODO(bsteiner): we should rewrite the feed/fetch nodes to reflect the + // inlining that's done in the item builder + // fetches->mutable_node_list()->add_value(g0[0].name()); + fetches->mutable_node_list()->add_value("SymbolicGradient/dx"); + fetches->mutable_node_list()->add_value("SymbolicGradient/dy_reshaped"); TF_CHECK_OK(scope.ToGraphDef(def)); @@ -109,11 +113,12 @@ TEST_F(GrapplerItemBuilderTest, SymbolicGradientInlining) { std::unique_ptr with_inline = CreateGrapplerItem(def, fetches); // For the inlined graph, there should be 0 symbolic gradient ops. - CHECK_EQ(0, CountSymbolicGradientOps(with_inline)); + EXPECT_EQ(0, CountSymbolicGradientOps(with_inline)); // For the inlined graph, make sure all the required expanded op’s are in the // graph. 
- CHECK_EQ(ops_of_inline.size(), CountOpsWithNames(with_inline, ops_of_inline)); + EXPECT_EQ(ops_of_inline.size(), + CountOpsWithNames(with_inline, ops_of_inline)); } } // namespace diff --git a/tensorflow/core/grappler/inputs/BUILD b/tensorflow/core/grappler/inputs/BUILD index 176b3e982fb..5c70f409697 100644 --- a/tensorflow/core/grappler/inputs/BUILD +++ b/tensorflow/core/grappler/inputs/BUILD @@ -22,7 +22,6 @@ cc_library( ], visibility = ["//visibility:public"], deps = [ - "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", ], @@ -37,7 +36,6 @@ cc_test( deps = [ ":utils", "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", ], diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD index 6bb3d50b76d..3a705b85f7a 100644 --- a/tensorflow/core/grappler/optimizers/BUILD +++ b/tensorflow/core/grappler/optimizers/BUILD @@ -59,7 +59,6 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":graph_optimizer", - "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core/grappler:devices", @@ -78,11 +77,11 @@ cc_test( ":auto_parallel", "//tensorflow/cc:cc_ops", "//tensorflow/core:protos_all_cc", + "//tensorflow/core:tensorflow", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:utils", - "//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder", ], ) @@ -200,6 +199,7 @@ cc_library( "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:utils", "//tensorflow/core/grappler/costs:graph_properties", + "//tensorflow/core/grappler/utils:topological_sort", ], ) @@ -229,7 +229,6 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":graph_optimizer", - "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", 
"//tensorflow/core/grappler:devices", @@ -254,7 +253,6 @@ cc_test( "//tensorflow/core:testlib", "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:utils", - "//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder", ], ) diff --git a/tensorflow/core/grappler/optimizers/auto_parallel.cc b/tensorflow/core/grappler/optimizers/auto_parallel.cc index d4326a022f4..42f2f1850f4 100644 --- a/tensorflow/core/grappler/optimizers/auto_parallel.cc +++ b/tensorflow/core/grappler/optimizers/auto_parallel.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/grappler/clusters/cluster.h" #include "tensorflow/core/grappler/devices.h" #include "tensorflow/core/grappler/grappler_item.h" @@ -248,7 +249,8 @@ void AutoParallel::BuildGraph(GraphDef* graph) { for (const auto& fetch : item_->fetch) { AddNodeControl(fetch, {control->name()}, graph); } - *(graph->mutable_library()) = item_->graph.library(); + *graph->mutable_library() = item_->graph.library(); + *graph->mutable_versions() = item_->graph.versions(); LOG(INFO) << "Parallelized graph size: " << graph->node_size(); } diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index ac04be6d331..07d2d1e1f18 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -18,11 +18,13 @@ limitations under the License. 
#include "tensorflow/core/grappler/optimizers/constant_folding.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/grappler/clusters/cluster.h" #include "tensorflow/core/grappler/costs/graph_properties.h" #include "tensorflow/core/grappler/grappler_item.h" @@ -283,15 +285,32 @@ bool ConstantFolding::IsFoldable(const NodeDef& node) const { return false; } + // We can only fold nodes if all their inputs are known statically, except in + // the case of a merge node that propagate the first inputs that becomes + // available, and therefore only requires a single constant input to be + // foldable. + bool has_constant_input = false; + const bool is_merge = IsMerge(node); for (const auto& input : node.input()) { if (IsControlInput(input)) { continue; } - bool is_const = IsConstant(*node_map_->GetNode(input)); - if (!is_const) { + const NodeDef* input_node = node_map_->GetNode(input); + bool is_const = IsConstant(*input_node); + if (!is_const && !is_merge) { return false; } + // Don't fold strings constants for now since this causes problems with + // checkpointing. + if (is_const && input_node->attr().at("dtype").type() == DT_STRING) { + return false; + } + has_constant_input |= is_const; } + if (is_merge) { + return has_constant_input; + } + return true; } @@ -387,6 +406,82 @@ Status ConstantFolding::EvaluateOneFoldable(const NodeDef& node, } Status ConstantFolding::FoldNode(const NodeDef& node, GraphDef* output) { + if (IsMerge(node)) { + // Merge nodes are special, in the sense that they execute as soon as one of + // their input is ready. 
We can therefore fold a merge node iff it has at + // least one constant input without control dependency. + // We still need to ensure that the nodes in the fanin of the merge node are + // scheduled. We'll therefore add a control dependency from the merge node + // to the folded constant. We end up with: + // * the merge node and its inputs are preserved as is + // * a new constant node C1, driven by the merge node through a control + // dependency, initialized to the value of the folded input + // * a new constant node C2, driven by the merge node through a control + // dependency, initialized to the index of the folded input + // * the fanout of the merge nodes is rewired to be driven by either C1 or + // C2. + for (int input_index = 0; input_index < node.input_size(); ++input_index) { + const auto& input = node.input(input_index); + if (IsControlInput(input)) { + // Try the next input. + continue; + } + NodeDef* input_node = node_map_->GetNode(input); + if (!IsConstant(*input_node)) { + continue; + } + bool valid_input = true; + for (const string& fanin_of_input : input_node->input()) { + if (IsControlInput(fanin_of_input)) { + valid_input = false; + break; + } + } + if (!valid_input) { + // Try the next input + continue; + } + NodeDef* const_out = output->add_node(); + *const_out = *input_node; + const_out->set_name( + AddPrefixToNodeName(node.name(), kConstantFoldingConst)); + *const_out->add_input() = AsControlDependency(node); + node_map_->AddNode(const_out->name(), const_out); + + NodeDef* const_index = output->add_node(); + const_index->set_op("Const"); + Tensor index(DT_INT32, TensorShape({})); + index.flat()(0) = input_index; + (*const_index->mutable_attr())["dtype"].set_type(DT_INT32); + index.AsProtoTensorContent( + (*const_index->mutable_attr())["value"].mutable_tensor()); + const_index->set_name(AddPrefixToNodeName( + strings::StrCat(node.name(), "_index"), kConstantFoldingConst)); + *const_index->add_input() = AsControlDependency(node); + 
node_map_->AddNode(const_index->name(), const_index); + + auto outputs = node_map_->GetOutputs(node.name()); + for (auto& output : outputs) { + for (int i = 0; i < output->input_size(); i++) { + int position; + string node_name = ParseNodeName(output->input(i), &position); + if (node_name == node.name()) { + if (position == 0) { + *output->mutable_input(i) = const_out->name(); + } else if (position == 1) { + *output->mutable_input(i) = const_index->name(); + } else { + // This is a control dependency (or an invalid edge since the + // merge node has only 2 inputs): preserve them. + } + } + } + } + return Status::OK(); + } + return Status::OK(); + } + std::vector const_nodes; TF_RETURN_IF_ERROR(EvaluateOneFoldable(node, &const_nodes)); @@ -549,6 +644,10 @@ Status ConstantFolding::Optimize(Cluster* cluster, const GrapplerItem& item, TF_RETURN_IF_ERROR(FoldGraph(output)); TF_RETURN_IF_ERROR(SimplifyGraph(output)); LOG(INFO) << "Optimized graph size: " << output->node_size(); + + *output->mutable_library() = item.graph.library(); + *output->mutable_versions() = item.graph.versions(); + return Status::OK(); } diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc index 566d3cd9a39..cb29ccc3410 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc @@ -383,6 +383,97 @@ TEST_F(ConstantFoldingTest, SwitchNodes) { } } +TEST_F(ConstantFoldingTest, MergeNodes) { + tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); + + Output x = + ops::RandomNormal(scope.WithOpName("x"), {3, 5}, DataType::DT_FLOAT); + Output y = + ops::RandomNormal(scope.WithOpName("y"), {3, 5}, DataType::DT_FLOAT); + Output const1 = + ops::Const(scope.WithOpName("const1").WithControlDependencies(x), 2.7f, + TensorShape({3, 5})); + Output const2 = + ops::Const(scope.WithOpName("const2"), 3.14f, TensorShape({3, 5})); + Output const3 = 
+ ops::Const(scope.WithOpName("const3").WithControlDependencies(x), 3.14f, + TensorShape({3, 5})); + + // Create 3 merge nodes: m1 is foldable, m2 and m3 aren't. + ops::Merge m1(scope.WithOpName("m1"), {x, const1, const2}); + ops::Merge m2(scope.WithOpName("m2"), {const1, const3}); + ops::Merge m3(scope.WithOpName("m3"), {x, y}); + + ops::Identity out1(scope.WithOpName("out1"), m1.output); + ops::Identity idx1(scope.WithOpName("idx1"), m1.value_index); + ops::Identity out2(scope.WithOpName("out2"), m2.output); + ops::Identity idx2(scope.WithOpName("idx2"), m2.value_index); + ops::Identity out3(scope.WithOpName("out3"), m3.output); + ops::Identity idx3(scope.WithOpName("idx3"), m3.value_index); + + GrapplerItem item; + item.fetch.push_back("out1, idx1, out2, idx2, out3, idx3"); + TF_CHECK_OK(scope.ToGraphDef(&item.graph)); + + ConstantFolding fold; + GraphDef output; + Status status = fold.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + + int found_nodes = 0; + for (const auto& node : output.node()) { + if (node.name() == "out1") { + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("ConstantFolding/m1", node.input(0)); + ++found_nodes; + } else if (node.name() == "idx1") { + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("ConstantFolding/m1_index", node.input(0)); + ++found_nodes; + } else if (node.name() == "ConstantFolding/m1") { + EXPECT_EQ("Const", node.op()); + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("^m1", node.input(0)); + ++found_nodes; + } else if (node.name() == "ConstantFolding/m1_index") { + EXPECT_EQ("Const", node.op()); + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("^m1", node.input(0)); + ++found_nodes; + } else if (node.name() == "out2") { + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("m2", node.input(0)); + ++found_nodes; + } else if (node.name() == "idx2") { + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("m2:1", node.input(0)); + ++found_nodes; + } else if (node.name() == "out3") { + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("m3", 
node.input(0)); + ++found_nodes; + } else if (node.name() == "idx3") { + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("m3:1", node.input(0)); + ++found_nodes; + } + } + // Make sure the graph contains all the nodes we're expecting. + EXPECT_EQ(8, found_nodes); + + std::vector fetch = {"out1", "idx1"}; + auto tensors = EvaluateNodes(output, fetch); + EXPECT_EQ(2, tensors.size()); + const Tensor& out_value = tensors[0]; + EXPECT_EQ(3 * 5, out_value.NumElements()); + for (int i = 0; i < 3 * 5; ++i) { + EXPECT_EQ(3.14f, out_value.flat()(i)); + } + const Tensor& out_idx = tensors[1]; + EXPECT_EQ(1, out_idx.NumElements()); + EXPECT_EQ(2, out_idx.flat()(0)); +} + TEST_F(ConstantFoldingTest, NoOpReduction) { // Build a simple graph with a reduction that can be reduced to the identity. tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc index 5ac42c2abab..5a017cb0e82 100644 --- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc @@ -1158,7 +1158,7 @@ Status LayoutOptimizer::InferOutputShapes(GrapplerItem* item) { for (const auto& tensor_property : tensor_properties) { *attr_output_shape.mutable_list()->add_shape() = tensor_property.shape(); } - node->mutable_attr()->insert({"_output_shapes", attr_output_shape}); + (*node->mutable_attr())["_output_shapes"] = attr_output_shape; } return Status::OK(); } diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer.cc b/tensorflow/core/grappler/optimizers/memory_optimizer.cc index 1ed7cab4abf..bd8ce8c30cf 100644 --- a/tensorflow/core/grappler/optimizers/memory_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/memory_optimizer.cc @@ -15,6 +15,9 @@ limitations under the License. 
#include "tensorflow/core/grappler/optimizers/memory_optimizer.h" +#include +#include +#include #include #include @@ -25,11 +28,325 @@ limitations under the License. #include "tensorflow/core/grappler/optimizers/graph_rewriter.h" #include "tensorflow/core/grappler/optimizers/static_schedule.h" #include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/grappler/utils/topological_sort.h" +#include "tensorflow/core/protobuf/rewriter_config.pb.h" namespace tensorflow { namespace grappler { +// Prefix added to nodes which are recomputed. const char* kRecomputedNodePrefix = "Recomputed"; +const char* kRecomputeTriggerNodePrefix = "RecomputeTrigger"; +// Attribute which may be added to nodes to manually allow them to be +// recomputed. +const char* kRecomputeHint = "_recompute_hint"; +const char* kRecomputationTargetNamePrefix = "gradients/"; + +// Ops which we wouldn't mind recomputing to save memory. +// TODO(allenl): Replace this list with a cost model. +std::unordered_set GetCheapToRecomputeOps() { + std::unordered_set cheap_ops = { + "Add", "AddN", "BiasAdd", + "Cast", "Fill", "FloorDiv", + "FloorMod", "FusedBatchNorm", "Mul", + "Neg", "RealDiv", "Reciprocal", + "Relu", "Reshape", "Rsqrt", + "Sqrt", "Square", "SquaredDifference", + "Sub", "Tile", "Transpose"}; + return cheap_ops; +} + +// Nodes whose inputs we may want to recompute (i.e. gradients). +// TODO(allenl): Rather than blindly recomputing gradient inputs, use a static +// schedule (grappler::EstimateEarliestExecutionTimes) to recompute only nodes +// whose outputs will sit around for a while. +bool IsTargetOp(const NodeDef& node) { + return node.name().find(kRecomputationTargetNamePrefix) == 0; +} + +// Find recomputable ops which feed into target nodes. 
+std::unordered_set FindCandidateRecomputeNodes( + const NodeMap& node_map, const GraphDef* graph, + const std::function& is_candidate) { + std::unordered_set candidate_recompute_nodes; + for (const auto& node : graph->node()) { + if (!is_candidate(node)) { + continue; + } + bool has_target_output = false; + for (const NodeDef* output : node_map.GetOutputs(node.name())) { + // It only makes sense to recompute this if it feeds into a target + // node. We expand this to dependencies in GetOpGroupsToRecompute. + if (IsTargetOp(*output)) { + has_target_output = true; + break; + } + } + if (!has_target_output) { + continue; + } + bool has_target_input = false; + for (const string& input_name : node.input()) { + // Don't recompute nodes which depend on target nodes. + const NodeDef* input_node = node_map.GetNode(input_name); + if (IsTargetOp(*input_node)) { + has_target_input = true; + break; + } + } + if (has_target_input) { + continue; + } + candidate_recompute_nodes.insert(&node); + } + return candidate_recompute_nodes; +} + +void connected_subgraph(const NodeMap& node_map, bool collect_inputs, + bool collect_outputs, + const std::function& is_candidate, + std::unordered_set* expanded_nodes) { + std::queue to_visit; + for (const NodeDef* starting_node : *expanded_nodes) { + to_visit.push(starting_node); + } + expanded_nodes->clear(); + while (!to_visit.empty()) { + const NodeDef* current_node = to_visit.front(); + to_visit.pop(); + if (!expanded_nodes->insert(current_node).second) { + // We already visited this node + continue; + } + if (collect_inputs) { + // Add inputs and outputs to this subgraph if they are candidates + for (const string& input_name_raw : current_node->input()) { + const NodeDef* input_node = node_map.GetNode(input_name_raw); + if (expanded_nodes->count(input_node) == 0 && + is_candidate(*input_node)) { + to_visit.push(input_node); + } + } + } + if (collect_outputs) { + for (const NodeDef* output : node_map.GetOutputs(current_node->name())) { + if 
(expanded_nodes->count(output) == 0 && is_candidate(*output)) { + to_visit.push(output); + } + } + } + } +} + +struct RecomputedSubGraph { + std::unordered_set recomputed_source_nodes; + std::unordered_set target_nodes; +}; + +// Find groups of ops to recompute together based on `should_recompute`. +std::vector GetOpGroupsToRecompute( + const GraphDef* graph, const NodeMap& node_map, + const std::function& should_recompute) { + std::unordered_set visited_nodes; + std::vector subgraphs_to_recompute; + std::unordered_set candidate_recompute_nodes = + FindCandidateRecomputeNodes(node_map, graph, should_recompute); + for (const NodeDef* recompute_node : candidate_recompute_nodes) { + if (visited_nodes.count(recompute_node) > 0) { + continue; + } + RecomputedSubGraph current_recomputation; + // Build out recomputation groups by expanding to inexpensive-to-recompute + // nodes which do not feed target nodes. The goal is to capture some + // intermediate activations within this graph. + std::unordered_set unpruned_recompute_nodes; + unpruned_recompute_nodes.insert(recompute_node); + connected_subgraph(node_map, + true, // Collect inputs + true, // Collect outputs + should_recompute, &unpruned_recompute_nodes); + visited_nodes.insert(unpruned_recompute_nodes.begin(), + unpruned_recompute_nodes.end()); + for (const NodeDef* recompute_node : unpruned_recompute_nodes) { + bool inserted_feed = false; + for (NodeDef* output : node_map.GetOutputs(recompute_node->name())) { + if (IsTargetOp(*output)) { + current_recomputation.target_nodes.insert(output); + if (!inserted_feed) { + // Keep track of nodes which feed directly into a target node. These + // and nodes which feed into them will define the recomputed + // subgraph. + current_recomputation.recomputed_source_nodes.insert( + recompute_node); + inserted_feed = true; + } + } + } + } + // Recompute only nodes which eventually feed into a target node. 
+ connected_subgraph(node_map, + true, // Collect inputs + false, // Collect outputs + [&unpruned_recompute_nodes](const NodeDef& node) { + return unpruned_recompute_nodes.count(&node) != 0; + }, + ¤t_recomputation.recomputed_source_nodes); + if (current_recomputation.target_nodes.empty()) { + continue; + } + subgraphs_to_recompute.push_back(current_recomputation); + } + return subgraphs_to_recompute; +} + +// Computes the maximum topological numbers of (1) target node components +// (gradient nodes being fed by the recomputation), and (2) child recompute node +// components for each recomputed node. We will not attach any control +// dependencies to a recomputation unless they have component numbers greater +// than this value (to prevent cycles). +std::unordered_map GetMaxDownstreamComponents( + const std::unordered_set& recomputed_source_nodes, + const std::unordered_set& target_nodes, const NodeMap& node_map, + const std::unordered_map& components) { + std::unordered_map recomputed_node_components; + // Start by setting component numbers to the maximum among target nodes. + for (const NodeDef* original_recompute_node : recomputed_source_nodes) { + int max_target_component = -1; + for (NodeDef* output : + node_map.GetOutputs(original_recompute_node->name())) { + if (target_nodes.count(output) != 0) { + int current_target_component = components.find(output)->second; + if (current_target_component > max_target_component) { + max_target_component = current_target_component; + } + } + } + if (max_target_component > -1) { + recomputed_node_components[original_recompute_node] = + max_target_component; + } + } + // Sort recomputed nodes topologically (based on the original graph) so we can + // efficiently assign to each node the maximum of its recomputed child + // components and its own targets. 
+ std::vector recomputed_source_nodes_topological( + recomputed_source_nodes.begin(), recomputed_source_nodes.end()); + std::sort(recomputed_source_nodes_topological.begin(), + recomputed_source_nodes_topological.end(), + [&components](const NodeDef* first, const NodeDef* second) { + return components.find(first)->second < + components.find(second)->second; + }); + for (const NodeDef* original_recompute_node : + recomputed_source_nodes_topological) { + int max_component; + auto recomputed_component_iterator = + recomputed_node_components.find(original_recompute_node); + if (recomputed_component_iterator != recomputed_node_components.end()) { + max_component = recomputed_component_iterator->second; + } else { + max_component = -1; + } + for (NodeDef* output : + node_map.GetOutputs(original_recompute_node->name())) { + if (recomputed_source_nodes.count(output) == 0) { + continue; + } + auto child_component_iterator = recomputed_node_components.find(output); + CHECK(child_component_iterator != recomputed_node_components.end()); + int child_component = child_component_iterator->second; + if (child_component > max_component) { + max_component = child_component; + } + } + CHECK_GE(max_component, 0); + recomputed_node_components[original_recompute_node] = max_component; + } + return recomputed_node_components; +} + +// Modifies `graph`, adding trigger nodes and returning a mapping from +// `recomputed_source_nodes` to trigger nodes which will not create loops in the +// graph (using the component numberings in `components` and +// `recomputed_node_max_feed_components`). The copied nodes (not the nodes in +// recomputed_source_nodes, which are the originals) eventually get these +// control dependencies. 
+std::unordered_map +AddRecomputeControlDependencyNodes( + const std::unordered_set& recomputed_source_nodes, + const std::unordered_set& target_nodes, const NodeMap& node_map, + const std::unordered_map& components, + const std::unordered_map& + recomputed_node_max_feed_components, + GraphDef* graph) { + // Sort recomputed nodes based on max downstream components. + std::vector recomputed_source_nodes_topological( + recomputed_source_nodes.begin(), recomputed_source_nodes.end()); + std::sort(recomputed_source_nodes_topological.begin(), + recomputed_source_nodes_topological.end(), + [&recomputed_node_max_feed_components](const NodeDef* first, + const NodeDef* second) { + int first_component = + recomputed_node_max_feed_components.find(first)->second; + int second_component = + recomputed_node_max_feed_components.find(second)->second; + return first_component > second_component + // Ensure a consistent ordering. This is necessary because + // we're working not with node component numbers (which are + // unique) but with the maximum across nodes they feed into + // (very much not unique). + || (first_component == second_component && + first->name() > second->name()); + }); + // Create merged control dependency nodes by sorting target inputs + // topologically and zipper merging with the sorted recomputed nodes. 
+ std::vector target_inputs_topological; + for (const NodeDef* target_node : target_nodes) { + for (const string& target_input_name_raw : target_node->input()) { + const NodeDef* target_input = node_map.GetNode(target_input_name_raw); + if (recomputed_source_nodes.count(target_input) != 0 || + components.find(target_node)->second == + components.find(target_input)->second) { + continue; + } + target_inputs_topological.push_back(target_input); + } + } + std::sort(target_inputs_topological.begin(), target_inputs_topological.end(), + [&components](const NodeDef* first, const NodeDef* second) { + return components.find(first)->second > + components.find(second)->second; + }); + auto target_input_iterator = target_inputs_topological.begin(); + NodeDef* current_trigger_node = nullptr; + std::unordered_map triggers; + for (const NodeDef* original_recomputed_node : + recomputed_source_nodes_topological) { + NodeDef* new_trigger_node = graph->add_node(); + new_trigger_node->set_name(AddPrefixToNodeName( + original_recomputed_node->name(), kRecomputeTriggerNodePrefix)); + new_trigger_node->set_op("NoOp"); + new_trigger_node->set_device(original_recomputed_node->device()); + if (current_trigger_node != nullptr) { + *new_trigger_node->add_input() = + strings::StrCat("^", current_trigger_node->name()); + } + current_trigger_node = new_trigger_node; + triggers[original_recomputed_node] = current_trigger_node; + for (; + target_input_iterator != target_inputs_topological.end() && + components.find(*target_input_iterator)->second > + recomputed_node_max_feed_components.find(original_recomputed_node) + ->second; + ++target_input_iterator) { + *current_trigger_node->add_input() = + strings::StrCat("^", (*target_input_iterator)->name()); + VLOG(2) << " Recomputation trigger " << current_trigger_node->name() + << " depends on " << (*target_input_iterator)->name(); + } + } + return triggers; +} string RecomputedOrOriginalNodeName( const std::unordered_set& recomputed_node_names, @@ 
-42,14 +359,28 @@ string RecomputedOrOriginalNodeName( } } +// Helper function to recompute a sub-graph (recomputed_source_nodes). Edges +// from recomputed_source_nodes to target_nodes are changed to start from the +// recomputed nodes. void RecomputeSubgraph( - const std::vector& recomputed_source_nodes, - const string& recompute_trigger_node_name, - const std::vector& target_nodes, GraphDef* graph) { + const std::unordered_set& recomputed_source_nodes, + const std::unordered_set& target_nodes, const NodeMap& node_map, + const std::unordered_map& components, + GraphDef* graph) { std::unordered_set recomputed_node_names; - for (const NodeDef* to_recompute : recomputed_source_nodes) { - recomputed_node_names.insert(to_recompute->name()); + VLOG(1) << "Recomputing a " << recomputed_source_nodes.size() + << " node subgraph"; + std::unordered_map recomputed_node_components = + GetMaxDownstreamComponents(recomputed_source_nodes, target_nodes, + node_map, components); + for (const NodeDef* original_node : recomputed_source_nodes) { + VLOG(2) << " " << original_node->name(); + recomputed_node_names.insert(original_node->name()); } + std::unordered_map triggers = + AddRecomputeControlDependencyNodes(recomputed_source_nodes, target_nodes, + node_map, components, + recomputed_node_components, graph); // Create the recomputed sub-graph for (const NodeDef* original_node : recomputed_source_nodes) { NodeDef* copied_node = graph->add_node(); @@ -64,10 +395,10 @@ void RecomputeSubgraph( *copied_node->add_input() = RecomputedOrOriginalNodeName( recomputed_node_names, original_input_name); } - // Set control dependencies on the recomputed nodes so that they are not run - // until the specified trigger runs. + // Each recomputed node gets a control dependency to prevent it from being + // recomputed immediately. 
*copied_node->add_input() = - strings::StrCat("^", recompute_trigger_node_name); + strings::StrCat("^", triggers[original_node]->name()); } // Set the inputs of nodes in the target subgraph to the recomputed nodes // where applicable. @@ -79,6 +410,52 @@ void RecomputeSubgraph( } } +void RecomputationRewritingPass(RewriterConfig::MemOptType optimization_level, + GraphDef* graph) { + // The topological numberings and NodeMap will be stale as soon as we start + // modifying the graph in RecomputeSubgraph. However, RecomputeSubgraph only + // looks up nodes which were in the original graph, and preserves the graph + // topology it's interested in. + // We don't use the results of this topological sort until later, but this + // call invalidates all NodeDef pointers, so it needs to be done before we + // start collecting those. + TopologicalSort(graph); + NodeMap node_map(graph); + std::vector recomputed_subgraphs; + if (optimization_level == RewriterConfig::HEURISTICS) { + // TODO(allenl): Handle ResNet-like architectures better. Right now all of + // the cheap forward ops get grouped into a single subgraph which must + // execute before gradients start executing (unless layers are manually + // separated by identity ops). 
+ std::unordered_set cheap_to_recompute_ops = + GetCheapToRecomputeOps(); + recomputed_subgraphs = GetOpGroupsToRecompute( + graph, node_map, [&cheap_to_recompute_ops](const NodeDef& node) { + return !IsTargetOp(node) && + (cheap_to_recompute_ops.count(node.op()) > 0 || + node.attr().count(kRecomputeHint) > 0); + }); + } else { // optimization_level == RewriterConfig::MANUAL + recomputed_subgraphs = + GetOpGroupsToRecompute(graph, node_map, [](const NodeDef& node) { + return !IsTargetOp(node) && node.attr().count(kRecomputeHint) > 0; + }); + } + if (!recomputed_subgraphs.empty()) { + std::unordered_map topological_numbering; + for (int node_number = 0; node_number < graph->node().size(); + ++node_number) { + topological_numbering[graph->mutable_node(node_number)] = + graph->node().size() - node_number - 1; + } + // Duplicate the indicated sub-graphs and set up control dependencies + for (const RecomputedSubGraph& subgraph : recomputed_subgraphs) { + RecomputeSubgraph(subgraph.recomputed_source_nodes, subgraph.target_nodes, + node_map, topological_numbering, graph); + } + } +} + std::pair BuildSwapPair(NodeDef* node, int input_to_swap, GraphDef* graph) { string tensor_to_swap = strings::StrCat(node->name(), "_", input_to_swap); @@ -205,6 +582,8 @@ Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, GraphDef* optimized_graph) { *optimized_graph = item.graph; + RecomputationRewritingPass(optimization_level_, optimized_graph); + // Figure out what needs to be swapped; std::unordered_map nodes_to_swap; for (auto& node : *optimized_graph->mutable_node()) { diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer.h b/tensorflow/core/grappler/optimizers/memory_optimizer.h index dfb24c05c99..5b7ba4001f0 100644 --- a/tensorflow/core/grappler/optimizers/memory_optimizer.h +++ b/tensorflow/core/grappler/optimizers/memory_optimizer.h @@ -16,9 +16,8 @@ limitations under the License. 
#ifndef TENSORFLOW_GRAPPLER_OPTIMIZERS_MEMORY_OPTIMIZER_H_ #define TENSORFLOW_GRAPPLER_OPTIMIZERS_MEMORY_OPTIMIZER_H_ -#include - #include "tensorflow/core/grappler/optimizers/graph_optimizer.h" +#include "tensorflow/core/protobuf/rewriter_config.pb.h" namespace tensorflow { namespace grappler { @@ -26,7 +25,8 @@ namespace grappler { // Swap tensors in and out of device memory. class MemoryOptimizer : public GraphOptimizer { public: - MemoryOptimizer() {} + explicit MemoryOptimizer(RewriterConfig::MemOptType optimization_level) + : optimization_level_(optimization_level) {} ~MemoryOptimizer() override {} string name() const override { return "memory_optimizer"; }; @@ -36,15 +36,10 @@ class MemoryOptimizer : public GraphOptimizer { void Feedback(Cluster* cluster, const GrapplerItem& item, const GraphDef& pruned_graph, double result) override; -}; -// Helper function to recompute a sub-graph (recomputed_source_nodes) on a -// trigger. Edges from recomputed_source_nodes to target_nodes are changed to -// start from the recomputed nodes. 
-void RecomputeSubgraph( - const std::vector& recomputed_source_nodes, - const string& recompute_trigger_node_name, - const std::vector& target_nodes, GraphDef* graph); + private: + RewriterConfig::MemOptType optimization_level_; +}; } // end namespace grappler } // end namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc b/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc index a4f8e22e1d8..27cad07d5bf 100644 --- a/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc @@ -35,89 +35,108 @@ TEST_F(RecomputeSubgraphTest, SimpleSubgraph) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output a = ops::Const(s.WithOpName("a"), 1.f, {2, 3, 4}); - Output b = ops::AddN(s.WithOpName("b"), {a}); // Recomputed - Output c = ops::AddN(s.WithOpName("c"), {b}); - Output d = ops::AddN(s.WithOpName("d"), {c}); - Output e = ops::AddN(s.WithOpName("e"), {d, b}); - Output f = ops::AddN(s.WithOpName("f"), {e, a}); + Output b = ops::Identity(s.WithOpName("b"), a); // Recomputed + Output c = ops::Identity(s.WithOpName("c"), b); + Output d = ops::AddN(s.WithOpName("gradients/d"), {c}); + Output e = ops::AddN(s.WithOpName("gradients/e"), {d, b}); + Output f = ops::AddN(s.WithOpName("gradients/f"), {e, a}); GrapplerItem item; TF_CHECK_OK(s.ToGraphDef(&item.graph)); EXPECT_EQ(6, item.graph.node_size()); NodeMap pre_transform_node_map(&item.graph); - std::vector recomputed_source_nodes; - recomputed_source_nodes.push_back(pre_transform_node_map.GetNode(b.name())); - std::vector target_nodes; - target_nodes.push_back(pre_transform_node_map.GetNode(e.name())); - RecomputeSubgraph(recomputed_source_nodes, d.name(), target_nodes, - &item.graph); - NodeMap post_transform_node_map(&item.graph); - EXPECT_EQ(7, item.graph.node_size()); + (*pre_transform_node_map.GetNode("b")->mutable_attr())["_recompute_hint"] + .set_i(0); + + MemoryOptimizer 
optimizer(RewriterConfig::MANUAL); + GraphDef output; + Status status = optimizer.Optimize(nullptr, item, &output); + + TF_EXPECT_OK(status); + NodeMap post_transform_node_map(&output); + EXPECT_EQ(8, output.node_size()); NodeDef* transformed_e = post_transform_node_map.GetNode(e.name()); EXPECT_EQ(2, transformed_e->input_size()); - EXPECT_EQ("d", transformed_e->input(0)); + EXPECT_EQ("gradients/d", transformed_e->input(0)); EXPECT_EQ("Recomputed/b", transformed_e->input(1)); NodeDef* recomputed_b = post_transform_node_map.GetNode("Recomputed/b"); EXPECT_EQ(2, recomputed_b->input_size()); EXPECT_EQ("a", recomputed_b->input(0)); - EXPECT_EQ("^d", recomputed_b->input(1).substr(0, 2)); + EXPECT_EQ("^RecomputeTrigger/b", recomputed_b->input(1)); + NodeDef* recompute_trigger = + post_transform_node_map.GetNode("RecomputeTrigger/b"); + EXPECT_EQ(1, recompute_trigger->input_size()); + EXPECT_EQ("^gradients/d", recompute_trigger->input(0)); } TEST_F(RecomputeSubgraphTest, MultiNode) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output a = ops::Const(s.WithOpName("Conv"), 1.f, {2, 3, 4}); - Output b = ops::AddN(s.WithOpName("BN"), {a}); // Recomputed - Output c = ops::AddN(s.WithOpName("ReLU"), {b}); // Recomputed - Output d = ops::AddN(s.WithOpName("Conv1"), {c}); + Output b = ops::Identity(s.WithOpName("BN"), a); // Recomputed + Output c = ops::Identity(s.WithOpName("ReLU"), b); // Recomputed + Output d = ops::Identity(s.WithOpName("Conv1"), c); - Output trigger = ops::Const(s.WithOpName("BN1Grad"), 0.f, {2, 3, 4}); - Output e = ops::AddN(s.WithOpName("Conv1Grad"), {trigger, c}); - Output f = ops::AddN(s.WithOpName("ReLUGrad"), {e, c}); - Output g = ops::AddN(s.WithOpName("BNGrad"), {f, a}); - Output h = ops::AddN(s.WithOpName("ConvGrad"), {g}); + // The "gradients/" prefix means the heuristic will pick these up as + // candidates to have their inputs recomputed. 
+ Output trigger = ops::AddN(s.WithOpName("gradients/BN1Grad"), {d}); + Output e = ops::AddN(s.WithOpName("gradients/Conv1Grad"), {trigger, c}); + Output f = ops::AddN(s.WithOpName("gradients/ReLUGrad"), {e, c}); + Output g = ops::AddN(s.WithOpName("gradients/BNGrad"), {f, a}); + Output h = ops::AddN(s.WithOpName("gradients/ConvGrad"), {g}); GrapplerItem item; TF_CHECK_OK(s.ToGraphDef(&item.graph)); EXPECT_EQ(9, item.graph.node_size()); NodeMap pre_transform_node_map(&item.graph); - std::vector recomputed_source_nodes; - recomputed_source_nodes.push_back(pre_transform_node_map.GetNode(b.name())); - recomputed_source_nodes.push_back(pre_transform_node_map.GetNode(c.name())); - std::vector target_nodes; - target_nodes.push_back(pre_transform_node_map.GetNode(e.name())); - target_nodes.push_back(pre_transform_node_map.GetNode(f.name())); - target_nodes.push_back(pre_transform_node_map.GetNode(g.name())); - RecomputeSubgraph(recomputed_source_nodes, trigger.name(), target_nodes, - &item.graph); - NodeMap post_transform_node_map(&item.graph); - EXPECT_EQ(11, item.graph.node_size()); + // Set op types so that the heuristic will pick these nodes up to be + // recomputed + pre_transform_node_map.GetNode("BN")->set_op("FusedBatchNorm"); + pre_transform_node_map.GetNode("ReLU")->set_op("Relu"); + + MemoryOptimizer optimizer(RewriterConfig::HEURISTICS); + GraphDef first_pass_output; + Status first_pass_status = + optimizer.Optimize(nullptr, item, &first_pass_output); + TF_EXPECT_OK(first_pass_status); + + NodeMap post_transform_node_map(&first_pass_output); + EXPECT_EQ(13, first_pass_output.node_size()); NodeDef* transformed_e = post_transform_node_map.GetNode(e.name()); EXPECT_EQ(2, transformed_e->input_size()); - EXPECT_EQ("BN1Grad", transformed_e->input(0)); + EXPECT_EQ("gradients/BN1Grad", transformed_e->input(0)); EXPECT_EQ("Recomputed/ReLU", transformed_e->input(1)); NodeDef* transformed_f = post_transform_node_map.GetNode(f.name()); EXPECT_EQ(2, 
transformed_f->input_size()); - EXPECT_EQ("Conv1Grad", transformed_f->input(0)); + EXPECT_EQ("gradients/Conv1Grad", transformed_f->input(0)); EXPECT_EQ("Recomputed/ReLU", transformed_f->input(1)); NodeDef* transformed_g = post_transform_node_map.GetNode(g.name()); EXPECT_EQ(2, transformed_g->input_size()); - EXPECT_EQ("ReLUGrad", transformed_g->input(0)); + EXPECT_EQ("gradients/ReLUGrad", transformed_g->input(0)); EXPECT_EQ("Conv", transformed_g->input(1)); NodeDef* recomputed_b = post_transform_node_map.GetNode("Recomputed/BN"); EXPECT_EQ(2, recomputed_b->input_size()); EXPECT_EQ("Conv", recomputed_b->input(0)); - EXPECT_EQ("^BN1Grad", recomputed_b->input(1).substr(0, 8)); + EXPECT_EQ("^RecomputeTrigger/BN", recomputed_b->input(1)); + NodeDef* recompute_trigger_b = + post_transform_node_map.GetNode("RecomputeTrigger/BN"); + EXPECT_EQ(1, recompute_trigger_b->input_size()); + EXPECT_EQ("^RecomputeTrigger/ReLU", recompute_trigger_b->input(0)); + NodeDef* recomputed_c = post_transform_node_map.GetNode("Recomputed/ReLU"); EXPECT_EQ(2, recomputed_c->input_size()); EXPECT_EQ("Recomputed/BN", recomputed_c->input(0)); - EXPECT_EQ("^BN1Grad", recomputed_c->input(1).substr(0, 8)); + EXPECT_EQ("^RecomputeTrigger/ReLU", recomputed_c->input(1)); + NodeDef* recompute_trigger_c = + post_transform_node_map.GetNode("RecomputeTrigger/ReLU"); + EXPECT_EQ(1, recompute_trigger_c->input_size()); + EXPECT_EQ("^gradients/BN1Grad", recompute_trigger_c->input(0)); } class MemoryOptimizerTest : public ::testing::Test { public: - static VirtualCluster CreateVirtualCluster() { + static std::unique_ptr CreateVirtualCluster() { DeviceProperties cpu_device; cpu_device.set_type("CPU"); cpu_device.set_frequency(1000); @@ -125,7 +144,7 @@ class MemoryOptimizerTest : public ::testing::Test { cpu_device.set_bandwidth(32); std::unordered_map devices; devices["/job:localhost/replica:0/task:0/cpu:0"] = cpu_device; - return VirtualCluster(devices); + return std::unique_ptr(new VirtualCluster(devices)); } 
}; @@ -148,11 +167,11 @@ TEST_F(MemoryOptimizerTest, SimpleSwapping) { (*item.graph.mutable_node(4)->mutable_attr())["_swap_to_host"]; val.mutable_list()->add_i(0); - VirtualCluster cluster(CreateVirtualCluster()); + std::unique_ptr cluster(CreateVirtualCluster()); - MemoryOptimizer optimizer; + MemoryOptimizer optimizer(RewriterConfig::MANUAL); GraphDef output; - Status status = optimizer.Optimize(&cluster, item, &output); + Status status = optimizer.Optimize(cluster.get(), item, &output); TF_EXPECT_OK(status); EXPECT_EQ(7, output.node_size()); diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc index 8bb7800df4e..4007c8802b5 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/grappler/optimizers/meta_optimizer.h" +#include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/grappler/optimizers/auto_parallel.h" #include "tensorflow/core/grappler/optimizers/constant_folding.h" @@ -41,7 +42,7 @@ std::unique_ptr MetaOptimizer::NewOptimizer( graph_optimizer.reset(new LayoutOptimizer()); } if (optimizer == "memory") { - graph_optimizer.reset(new MemoryOptimizer()); + graph_optimizer.reset(new MemoryOptimizer(RewriterConfig::MANUAL)); } if (optimizer == "autoparallel") { graph_optimizer.reset( @@ -66,8 +67,8 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, std::unique_ptr(new LayoutOptimizer())); } if (cfg_.memory_optimization() > 0) { - optimizers.push_back( - std::unique_ptr(new MemoryOptimizer())); + optimizers.push_back(std::unique_ptr( + new MemoryOptimizer(cfg_.memory_optimization()))); } if (cfg_.auto_parallel().enable()) { optimizers.push_back(std::unique_ptr( @@ -101,8 
+102,14 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, } } TopologicalSort(optimized_graph); - // Copy the graph version. - *optimized_graph->mutable_versions() = item.graph.versions(); + + // Make sure that the optimizers preserved the graph version and library. + DCHECK_GE(optimized_graph->library().function_size(), + item.graph.library().function_size()); + DCHECK_GE(optimized_graph->library().gradient_size(), + item.graph.library().gradient_size()); + DCHECK_EQ(optimized_graph->versions().producer(), + item.graph.versions().producer()); return Status::OK(); } @@ -114,7 +121,8 @@ void MetaOptimizer::Feedback(Cluster* cluster, const GrapplerItem& item, bool MetaOptimizerEnabled(const RewriterConfig& cfg) { return cfg.optimize_tensor_layout() || cfg.constant_folding() || - cfg.auto_parallel().enable() || !cfg.optimizers().empty(); + cfg.auto_parallel().enable() || cfg.memory_optimization() > 0 || + !cfg.optimizers().empty(); } Status RunMetaOptimizer(const GrapplerItem& item, const RewriterConfig& cfg, diff --git a/tensorflow/core/grappler/optimizers/model_pruner.cc b/tensorflow/core/grappler/optimizers/model_pruner.cc index efa21638369..34fb5b92c75 100644 --- a/tensorflow/core/grappler/optimizers/model_pruner.cc +++ b/tensorflow/core/grappler/optimizers/model_pruner.cc @@ -15,7 +15,9 @@ limitations under the License. #include "tensorflow/core/grappler/optimizers/model_pruner.h" #include +#include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/optimizers/graph_rewriter.h" #include "tensorflow/core/grappler/utils.h" @@ -73,6 +75,9 @@ Status ModelPruner::Optimize(Cluster* cluster, const GrapplerItem& item, << " nodes from the graph. 
The graph now contains " << pruned_graph->node_size() << " nodes."; + *pruned_graph->mutable_library() = item.graph.library(); + *pruned_graph->mutable_versions() = item.graph.versions(); + return Status::OK(); } diff --git a/tensorflow/core/grappler/optimizers/static_schedule_test.cc b/tensorflow/core/grappler/optimizers/static_schedule_test.cc index c932c90765e..95a745be21a 100644 --- a/tensorflow/core/grappler/optimizers/static_schedule_test.cc +++ b/tensorflow/core/grappler/optimizers/static_schedule_test.cc @@ -29,7 +29,7 @@ namespace { class StaticScheduleTest : public ::testing::Test { public: - VirtualCluster CreateVirtualCluster() const { + std::unique_ptr CreateVirtualCluster() const { // Invent a CPU so that predictions remain the same from machine to machine. DeviceProperties cpu_device; cpu_device.set_type("CPU"); @@ -41,7 +41,7 @@ class StaticScheduleTest : public ::testing::Test { cpu_device.set_l3_cache_size(4 * 1024 * 1024); std::unordered_map devices; devices["/job:localhost/replica:0/task:0/cpu:0"] = cpu_device; - return VirtualCluster(devices); + return std::unique_ptr(new VirtualCluster(devices)); } }; @@ -51,11 +51,11 @@ TEST_F(StaticScheduleTest, BasicGraph) { GrapplerItem item; CHECK(fake_input.NextItem(&item)); - VirtualCluster cluster(CreateVirtualCluster()); + std::unique_ptr cluster(CreateVirtualCluster()); std::unordered_map completion_times; Status status = - EstimateEarliestExecutionTimes(item, &cluster, &completion_times); + EstimateEarliestExecutionTimes(item, cluster.get(), &completion_times); TF_EXPECT_OK(status); EXPECT_EQ(item.graph.node_size(), completion_times.size()); @@ -97,11 +97,11 @@ TEST_F(StaticScheduleTest, BasicGraphWithCtrlDependencies) { EXPECT_EQ("e", item.graph.node(4).name()); *item.graph.mutable_node(4)->add_input() = "^c"; - VirtualCluster cluster(CreateVirtualCluster()); + std::unique_ptr cluster(CreateVirtualCluster()); std::unordered_map completion_times; Status status = - 
EstimateEarliestExecutionTimes(item, &cluster, &completion_times); + EstimateEarliestExecutionTimes(item, cluster.get(), &completion_times); TF_EXPECT_OK(status); EXPECT_EQ(item.graph.node_size(), completion_times.size()); diff --git a/tensorflow/core/grappler/utils/topological_sort.cc b/tensorflow/core/grappler/utils/topological_sort.cc index 131756fc5c2..fdf7fb0d3d7 100644 --- a/tensorflow/core/grappler/utils/topological_sort.cc +++ b/tensorflow/core/grappler/utils/topological_sort.cc @@ -56,7 +56,7 @@ void TopologicalSort(GraphDef* graph) { ready_nodes.pop_front(); } if (sorted_graph.node_size() == graph->node_size()) { - *graph = sorted_graph; + graph->mutable_node()->Swap(sorted_graph.mutable_node()); } } diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index b8370a96a85..2bfbad5dc4c 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -1936,15 +1936,6 @@ tf_kernel_library( deps = IO_DEPS, ) -# TODO(jhseu): Restore after merge. 
-#tf_kernel_library( -# name = "lmdb_reader_op", -# prefix = "lmdb_reader_op", -# deps = IO_DEPS + [ -# "@lmdb", -# ], -#) - tf_kernel_library( name = "lmdb_reader_op", prefix = "lmdb_reader_op", @@ -2355,7 +2346,7 @@ tf_kernel_library( tf_kernel_library( name = "cwise_op", prefix = "cwise_op", - deps = MATH_DEPS, + deps = MATH_DEPS + ["//tensorflow/core:bitwise_ops_op_lib"], ) tf_kernel_library( @@ -4171,6 +4162,9 @@ filegroup( "cwise_op_abs.cc", "cwise_op_add_1.cc", "cwise_op_add_2.cc", + "cwise_op_bitwise_and.cc", + "cwise_op_bitwise_or.cc", + "cwise_op_bitwise_xor.cc", "cwise_op_div.cc", "cwise_op_equal_to_1.cc", "cwise_op_equal_to_2.cc", @@ -4179,6 +4173,7 @@ filegroup( "cwise_op_floor_div.cc", "cwise_op_greater.cc", "cwise_op_greater_equal.cc", + "cwise_op_invert.cc", "cwise_op_isfinite.cc", "cwise_op_less.cc", "cwise_op_log.cc", @@ -4990,7 +4985,6 @@ tf_kernel_library( name = "remote_fused_graph_ops", prefix = "remote_fused_graph_execute_op", deps = [ - ":remote_fused_graph_execute_op", ":remote_fused_graph_execute_utils", "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -5000,18 +4994,6 @@ tf_kernel_library( ], ) -cc_library( - name = "remote_fused_graph_execute_op", - srcs = ["remote_fused_graph_execute_op.cc"], - deps = [ - ":remote_fused_graph_execute_utils", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:protos_all_cc", - ], -) - cc_library( name = "remote_fused_graph_execute_utils", srcs = ["remote_fused_graph_execute_utils.cc"], @@ -5080,9 +5062,9 @@ tf_cc_test( deps = [ ":ops_testutil", ":ops_util", - ":remote_fused_graph_execute_op", ":remote_fused_graph_execute_op_test_utils", ":remote_fused_graph_execute_utils", + ":remote_fused_graph_ops", "//tensorflow/cc:cc_ops", "//tensorflow/cc:ops", "//tensorflow/cc:scope", diff --git a/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc b/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc index 2307c2de0e6..b2bbc5831ca 100644 --- 
a/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc +++ b/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc @@ -272,6 +272,75 @@ __global__ void SwapDimension1And2InTensor3UsingTiles(const T* input, } } +// Use shared memory tiles to swap dimension-1 and dimension-2 of a 3D tensor +// when only one of the dimension sizes is smaller than 16, +// where dimensions are zero-based: output[i][j][k] = input[i][k][j]. +// +// small_dim = the_smaller_dimension_size +// large_dim = the_larger_dimension_size +// tile_num_per_block = blockDim.x +// kTileLength = small_dim +// +// Each thread block operates on a single rectangle tile, where its width is +// kTileLength (we currently set it to 64) and its height is small_dim, +// We set the thread block's X dimension to be tile_num_per_block, and its Y +// and Z to be one. +template +__global__ void SwapDimension1And2InTensor3SmallDim(const T* input, + Dimension<3> input_dims, + T* output) { + // TODO(yangzihao) avoid share memory bank conflict. + extern __shared__ __align__(sizeof(T)) unsigned char shmem[]; + T* shared_memory_tile = reinterpret_cast(shmem); + + eigen_assert(blockDim.y == 1); + eigen_assert(blockDim.z == 1); + eigen_assert(gridDim.z == 1); + + int x = threadIdx.x; + int tile_height = blockDim.x; + + // Get tile height, width, and thread/block origin indices. + int small_dim = SmallDim2 ? input_dims[2] : input_dims[1]; + int large_dim = SmallDim2 ? input_dims[1] : input_dims[2]; + int block_origin_idx = small_dim * large_dim * blockIdx.y; + int block_offset = blockIdx.x * blockDim.x; + int thread_origin_idx = + block_origin_idx + (SmallDim2 ? block_offset * small_dim : block_offset) + + x; + + if (block_offset + blockDim.x > large_dim) { + tile_height = large_dim - block_offset; + } + + // Load a continous memory region to shared memory tile. + if (x < tile_height) { + for (int y = 0; y < small_dim; y++) { + int shmem_index = SmallDim2 ? 
(x + y * tile_height) : (x * small_dim + y); + shared_memory_tile[shmem_index] = + ldg(input + thread_origin_idx + + y * (SmallDim2 ? tile_height : large_dim)); + } + } + + __syncthreads(); + + // Get block origin index for output array. + int output_block_offset = block_origin_idx; + int output_block_idx = SmallDim2 ? block_offset : block_offset * small_dim; + int output_block_origin_idx = output_block_offset + output_block_idx; + + // Store the tranposed memory region in shared memory to device. + if (x < tile_height) { + for (int y = 0; y < small_dim; y++) { + int output_idx = output_block_origin_idx + x + + y * (SmallDim2 ? large_dim : tile_height); + int shmem_index = SmallDim2 ? (x * small_dim + y) : (x + y * tile_height); + output[output_idx] = shared_memory_tile[shmem_index]; + } + } +} + // A Cuda custom kernel that convert input to output, given proper padding on // the left and the top. The padded value is zero. template @@ -420,25 +489,60 @@ template void RunSwapDimension1And2InTensor3(const GPUDevice& d, const T* input, const Dimension<3>& input_dims, T* output) { // If both dimensions are not trivial, use tiles for the actual swapping. + // If one dimension is trivial, use SmallDim kernel for swapping. // Otherwise, the trivial swapping relying on the ldg cache is more efficient. static const int kMinDimensionToUseTiles = 16; bool use_tiles = (input_dims[1] >= kMinDimensionToUseTiles && input_dims[2] >= kMinDimensionToUseTiles); + bool use_small_dim = ((input_dims[1] >= kMinDimensionToUseTiles && + input_dims[2] < kMinDimensionToUseTiles)) || + ((input_dims[1] < kMinDimensionToUseTiles && + input_dims[2] >= kMinDimensionToUseTiles)); + + static const int NumSubTiles = 8; if (use_tiles) { - // We get best performance when TileSize is the number of threads in a warp - // (32 on our GPUs) and NumSubTiles is 8, so our block size is 8 * 32 = 256 - // threads. 
static const int TileSize = 32; - static const int NumSubTiles = 8; Dimension<3> input_dims_in_tiles = { input_dims[0], (input_dims[1] + TileSize - 1) / TileSize, (input_dims[2] + TileSize - 1) / TileSize, }; int total_tiles_count = input_dims_in_tiles[0] * input_dims_in_tiles[1] * input_dims_in_tiles[2]; + // We get best performance when TileSize is the number of threads in a warp + // (32 on our GPUs) and NumSubTiles is 8, so our block size is 8 * 32 = 256 + // threads. SwapDimension1And2InTensor3UsingTiles<<< total_tiles_count, dim3(TileSize, NumSubTiles), 0, d.stream()>>>( input, input_dims, output); + } else if (use_small_dim) { + // When only one of the dimensions is smaller than kMinDimensionToUseTiles, + // we use one block to process a rectangle region with the size of + // kTileLength * small_dim. We found that when set kTileLength to 64 on + // TitanX Maxwell GPU, it achieves the best performance. + // large_dim + // +---------------...--------+ + // | | | | + // small_dim | | ... 
| | + // | | | | + // +--------------...---------+ + // \----- ------/ \- -/ + // V V + // kTileLength(tile_height) tile_height + static const int kTileLength = 64; + int small_dim = std::min(input_dims[2], input_dims[1]); + int large_dim = std::max(input_dims[2], input_dims[1]); + int tile_num_per_block = (large_dim + kTileLength - 1) / kTileLength; + if (input_dims[2] < input_dims[1]) { + SwapDimension1And2InTensor3SmallDim + <<>>( + input, input_dims, output); + } else { + SwapDimension1And2InTensor3SmallDim + <<>>( + input, input_dims, output); + } } else { int total_element_count = input_dims[0] * input_dims[1] * input_dims[2]; CudaLaunchConfig config = GetCudaLaunchConfig(total_element_count, d); diff --git a/tensorflow/core/kernels/cwise_op_bitwise_and.cc b/tensorflow/core/kernels/cwise_op_bitwise_and.cc new file mode 100644 index 00000000000..017a2182dcf --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_bitwise_and.cc @@ -0,0 +1,42 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/kernels/cwise_ops_common.h" + +namespace tensorflow { +REGISTER6(BinaryOp, CPU, "BitwiseAnd", functor::bitwise_and, int8, int16, int32, + int64, uint8, uint16); + +#if TENSORFLOW_USE_SYCL +#define REGISTER_SYCL_KERNEL(TYPE) \ + REGISTER_KERNEL_BUILDER( \ + Name("BitwiseAnd").Device(DEVICE_SYCL).TypeConstraint("T"), \ + BinaryOp>); +REGISTER_SYCL_KERNEL(int8); +REGISTER_SYCL_KERNEL(int16); +REGISTER_SYCL_KERNEL(int32); +REGISTER_SYCL_KERNEL(int64); +REGISTER_SYCL_KERNEL(uint8); +REGISTER_SYCL_KERNEL(uint16); +#undef REGISTER_SYCL_KERNEL + +#endif // TENSORFLOW_USE_SYCL + +#if GOOGLE_CUDA +REGISTER6(BinaryOp, GPU, "BitwiseAnd", functor::bitwise_and, int8, int16, int32, + int64, uint8, uint16); +#endif // GOOGLE_CUDA + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_bitwise_or.cc b/tensorflow/core/kernels/cwise_op_bitwise_or.cc new file mode 100644 index 00000000000..36f45fe92df --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_bitwise_or.cc @@ -0,0 +1,42 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/kernels/cwise_ops_common.h" + +namespace tensorflow { +REGISTER6(BinaryOp, CPU, "BitwiseOr", functor::bitwise_or, int8, int16, int32, + int64, uint8, uint16); + +#if TENSORFLOW_USE_SYCL +#define REGISTER_SYCL_KERNEL(TYPE) \ + REGISTER_KERNEL_BUILDER( \ + Name("BitwiseOr").Device(DEVICE_SYCL).TypeConstraint("T"), \ + BinaryOp>); +REGISTER_SYCL_KERNEL(int8); +REGISTER_SYCL_KERNEL(int16); +REGISTER_SYCL_KERNEL(int32); +REGISTER_SYCL_KERNEL(int64); +REGISTER_SYCL_KERNEL(uint8); +REGISTER_SYCL_KERNEL(uint16); +#undef REGISTER_SYCL_KERNEL + +#endif // TENSORFLOW_USE_SYCL + +#if GOOGLE_CUDA +REGISTER6(BinaryOp, GPU, "BitwiseOr", functor::bitwise_or, int8, int16, int32, + int64, uint8, uint16); +#endif // GOOGLE_CUDA + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_bitwise_xor.cc b/tensorflow/core/kernels/cwise_op_bitwise_xor.cc new file mode 100644 index 00000000000..36432d851d9 --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_bitwise_xor.cc @@ -0,0 +1,42 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/kernels/cwise_ops_common.h" + +namespace tensorflow { +REGISTER6(BinaryOp, CPU, "BitwiseXor", functor::bitwise_xor, int8, int16, int32, + int64, uint8, uint16); + +#if TENSORFLOW_USE_SYCL +#define REGISTER_SYCL_KERNEL(TYPE) \ + REGISTER_KERNEL_BUILDER( \ + Name("BitwiseXor").Device(DEVICE_SYCL).TypeConstraint("T"), \ + BinaryOp>); +REGISTER_SYCL_KERNEL(int8); +REGISTER_SYCL_KERNEL(int16); +REGISTER_SYCL_KERNEL(int32); +REGISTER_SYCL_KERNEL(int64); +REGISTER_SYCL_KERNEL(uint8); +REGISTER_SYCL_KERNEL(uint16); +#undef REGISTER_SYCL_KERNEL + +#endif // TENSORFLOW_USE_SYCL + +#if GOOGLE_CUDA +REGISTER6(BinaryOp, GPU, "BitwiseXor", functor::bitwise_xor, int8, int16, int32, + int64, uint8, uint16); +#endif // GOOGLE_CUDA + +} // namespace tensorflow diff --git a/tensorflow/tensorboard/components/tf_imports/graphlib.html b/tensorflow/core/kernels/cwise_op_gpu_bitwise_and.cu.cc similarity index 56% rename from tensorflow/tensorboard/components/tf_imports/graphlib.html rename to tensorflow/core/kernels/cwise_op_gpu_bitwise_and.cu.cc index 664b855f17f..27f973c90d7 100644 --- a/tensorflow/tensorboard/components/tf_imports/graphlib.html +++ b/tensorflow/core/kernels/cwise_op_gpu_bitwise_and.cu.cc @@ -1,6 +1,4 @@ - +==============================================================================*/ - +#if GOOGLE_CUDA - +#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h" + +namespace tensorflow { +namespace functor { +DEFINE_BINARY6(bitwise_and, int8, int16, int32, int64, uint8, uint16); +} // namespace functor +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/tensorboard/components/tf_dashboard_common/tensorboard-color.html b/tensorflow/core/kernels/cwise_op_gpu_bitwise_or.cu.cc similarity index 56% rename from tensorflow/tensorboard/components/tf_dashboard_common/tensorboard-color.html rename to 
tensorflow/core/kernels/cwise_op_gpu_bitwise_or.cu.cc index 7f9ca646148..a34c3a52cd6 100644 --- a/tensorflow/tensorboard/components/tf_dashboard_common/tensorboard-color.html +++ b/tensorflow/core/kernels/cwise_op_gpu_bitwise_or.cu.cc @@ -1,6 +1,4 @@ - +==============================================================================*/ - +#if GOOGLE_CUDA - +#endif // GOOGLE_CUDA diff --git a/tensorflow/tensorboard/components/tf_storage/tf-storage.html b/tensorflow/core/kernels/cwise_op_gpu_bitwise_xor.cu.cc similarity index 56% rename from tensorflow/tensorboard/components/tf_storage/tf-storage.html rename to tensorflow/core/kernels/cwise_op_gpu_bitwise_xor.cu.cc index ff3f7b0ad4a..a4531ab7c6f 100644 --- a/tensorflow/tensorboard/components/tf_storage/tf-storage.html +++ b/tensorflow/core/kernels/cwise_op_gpu_bitwise_xor.cu.cc @@ -1,6 +1,4 @@ - +==============================================================================*/ - - +#if GOOGLE_CUDA - +#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h" + +namespace tensorflow { +namespace functor { +DEFINE_BINARY6(bitwise_xor, int8, int16, int32, int64, uint8, uint16); +} // namespace functor +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/tensorboard/components/vz_projector/test/assert.ts b/tensorflow/core/kernels/cwise_op_gpu_invert.cu.cc similarity index 60% rename from tensorflow/tensorboard/components/vz_projector/test/assert.ts rename to tensorflow/core/kernels/cwise_op_gpu_invert.cu.cc index f489517a7f2..62f33612db0 100644 --- a/tensorflow/tensorboard/components/vz_projector/test/assert.ts +++ b/tensorflow/core/kernels/cwise_op_gpu_invert.cu.cc @@ -1,10 +1,10 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at -http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, @@ -13,4 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -const assert = chai.assert; +#if GOOGLE_CUDA + +#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h" + +namespace tensorflow { +namespace functor { +DEFINE_UNARY6(invert, int8, int16, int32, int64, uint8, uint16); +} // namespace functor +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/cwise_op_invert.cc b/tensorflow/core/kernels/cwise_op_invert.cc new file mode 100644 index 00000000000..c84ee6894eb --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_invert.cc @@ -0,0 +1,32 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/kernels/cwise_ops_common.h" + +namespace tensorflow { +REGISTER6(UnaryOp, CPU, "Invert", functor::invert, int8, int16, int32, int64, + uint8, uint16); + +#ifdef TENSORFLOW_USE_SYCL +REGISTER(UnaryOp, SYCL, "Invert", functor::invert, int8, int16, int32, int64, + uint8, uint16); +#endif // TENSORFLOW_USE_SYCL + +#if GOOGLE_CUDA +REGISTER6(UnaryOp, GPU, "Invert", functor::invert, int8, int16, int32, int64, + uint8, uint16); +#endif // GOOGLE_CUDA + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_ops.h b/tensorflow/core/kernels/cwise_ops.h index 6d80e4bfc1d..c11d6cfabb3 100644 --- a/tensorflow/core/kernels/cwise_ops.h +++ b/tensorflow/core/kernels/cwise_ops.h @@ -528,6 +528,17 @@ struct atan : base > {}; struct logical_not : base > { }; +// Flip all bits. Named invert to be consistent with numpy. +template +struct invert_op { + EIGEN_EMPTY_STRUCT_CTOR(invert_op) + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& a) const { + return ~a; + } +}; + +template +struct invert : base> {}; // NOTE: std::isinf, std::isnan, std::isfinite are plain function. 
// Therefore we need to wrap them in functors to be used with Eigen's @@ -708,6 +719,42 @@ struct logical_and : base {}; struct logical_or : base {}; +template +struct bitwise_and_op { + EIGEN_EMPTY_STRUCT_CTOR(bitwise_and_op) + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x, + const T& y) const { + return x & y; + } +}; + +template +struct bitwise_or_op { + EIGEN_EMPTY_STRUCT_CTOR(bitwise_or_op) + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x, + const T& y) const { + return x | y; + } +}; + +template +struct bitwise_xor_op { + EIGEN_EMPTY_STRUCT_CTOR(bitwise_xor_op) + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& x, + const T& y) const { + return x ^ y; + } +}; + +template +struct bitwise_and : base> {}; + +template +struct bitwise_or : base> {}; + +template +struct bitwise_xor : base> {}; + template struct make_complex_func { typedef std::complex result_type; diff --git a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc index be9fc5de693..f63a99a7308 100644 --- a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc +++ b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc @@ -59,6 +59,10 @@ EIGEN_DEVICE_FUNC bool CanLaunchDepthwiseConv2dBackpropFilterGPUSmall( args.filter_rows * args.filter_cols <= args.in_cols * block_rows; } +// The DepthwiseConv2dGPUKernels perform either forward or backprop input +// convolution depending on a template argument of this enum. +enum DepthwiseConv2dDirection { DIRECTION_FORWARD, DIRECTION_BACKWARD }; + // A Cuda kernel to compute the depthwise convolution forward pass // in NHWC format. template +// Backprop input direction is the same as forward direction with the filter +// rotated by 180°. 
+template __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNHWCSmall( const DepthwiseArgs args, const T* input, const T* filter, T* output) { assert(CanLaunchDepthwiseConv2dGPUSmall(args)); @@ -217,7 +224,9 @@ __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNHWCSmall( const int max_depth = in_depth - thread_depth; const int filter_write_offset = thread_pix < filter_pixels ? tile_size + thread_idx : 0; - const int filter_read_offset = tile_size + thread_depth; + const int filter_read_offset = + tile_size + thread_depth + + (kDirection == DIRECTION_FORWARD ? 0 : filter_pixels * kBlockSlices); const bool skip_second = !kKnownEvenRows && thread_row + (in_rows & 1) == block_rows; @@ -253,12 +262,17 @@ __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNHWCSmall( const T* filter_ptr = filter_read_offset + shared_data; UNROLL for (int r = 0; r < filter_rows; ++r) { UNROLL for (int c = 0; c < filter_cols; ++c) { + if (kDirection == DIRECTION_BACKWARD) { + filter_ptr -= kBlockSlices; + } const T filter_value = *filter_ptr; const T* const tile_ptr = shared_offset + shared_data; sum1 += filter_value * tile_ptr[0]; sum2 += filter_value * tile_ptr[tile_offset]; shared_offset += kBlockSlices; - filter_ptr += kBlockSlices; + if (kDirection == DIRECTION_FORWARD) { + filter_ptr += kBlockSlices; + } } shared_offset += in_increment; } @@ -408,8 +422,11 @@ __global__ void __launch_bounds__(1024, 2) // Tiles of the input and filter tensors are loaded into shared memory before // performing the convolution. Each thread handles two elements per iteration, // one each in the lower and upper half of a tile. -template +// Backprop input direction is the same as forward direction with the filter +// rotated by 180°. 
+template __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNCHWSmall( const DepthwiseArgs args, const T* input, const T* filter, T* output) { assert(CanLaunchDepthwiseConv2dGPUSmall(args)); @@ -480,7 +497,9 @@ __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNCHWSmall( const int max_slice = in_slices - thread_depth; const int filter_write_offset = filter_pix < filter_pixels ? tile_size + thread_idx : 0; - const int filter_read_offset = tile_size + thread_depth; + const int filter_read_offset = + tile_size + thread_depth + + (kDirection == DIRECTION_FORWARD ? 0 : filter_pixels * kBlockSlices); const bool skip_second = !kKnownEvenRows && thread_row + (in_rows & 1) == block_rows; @@ -514,12 +533,17 @@ __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNCHWSmall( const T* filter_ptr = filter_read_offset + shared_data; UNROLL for (int r = 0; r < filter_rows; ++r) { UNROLL for (int c = 0; c < filter_cols; ++c) { + if (kDirection == DIRECTION_BACKWARD) { + filter_ptr -= kBlockSlices; + } const T filter_value = *filter_ptr; const T* const tile_ptr = shared_offset + shared_data; sum1 += filter_value * tile_ptr[0]; sum2 += filter_value * tile_ptr[tile_offset]; ++shared_offset; - filter_ptr += kBlockSlices; + if (kDirection == DIRECTION_FORWARD) { + filter_ptr += kBlockSlices; + } } shared_offset += in_increment; } @@ -535,83 +559,80 @@ __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNCHWSmall( } } -template +template void LaunchDepthwiseConv2dGPUSmall(const GpuDevice& d, const DepthwiseArgs args, const T* input, const T* filter, T* output, TensorFormat data_format) { const int block_rows = (args.in_rows + 1) / 2; + dim3 block_dim; + void (*kernel)(const DepthwiseArgs, const T*, const T*, T*); + if (data_format == FORMAT_NHWC) { + block_dim = dim3(kBlockSlices, args.in_cols, block_rows); + kernel = DepthwiseConv2dGPUKernelNHWCSmall; + } else if (data_format == FORMAT_NCHW) { + block_dim = 
dim3(args.in_cols, block_rows, kBlockSlices); + kernel = DepthwiseConv2dGPUKernelNCHWSmall; + } else { + assert(false && "Incorrect data format"); + return; + } const int tile_cols = args.in_cols + args.filter_cols - 1; const int tile_rows = block_rows * 2 + args.filter_rows - 1; const int tile_pixels = tile_rows * tile_cols; const int filter_pixels = args.filter_rows * args.filter_cols; - const int shared_memory_size = kBlockSlices * (tile_pixels + filter_pixels) * sizeof(T); const int num_outputs = args.batch * args.out_rows * args.out_cols * args.out_depth; - - if (data_format == FORMAT_NHWC) { - dim3 block_dim = dim3(kBlockSlices, args.in_cols, block_rows); - CudaLaunchConfig config = GetCudaLaunchConfig( - num_outputs, d, - DepthwiseConv2dGPUKernelNHWCSmall, - shared_memory_size, block_dim.x * block_dim.y * block_dim.z); - DepthwiseConv2dGPUKernelNHWCSmall - <<>>( - args, input, filter, output); - } else if (data_format == FORMAT_NCHW) { - dim3 block_dim = dim3(args.in_cols, block_rows, kBlockSlices); - CudaLaunchConfig config = GetCudaLaunchConfig( - num_outputs, d, - DepthwiseConv2dGPUKernelNCHWSmall, - shared_memory_size, block_dim.x * block_dim.y * block_dim.z); - DepthwiseConv2dGPUKernelNCHWSmall - <<>>( - args, input, filter, output); - } else { - assert(false && "Incorrect data format"); - } + CudaLaunchConfig config = + GetCudaLaunchConfig(num_outputs, d, kernel, shared_memory_size, + block_dim.x * block_dim.y * block_dim.z); + kernel<<>>( + args, input, filter, output); } -template +template void LaunchDepthwiseConv2dGPUSmall(const GpuDevice& d, const DepthwiseArgs args, const T* input, const T* filter, T* output, TensorFormat data_format) { if (args.in_rows & 1) { - LaunchDepthwiseConv2dGPUSmall(d, args, input, filter, - output, data_format); + LaunchDepthwiseConv2dGPUSmall( + d, args, input, filter, output, data_format); } else { - LaunchDepthwiseConv2dGPUSmall(d, args, input, filter, - output, data_format); + LaunchDepthwiseConv2dGPUSmall( + d, 
args, input, filter, output, data_format); } } -template +template void LaunchDepthwiseConv2dGPUSmall(const GpuDevice& d, const DepthwiseArgs args, const T* input, const T* filter, T* output, TensorFormat data_format) { // Maximize (power of two) kBlockSlices while keeping a block within 1024 // threads (2 pixels per thread). - const int in_pixels = args.in_rows * args.in_cols; - if (in_pixels > 512) { - LaunchDepthwiseConv2dGPUSmall( - d, args, input, filter, output, data_format); - } else if (in_pixels > 256) { - LaunchDepthwiseConv2dGPUSmall( - d, args, input, filter, output, data_format); + const int block_pixels = (args.in_rows + 1) / 2 * args.in_cols; + if (block_pixels > 256) { + LaunchDepthwiseConv2dGPUSmall(d, args, input, filter, + output, data_format); + } else if (block_pixels > 128) { + LaunchDepthwiseConv2dGPUSmall(d, args, input, filter, + output, data_format); } else { - LaunchDepthwiseConv2dGPUSmall( - d, args, input, filter, output, data_format); + LaunchDepthwiseConv2dGPUSmall(d, args, input, filter, + output, data_format); } } @@ -620,38 +641,30 @@ template ; + } else if (data_format == FORMAT_NCHW) { + kernel = + DepthwiseConv2dGPUKernelNCHW; + } else { + assert(false && "Incorrect data format"); + return; + } const int num_outputs = args.batch * args.out_rows * args.out_cols * args.out_depth; + CudaLaunchConfig config = GetCudaLaunchConfig(num_outputs, d, kernel, 0, 0); // The compile-time constant version runs faster with a single block. const int max_block_count = kKnownFilterWidth < 0 || kKnownFilterHeight < 0 || kKnownDepthMultiplier < 0 ? 
std::numeric_limits::max() : d.getNumCudaMultiProcessors(); - if (data_format == FORMAT_NHWC) { - CudaLaunchConfig config = GetCudaLaunchConfig( - num_outputs, d, - DepthwiseConv2dGPUKernelNHWC, - 0, 0); - DepthwiseConv2dGPUKernelNHWC - <<>>(args, input, filter, output, num_outputs); - } else if (data_format == FORMAT_NCHW) { - CudaLaunchConfig config = GetCudaLaunchConfig( - num_outputs, d, - DepthwiseConv2dGPUKernelNCHW, - 0, 0); - DepthwiseConv2dGPUKernelNCHW - <<>>(args, input, filter, - output, num_outputs); - } else { - assert(false && "Incorrect data format"); - } } template @@ -660,8 +673,9 @@ void LaunchDepthwiseConv2dGPU(const GpuDevice& d, const DepthwiseArgs args, TensorFormat data_format) { if (args.depth_multiplier == 1) { if (CanLaunchDepthwiseConv2dGPUSmall(args)) { - LaunchDepthwiseConv2dGPUSmall( - d, args, input, filter, output, data_format); + LaunchDepthwiseConv2dGPUSmall(d, args, input, filter, + output, data_format); return; } @@ -756,145 +770,6 @@ __global__ void __launch_bounds__(640, 2) } } -// CUDA kernel to compute the depthwise convolution backward w.r.t. input in -// NHWC format, tailored for small images up to 32x32. Stride and depth -// multiplier must be 1. Padding must be 'SAME', which allows to reuse the index -// computation. Only use this kernel if CanLaunchDepthwiseConv2dGPUSmall(args) -// returns true. -// Implementation is the same as the forward pass, except that the filter is -// rotate by 180°, see filter_read_offset and filter_ptr. -// Tiles of the input and filter tensors are loaded into shared memory before -// performing the convolution. Each thread handles two elements per iteration, -// one each in the lower and upper half of a tile. 
-template -__global__ -__launch_bounds__(1024, 2) void DepthwiseConv2dBackpropInputGPUKernelNHWCSmall( - const DepthwiseArgs args, const T* input, const T* filter, T* output) { - assert(CanLaunchDepthwiseConv2dGPUSmall(args)); - // Holds block plus halo and filter data for blockDim.x depths. - extern __shared__ __align__(sizeof(T)) unsigned char shared_memory[]; - T* const shared_data = reinterpret_cast(shared_memory); - - const int batches = args.batch; - const int in_rows = args.in_rows; - const int in_cols = args.in_cols; - const int in_depth = args.in_depth; - const int filter_rows = - kKnownFilterHeight < 0 ? args.filter_rows : kKnownFilterHeight; - const int filter_cols = - kKnownFilterWidth < 0 ? args.filter_cols : kKnownFilterWidth; - const int pad_rows = args.pad_rows; - const int pad_cols = args.pad_cols; - - // Fixed blockDim.x, corresponding to Pascal's global load granularity of 32B. - const int block_rows = blockDim.z; - - // These values are the same for all threads and could - // be precomputed on the CPU. - const int block_size = block_rows * in_cols * kBlockSlices; - const int in_row_size = in_cols * in_depth; - const int in_size = in_rows * in_row_size; - const int in_increment = (in_cols - 1) * kBlockSlices; - const int filter_pixels = filter_rows * filter_cols; - const int tile_cols = in_cols + filter_cols - 1; - const int even_rows = kKnownEvenRows || (1 & ~in_rows); - const int tile_rows = in_rows + filter_rows - even_rows; - const int tile_row_size = tile_cols * kBlockSlices; - const int tile_size = tile_rows * tile_row_size; - const int tile_offset = block_rows * tile_row_size; - const int pad_offset = pad_rows * tile_cols + pad_cols; - const int batch_blocks = (in_depth + kBlockSlices - 1) / kBlockSlices; - const int in_blocks = batch_blocks * batches; - const int tensor_offset = - kKnownEvenRows ? 
in_size / 2 : block_rows * in_row_size; - - const int thread_depth = threadIdx.x; - const int thread_col = threadIdx.y; - const int thread_row = threadIdx.z; - - // Position in block. - const int thread_pix = thread_row * in_cols + thread_col; - const int thread_idx = thread_pix * kBlockSlices + thread_depth; - - // Initialize tile, in particular the padding. - for (int i = thread_idx; i < tile_size; i += block_size) { - shared_data[i] = T(0); - } - __syncthreads(); - - // Position in tensors. - const int tensor_idx = thread_pix * in_depth + thread_depth; - - // Position in (padded) shared memory. - const int data_pix = thread_row * tile_cols + thread_col; - const int data_idx = data_pix * kBlockSlices + thread_depth; - - // Position in shared memory, offset by pad_rows / pad_cols. - const int tile_pix = data_pix + pad_offset; - const int tile_idx = tile_pix * kBlockSlices + thread_depth; - - const int max_depth = in_depth - thread_depth; - const int filter_write_offset = - thread_pix < filter_pixels ? 
tile_size + thread_idx : 0; - const int filter_read_offset = - tile_size + filter_pixels * kBlockSlices + thread_depth; - const bool skip_second = - !kKnownEvenRows && thread_row + (in_rows & 1) == block_rows; - - for (int b = blockIdx.x; b < in_blocks; b += gridDim.x) { - const int batch = b / batch_blocks; - const int stack = b - batch * batch_blocks; - - const int start_depth = stack * kBlockSlices; - const int filter_offset = tensor_idx + start_depth; - const int inout_offset = batch * in_size + filter_offset; - const bool depth_in_range = start_depth < max_depth; - - if (depth_in_range) { - const T* const in_ptr = inout_offset + input; - T* const tile_ptr = tile_idx + shared_data; - tile_ptr[0] = ldg(in_ptr); - if (!skip_second) { - tile_ptr[tile_offset] = ldg(tensor_offset + in_ptr); - } - - if (filter_write_offset != 0) { - shared_data[filter_write_offset] = ldg(filter_offset + filter); - } - } - - // Note: the condition to reach this is uniform across the entire block. - __syncthreads(); - - if (depth_in_range) { - T sum1 = 0; - T sum2 = 0; - int shared_offset = data_idx; - const T* filter_ptr = filter_read_offset + shared_data; - UNROLL for (int r = 0; r < filter_rows; ++r) { - UNROLL for (int c = 0; c < filter_cols; ++c) { - filter_ptr -= kBlockSlices; - const T filter_value = *filter_ptr; - const T* const tile_ptr = shared_offset + shared_data; - sum1 += filter_value * tile_ptr[0]; - sum2 += filter_value * tile_ptr[tile_offset]; - shared_offset += kBlockSlices; - } - shared_offset += in_increment; - } - T* const out_ptr = inout_offset + output; - out_ptr[0] = sum1; - if (!skip_second) { - out_ptr[tensor_offset] = sum2; - } - } - - // Note: the condition to reach this is uniform across the entire block. - __syncthreads(); - } -} - template __global__ void __launch_bounds__(640, 2) @@ -966,234 +841,6 @@ __global__ void __launch_bounds__(640, 2) } } -// CUDA kernel to compute the depthwise convolution backward w.r.t. 
input in -// NHWC format, tailored for small images up to 32x32. Stride and depth -// multiplier must be 1. Padding must be 'SAME', which allows to reuse the index -// computation. Only use this kernel if CanLaunchDepthwiseConv2dGPUSmall(args) -// returns true. -// Implementation is the same as the forward pass, except that the filter is -// rotate by 180°, see filter_read_offset and filter_ptr. -// Tiles of the input and filter tensors are loaded into shared memory before -// performing the convolution. Each thread handles two elements per iteration, -// one each in the lower and upper half of a tile. -template -__global__ -__launch_bounds__(1024, 2) void DepthwiseConv2dBackpropInputGPUKernelNCHWSmall( - const DepthwiseArgs args, const T* input, const T* filter, T* output) { - assert(CanLaunchDepthwiseConv2dGPUSmall(args)); - // Holds block plus halo and filter data for blockDim.z depths. - extern __shared__ __align__(sizeof(T)) unsigned char shared_memory[]; - T* const shared_data = reinterpret_cast(shared_memory); - - const int batches = args.batch; - const int in_rows = args.in_rows; - const int in_cols = args.in_cols; - const int in_depth = args.in_depth; - const int filter_rows = - kKnownFilterHeight < 0 ? args.filter_rows : kKnownFilterHeight; - const int filter_cols = - kKnownFilterWidth < 0 ? args.filter_cols : kKnownFilterWidth; - const int pad_rows = args.pad_rows; - const int pad_cols = args.pad_cols; - - // Fixed blockDim.z, tailored for maximum grid size for images of size 16x16. - const int block_rows = blockDim.y; - - // These values are the same for all threads and could - // be precomputed on the CPU. 
- const int block_pixels = in_cols * block_rows; - const int block_size = block_pixels * kBlockSlices; - const int in_pixels = in_cols * in_rows; - const int in_increment = in_cols - 1; - const int filter_pixels = filter_rows * filter_cols; - const int tile_cols = in_cols + filter_cols - 1; - const int even_rows = kKnownEvenRows || (1 & ~in_rows); - const int tile_rows = in_rows + filter_rows - even_rows; - const int tile_pixels = tile_cols * tile_rows; - const int tile_size = tile_pixels * kBlockSlices; - const int tile_offset = block_rows * tile_cols; - const int pad_offset = pad_rows * tile_cols + pad_cols; - const int in_slices = in_depth * batches; - const int in_blocks = (in_slices + kBlockSlices - 1) / kBlockSlices; - - const int thread_col = threadIdx.x; - const int thread_row = threadIdx.y; - const int thread_depth = threadIdx.z; - - // Position in block. - const int thread_pix = thread_row * in_cols + thread_col; - const int thread_idx = thread_depth * block_pixels + thread_pix; - - // Initialize tile, in particular the padding. - for (int i = thread_idx; i < tile_size; i += block_size) { - shared_data[i] = T(0); - } - __syncthreads(); - - // Position in tensors. - const int tensor_idx = thread_depth * in_pixels + thread_pix; - - // Position in (padded) shared memory. - const int data_pix = thread_row * tile_cols + thread_col; - const int data_idx = thread_depth * tile_pixels + data_pix; - - // Position in shared memory, offset by pad_rows / pad_cols. - const int tile_idx = data_idx + pad_offset; - - // Filter is always in HWCK format, irrespective of the input/output format. - const int filter_pix = thread_idx / kBlockSlices; - const int filter_depth = thread_idx % kBlockSlices; - const int filter_idx = filter_pix * in_depth; - - const int max_slice = in_slices - thread_depth; - const int filter_write_offset = - filter_pix < filter_pixels ? 
tile_size + thread_idx : 0; - const int filter_read_offset = - tile_size + filter_pixels * kBlockSlices + thread_depth; - const bool skip_second = - !kKnownEvenRows && thread_row + (in_rows & 1) == block_rows; - - for (int b = blockIdx.x; b < in_blocks; b += gridDim.x) { - const int slice = b * kBlockSlices; - - const int inout_offset = slice * in_pixels + tensor_idx; - const bool slice_in_range = slice < max_slice; - - if (slice_in_range) { - const T* const in_ptr = inout_offset + input; - T* const tile_ptr = tile_idx + shared_data; - tile_ptr[0] = ldg(in_ptr); - if (!skip_second) { - tile_ptr[tile_offset] = ldg(block_pixels + in_ptr); - } - } - - if (filter_write_offset != 0) { - const int filter_offset = filter_idx + (slice + filter_depth) % in_depth; - shared_data[filter_write_offset] = ldg(filter_offset + filter); - } - - // Note: the condition to reach this is uniform across the entire block. - __syncthreads(); - - if (slice_in_range) { - T sum1 = 0; - T sum2 = 0; - int shared_offset = data_idx; - const T* filter_ptr = filter_read_offset + shared_data; - UNROLL for (int r = 0; r < filter_rows; ++r) { - UNROLL for (int c = 0; c < filter_cols; ++c) { - filter_ptr -= kBlockSlices; - const T filter_value = *filter_ptr; - const T* const tile_ptr = shared_offset + shared_data; - sum1 += filter_value * tile_ptr[0]; - sum2 += filter_value * tile_ptr[tile_offset]; - ++shared_offset; - } - shared_offset += in_increment; - } - T* const out_ptr = inout_offset + output; - out_ptr[0] = sum1; - if (!skip_second) { - out_ptr[block_pixels] = sum2; - } - } - - // Note: the condition to reach this is uniform across the entire block. 
- __syncthreads(); - } -} - -template -void LaunchDepthwiseConv2dBackpropInputGPUSmall(const GpuDevice& d, - const DepthwiseArgs args, - const T* out_backprop, - const T* filter, T* in_backprop, - TensorFormat data_format) { - const int block_rows = (args.in_rows + 1) / 2; - const int tile_cols = args.in_cols + args.filter_cols - 1; - const int tile_rows = block_rows * 2 + args.filter_rows - 1; - const int tile_pixels = tile_rows * tile_cols; - const int filter_pixels = args.filter_rows * args.filter_cols; - - const int shared_memory_size = - kBlockSlices * (tile_pixels + filter_pixels) * sizeof(T); - const int num_outputs = - args.batch * args.out_rows * args.out_cols * args.out_depth; - - if (data_format == FORMAT_NHWC) { - dim3 block_dim = dim3(kBlockSlices, args.in_cols, block_rows); - CudaLaunchConfig config = GetCudaLaunchConfig( - num_outputs, d, - DepthwiseConv2dBackpropInputGPUKernelNHWCSmall< - T, kKnownFilterWidth, kKnownFilterHeight, kBlockSlices, - kKnownEvenRows>, - shared_memory_size, block_dim.x * block_dim.y * block_dim.z); - DepthwiseConv2dBackpropInputGPUKernelNHWCSmall< - T, kKnownFilterWidth, kKnownFilterHeight, kBlockSlices, kKnownEvenRows> - <<>>( - args, out_backprop, filter, in_backprop); - } else if (data_format == FORMAT_NCHW) { - dim3 block_dim = dim3(args.in_cols, block_rows, kBlockSlices); - CudaLaunchConfig config = GetCudaLaunchConfig( - num_outputs, d, - DepthwiseConv2dBackpropInputGPUKernelNCHWSmall< - T, kKnownFilterWidth, kKnownFilterHeight, kBlockSlices, - kKnownEvenRows>, - shared_memory_size, block_dim.x * block_dim.y * block_dim.z); - DepthwiseConv2dBackpropInputGPUKernelNCHWSmall< - T, kKnownFilterWidth, kKnownFilterHeight, kBlockSlices, kKnownEvenRows> - <<>>( - args, out_backprop, filter, in_backprop); - } else { - assert(false && "Incorrect data format"); - } -} - -template -void LaunchDepthwiseConv2dBackpropInputGPUSmall(const GpuDevice& d, - const DepthwiseArgs args, - const T* out_backprop, - const T* filter, T* 
in_backprop, - TensorFormat data_format) { - if (args.in_rows & 1) { - LaunchDepthwiseConv2dBackpropInputGPUSmall< - T, kKnownFilterWidth, kKnownFilterHeight, kBlockSlices, false>( - d, args, out_backprop, filter, in_backprop, data_format); - } else { - LaunchDepthwiseConv2dBackpropInputGPUSmall< - T, kKnownFilterWidth, kKnownFilterHeight, kBlockSlices, true>( - d, args, out_backprop, filter, in_backprop, data_format); - } -} - -template -void LaunchDepthwiseConv2dBackpropInputGPUSmall(const GpuDevice& d, - const DepthwiseArgs args, - const T* input, const T* filter, - T* output, - TensorFormat data_format) { - // Maximize (power of two) kBlockSlices while keeping a block within 1024 - // threads (2 pixels per thread). - const int in_pixels = args.in_rows * args.in_cols; - if (in_pixels > 512) { - LaunchDepthwiseConv2dBackpropInputGPUSmall( - d, args, input, filter, output, data_format); - } else if (in_pixels > 256) { - LaunchDepthwiseConv2dBackpropInputGPUSmall( - d, args, input, filter, output, data_format); - } else { - LaunchDepthwiseConv2dBackpropInputGPUSmall( - d, args, input, filter, output, data_format); - } -} - template void LaunchDepthwiseConv2dBackpropInputGPU(const GpuDevice& d, @@ -1201,31 +848,23 @@ void LaunchDepthwiseConv2dBackpropInputGPU(const GpuDevice& d, const T* out_backprop, const T* filter, T* in_backprop, TensorFormat data_format) { - const int num_in_backprop = - args.batch * args.in_rows * args.in_cols * args.in_depth; + void (*kernel)(const DepthwiseArgs, const T*, const T*, T*, int); if (data_format == FORMAT_NHWC) { - CudaLaunchConfig config = GetCudaLaunchConfig( - num_in_backprop, d, - DepthwiseConv2dBackpropInputGPUKernelNHWC< - T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier>, - 0, 0); - DepthwiseConv2dBackpropInputGPUKernelNHWC< - T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier> - <<>>( - args, out_backprop, filter, in_backprop, num_in_backprop); + kernel = 
DepthwiseConv2dBackpropInputGPUKernelNHWC< + T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier>; } else if (data_format == FORMAT_NCHW) { - CudaLaunchConfig config = GetCudaLaunchConfig( - num_in_backprop, d, - DepthwiseConv2dBackpropInputGPUKernelNCHW< - T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier>, - 0, 0); - DepthwiseConv2dBackpropInputGPUKernelNCHW< - T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier> - <<>>( - args, out_backprop, filter, in_backprop, num_in_backprop); + kernel = DepthwiseConv2dBackpropInputGPUKernelNCHW< + T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier>; } else { assert(false && "Incorrect data format"); + return; } + const int num_in_backprop = + args.batch * args.in_rows * args.in_cols * args.in_depth; + CudaLaunchConfig config = + GetCudaLaunchConfig(num_in_backprop, d, kernel, 0, 0); + kernel<<>>( + args, out_backprop, filter, in_backprop, num_in_backprop); } template @@ -1236,8 +875,8 @@ void LaunchDepthwiseConv2dBackpropInputGPU(const GpuDevice& d, TensorFormat data_format) { if (args.depth_multiplier == 1) { if (CanLaunchDepthwiseConv2dGPUSmall(args)) { - LaunchDepthwiseConv2dBackpropInputGPUSmall( + LaunchDepthwiseConv2dGPUSmall( d, args, out_backprop, filter, in_backprop, data_format); return; } @@ -1783,17 +1422,9 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall( template bool TryLaunchDepthwiseConv2dBackpropFilterGPUSmall( - const GpuDevice& d, const DepthwiseArgs args, const T* out_backprop, - const T* input, T* filter_backprop, TensorFormat data_format) { - int block_rows = (args.in_rows + 1) / 2; - // args.in_cols * block_rows * kBlockSlices must be multiple of 32. 
- for (int round_mask = 1; args.in_cols * block_rows * kBlockSlices & 31; - round_mask = round_mask * 2 + 1) { - block_rows = block_rows + round_mask & ~round_mask; - } - if (!CanLaunchDepthwiseConv2dBackpropFilterGPUSmall(args, block_rows)) { - return false; - } + const GpuDevice& d, const DepthwiseArgs args, const int block_rows, + const T* out_backprop, const T* input, T* filter_backprop, + TensorFormat data_format) { const int tile_cols = args.in_cols + args.filter_cols - 1; const int tile_rows = block_rows * 2 + args.filter_rows - 1; const int tile_pixels = tile_rows * tile_cols; @@ -1804,58 +1435,51 @@ bool TryLaunchDepthwiseConv2dBackpropFilterGPUSmall( return false; } - const int num_out_backprop = - args.batch * args.out_rows * args.out_cols * args.out_depth; + dim3 block_dim; + void (*kernel)(const DepthwiseArgs, const T*, const T*, T*); if (data_format == FORMAT_NHWC) { - dim3 block_dim = dim3(kBlockSlices, args.in_cols, block_rows); - CudaLaunchConfig config = GetCudaLaunchConfig( - num_out_backprop, d, - DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall< - T, kKnownFilterWidth, kKnownFilterHeight, kBlockSlices, - kAccumPixels>, - shared_memory_size, block_dim.x * block_dim.y * block_dim.z); - DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall< - T, kKnownFilterWidth, kKnownFilterHeight, kBlockSlices, kAccumPixels> - <<>>( - args, out_backprop, input, filter_backprop); + block_dim = dim3(kBlockSlices, args.in_cols, block_rows); + kernel = DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall< + T, kKnownFilterWidth, kKnownFilterHeight, kBlockSlices, kAccumPixels>; } else if (data_format == FORMAT_NCHW) { - dim3 block_dim = dim3(args.in_cols, block_rows, kBlockSlices); - CudaLaunchConfig config = GetCudaLaunchConfig( - num_out_backprop, d, - DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall< - T, kKnownFilterWidth, kKnownFilterHeight, kBlockSlices, - kAccumPixels>, - shared_memory_size, block_dim.x * block_dim.y * block_dim.z); - 
DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall< - T, kKnownFilterWidth, kKnownFilterHeight, kBlockSlices, kAccumPixels> - <<>>( - args, out_backprop, input, filter_backprop); + block_dim = dim3(args.in_cols, block_rows, kBlockSlices); + kernel = DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall< + T, kKnownFilterWidth, kKnownFilterHeight, kBlockSlices, kAccumPixels>; } else { assert(false && "Incorrect data format"); + return false; } + const int num_out_backprop = + args.batch * args.out_rows * args.out_cols * args.out_depth; + CudaLaunchConfig config = + GetCudaLaunchConfig(num_out_backprop, d, kernel, shared_memory_size, + block_dim.x * block_dim.y * block_dim.z); + kernel<<>>( + args, out_backprop, input, filter_backprop); return true; } template bool TryLaunchDepthwiseConv2dBackpropFilterGPUSmall( - const GpuDevice& d, const DepthwiseArgs args, const T* out_backprop, - const T* input, T* filter_backprop, TensorFormat data_format) { + const GpuDevice& d, const DepthwiseArgs args, const int block_rows, + const T* out_backprop, const T* input, T* filter_backprop, + TensorFormat data_format) { // Minimize (power of two) kAccumPixels, while satisfying - // kAccumPixels * 64 >= in_pixels * kBlockSlices. - const int block_pixels = args.in_rows * args.in_cols * kBlockSlices; - if (block_pixels > 1024) { + // kAccumPixels * 32 >= block_rows * in_cols * kBlockSlices. 
+ const int block_pixels = block_rows * args.in_cols * kBlockSlices; + if (block_pixels > 512) { return TryLaunchDepthwiseConv2dBackpropFilterGPUSmall< T, kKnownFilterWidth, kKnownFilterHeight, kBlockSlices, 32>( - d, args, out_backprop, input, filter_backprop, data_format); - } else if (block_pixels > 512) { + d, args, block_rows, out_backprop, input, filter_backprop, data_format); + } else if (block_pixels > 256) { return TryLaunchDepthwiseConv2dBackpropFilterGPUSmall< T, kKnownFilterWidth, kKnownFilterHeight, kBlockSlices, 16>( - d, args, out_backprop, input, filter_backprop, data_format); + d, args, block_rows, out_backprop, input, filter_backprop, data_format); } else { return TryLaunchDepthwiseConv2dBackpropFilterGPUSmall< T, kKnownFilterWidth, kKnownFilterHeight, kBlockSlices, 8>( - d, args, out_backprop, input, filter_backprop, data_format); + d, args, block_rows, out_backprop, input, filter_backprop, data_format); } } @@ -1865,19 +1489,43 @@ bool TryLaunchDepthwiseConv2dBackpropFilterGPUSmall( const T* input, T* filter_backprop, TensorFormat data_format) { // Maximize (power of two) kBlockSlices while keeping a block within 1024 // threads (2 pixels per thread). 
- const int in_pixels = args.in_rows * args.in_cols; - if (in_pixels > 512) { - return TryLaunchDepthwiseConv2dBackpropFilterGPUSmall< - T, kKnownFilterWidth, kKnownFilterHeight, 2>( - d, args, out_backprop, input, filter_backprop, data_format); - } else if (in_pixels > 256) { - return TryLaunchDepthwiseConv2dBackpropFilterGPUSmall< - T, kKnownFilterWidth, kKnownFilterHeight, 4>( - d, args, out_backprop, input, filter_backprop, data_format); - } else { - return TryLaunchDepthwiseConv2dBackpropFilterGPUSmall< - T, kKnownFilterWidth, kKnownFilterHeight, 8>( - d, args, out_backprop, input, filter_backprop, data_format); + int block_slices = 8; + int block_rows = (args.in_rows + 1) / 2; + int round_mask = 1; + for (; block_slices > 1; block_slices /= 2) { + // args.in_cols * block_rows * kBlockSlices must be multiple of 32. + for (; block_rows * args.in_cols * block_slices & 31; + round_mask = round_mask * 2 + 1) { + block_rows = block_rows + round_mask & ~round_mask; + } + int block_size = block_rows * args.in_cols * block_slices; + if (block_size <= 1024) { + break; + } + } + + if (!CanLaunchDepthwiseConv2dBackpropFilterGPUSmall(args, block_rows)) { + return false; + } + + switch (block_slices) { + case 8: + return TryLaunchDepthwiseConv2dBackpropFilterGPUSmall< + T, kKnownFilterWidth, kKnownFilterHeight, 8>( + d, args, block_rows, out_backprop, input, filter_backprop, + data_format); + case 4: + return TryLaunchDepthwiseConv2dBackpropFilterGPUSmall< + T, kKnownFilterWidth, kKnownFilterHeight, 4>( + d, args, block_rows, out_backprop, input, filter_backprop, + data_format); + case 2: + return TryLaunchDepthwiseConv2dBackpropFilterGPUSmall< + T, kKnownFilterWidth, kKnownFilterHeight, 2>( + d, args, block_rows, out_backprop, input, filter_backprop, + data_format); + default: + return false; } } @@ -1888,31 +1536,23 @@ void LaunchDepthwiseConv2dBackpropFilterGPU(const GpuDevice& d, const T* out_backprop, const T* input, T* filter_backprop, TensorFormat data_format) { - 
const int num_out_backprop = - args.batch * args.out_rows * args.out_cols * args.out_depth; + void (*kernel)(const DepthwiseArgs, const T*, const T*, T*, int); if (data_format == FORMAT_NHWC) { - CudaLaunchConfig config = GetCudaLaunchConfig( - num_out_backprop, d, - DepthwiseConv2dBackpropFilterGPUKernelNHWC< - T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier>, - 0, 0); - DepthwiseConv2dBackpropFilterGPUKernelNHWC< - T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier> - <<>>( - args, out_backprop, input, filter_backprop, num_out_backprop); + kernel = DepthwiseConv2dBackpropFilterGPUKernelNHWC< + T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier>; } else if (data_format == FORMAT_NCHW) { - CudaLaunchConfig config = GetCudaLaunchConfig( - num_out_backprop, d, - DepthwiseConv2dBackpropFilterGPUKernelNCHW< - T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier>, - 0, 0); - DepthwiseConv2dBackpropFilterGPUKernelNCHW< - T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier> - <<>>( - args, out_backprop, input, filter_backprop, num_out_backprop); + kernel = DepthwiseConv2dBackpropFilterGPUKernelNCHW< + T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier>; } else { assert(false && "Incorrect data format"); + return; } + const int num_out_backprop = + args.batch * args.out_rows * args.out_cols * args.out_depth; + CudaLaunchConfig config = + GetCudaLaunchConfig(num_out_backprop, d, kernel, 0, 0); + kernel<<>>( + args, out_backprop, input, filter_backprop, num_out_backprop); } template diff --git a/tensorflow/core/kernels/encode_jpeg_op.cc b/tensorflow/core/kernels/encode_jpeg_op.cc index 8e021b92563..4fcae25aa6e 100644 --- a/tensorflow/core/kernels/encode_jpeg_op.cc +++ b/tensorflow/core/kernels/encode_jpeg_op.cc @@ -55,8 +55,6 @@ class EncodeJpegOp : public OpKernel { context, context->GetAttr("optimize_size", &flags_.optimize_jpeg_size)); OP_REQUIRES_OK(context, 
context->GetAttr("chroma_downsampling", &flags_.chroma_downsampling)); - OP_REQUIRES_OK(context, context->GetAttr("chroma_downsampling", - &flags_.chroma_downsampling)); string density_unit; OP_REQUIRES_OK(context, context->GetAttr("density_unit", &density_unit)); diff --git a/tensorflow/core/kernels/fft_ops.cc b/tensorflow/core/kernels/fft_ops.cc index 32936d65c8e..35a3f7b189c 100644 --- a/tensorflow/core/kernels/fft_ops.cc +++ b/tensorflow/core/kernels/fft_ops.cc @@ -128,7 +128,7 @@ class FFTCPU : public FFTBase { auto device = ctx->eigen_device(); if (!IsReal()) { - auto input = (Tensor(in)).flat_inner_dims(); + auto input = Tensor(in).flat_inner_dims(); // Compute the FFT using eigen. auto output = out->flat_inner_dims(); constexpr auto direction = @@ -137,7 +137,7 @@ class FFTCPU : public FFTBase { input.template fft(axes); } else { if (IsForward()) { - auto input = (Tensor(in)).flat_inner_dims(); + auto input = Tensor(in).flat_inner_dims(); const auto input_dims = input.dimensions(); // Slice input to fft_shape on its inner-most dimensions. @@ -166,7 +166,7 @@ class FFTCPU : public FFTBase { full_fft.slice(zero_start_indices, output.dimensions()); } else { // Reconstruct the full FFT and take the inverse. - auto input = ((Tensor)in).flat_inner_dims(); + auto input = Tensor(in).flat_inner_dims(); auto output = out->flat_inner_dims(); const auto input_dims = input.dimensions(); diff --git a/tensorflow/core/kernels/filter_dataset_op.cc b/tensorflow/core/kernels/filter_dataset_op.cc index 3503c45f9af..c0e909d73ac 100644 --- a/tensorflow/core/kernels/filter_dataset_op.cc +++ b/tensorflow/core/kernels/filter_dataset_op.cc @@ -124,6 +124,10 @@ class FilterDatasetOp : public UnaryDatasetOpKernel { "Filter predicate `f` must return a scalar bool."); } matched = result[0].scalar()(); + if (!matched) { + // Clear the output tensor list since it didn't match. 
+ out_tensors->clear(); + } } while (!matched); *end_of_sequence = false; return Status::OK(); diff --git a/tensorflow/core/kernels/function_ops.cc b/tensorflow/core/kernels/function_ops.cc index 8c3137ece9f..58c4ed37c4f 100644 --- a/tensorflow/core/kernels/function_ops.cc +++ b/tensorflow/core/kernels/function_ops.cc @@ -122,6 +122,12 @@ TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name("_Arg") ArgOp); #undef REGISTER +REGISTER_KERNEL_BUILDER(Name("_Arg") + .Device(DEVICE_GPU) + .HostMemory("output") + .TypeConstraint("T"), + ArgOp); + #define REGISTER(type) \ REGISTER_KERNEL_BUILDER( \ Name("_Retval").Device(DEVICE_GPU).TypeConstraint("T"), RetvalOp); diff --git a/tensorflow/core/kernels/hexagon/graph_transferer.cc b/tensorflow/core/kernels/hexagon/graph_transferer.cc index d927ef3efa0..a82ae61ad9d 100644 --- a/tensorflow/core/kernels/hexagon/graph_transferer.cc +++ b/tensorflow/core/kernels/hexagon/graph_transferer.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/framework/op.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/public/session.h" @@ -43,10 +44,14 @@ const char INPUTS_NODE_PREFIX[] = "inputs_for_"; const char OUTPUTS_NODE_PREFIX[] = "outputs_for_"; const char DATA_NODE_PREFIX[] = "data_for_op_"; const char CONST_SHAPE_PREFIX[] = "const_shape_"; +const char CONST_VAL_PREFIX[] = "const_val_"; +const char CONST_TENSOR_PREFIX[] = "const_tensor_"; const char PADDING_ATTR_NAME[] = "padding"; const char STRIDES_ATTR_NAME[] = "strides"; +const char KEEP_DIMS_ATTR_NAME[] = "keep_dims"; const char KSIZE_ATTR_NAME[] = "ksize"; const char NULL_OUTPUT_NAME[] = "NULL"; +const char AGGREGATED_INPUT_NODE_NAME[] = "graph_transfer_aggregated_input"; const int PADDING_NA_ID = 0; // VALID = 1, SAME = 2 // This is a temporary 
workaround to support android build @@ -58,6 +63,16 @@ static string ToString(T val) { return stream.str(); } +static Node* FindMutableNodeByName(const string& name, Graph* graph) { + const TensorId tid = ParseTensorName(name); + for (Node* node : graph->nodes()) { + if (node != nullptr && node->name() == tid.first) { + return node; + } + } + return nullptr; +} + /** * graph loading functions * - LoadGraphFromProto @@ -86,13 +101,22 @@ Status GraphTransferer::LoadGraphFromProto( } } + TF_RETURN_IF_ERROR(TransformGraphToAddAggregatedInputNode( + input_node_info_list, &graph, &shape_refiner)); + std::unordered_multimap op_name_to_node_multimap( graph.num_nodes()); for (const Node* const node : graph.nodes()) { + if (node == nullptr) { + continue; + } CacheNode(*node); } for (const Node* const node : graph.nodes()) { + if (node == nullptr) { + continue; + } VLOG(1) << " " << node->name(); for (const Node* const input_node : node->in_nodes()) { const string& name = input_node->name(); @@ -102,6 +126,9 @@ Status GraphTransferer::LoadGraphFromProto( } for (const Node* const node : graph.nodes()) { + if (node == nullptr) { + continue; + } status = RegisterNodeIfAllInputsAreCached( ops_definitions, shape_refiner, *node, false, input_node_info_list, output_node_names); @@ -265,19 +292,16 @@ GraphTransferInfo& GraphTransferer::GetMutableGraphTransferInfo() { return graph_transfer_info_; } -int GraphTransferer::CacheNode(const Node& node) { +void GraphTransferer::CacheNode(const Node& node) { if (node_name_to_id_cache_map_.count(node.name()) > 0) { - VLOG(1) << "Emplace node to cache failed"; - // TODO(satok): check here? 
- return -1; + return; } - VLOG(1) << "Cache node: " << node.name() << ", " << node.op_def().name(); node_name_cache_list_.emplace_back(&node); + const int node_id = node_name_cache_list_.size() - 1; bool emplace_succeeded = false; - std::tie(std::ignore, emplace_succeeded) = node_name_to_id_cache_map_.emplace( - node.name(), node_name_cache_list_.size() - 1); + std::tie(std::ignore, emplace_succeeded) = + node_name_to_id_cache_map_.emplace(node.name(), node_id); CHECK(emplace_succeeded); - return node_name_cache_list_.size() - 1; } bool GraphTransferer::AreAllInputsCached(const Node& node) const { @@ -291,22 +315,126 @@ bool GraphTransferer::AreAllInputsCached(const Node& node) const { return true; } +Status GraphTransferer::TransformGraphToAddAggregatedInputNode( + const std::vector>& input_node_info_list, + Graph* graph, ShapeRefiner* shape_refiner) { + // Transform a remote fused graph to add an aggregated input node which takes + // all inputs of the remote graph. + DataTypeVector input_data_types; + std::vector data_types; + std::vector shapes; + std::vector input_nodes; + for (int i = 0; i < input_node_info_list.size(); ++i) { + Node* node = FindMutableNodeByName(input_node_info_list.at(i).first, graph); + CHECK_NOTNULL(node); + input_nodes.emplace_back(node->name()); + input_data_types.emplace_back(input_node_info_list.at(i).second.dtype()); + data_types.emplace_back(input_node_info_list.at(i).second.dtype()); + shapes.emplace_back(input_node_info_list.at(i).second.shape()); + } + + NodeDef input_node_def; + auto builder = + NodeBuilder(AGGREGATED_INPUT_NODE_NAME, "RemoteFusedGraphExecute") + .Input(std::vector{}) + .Attr("Tinputs", DataTypeVector{}) + .Attr("Toutputs", input_data_types) + .Attr("serialized_remote_fused_graph_execute_info", "") + .Attr(RemoteFusedGraphExecuteUtils::ATTR_OUTPUT_DATA_TYPES, + data_types) + .Attr(RemoteFusedGraphExecuteUtils::ATTR_OUTPUT_SHAPES, shapes); + + Node* input_node; + TF_RETURN_IF_ERROR(builder.Finalize(graph, 
&input_node)); + CHECK_NOTNULL(input_node); + + bool refined; + TF_RETURN_IF_ERROR( + shape_refiner->UpdateNode(input_node, false /* relax */, &refined)); + + shape_inference::InferenceContext* context = + shape_refiner->GetContext(input_node); + for (int i = 0; i < input_node_info_list.size(); ++i) { + shape_inference::ShapeHandle handle; + TF_RETURN_IF_ERROR(context->MakeShapeFromTensorShape( + input_node_info_list.at(i).second.shape(), &handle)); + TF_RETURN_IF_ERROR(shape_refiner->SetShape(input_node, i, handle)); + } + + // Cache the aggregate input node first as it's consumed first. + CacheNode(*input_node); + + std::vector original_input_nodes(input_nodes.size()); + + for (int i = 0; i < input_nodes.size(); ++i) { + const string& node_name = input_nodes.at(i); + Node* original_input_node = FindMutableNodeByName(node_name, graph); + CHECK_NOTNULL(original_input_node); + CHECK_EQ(1, original_input_node->num_outputs()); // replaced by identity. + Node* created_node; + TF_RETURN_IF_ERROR(RemoteFusedGraphExecuteUtils::BuildIdentityOpNode( + node_name, AGGREGATED_INPUT_NODE_NAME, i, data_types.at(i), graph, + &created_node)); + CHECK_NOTNULL(created_node); + std::vector data_types; + std::vector shapes; + Status status = RemoteFusedGraphExecuteUtils::GetOutputTensorShapeType( + original_input_node->def(), &data_types, &shapes); + if (status.ok()) { + created_node->AddAttr( + RemoteFusedGraphExecuteUtils::ATTR_OUTPUT_DATA_TYPES, data_types); + created_node->AddAttr(RemoteFusedGraphExecuteUtils::ATTR_OUTPUT_SHAPES, + shapes); + } + for (const Edge* out_edge : original_input_node->out_edges()) { + Node* dst = out_edge->dst(); + int dst_port = out_edge->dst_input(); + // Unused edge will be removed when removing node. 
+ graph->AddEdge(created_node, 0, dst, dst_port); + } + original_input_nodes[i] = original_input_node; + + TF_RETURN_IF_ERROR( + shape_refiner->UpdateNode(created_node, false /* relax */, &refined)); + + shape_inference::InferenceContext* context = + shape_refiner->GetContext(created_node); + CHECK_NOTNULL(context); + + // Cache replaced input node next to the aggregated input node. + CacheNode(*created_node); + } + + // Remove original input nodes after adding new input nodes to avoid + // reusing same pointer in Graph. + for (Node* original_input_node : original_input_nodes) { + graph->RemoveNode(original_input_node); + } + + return Status::OK(); +} + Status GraphTransferer::RegisterNode( const IGraphTransferOpsDefinitions& ops_definitions, const ShapeRefiner& shape_refiner, const Node& node, const std::vector>& input_node_info_list, const std::vector& output_node_names) { - VLOG(1) << "Register node: " << node.name(); + VLOG(1) << "Register node: " << node.name() << ", " << std::hex + << node_name_to_id_cache_map_.at(node.name()); if (node.name() == SOURCE_NODE_NAME || node.name() == SINK_NODE_NAME) { // Just ignore sink and source - return Status(); - } else if (RemoteFusedGraphExecuteUtils::IsInputNode(input_node_info_list, - node.name())) { + return Status::OK(); + } else if (node.name() == AGGREGATED_INPUT_NODE_NAME) { RegisterInputNode(ops_definitions, shape_refiner, node); + return Status::OK(); } else if (node.IsConstant()) { RegisterConstantNode(shape_refiner, node); + } else if (IsPadNode(node)) { + RegisterPadNode(ops_definitions, shape_refiner, node); } else if (HasPaddingAndStrides(node)) { RegisterNodeWithPaddingAndStrides(ops_definitions, shape_refiner, node); + } else if (NeedsToAddRank(node)) { + RegisterNodeWithRank(ops_definitions, shape_refiner, node); } else if (IsNodeFlattenReshape(node, shape_refiner)) { RegisterFlattenNode(ops_definitions, shape_refiner, node); } else if (ops_definitions.GetOpIdFor(node.type_string(), {}) != @@ -318,7 
+446,7 @@ Status GraphTransferer::RegisterNode( " has not been implemented yet."); } - return Status(); + return Status::OK(); } void GraphTransferer::RegisterConstantNode(const ShapeRefiner& shape_refiner, @@ -361,8 +489,7 @@ void GraphTransferer::RegisterConstantNode(const ShapeRefiner& shape_refiner, const TensorProto* proto = nullptr; TF_CHECK_OK(GetNodeAttr(node.attrs(), "value", &proto)); Tensor const_tensor; - // TODO(b/32704451): Don't just ignore this status! - MakeTensorFromProto(*proto, &const_tensor).IgnoreError(); + TF_CHECK_OK(MakeTensorFromProto(*proto, &const_tensor)); const_node_info.set_dtype(const_tensor.dtype()); if (data_size > 0) { @@ -394,12 +521,82 @@ int GraphTransferer::RegisterConstantShape(const std::vector& shape) { return node_name_to_id_cache_map_[shape_name]; } +int GraphTransferer::RegisterConstTensor(const Tensor& tensor, + const string& suffix) { + VLOG(1) << "Cache const tensor."; + const int dims = tensor.shape().dims(); + CHECK(dims <= 4); + const string node_name = strings::StrCat(CONST_TENSOR_PREFIX, "_", suffix); + if (node_name_to_id_cache_map_.count(node_name) <= 0) { + node_name_cache_list_.emplace_back(nullptr); + const int id = node_name_cache_list_.size() - 1; + node_name_to_id_cache_map_.emplace(node_name, id); + GraphTransferInfo::ConstNodeInfo& const_node_info = + *graph_transfer_info_.add_const_node_info(); + const_node_info.set_name(node_name); + const_node_info.set_node_id(id); + CHECK_EQ(4, SHAPE_ARRAY_SIZE); + for (int i = 0; i < SHAPE_ARRAY_SIZE; ++i) { + if (i < SHAPE_ARRAY_SIZE - dims) { + const_node_info.add_shape(1); + } else { + const_node_info.add_shape( + tensor.shape().dim_size(i - (SHAPE_ARRAY_SIZE - dims))); + } + } + const_node_info.set_dtype(tensor.dtype()); + const_node_info.set_data(tensor.tensor_data().data(), + tensor.tensor_data().size()); + } + return node_name_to_id_cache_map_[node_name]; +} + +int GraphTransferer::RegisterConstScalar(const DataType dt, const int val, + const int dst_id, + 
const int dst_input_count) { + VLOG(1) << "Cache const."; + const string val_name = + CONST_VAL_PREFIX + ToString(dst_id) + '_' + ToString(dst_input_count); + if (node_name_to_id_cache_map_.count(val_name) <= 0) { + node_name_cache_list_.emplace_back(nullptr); + const int id = node_name_cache_list_.size() - 1; + node_name_to_id_cache_map_.emplace(val_name, id); + GraphTransferInfo::ConstNodeInfo& const_node_info = + *graph_transfer_info_.add_const_node_info(); + const_node_info.set_name(val_name); + const_node_info.set_node_id(id); + // TODO(satok): Do not assume rank is 4 here. + const_node_info.add_shape(static_cast(1)); + const_node_info.add_shape(static_cast(1)); + const_node_info.add_shape(static_cast(1)); + const_node_info.add_shape(static_cast(1)); + const_node_info.set_data(&val, DataTypeSize(dt)); + } + return node_name_to_id_cache_map_[val_name]; +} + bool GraphTransferer::HasPaddingAndStrides(const Node& node) { auto attrs = node.attrs(); return attrs.Find(PADDING_ATTR_NAME) != nullptr && attrs.Find(STRIDES_ATTR_NAME) != nullptr; } +bool GraphTransferer::NeedsToAddRank(const Node& node) { + const string& op_type = node.def().op(); + if (op_type == "Transpose" || op_type == "ExpandDims") { + return true; + } + return false; +} + +bool GraphTransferer::IsPadNode(const Node& node) { + const string& op_type = node.def().op(); + if (op_type == "Pad") { + return true; + } + return false; +} + bool GraphTransferer::IsNodeFlattenReshape(const Node& node, const ShapeRefiner& shape_refiner) { // Check if node is reshape op @@ -473,15 +670,123 @@ void GraphTransferer::RegisterNodeWithPaddingAndStrides( node.num_outputs(), true /* append_input */, true /* append_output */); } +void GraphTransferer::RegisterNodeWithRank( + const IGraphTransferOpsDefinitions& ops_definitions, + const ShapeRefiner& shape_refiner, const Node& node) { + CHECK_EQ(node_name_to_id_cache_map_.count(node.name()), 1); + const int id = node_name_to_id_cache_map_[node.name()]; + 
shape_inference::InferenceContext* context = shape_refiner.GetContext(&node); + const Node* input0_node; + TF_CHECK_OK(node.input_node(0, &input0_node)); + CHECK_NOTNULL(input0_node); + std::vector shapes; + Status status = RemoteFusedGraphExecuteUtils::GetOutputTensorShapeType( + input0_node->def(), nullptr, &shapes); + CHECK_EQ(1, shapes.size()) << "Output size should be 1."; + const int const_val_id = + RegisterConstScalar(DT_INT32, shapes.at(0).dims(), id, node.num_inputs()); + std::vector extra_inputs{const_val_id}; + // TODO(satok): Set correct data type if it's given. + const int op_type_id = ops_definitions.GetOpIdFor(node.type_string(), {}); + CHECK(op_type_id >= 0 && op_type_id < ops_definitions.GetTotalOpsCount()) + << "Op " << node.type_string() << " not found in map(id = " << op_type_id + << ")"; + bool keep_dims = false; + int padding_id = PADDING_NA_ID; + if (context->GetAttr(KEEP_DIMS_ATTR_NAME, &keep_dims).ok()) { + padding_id = keep_dims ? Padding::SAME : Padding::VALID; + } + + AppendNodeParamsWithIoParams( + shape_refiner, node, node.name(), id, node.type_string(), op_type_id, + padding_id, node.num_inputs(), extra_inputs, node.num_outputs(), + true /* append_input */, true /* append_output */); +} + +void GraphTransferer::RegisterPadNode( + const IGraphTransferOpsDefinitions& ops_definitions, + const ShapeRefiner& shape_refiner, const Node& node) { + static constexpr int PAD_WIDTH = 4; + static constexpr int PAD_HEIGHT = 2; + VLOG(1) << "Register generic node: " << node.name(); + CHECK_EQ(node_name_to_id_cache_map_.count(node.name()), 1); + const int id = node_name_to_id_cache_map_[node.name()]; + + // TODO(satok): Set correct data type if it's given. 
+ const int op_type_id = ops_definitions.GetOpIdFor(node.type_string(), {}); + CHECK(op_type_id >= 0 && op_type_id < ops_definitions.GetTotalOpsCount()); + + CHECK_EQ(2, node.num_inputs()); + + GraphTransferInfo::NodeInputInfo& node_input_info = + *graph_transfer_info_.add_node_input_info(); + node_input_info.set_node_id(id); + + AddNodeInputByInputIndex(node, 0, &node_input_info); + + const Edge* edge = nullptr; + TF_CHECK_OK(node.input_edge(1, &edge)); + const Node* input_node = edge->src(); + CHECK_NOTNULL(input_node); + CHECK(input_node->IsConstant()); + + const TensorProto* tensor_proto = nullptr; + TF_CHECK_OK(GetNodeAttr(input_node->def(), "value", &tensor_proto)); + CHECK_NOTNULL(tensor_proto); + Tensor const_tensor; + TF_CHECK_OK(MakeTensorFromProto(*tensor_proto, &const_tensor)); + CHECK_EQ(2, const_tensor.shape().dims()); + CHECK_EQ(PAD_HEIGHT, const_tensor.shape().dim_size(1)); + if (const_tensor.shape().dim_size(0) == PAD_WIDTH) { + AddNodeInputByInputIndex(node, 1, &node_input_info); + } else if (const_tensor.shape().dim_size(0) < PAD_WIDTH) { + const int width = const_tensor.shape().dim_size(0); + const TensorProto* proto = nullptr; + TF_CHECK_OK(GetNodeAttr(input_node->def(), "value", &proto)); + Tensor const_tensor; + TF_CHECK_OK(MakeTensorFromProto(*proto, &const_tensor)); + CHECK_EQ(DT_INT32, const_tensor.dtype()); + // reshape tensor input to be rank 4. + // TODO(satok): Never assume rank is 4. 
+ Tensor new_const_tensor(const_tensor.dtype(), TensorShape{4, 2}); + for (int i = 0; i < PAD_HEIGHT; ++i) { + for (int j = 0; j < PAD_WIDTH; ++j) { + if (j < PAD_WIDTH - width) { + new_const_tensor.matrix()(j, i) = 0; + } else { + new_const_tensor.matrix()(j, i) = + const_tensor.matrix()(j - (PAD_WIDTH - width), i); + } + } + } + + const int id = RegisterConstTensor( + new_const_tensor, + strings::StrCat(input_node->name(), "_", node.name(), "_1")); + + GraphTransferInfo::NodeInput& node_input = + *node_input_info.add_node_input(); + node_input.set_node_id(id); + node_input.set_output_port(0); + } else { + CHECK(false); + } + + AppendNodeParamsWithIoParams( + shape_refiner, node, node.name(), id, node.type_string(), op_type_id, + PADDING_NA_ID, node.num_inputs(), {}, node.num_outputs(), + false /* append_input */, true /* append_output */); +} + void GraphTransferer::RegisterInputNode( const IGraphTransferOpsDefinitions& ops_definitions, const ShapeRefiner& shape_refiner, const Node& node) { - VLOG(1) << "Register input node: " << node.name(); + const string op_type = node.type_string(); + VLOG(1) << "Register input node: " << node.name() << ", " << op_type; CHECK_EQ(node_name_to_id_cache_map_.count(node.name()), 1); const int id = node_name_to_id_cache_map_[node.name()]; - const string op_type = node.type_string(); // TODO(satok): Set correct data type if it's given. 
- const int op_type_id = ops_definitions.GetOpIdFor(op_type, {}); + const int op_type_id = ops_definitions.GetOpIdFor("INPUT", {}); CHECK(op_type_id >= 0 && op_type_id < ops_definitions.GetTotalOpsCount()) << "Op" << node.name() << ", " << op_type << " is not supported," << op_type_id; @@ -546,7 +851,6 @@ void GraphTransferer::AppendNodeParams(const string& name, const int id, const int padding, const int inputs_size, const std::vector& extra_inputs, const int outputs_size) { - VLOG(1) << "Append node params: " << name; GraphTransferInfo::NodeInfo& node_info = *graph_transfer_info_.add_node_info(); node_info.set_name(name); @@ -559,6 +863,23 @@ void GraphTransferer::AppendNodeParams(const string& name, const int id, node_info.set_output_count(static_cast(outputs_size)); } +void GraphTransferer::AddNodeInputByInputIndex( + const Node& node, const int idx, + GraphTransferInfo::NodeInputInfo* node_input_info) { + const Edge* edge = nullptr; + TF_CHECK_OK(node.input_edge(idx, &edge)); + const Node* input_node = edge->src(); + CHECK_NOTNULL(input_node); + const int port = edge->src_output(); + + const std::string& op_name = input_node->name(); + CHECK_GT(node_name_to_id_cache_map_.count(op_name), 0) << op_name; + const int src_id = node_name_to_id_cache_map_[op_name]; + GraphTransferInfo::NodeInput& node_input = *node_input_info->add_node_input(); + node_input.set_node_id(src_id); + node_input.set_output_port(port); +} + void GraphTransferer::AppendNodeInputParams( const int id, const Node& node, const std::vector& extra_inputs) { VLOG(1) << "Append input params: " << node.name() << ", " << node.num_inputs() @@ -567,18 +888,7 @@ void GraphTransferer::AppendNodeInputParams( *graph_transfer_info_.add_node_input_info(); node_input_info.set_node_id(id); for (int i = 0; i < node.num_inputs(); ++i) { - const Edge* edge = nullptr; - TF_CHECK_OK(node.input_edge(i, &edge)); - const Node* input_node = edge->src(); - const int port = edge->src_output(); - - const std::string& 
op_name = input_node->name(); - CHECK_GT(node_name_to_id_cache_map_.count(op_name), 0) << op_name; - const int src_id = node_name_to_id_cache_map_[op_name]; - GraphTransferInfo::NodeInput& node_input = - *node_input_info.add_node_input(); - node_input.set_node_id(src_id); - node_input.set_output_port(port); + AddNodeInputByInputIndex(node, i, &node_input_info); } for (const int extra_input : extra_inputs) { GraphTransferInfo::NodeInput& node_input = @@ -596,9 +906,10 @@ void GraphTransferer::AppendNodeOutputParams(const ShapeRefiner& shape_refiner, *graph_transfer_info_.add_node_output_info(); node_output_info.set_node_id(id); + std::vector data_types; std::vector shapes; Status status = RemoteFusedGraphExecuteUtils::GetOutputTensorShapeType( - node.attrs(), nullptr, &shapes); + node.attrs(), &data_types, &shapes); for (int i = 0; i < node.num_outputs(); ++i) { int data_size = -1; @@ -608,16 +919,20 @@ void GraphTransferer::AppendNodeOutputParams(const ShapeRefiner& shape_refiner, shape_inference::InferenceContext* context = shape_refiner.GetContext(&node); - shape_inference::ShapeHandle shape_handle = context->output(output_index); - const shape_inference::DimensionHandle num_elements_dim = - context->NumElements(shape_handle); - if (context->ValueKnown(num_elements_dim)) { + + if (context != nullptr && context->ValueKnown(context->NumElements( + context->output(output_index)))) { + const shape_inference::DimensionHandle num_elements_dim = + context->NumElements(context->output(output_index)); const int64 num_output_elements = context->Value(num_elements_dim); data_size = max_bytes_per_data * num_output_elements; + if (status.ok()) { + TF_CHECK_OK(status); + CHECK_EQ(shapes.at(i).num_elements(), num_output_elements); + } } else { TF_CHECK_OK(status); // Use attribute attached to node - CHECK_EQ(node.num_outputs(), shapes.size()) << node.name(); data_size = max_bytes_per_data * shapes.at(i).num_elements(); } CHECK_GE(data_size, 0); @@ -722,11 +1037,11 @@ bool 
GraphTransferer::TransferParamsComparator::operator()( const int node_id0 = obj0.node_id(); const int node_id1 = obj1.node_id(); bool obj0_uses_obj1 = false; - if (dependency_map_.count(node_id0)) { + if (dependency_map_.count(node_id0) > 0) { obj0_uses_obj1 = dependency_map_.at(node_id0).count(node_id1) > 0; } bool obj1_uses_obj0 = false; - if (dependency_map_.count(node_id1)) { + if (dependency_map_.count(node_id1) > 0) { obj1_uses_obj0 = dependency_map_.at(node_id1).count(node_id0) > 0; } CHECK(!obj0_uses_obj1 || !obj1_uses_obj0); @@ -735,7 +1050,9 @@ bool GraphTransferer::TransferParamsComparator::operator()( } else if (obj1_uses_obj0) { return true; } - return node_id0 > node_id1; + // If there is no dependency between two nodes, it expects that + // the execution order follows node id order. + return node_id0 < node_id1; } /* static */ void GraphTransferer::FillDependencyRec( diff --git a/tensorflow/core/kernels/hexagon/graph_transferer.h b/tensorflow/core/kernels/hexagon/graph_transferer.h index fa12b22d75d..64c60b87c66 100644 --- a/tensorflow/core/kernels/hexagon/graph_transferer.h +++ b/tensorflow/core/kernels/hexagon/graph_transferer.h @@ -88,6 +88,9 @@ class GraphTransferer { // Dump verification string of parameters to verify with offline tools void DumpVerificationStringOfNodeTransferParams() const; + static std::array ToTensorShapeArray( + const TensorShape& shape); + private: class TransferParamsComparator { public: @@ -98,10 +101,16 @@ class GraphTransferer { const std::unordered_map>& dependency_map_; }; - int CacheNode(const Node& node); + void CacheNode(const Node& node); bool AreAllInputsCached(const Node& node) const; + // Transform a remote fused graph to add an aggregated input node which takes + // all inputs of the remote graph. 
+ Status TransformGraphToAddAggregatedInputNode( + const std::vector>& input_node_info_list, + Graph* graph, ShapeRefiner* shape_refiner); + Status RegisterNode( const IGraphTransferOpsDefinitions& ops_definitions, const ShapeRefiner& shape_refiner, const Node& node, @@ -113,8 +122,17 @@ class GraphTransferer { int RegisterConstantShape(const std::vector& shape); + int RegisterConstTensor(const Tensor& tensor, const string& suffix); + + int RegisterConstScalar(const DataType dt, const int val, const int dst_id, + const int dst_input_count); + bool HasPaddingAndStrides(const Node& node); + bool NeedsToAddRank(const Node& node); + + bool IsPadNode(const Node& node); + // Return true if the node is a reshape op which just flattens input // TODO(satok): Remove this method once generic reshape op is implemented in // SOC @@ -125,6 +143,13 @@ class GraphTransferer { const IGraphTransferOpsDefinitions& ops_definitions, const ShapeRefiner& shape_refiner, const Node& node); + void RegisterNodeWithRank(const IGraphTransferOpsDefinitions& ops_definitions, + const ShapeRefiner& shape_refiner, + const Node& node); + + void RegisterPadNode(const IGraphTransferOpsDefinitions& ops_definitions, + const ShapeRefiner& shape_refiner, const Node& node); + void RegisterInputNode(const IGraphTransferOpsDefinitions& ops_definitions, const ShapeRefiner& shape_refiner, const Node& node); @@ -150,6 +175,10 @@ class GraphTransferer { const std::vector& extra_inputs, const int outputs_size); + void AddNodeInputByInputIndex( + const Node& node, const int idx, + GraphTransferInfo::NodeInputInfo* node_input_info); + void AppendNodeInputParams(const int id, const Node& node, const std::vector& extra_inputs); @@ -167,9 +196,6 @@ class GraphTransferer { const int outputs_size, const bool append_input_params, const bool append_output_params); - static std::array ToTensorShapeArray( - const TensorShape& shape); - static string ToPaddingDebugString(int padding); // Create dependency map diff --git 
a/tensorflow/core/kernels/hexagon/graph_transferer_test.cc b/tensorflow/core/kernels/hexagon/graph_transferer_test.cc index ebd4a903301..74ffc026f74 100644 --- a/tensorflow/core/kernels/hexagon/graph_transferer_test.cc +++ b/tensorflow/core/kernels/hexagon/graph_transferer_test.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/core/kernels/hexagon/i_graph_transfer_ops_definitions.h" #include "tensorflow/core/kernels/i_remote_fused_graph_executor.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/public/session.h" @@ -47,21 +48,19 @@ class GraphTransfererTest : public ::testing::Test { GraphTransferer gt_; }; -static const std::vector OP_TYPES{ - "INPUT", "OUTPUT", "Conv2D", "MaxPool", "NoOp", "Add", "Const", "Softmax"}; const RemoteFusedGraphExecuteUtils::TensorShapeMap EMPTY_OUTPUT_TENSOR_MAP; class TestGraphTransferOpsDefinitions : public IGraphTransferOpsDefinitions { public: - int GetTotalOpsCount() const final { return OP_TYPES.size(); } + int GetTotalOpsCount() const final { return op_types_.size(); } -int GetOpIdFor(const string& op_type, const DataTypeVector&) const final { - for (int i = 0; i < OP_TYPES.size(); ++i) { - if (OP_TYPES[i] == op_type) { - return i; + int GetOpIdFor(const string& op_type, const DataTypeVector&) const final { + for (int i = 0; i < op_types_.size(); ++i) { + if (op_types_[i] == op_type) { + return i; + } } - } - return -1; + return -1; } GraphTransferInfo::Destination GetTransferDestination() const final { @@ -69,6 +68,9 @@ GraphTransferInfo::Destination GetTransferDestination() const final { } private: + const std::vector op_types_{"INPUT", "OUTPUT", "Conv2D", + "MaxPool", "NoOp", "Add", + "Const", "Softmax", "Identity"}; } TEST_GRAPH_TRANSFER_OPS_DEFINITIONS; static Output BuildAddOps(const Scope& scope, const Input& x, const Input& y) { @@ 
-312,7 +314,7 @@ TEST_F(GraphTransfererTest, LoadAddGraphWithOutputTensorMap) { const std::vector output_node_names = {NAME_A_PLUS_B}; status = gt_.LoadGraphFromProto(TEST_GRAPH_TRANSFER_OPS_DEFINITIONS, def, inputs, output_node_names, false); - ASSERT_TRUE(status.ok()); + TF_ASSERT_OK(status); } TEST_F(GraphTransfererTest, LoadConvGraph) { @@ -330,7 +332,7 @@ TEST_F(GraphTransfererTest, LoadConvGraph) { gt_.GetGraphTransferInfo().const_node_info_size(); ASSERT_EQ(2, const_node_count); const int op_node_count = gt_.GetGraphTransferInfo().node_info_size(); - ASSERT_EQ(3, op_node_count); + ASSERT_EQ(4, op_node_count); const GraphTransferInfo::NodeInfo* params_conv = FindNodeInfo(gt_, "conv"); ASSERT_TRUE(params_conv != nullptr); const int id = params_conv->node_id(); @@ -356,7 +358,7 @@ TEST_F(GraphTransfererTest, LoadMaxPoolGraph) { gt_.GetGraphTransferInfo().const_node_info_size(); ASSERT_EQ(2, const_node_count); const int op_node_count = gt_.GetGraphTransferInfo().node_info_size(); - ASSERT_EQ(3, op_node_count); + ASSERT_EQ(4, op_node_count); const GraphTransferInfo::NodeInfo* params_max_pool = FindNodeInfo(gt_, "maxpool"); ASSERT_TRUE(params_max_pool != nullptr); diff --git a/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.cc b/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.cc index 518b399c374..660ffd268df 100644 --- a/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.cc +++ b/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.cc @@ -27,6 +27,8 @@ namespace tensorflow { constexpr const char* const INPUT_OP_NAME = "INPUT"; constexpr const char* const OUTPUT_OP_NAME = "OUTPUT"; +constexpr int ALIGNMENT_BYTES = 16; + const bool DBG_DUMP_VERIFICATION_STRING = false; const int DBG_LEVEL = 0; // -2: verbose, -1: debug, 0: info const bool DBG_USE_DUMMY_INPUT = false; @@ -34,6 +36,22 @@ const bool DBG_USE_SAMPLE_INPUT = false; const int64 FLAG_ENABLE_PANDA_BINARY_INPUT = 0x01; const bool DBG_DUMP_INPUT_TENSOR_AS_FLOAT_DATA = false; +static 
string AddPort(const string& node_name) { + if (node_name.find(':') != string::npos) { + return node_name; + } else { + return strings::StrCat(node_name, ":", 0); + } +} + +static uint8* FindAlignedPointer(uint8* ptr) { + const uintptr_t data_ptr_int = reinterpret_cast(ptr); + const int shift_count = + (ALIGNMENT_BYTES - data_ptr_int % ALIGNMENT_BYTES) % ALIGNMENT_BYTES; + uint8* data_ptr = ptr + shift_count; + return data_ptr; +} + /* static */ GraphTransferInfo::NodeInfo* HexagonControlWrapper::FindNodeInfo( const string& name, GraphTransferInfo* graph_transfer_info) { for (GraphTransferInfo::NodeInfo& node_info : @@ -60,18 +78,57 @@ bool HexagonControlWrapper::Init(const RemoteFusedGraphExecuteInfo& info) { std::vector outputs; RemoteFusedGraphExecuteUtils::BuildRemoteGraphInputsAndOutputsFromProto( info, &inputs, &outputs); - graph_transferer_.LoadGraphFromProto( + Status status = graph_transferer_.LoadGraphFromProto( HexagonOpsDefinitions::getInstance(), info.remote_graph(), inputs, outputs, false // shape_inference_for_unknown_shape - ); + ); + TF_CHECK_OK(status) << status; } else { // If graph transfer info is attached, just import it. graph_transferer_.SetSerializedGraphTransferInfo( info.serialized_executor_parameters()); } execute_info_ = &info; - return soc_interface_Init(); + bool success = soc_interface_Init(); + if (!success) { + LOG(ERROR) << "Hexagon initialization was failed. 
See log output."; + return false; + } + const GraphTransferInfo& gt_info = graph_transferer_.GetGraphTransferInfo(); + std::vector input_sizes; + std::vector output_sizes; + CHECK_NOTNULL(execute_info_); + for (int i = 0; i < execute_info_->graph_input_node_name_size(); ++i) { + const string& input = execute_info_->graph_input_node_name(i); + LOG(INFO) << "Add input: " << input << ", " << i; + CHECK(input_port_map_.emplace(AddPort(input), i).second); + const RemoteFusedGraphExecuteInfo::TensorShapeTypeProto& shape_type = + execute_info_->default_graph_input_tensor_shape(i); + int64 buf_size = DataTypeSize(shape_type.dtype()); + for (const TensorShapeProto::Dim& dim : shape_type.shape().dim()) { + buf_size *= dim.size(); + } + input_sizes.emplace_back(static_cast(buf_size)); + } + for (int i = 0; i < execute_info_->graph_output_node_name_size(); ++i) { + const string& output = execute_info_->graph_output_node_name(i); + CHECK(output_port_map_.emplace(AddPort(output), i).second); + const RemoteFusedGraphExecuteInfo::TensorShapeTypeProto& shape_type = + execute_info_->default_graph_output_tensor_shape(i); + + int64 buf_size = DataTypeSize(shape_type.dtype()); + for (const TensorShapeProto::Dim& dim : shape_type.shape().dim()) { + buf_size *= dim.size(); + } + output_sizes.emplace_back(static_cast(buf_size)); + } + + LOG(INFO) << "Allocate inout buffer"; + success &= soc_interface_AllocateInOutNodeBuffers( + input_sizes.size(), input_sizes.data(), output_sizes.size(), + output_sizes.data()); + return success; } bool HexagonControlWrapper::Finalize() { return soc_interface_Finalize(); } @@ -86,9 +143,6 @@ bool HexagonControlWrapper::SetupGraph() { GraphTransferInfo::NodeInfo* node_info = FindNodeInfo(graph_input.name(), &graph_transfer_info); CHECK_NE(node_info, nullptr); - node_info->set_type_name(INPUT_OP_NAME); - node_info->set_soc_op_id( - HexagonOpsDefinitions::getInstance().GetOpIdFor(INPUT_OP_NAME, {})); } // Generate a new output node which is connected to graph 
output node @@ -202,12 +256,8 @@ bool HexagonControlWrapper::SetupGraph() { auto data = dummy_const_data_.emplace( std::piecewise_construct, std::make_tuple(node_id), std::make_tuple()); CHECK(data.second); - const int additional_bytes_for_alignment = 16; - data.first->second.resize(data_size + additional_bytes_for_alignment - 1); - const uintptr_t data_ptr_int = - reinterpret_cast(data.first->second.data()); - const int shift_count = (16 - data_ptr_int % 16) % 16; - uint8* data_ptr = data.first->second.data() + shift_count; + data.first->second.resize(data_size + ALIGNMENT_BYTES - 1); + uint8* data_ptr = FindAlignedPointer(data.first->second.data()); std::memcpy(data_ptr, params.data().data(), data_size); soc_interface_AppendConstNode(params.name().c_str(), node_id + NODE_ID_OFFSET, shape_0, shape_1, @@ -267,27 +317,37 @@ bool HexagonControlWrapper::TeardownGraph() { return soc_interface_TeardownGraph(); } -bool HexagonControlWrapper::FillInputNode(const string& node_name, - const ConstByteArray bytes) { - uint64 byte_size; - const int x = 1; - const int y = 299; - const int z = 299; - const int d = 3; - if (DBG_USE_DUMMY_INPUT) { - const int array_length = x * y * z * d; - byte_size = array_length * sizeof(float); - dummy_input_float_.resize(array_length); - std::memset(dummy_input_float_.data(), 0, byte_size); - } else { - CHECK(std::get<2>(bytes) == DT_FLOAT); - byte_size = std::get<1>(bytes); - dummy_input_float_.resize(byte_size / sizeof(float)); - std::memcpy(dummy_input_float_.data(), std::get<0>(bytes), byte_size); +bool HexagonControlWrapper::FillInputNode( + const string& node_name, + const std::array& shape, + const ConstByteArray bytes) { + const string tensor_name = AddPort(node_name); + CHECK(input_port_map_.count(tensor_name) > 0); + const int port = input_port_map_.at(tensor_name); + if (input_tensor_data_.count(port) <= 0) { + input_tensor_data_.emplace(port, std::vector{}); } - return soc_interface_FillInputNodeFloat( - x, y, z, d, 
reinterpret_cast(dummy_input_float_.data()), - byte_size); + std::vector& input_tensor_data = input_tensor_data_.at(port); + + // hexagon only supports 32bit dimension + const int x = static_cast(shape[0]); + const int y = static_cast(shape[1]); + const int z = static_cast(shape[2]); + const int d = static_cast(shape[3]); + + const uint64 byte_size = x * y * z * d * DataTypeSize(std::get<2>(bytes)); + CHECK_EQ(byte_size, std::get<1>(bytes)); + input_tensor_data.resize(byte_size + ALIGNMENT_BYTES); + uint8* data_ptr = FindAlignedPointer(input_tensor_data.data()); + + if (DBG_USE_DUMMY_INPUT) { + std::memset(data_ptr, 0, byte_size); + } else { + std::memcpy(data_ptr, std::get<0>(bytes), byte_size); + } + + return soc_interface_FillInputNodeWithPort(port, x, y, z, d, data_ptr, + byte_size); } bool HexagonControlWrapper::ReadOutputNode( @@ -304,26 +364,28 @@ bool HexagonControlWrapper::ReadOutputNode( break; } } - std::vector outputs; + std::vector outputs; ReadOutputNode(node_name, &outputs); CHECK_EQ(1, outputs.size()); - IRemoteFusedGraphExecutor::ByteArray& output = outputs[0]; + ByteArray& output = outputs[0]; Tensor* output_tensor = tensor_allocator(output_shape); CHECK(output_tensor->TotalBytes() >= std::get<1>(output)) << output_tensor->TotalBytes() << ", " << std::get<1>(output); - // TODO(satok): Avoid specifying float - std::memcpy(output_tensor->flat().data(), std::get<0>(output), - std::get<1>(output)); + TF_CHECK_OK(RemoteFusedGraphExecuteUtils::CopyByteArrayToTensor( + std::get<0>(output), std::get<1>(output), output_tensor)); } bool HexagonControlWrapper::ReadOutputNode( const string& node_name, std::vector* const outputs) { CHECK(outputs != nullptr); ByteArray output; - soc_interface_ReadOutputNodeFloat(node_name.c_str(), &std::get<0>(output), - &std::get<1>(output)); + const string tensor_name = AddPort(node_name); + CHECK(output_port_map_.count(tensor_name) > 0); + const int port = output_port_map_.at(tensor_name); + 
soc_interface_ReadOutputNodeWithPort(port, &std::get<0>(output), + &std::get<1>(output)); // TODO: Accept all results - std::get<2>(output) = DT_FLOAT; + // std::get<2>(output) = DT_FLOAT; outputs->emplace_back(output); return true; } @@ -347,7 +409,9 @@ bool HexagonControlWrapper::FillInputNode(const string& node_name, } } } - FillInputNode(node_name, ba); + const std::array shape = + GraphTransferer::ToTensorShapeArray(tensor.shape()); + FillInputNode(node_name, shape, ba); return true; } @@ -360,7 +424,9 @@ bool HexagonControlWrapper::Finalize() { return false; } bool HexagonControlWrapper::SetupGraph() { return false; } bool HexagonControlWrapper::ExecuteGraph() { return false; } bool HexagonControlWrapper::TeardownGraph() { return false; } -bool HexagonControlWrapper::FillInputNode(const string&, const ConstByteArray) { +bool HexagonControlWrapper::FillInputNode( + const string&, const std::array&, + const ConstByteArray) { return false; } bool HexagonControlWrapper::FillInputNode(const string&, const Tensor&) { diff --git a/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.h b/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.h index 97448884e1d..209ac9dbf4a 100644 --- a/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.h +++ b/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.h @@ -16,6 +16,7 @@ limitations under the License. 
#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_HEXAGON_CONTROL_WRAPPER_H_ #define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_HEXAGON_CONTROL_WRAPPER_H_ +#include #include #include "tensorflow/core/framework/types.h" @@ -32,6 +33,9 @@ namespace tensorflow { */ class HexagonControlWrapper final : public IRemoteFusedGraphExecutor { public: + using ByteArray = + std::tuple; + HexagonControlWrapper() = default; int GetVersion() final; bool Init(const RemoteFusedGraphExecuteInfo& info) final; @@ -45,7 +49,13 @@ class HexagonControlWrapper final : public IRemoteFusedGraphExecutor { bool ReadOutputNode(const string& node_name, std::vector* outputs); private: - bool FillInputNode(const string& node_name, const ConstByteArray bytes); + using ConstByteArray = std::tuple; + + bool FillInputNode( + const string& node_name, + const std::array& shape, + const ConstByteArray bytes); // CAVEAT: Need offset as HVX library reserves some ids static constexpr int NODE_ID_OFFSET = 0x10000; @@ -57,11 +67,15 @@ class HexagonControlWrapper final : public IRemoteFusedGraphExecutor { GraphTransferer graph_transferer_{}; // Dummy float array for input node. // TODO(satok): Use actual data passed by FillInputNode and remove - std::vector dummy_input_float_{}; + // std::vector dummy_input_float_{}; + std::unordered_map> input_tensor_data_{}; // Dummy byte array for cosnt node. 
// TODO(satok): Remove std::unordered_map> dummy_const_data_{}; + std::unordered_map input_port_map_{}; + std::unordered_map output_port_map_{}; + TF_DISALLOW_COPY_AND_ASSIGN(HexagonControlWrapper); }; diff --git a/tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc b/tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc index 54ba101501f..cb9091e29f8 100644 --- a/tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc +++ b/tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc @@ -46,8 +46,7 @@ adb push /tmp/imagenet_comp_graph_label_strings.txt /data/local/tmp namespace tensorflow { -using ByteArray = IRemoteFusedGraphExecutor::ByteArray; -using ConstByteArray = IRemoteFusedGraphExecutor::ConstByteArray; +using ByteArray = HexagonControlWrapper::ByteArray; constexpr const char* const IMAGE_FILENAME = "/data/local/tmp/img_299x299.bmp"; constexpr const char* const MODEL_FILENAME = @@ -87,8 +86,7 @@ static void DumpTop10Results(const int byte_size, 10 /* show top_n results */); } -static void DumpTop10Results( - const std::vector& outputs) { +static void DumpTop10Results(const std::vector& outputs) { CHECK(outputs.size() == 1); const int byte_size = std::get<1>(outputs.at(0)); const float* float_array = @@ -96,9 +94,8 @@ static void DumpTop10Results( DumpTop10Results(byte_size, float_array); } -static void CheckFirstResult( - const std::vector& outputs, - const int expected_first_id) { +static void CheckFirstResult(const std::vector& outputs, + const int expected_first_id) { EXPECT_GE(outputs.size(), 1); const int byte_size = std::get<1>(outputs.at(0)); const int element_count = byte_size / sizeof(float); @@ -240,7 +237,7 @@ static void RunInferenceByHexagonControlWrapper( } // 5-1. Read output node's outputs - std::vector outputs; + std::vector outputs; hexagon_control_wrapper.ReadOutputNode("softmax", &outputs); // 5-2. 
Dump results diff --git a/tensorflow/core/kernels/hexagon/hexagon_ops_definitions.cc b/tensorflow/core/kernels/hexagon/hexagon_ops_definitions.cc index a4b79e6ec4f..2b7585aed1f 100644 --- a/tensorflow/core/kernels/hexagon/hexagon_ops_definitions.cc +++ b/tensorflow/core/kernels/hexagon/hexagon_ops_definitions.cc @@ -350,6 +350,8 @@ HexagonOpsDefinitions::BuildOpNameToSocOpTypeMap() { #ifdef ENABLE_EXPERIMENTAL_HEXNN_OPS EmplaceOpType("QuantizedMul", {}, SupportedOpType::QUANTIZED_MUL_8x8to32, &op_map); + EmplaceOpType("QuantizedAdd", {}, SupportedOpType::QUANTIZED_ADD_8p8to32, + &op_map); EmplaceOpType("Pad", {}, SupportedOpType::PAD_F, &op_map); EmplaceOpType("SpaceToBatchND", {}, SupportedOpType::SPACE_TO_BATCH_ND_F, &op_map), @@ -359,6 +361,11 @@ HexagonOpsDefinitions::BuildOpNameToSocOpTypeMap() { &op_map); EmplaceOpType("ConcatV2", {}, SupportedOpType::CONCAT_V2_F, &op_map); EmplaceOpType("Conv2DBackpropInput", {}, SupportedOpType::DECONV_F, &op_map); + + EmplaceOpType("Tanh", {}, SupportedOpType::TANH_F, &op_map); + EmplaceOpType("Split", {}, SupportedOpType::SPLIT_F, &op_map); + EmplaceOpType("Transpose", {}, SupportedOpType::TRANSPOSE_F, &op_map); + EmplaceOpType("Concat", {}, SupportedOpType::CONCAT_F, &op_map); #endif return op_map; }; diff --git a/tensorflow/core/kernels/i_remote_fused_graph_executor.h b/tensorflow/core/kernels/i_remote_fused_graph_executor.h index fe62a259de8..09d1f43ff11 100644 --- a/tensorflow/core/kernels/i_remote_fused_graph_executor.h +++ b/tensorflow/core/kernels/i_remote_fused_graph_executor.h @@ -25,10 +25,6 @@ namespace tensorflow { class IRemoteFusedGraphExecutor { public: - using ByteArray = - std::tuple; - using ConstByteArray = std::tuple; using TensorAllocatorFunc = std::function; IRemoteFusedGraphExecutor() = default; diff --git a/tensorflow/core/kernels/iterator_ops.cc b/tensorflow/core/kernels/iterator_ops.cc index ed350d98331..b6825b4e959 100644 --- a/tensorflow/core/kernels/iterator_ops.cc +++ 
b/tensorflow/core/kernels/iterator_ops.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -160,16 +161,25 @@ class MakeIteratorOp : public OpKernel { } }; -class OneShotIteratorOp : public OpKernel { +class OneShotIteratorOp : public AsyncOpKernel { public: - explicit OneShotIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + explicit OneShotIteratorOp(OpKernelConstruction* ctx) + : AsyncOpKernel(ctx), + thread_pool_(new thread::ThreadPool( + ctx->env(), ThreadOptions(), + strings::StrCat("one_shot_iterator_initialization_thread_", + SanitizeThreadSuffix(def().name())), + 1 /* num_threads */, false /* low_latency_hint */)) + + { string shared_name; OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &shared_name)); OP_REQUIRES(ctx, shared_name.empty(), errors::InvalidArgument("OneShotIteratorOp does not currently " "support the 'shared_name' attr.")); - OP_REQUIRES_OK(ctx, - ctx->GetAttr("dataset_factory", &dataset_factory_func_)); + const NameAttrList* dataset_factory_func; + OP_REQUIRES_OK(ctx, ctx->GetAttr("dataset_factory", &dataset_factory_func)); + dataset_factory_func_ = *dataset_factory_func; OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_dtypes_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); } @@ -187,102 +197,159 @@ class OneShotIteratorOp : public OpKernel { // NOTE(mrry): This is based on `ResourceOpKernel::Compute()`, // but due to the fact that `ResourceOpKernel::CreateResource()` - // does not provide access to the `OpKernelContext*` and we need this - // to invoke the factory function, it's not possible to implement - // this kernel by implementing `CreateResource()`. 
- void Compute(OpKernelContext* ctx) override { - mutex_lock l(mu_); - if (iterator_resource_ == nullptr) { - ResourceMgr* mgr = ctx->resource_manager(); - OP_REQUIRES_OK(ctx, cinfo_.Init(mgr, def())); - - // Create an IteratorResource that will hold the iterator for this op. - IteratorResource* resource; - OP_REQUIRES_OK( - ctx, - mgr->LookupOrCreate( - cinfo_.container(), cinfo_.name(), &resource, - [this](IteratorResource** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) { - *ret = new IteratorResource(output_dtypes_, output_shapes_); - return Status::OK(); - })); - Status s = VerifyTypesMatch(output_dtypes_, resource->output_dtypes()); - s.Update( - VerifyShapesCompatible(output_shapes_, resource->output_shapes())); - if (TF_PREDICT_FALSE(!s.ok())) { - resource->Unref(); - ctx->SetStatus(s); + // does not provide access to the `OpKernelContext*` and we need + // this to invoke the factory function, it's not possible to + // implement this kernel by implementing `CreateResource()`. + // Furthermore, due to the fact that this kernel might block when + // running the initialization function, we must implement this + // kernel as an async kernel. + void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { + { + mutex_lock l(mu_); + if (iterator_resource_ == nullptr && initialization_status_.ok()) { + // The initialization thread will call `done`. + if (!initialization_started_) { + // TODO(mrry): Convert the initialization code to use + // callbacks instead of wasting a thread. + thread_pool_->Schedule([this, ctx, done]() { Init(ctx, done); }); + initialization_started_ = true; + } else { + done_callbacks_.emplace_back(ctx, std::move(done)); + } return; } - iterator_resource_ = resource; - - // Call the dataset_factory_func_ to create a new dataset, - // over which this op will iterate. 
- FunctionLibraryRuntime::Handle f_handle; - OP_REQUIRES_OK(ctx, - ctx->function_library()->Instantiate( - dataset_factory_func_->name(), - AttrSlice(&dataset_factory_func_->attr()), &f_handle)); - FunctionLibraryRuntime::Options opts; - opts.cancellation_manager = ctx->cancellation_manager(); - // Choose a step ID that is guaranteed not to clash with any - // Session-generated step ID. DirectSession only generates - // non-negative step IDs (contiguous, starting from 0), and - // MasterSession generates 56-bit random step IDs whose MSB is - // always 0, so a negative random step ID should suffice. - opts.step_id = -std::abs(static_cast(random::New64())); - ScopedStepContainer step_container( - opts.step_id, [ctx](const string& name) { - ctx->resource_manager()->Cleanup(name).IgnoreError(); - }); - opts.step_container = &step_container; - opts.runner = ctx->runner(); - Notification n; - Status factory_status; - std::vector return_values; - ctx->function_library()->Run(opts, f_handle, {}, &return_values, - [&n, &factory_status](Status s) { - factory_status.Update(s); - n.Notify(); - }); - n.WaitForNotification(); - OP_REQUIRES_OK(ctx, factory_status); - OP_REQUIRES( - ctx, - return_values.size() == 1 && - return_values[0].dtype() == DT_RESOURCE && - TensorShapeUtils::IsScalar(return_values[0].shape()), - errors::InvalidArgument("The `dataset_factory` function must return " - "a single scalar of dtype DT_RESOURCE.")); - - // Retrieve the dataset that was created in the factory function. - DatasetBase* dataset; - const ResourceHandle& dataset_resource = - return_values[0].flat()(0); - OP_REQUIRES_OK(ctx, LookupResource(ctx, dataset_resource, &dataset)); - core::ScopedUnref unref_dataset(dataset); - - // Create an iterator for the dataset that was created in the - // factory function. This transfers ownership of the dataset to - // the iterator, so we can delete it from the resource manager. 
- OP_REQUIRES_OK(ctx, - iterator_resource_->set_iterator(dataset->MakeIterator())); - OP_REQUIRES_OK(ctx, DeleteResource(ctx, dataset_resource)); } - Tensor* handle; - OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &handle)); - handle->scalar()() = MakeResourceHandle( - ctx, cinfo_.container(), cinfo_.name()); + ProduceOutput(ctx, std::move(done)); } private: - const NameAttrList* dataset_factory_func_; + void Init(OpKernelContext* ctx, DoneCallback done) { + IteratorResource* iterator = nullptr; + ContainerInfo cinfo; + Status s = TryInit(ctx, &iterator, &cinfo); + + std::vector> callbacks_to_run; + { + mutex_lock l(mu_); + if (s.ok()) { + iterator_resource_ = iterator; + cinfo_ = cinfo; + } + initialization_status_ = s; + std::swap(done_callbacks_, callbacks_to_run); + } + + for (auto&& ctx_done : callbacks_to_run) { + ProduceOutput(ctx_done.first, std::move(ctx_done.second)); + } + ProduceOutput(ctx, std::move(done)); + } + + Status TryInit(OpKernelContext* ctx, IteratorResource** iterator, + ContainerInfo* cinfo) { + TF_RETURN_IF_ERROR(cinfo->Init(ctx->resource_manager(), def())); + + // Create an IteratorResource that will hold the iterator for this op. + TF_RETURN_IF_ERROR( + ctx->resource_manager()->LookupOrCreate( + cinfo->container(), cinfo->name(), iterator, + [this](IteratorResource** ret) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + *ret = new IteratorResource(output_dtypes_, output_shapes_); + return Status::OK(); + })); + + core::ScopedUnref unref_iterator(*iterator); + + TF_RETURN_IF_ERROR( + VerifyTypesMatch(output_dtypes_, (*iterator)->output_dtypes())); + TF_RETURN_IF_ERROR( + VerifyShapesCompatible(output_shapes_, (*iterator)->output_shapes())); + + // Call the dataset_factory_func_ to create a new dataset, + // over which this op will iterate. 
+ FunctionLibraryRuntime::Handle f_handle; + TF_RETURN_IF_ERROR(ctx->function_library()->Instantiate( + dataset_factory_func_.name(), AttrSlice(&dataset_factory_func_.attr()), + &f_handle)); + FunctionLibraryRuntime::Options opts; + opts.cancellation_manager = ctx->cancellation_manager(); + // Choose a step ID that is guaranteed not to clash with any + // Session-generated step ID. DirectSession only generates + // non-negative step IDs (contiguous, starting from 0), and + // MasterSession generates 56-bit random step IDs whose MSB is + // always 0, so a negative random step ID should suffice. + opts.step_id = -std::abs(static_cast(random::New64())); + ScopedStepContainer step_container(opts.step_id, [ctx](const string& name) { + ctx->resource_manager()->Cleanup(name).IgnoreError(); + }); + opts.step_container = &step_container; + opts.runner = ctx->runner(); + Notification n; + Status factory_status; + std::vector return_values; + ctx->function_library()->Run(opts, f_handle, {}, &return_values, + [&n, &factory_status](Status s) { + factory_status.Update(s); + n.Notify(); + }); + n.WaitForNotification(); + TF_RETURN_IF_ERROR(factory_status); + if (return_values.size() != 1 || return_values[0].dtype() != DT_RESOURCE || + !TensorShapeUtils::IsScalar(return_values[0].shape())) { + return errors::InvalidArgument( + "The `dataset_factory` function must return " + "a single scalar of dtype DT_RESOURCE."); + } + + // Retrieve the dataset that was created in the factory function. + DatasetBase* dataset; + const ResourceHandle& dataset_resource = + return_values[0].flat()(0); + TF_RETURN_IF_ERROR(LookupResource(ctx, dataset_resource, &dataset)); + core::ScopedUnref unref_dataset(dataset); + + // Create an iterator for the dataset that was created in the + // factory function. This transfers ownership of the dataset to + // the iterator, so we can delete it from the resource manager. 
+ TF_RETURN_IF_ERROR((*iterator)->set_iterator(dataset->MakeIterator())); + TF_RETURN_IF_ERROR(DeleteResource(ctx, dataset_resource)); + + (*iterator)->Ref(); + return Status::OK(); + } + + void ProduceOutput(OpKernelContext* ctx, DoneCallback done) { + Tensor* handle; + OP_REQUIRES_OK_ASYNC(ctx, ctx->allocate_output(0, TensorShape({}), &handle), + done); + Status s; + { + mutex_lock l(mu_); + s = initialization_status_; + if (s.ok()) { + handle->scalar()() = + MakeResourceHandle(ctx, cinfo_.container(), + cinfo_.name()); + } + } + OP_REQUIRES_OK_ASYNC(ctx, s, done); + done(); + } + + NameAttrList dataset_factory_func_; DataTypeVector output_dtypes_; std::vector output_shapes_; + std::unique_ptr thread_pool_; + mutex mu_; ContainerInfo cinfo_ GUARDED_BY(mu_); - IteratorResource* iterator_resource_ = nullptr; + IteratorResource* iterator_resource_ GUARDED_BY(mu_) = nullptr; + + bool initialization_started_ GUARDED_BY(mu_) = false; + Status initialization_status_ GUARDED_BY(mu_); + std::vector> done_callbacks_ + GUARDED_BY(mu_); }; class IteratorGetNextOp : public AsyncOpKernel { diff --git a/tensorflow/core/kernels/lmdb_reader_op.cc b/tensorflow/core/kernels/lmdb_reader_op.cc index 23cabe7b547..3bb07301b5a 100755 --- a/tensorflow/core/kernels/lmdb_reader_op.cc +++ b/tensorflow/core/kernels/lmdb_reader_op.cc @@ -13,18 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "lmdb.h" #include "tensorflow/core/framework/reader_op_kernel.h" #include "tensorflow/core/framework/reader_base.h" #include "tensorflow/core/lib/core/errors.h" #include +#include "lmdb.h" namespace tensorflow { -inline void MDB_CHECK(int mdb_status) { - CHECK_EQ(mdb_status, MDB_SUCCESS) << mdb_strerror(mdb_status); -} +#define MDB_CHECK(val) CHECK_EQ(val, MDB_SUCCESS) << mdb_strerror(val) class LMDBReader : public ReaderBase { public: @@ -131,4 +129,4 @@ class LMDBReaderOp : public ReaderOpKernel { REGISTER_KERNEL_BUILDER(Name("LMDBReader").Device(DEVICE_CPU), LMDBReaderOp); -} +} // namespace tensorflow diff --git a/tensorflow/core/kernels/remote_fused_graph_execute_op.cc b/tensorflow/core/kernels/remote_fused_graph_execute_op.cc index aa3835ecc56..f0b2aa6e1f7 100644 --- a/tensorflow/core/kernels/remote_fused_graph_execute_op.cc +++ b/tensorflow/core/kernels/remote_fused_graph_execute_op.cc @@ -109,6 +109,12 @@ class RemoteFusedGraphExecuteOp : public OpKernel { TF_CHECK_OK(ctx->allocate_output(i, shape, &output)); return output; }); + } else { + // For compatibility purpose, returns an empty tensor with specified + // data type as output if no executor is used. 
+ Tensor* output = nullptr; + TensorShape ts({}); + TF_CHECK_OK(ctx->allocate_output(i, ts, &output)); } } } diff --git a/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc b/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc index 103b2be6914..dd9839d2453 100644 --- a/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc +++ b/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc @@ -1280,6 +1280,69 @@ RemoteFusedGraphExecuteUtils::FuseRemoteGraphByPlacedArguments( return true; } +/* static */ Status RemoteFusedGraphExecuteUtils::CopyByteArrayToTensor( + const void* src_ptr, const int src_size, Tensor* tensor) { + CHECK(tensor->TotalBytes() >= src_size) + << tensor->TotalBytes() << ", " << src_size; + void* dst_ptr; + switch (tensor->dtype()) { + case DT_FLOAT: + dst_ptr = tensor->flat().data(); + break; + case DT_DOUBLE: + dst_ptr = tensor->flat().data(); + break; + case DT_INT32: + dst_ptr = tensor->flat().data(); + break; + case DT_UINT8: + dst_ptr = tensor->flat().data(); + break; + case DT_INT16: + dst_ptr = tensor->flat().data(); + break; + case DT_INT8: + dst_ptr = tensor->flat().data(); + break; + case DT_STRING: + dst_ptr = tensor->flat().data(); + break; + case DT_INT64: + dst_ptr = tensor->flat().data(); + break; + case DT_BOOL: + dst_ptr = tensor->flat().data(); + break; + case DT_QINT8: + dst_ptr = tensor->flat().data(); + break; + case DT_QUINT8: + dst_ptr = tensor->flat().data(); + break; + case DT_QINT32: + dst_ptr = tensor->flat().data(); + break; + case DT_BFLOAT16: + dst_ptr = tensor->flat().data(); + break; + case DT_QINT16: + dst_ptr = tensor->flat().data(); + break; + case DT_QUINT16: + dst_ptr = tensor->flat().data(); + break; + case DT_UINT16: + dst_ptr = tensor->flat().data(); + break; + default: + CHECK(false) << "type " << tensor->dtype() << " is not supported."; + break; + } + CHECK_NOTNULL(dst_ptr); + std::memcpy(dst_ptr, src_ptr, src_size); + return Status::OK(); +} + /* static */ Status 
RemoteFusedGraphExecuteUtils::ReplaceInputNodeByPlaceHolder( const string& input, const DataType type, const TensorShape& shape, GraphDef* graph_def) { diff --git a/tensorflow/core/kernels/remote_fused_graph_execute_utils.h b/tensorflow/core/kernels/remote_fused_graph_execute_utils.h index a80fc797841..1d4423ed46e 100644 --- a/tensorflow/core/kernels/remote_fused_graph_execute_utils.h +++ b/tensorflow/core/kernels/remote_fused_graph_execute_utils.h @@ -157,7 +157,7 @@ class RemoteFusedGraphExecuteUtils { const std::vector>& input_tensors, const bool dry_run_inference, GraphDef* graph_def); - // Build remote fused graph execute info + // Build remote fused graph execute info. static Status BuildRemoteFusedGraphExecuteInfo( const string& executor_name, const GraphDef& subgraph_def, const std::vector& inputs, const std::vector& outputs, @@ -165,31 +165,31 @@ class RemoteFusedGraphExecuteUtils { DataTypeVector* input_types, DataTypeVector* output_types); // Build remote fused graph execute op node by fusing specified subgraph - // as remote fused graph execute info + // as remote fused graph execute info. static Status BuildRemoteFusedGraphExecuteOpNode( const string& node_name, const string& executor_name, const GraphDef& subgraph_def, const std::vector& inputs, const std::vector& outputs, const bool require_shape_type, Graph* graph, Node** created_node); - // Build Identity node to forward remote graph node output + // Build Identity node to forward remote graph node output. static Status BuildIdentityOpNode(const string& node_name, const string& input_node_name, const int input_node_port, const DataType dt, Graph* graph, Node** created_node); - // Create clusters of given nodes + // Create clusters of given nodes. static Status ClusterizeNodes(const std::unordered_set& node_names, const GraphDef& graph_def, std::vector* cluster_infos); - // Build GraphDef of a given cluster + // Build GraphDef of a given cluster. 
static Status BuildClusterSubgraphDef(const ClusterInfo& cluster, const GraphDef& graph_def, GraphDef* subgraph_def); - // Build a cluster by given border + // Build a cluster by given border. // CAVEAT: The border must be consistent for one cluster. static Status BuildClusterByBorder(const std::vector& border_inputs, const std::vector& border_outputs, @@ -211,7 +211,7 @@ class RemoteFusedGraphExecuteUtils { const bool require_shape_type, GraphDef* output_graph_def); - // Fuse subgraph of specified nodes + // Fuse subgraph of specified nodes. static Status FuseRemoteGraphByNodeNames( const GraphDef& input_graph_def, const std::vector& inputs, const std::vector& outputs, @@ -220,7 +220,7 @@ class RemoteFusedGraphExecuteUtils { const string& remote_fused_graph_executor_name, const bool require_shape_type, GraphDef* output_graph_def); - // Fuse subgraph of specified border + // Fuse subgraph of specified border. static Status FuseRemoteGraphByBorder( const GraphDef& input_graph_def, const std::vector& inputs, const std::vector& outputs, @@ -230,7 +230,7 @@ class RemoteFusedGraphExecuteUtils { const string& remote_graph_executor_name, const bool require_shape_type, GraphDef* output_graph_def); - // Place arguments to fuse remote graph + // Place arguments to fuse remote graph. static Status PlaceRemoteGraphArguments( const std::vector& inputs, const std::vector& outputs, const std::unordered_set& fused_node_names, @@ -239,7 +239,7 @@ class RemoteFusedGraphExecuteUtils { const string& remote_fused_graph_node_name, const string& remote_graph_executor_name, GraphDef* graph_def); - // Fuse remote graph by placed arguments + // Fuse remote graph by placed arguments. static Status FuseRemoteGraphByPlacedArguments( const GraphDef& input_graph_def, const std::vector>& input_tensors, @@ -249,6 +249,15 @@ class RemoteFusedGraphExecuteUtils { const GraphDef& input_graph_def, const std::vector>& input_tensors); + // Copy a byte array to a tensor data. 
Though tensor data must be + // updated with typed information in general, we can't guarantee that + // returned values from a remote processor has typed information because + // a logic running in the remote processor possibly be in a separate binary + // which may not link tensorflow libraries. To deal with this situation, + // remote fused graph needs to overwrite the tensor data by a byte array. + static Status CopyByteArrayToTensor(const void* src_ptr, const int src_size, + Tensor* tensor); + private: static void EmplaceTensorShapeType(const string& name, const Tensor& tensor, TensorShapeMap* tensor_shape_map); diff --git a/tensorflow/core/kernels/resize_area_op.cc b/tensorflow/core/kernels/resize_area_op.cc index ad94de89dba..ada50dfb70d 100644 --- a/tensorflow/core/kernels/resize_area_op.cc +++ b/tensorflow/core/kernels/resize_area_op.cc @@ -33,7 +33,6 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; namespace { - struct CachedInterpolation { int64 start; int64 end; @@ -41,7 +40,7 @@ struct CachedInterpolation { float end_minus_one_scale; bool needs_bounding; }; -}; +} // namespace template class ResizeAreaOp : public OpKernel { @@ -170,7 +169,7 @@ class ResizeAreaOp : public OpKernel { : (v + 1 > in_x1 ? in_x1 - v : 1.0); v = ceil(in_x1); - x_interp.end = ceil(in_x1); + x_interp.end = v; v = x_interp.end - 1; x_interp.end_minus_one_scale = v < in_x ? (v + 1 > in_x1 ? st.width_scale : v + 1 - in_x) diff --git a/tensorflow/core/kernels/summary_tensor_op_test.cc b/tensorflow/core/kernels/summary_tensor_op_test.cc index 0006a71bd7b..010ff443fa7 100644 --- a/tensorflow/core/kernels/summary_tensor_op_test.cc +++ b/tensorflow/core/kernels/summary_tensor_op_test.cc @@ -85,8 +85,15 @@ TEST_F(SummaryTensorOpV2Test, BasicPluginData) { ASSERT_EQ(0, out_tensor->dims()); Summary summary; ParseProtoUnlimited(&summary, out_tensor->scalar()()); - ASSERT_EQ(1, summary.value_size()); + + // Check the content of the tensor stored in the summary. 
+ Tensor string_content_tensor; + CHECK(string_content_tensor.FromProto(summary.value(0).tensor())); + ASSERT_EQ("some string tensor content", + string_content_tensor.scalar()()); + + // Check plugin-related data. ASSERT_EQ("tag_foo", summary.value(0).tag()); ASSERT_EQ(2, summary.value(0).metadata().plugin_data_size()); ASSERT_EQ("foo", summary.value(0).metadata().plugin_data(0).plugin_name()); diff --git a/tensorflow/core/lib/strings/strcat.cc b/tensorflow/core/lib/strings/strcat.cc index 3e864c4f282..46a45a66783 100644 --- a/tensorflow/core/lib/strings/strcat.cc +++ b/tensorflow/core/lib/strings/strcat.cc @@ -38,7 +38,7 @@ AlphaNum::AlphaNum(Hex hex) { // We accomplish minimum width by OR'ing in 0x10000 to the user's value, // where 0x10000 is the smallest hex number that is as wide as the user // asked for. - uint64 mask = ((static_cast(1) << (width - 1) * 4)) | value; + uint64 mask = (static_cast(1) << (width - 1) * 4) | value; static const char hexdigits[] = "0123456789abcdef"; do { *--writer = hexdigits[value & 0xF]; diff --git a/tensorflow/core/ops/bitwise_ops.cc b/tensorflow/core/ops/bitwise_ops.cc new file mode 100644 index 00000000000..2005d5e1028 --- /dev/null +++ b/tensorflow/core/ops/bitwise_ops.cc @@ -0,0 +1,64 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/numeric_op.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { + +REGISTER_OP("Invert") + .Input("x: T") + .Output("y: T") + .Attr("T: {int8, int16, int32, int64, uint8, uint16}") + .SetShapeFn(shape_inference::UnchangedShape) + .Doc(R"doc( +Flips all bits elementwise. + +The result will have exactly those bits set, that are not set in `x`. The +computation is performed on the underlying representation of x. +)doc"); + +#define BINARY_BITWISE() \ + Input("x: T") \ + .Input("y: T") \ + .Output("z: T") \ + .SetIsCommutative() \ + .Attr("T: {int8, int16, int32, int64, uint8, uint16}") \ + .SetShapeFn(shape_inference::UnchangedShape) + +REGISTER_OP("BitwiseAnd").BINARY_BITWISE().Doc(R"doc( +Elementwise computes the bitwise AND of `x` and `y`. + +The result will have those bits set, that are set in both `x` and `y`. The +computation is performed on the underlying representations of `x` and `y`. +)doc"); + +REGISTER_OP("BitwiseOr").BINARY_BITWISE().Doc(R"doc( +Elementwise computes the bitwise OR of `x` and `y`. + +The result will have those bits set, that are set in `x`, `y` or both. The +computation is performed on the underlying representations of `x` and `y`. +)doc"); + +REGISTER_OP("BitwiseXor").BINARY_BITWISE().Doc(R"doc( +Elementwise computes the bitwise XOR of `x` and `y`. + +The result will have those bits set, that are different in `x` and `y`. The +computation is performed on the underlying representations of `x` and `y`. 
+)doc"); + +} // namespace tensorflow diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt index 4b10e5b79e3..edc21b60291 100644 --- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt +++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt @@ -3725,6 +3725,96 @@ op { } } } +op { + name: "BitwiseAnd" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_INT8 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_UINT8 + type: DT_UINT16 + } + } + } + is_commutative: true +} +op { + name: "BitwiseOr" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_INT8 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_UINT8 + type: DT_UINT16 + } + } + } + is_commutative: true +} +op { + name: "BitwiseXor" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_INT8 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_UINT8 + type: DT_UINT16 + } + } + } + is_commutative: true +} op { name: "BroadcastArgs" input_arg { @@ -10428,6 +10518,31 @@ op { version: 17 } } +op { + name: "Invert" + input_arg { + name: "x" + type_attr: "T" + } + output_arg { + name: "y" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_INT8 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_UINT8 + type: DT_UINT16 + } + } + } +} op { name: "InvertPermutation" input_arg { @@ -18798,6 +18913,31 @@ op { } } } +op { + name: "RemoteFusedGraphExecute" + input_arg { + name: "inputs" + type_list_attr: 
"Tinputs" + } + output_arg { + name: "outputs" + type_list_attr: "Toutputs" + } + attr { + name: "Tinputs" + type: "list(type)" + has_minimum: true + } + attr { + name: "Toutputs" + type: "list(type)" + has_minimum: true + } + attr { + name: "serialized_remote_fused_graph_execute_info" + type: "string" + } +} op { name: "RepeatDataset" input_arg { diff --git a/tensorflow/core/ops/data_flow_ops.cc b/tensorflow/core/ops/data_flow_ops.cc index f0fcd028350..4bf3d24324d 100644 --- a/tensorflow/core/ops/data_flow_ops.cc +++ b/tensorflow/core/ops/data_flow_ops.cc @@ -24,6 +24,25 @@ using shape_inference::DimensionHandle; using shape_inference::InferenceContext; using shape_inference::ShapeHandle; +namespace { + +Status DequeueManyV2Shape(InferenceContext* c, ShapeHandle n_shape) { + auto* t = c->input_handle_shapes_and_types(0); + if (t != nullptr && t->size() == c->num_outputs()) { + for (int i = 0; i < c->num_outputs(); ++i) { + ShapeHandle combined_shape; + TF_RETURN_IF_ERROR( + c->Concatenate(n_shape, (*t)[i].shape, &combined_shape)); + c->set_output(i, combined_shape); + } + return Status::OK(); + } else { + return shape_inference::UnknownShape(c); + } +} + +} // namespace + // -------------------------------------------------------------------------- REGISTER_OP("DynamicPartition") @@ -711,7 +730,19 @@ REGISTER_OP("QueueDequeueManyV2") .Output("components: component_types") .Attr("component_types: list(type) >= 1") .Attr("timeout_ms: int = -1") - .SetShapeFn(shape_inference::UnknownShape) + .SetShapeFn([](InferenceContext* c) { + ShapeHandle n_shape; + if (c->input_tensor(1) == nullptr) { + n_shape = c->Vector(InferenceContext::kUnknownDim); + } else { + const int32 n = c->input_tensor(1)->scalar()(); + if (n < 0) { + return errors::InvalidArgument("Input 'n' must be >= 0, but is ", n); + } + n_shape = c->Vector(n); + } + return DequeueManyV2Shape(c, n_shape); + }) .Doc(R"doc( Dequeues `n` tuples of one or more tensors from the given queue. 
@@ -781,7 +812,9 @@ REGISTER_OP("QueueDequeueUpToV2") .Output("components: component_types") .Attr("component_types: list(type) >= 1") .Attr("timeout_ms: int = -1") - .SetShapeFn(shape_inference::UnknownShape) + .SetShapeFn([](InferenceContext* c) { + return DequeueManyV2Shape(c, c->Vector(InferenceContext::kUnknownDim)); + }) .Doc(R"doc( Dequeues `n` tuples of one or more tensors from the given queue. diff --git a/tensorflow/core/ops/data_flow_ops_test.cc b/tensorflow/core/ops/data_flow_ops_test.cc index 53c843eb60b..6f59db3a1b4 100644 --- a/tensorflow/core/ops/data_flow_ops_test.cc +++ b/tensorflow/core/ops/data_flow_ops_test.cc @@ -149,4 +149,99 @@ TEST(DataFlowOpsTest, TensorArrayV3) { INFER_ERROR("Shape must be rank 0 but is rank 1", op, "[2]"); } +TEST(DataFlowOpsTest, QueueDequeueV2ShapeFn) { + ShapeInferenceTestOp op("QueueDequeueV2"); + TF_ASSERT_OK(NodeDefBuilder("test", op.name) + .Input("handle", 0, DT_RESOURCE) + .Attr("component_types", {DT_FLOAT, DT_INT32}) + .Finalize(&op.node_def)); + + INFER_OK(op, "?", "?;?"); + + std::vector shapes_and_types; + op.input_resource_handle_shapes_and_types.push_back(&shapes_and_types); + INFER_OK(op, "?", "?;?"); + + // Wrong number of shapes provided by handle. + shapes_and_types.emplace_back("[1,?,3]", DT_FLOAT); + INFER_OK(op, "?", "?;?"); + + // Correct number of shapes provided by handle. + shapes_and_types.emplace_back("[?,2]", DT_FLOAT); + INFER_OK(op, "?", "[1,?,3];[?,2]"); +} + +TEST(DataFlowOpsTest, QueueDequeueManyV2ShapeFn) { + ShapeInferenceTestOp op("QueueDequeueManyV2"); + TF_ASSERT_OK(NodeDefBuilder("test", op.name) + .Input("handle", 0, DT_RESOURCE) + .Input("n", 0, DT_INT32) + .Attr("component_types", {DT_FLOAT, DT_INT32}) + .Finalize(&op.node_def)); + + //////////////////////////// + // Input n is not a constant. 
+ INFER_OK(op, "?;?", "?;?"); + std::vector shapes_and_types; + op.input_resource_handle_shapes_and_types.push_back(&shapes_and_types); + op.input_resource_handle_shapes_and_types.push_back(nullptr); + // Wrong number of shapes provided by handle. + shapes_and_types.emplace_back("[1,?,3]", DT_FLOAT); + INFER_OK(op, "?;?", "?;?"); + // Correct number of shapes provided by handle. + shapes_and_types.emplace_back("[?,2]", DT_FLOAT); + INFER_OK(op, "?;?", "[?,1,?,3];[?,?,2]"); + + //////////////////////////// + // Input n is a constant. (set up test and repeat the cases from above). + Tensor n_tensor = test::AsScalar(12); + op.input_tensors.push_back(nullptr); + op.input_tensors.push_back(&n_tensor); + op.input_resource_handle_shapes_and_types.clear(); + shapes_and_types.clear(); + + INFER_OK(op, "?;?", "?;?"); + op.input_resource_handle_shapes_and_types.push_back(&shapes_and_types); + op.input_resource_handle_shapes_and_types.push_back(nullptr); + // Wrong number of shapes provided by handle. + shapes_and_types.emplace_back("[1,?,3]", DT_FLOAT); + INFER_OK(op, "?;?", "?;?"); + // Correct number of shapes provided by handle. + shapes_and_types.emplace_back("[?,2]", DT_FLOAT); + INFER_OK(op, "?;?", "[12,1,?,3];[12,?,2]"); + + n_tensor = test::AsScalar(-1); // invalid value of n. + INFER_ERROR("must be >= 0", op, "?;?"); +} + +TEST(DataFlowOpsTest, QueueDequeueUpToV2ShapeFn) { + // Results are the same regardless of what value is passed for n. + for (int pass = 0; pass < 2; ++pass) { + ShapeInferenceTestOp op("QueueDequeueUpToV2"); + TF_ASSERT_OK(NodeDefBuilder("test", op.name) + .Input("handle", 0, DT_RESOURCE) + .Input("n", 0, DT_INT32) + .Attr("component_types", {DT_FLOAT, DT_INT32}) + .Finalize(&op.node_def)); + + Tensor n_tensor = test::AsScalar(12); + if (pass == 1) { + // Second pass, pass value of as a constant. 
+ op.input_tensors.push_back(nullptr); + op.input_tensors.push_back(&n_tensor); + } + + INFER_OK(op, "?;?", "?;?"); + std::vector shapes_and_types; + op.input_resource_handle_shapes_and_types.push_back(&shapes_and_types); + op.input_resource_handle_shapes_and_types.push_back(nullptr); + // Wrong number of shapes provided by handle. + shapes_and_types.emplace_back("[1,?,3]", DT_FLOAT); + INFER_OK(op, "?;?", "?;?"); + // Correct number of shapes provided by handle. + shapes_and_types.emplace_back("[?,2]", DT_FLOAT); + INFER_OK(op, "?;?", "[?,1,?,3];[?,?,2]"); + } +} + } // end namespace tensorflow diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index 37d4379d48d..57615c3fda6 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -3697,6 +3697,102 @@ op { summary: "Bitcasts a tensor from one type to another without copying data." description: "Given a tensor `input`, this operation returns a tensor that has the same buffer\ndata as `input` with datatype `type`.\n\nIf the input datatype `T` is larger than the output datatype `type` then the\nshape changes from [...] to [..., sizeof(`T`)/sizeof(`type`)].\n\nIf `T` is smaller than `type`, the operator requires that the rightmost\ndimension be equal to sizeof(`type`)/sizeof(`T`). The shape then goes from\n[..., sizeof(`type`)/sizeof(`T`)] to [...].\n\n*NOTE*: Bitcast is implemented as a low-level cast, so machines with different\nendian orderings will give different results." } +op { + name: "BitwiseAnd" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_INT8 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_UINT8 + type: DT_UINT16 + } + } + } + summary: "Elementwise computes the bitwise AND of `x` and `y`." + description: "The result will have those bits set, that are set in both `x` and `y`. 
The\ncomputation is performed on the underlying representations of `x` and `y`." + is_commutative: true +} +op { + name: "BitwiseOr" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_INT8 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_UINT8 + type: DT_UINT16 + } + } + } + summary: "Elementwise computes the bitwise OR of `x` and `y`." + description: "The result will have those bits set, that are set in `x`, `y` or both. The\ncomputation is performed on the underlying representations of `x` and `y`." + is_commutative: true +} +op { + name: "BitwiseXor" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_INT8 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_UINT8 + type: DT_UINT16 + } + } + } + summary: "Elementwise computes the bitwise XOR of `x` and `y`." + description: "The result will have those bits set, that are different in `x` and `y`. The\ncomputation is performed on the underlying representations of `x` and `y`." + is_commutative: true +} op { name: "BroadcastArgs" input_arg { @@ -9933,6 +10029,33 @@ op { explanation: "Use ReciprocalGrad" } } +op { + name: "Invert" + input_arg { + name: "x" + type_attr: "T" + } + output_arg { + name: "y" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_INT8 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_UINT8 + type: DT_UINT16 + } + } + } + summary: "Flips all bits elementwise." + description: "The result will have exactly those bits set, that are not set in `x`. The\ncomputation is performed on the underlying representation of x." 
+} op { name: "InvertPermutation" input_arg { @@ -18591,6 +18714,36 @@ op { } summary: "Computes rectified linear gradients for a Relu operation." } +op { + name: "RemoteFusedGraphExecute" + input_arg { + name: "inputs" + description: "Arbitrary number of tensors with arbitrary data types" + type_list_attr: "Tinputs" + } + output_arg { + name: "outputs" + description: "Arbitrary number of tensors with arbitrary data types" + type_list_attr: "Toutputs" + } + attr { + name: "Tinputs" + type: "list(type)" + has_minimum: true + } + attr { + name: "Toutputs" + type: "list(type)" + has_minimum: true + } + attr { + name: "serialized_remote_fused_graph_execute_info" + type: "string" + description: "Serialized protocol buffer\nof RemoteFusedGraphExecuteInfo which contains graph specifications." + } + summary: "Execute a sub graph on a remote processor." + description: "The graph specifications(such as graph itself, input tensors and output names)\nare stored as a serialized protocol buffer of RemoteFusedGraphExecuteInfo\nas serialized_remote_fused_graph_execute_info.\nThe specifications will be passed to a dedicated registered\nremote fused graph executor. The executor will send the graph specifications\nto a remote processor and execute that graph. The execution results\nwill be passed to consumer nodes as outputs of this node." +} op { name: "RepeatDataset" input_arg { diff --git a/tensorflow/core/ops/remote_fused_graph_ops.cc b/tensorflow/core/ops/remote_fused_graph_ops.cc index 6e9f37a6152..85370e648c4 100644 --- a/tensorflow/core/ops/remote_fused_graph_ops.cc +++ b/tensorflow/core/ops/remote_fused_graph_ops.cc @@ -19,19 +19,40 @@ limitations under the License. 
namespace tensorflow { -// TODO(satok): Implement shape_inference +namespace { +using shape_inference::InferenceContext; + +Status RemoteFusedGraphExecuteShapeFn(InferenceContext* c) { + for (int i = 0; i < c->num_outputs(); ++i) { + c->set_output(i, c->UnknownShape()); + } + return Status::OK(); +} +} // namespace + REGISTER_OP("RemoteFusedGraphExecute") .Input("inputs: Tinputs") .Output("outputs: Toutputs") .Attr("Tinputs: list(type) >= 0") .Attr("Toutputs: list(type) >= 0") .Attr("serialized_remote_fused_graph_execute_info: string") - .SetShapeFn(shape_inference::UnknownShape) + .SetShapeFn(RemoteFusedGraphExecuteShapeFn) .Doc(R"doc( -Execute a sub graph on a remote processor transferred by GraphTransferer. -The graph specifications are serialized by protobuf as graph_transfer_info. -The implementation / limitations may differ for each platform -and each available peripheral. +Execute a sub graph on a remote processor. + +The graph specifications(such as graph itself, input tensors and output names) +are stored as a serialized protocol buffer of RemoteFusedGraphExecuteInfo +as serialized_remote_fused_graph_execute_info. +The specifications will be passed to a dedicated registered +remote fused graph executor. The executor will send the graph specifications +to a remote processor and execute that graph. The execution results +will be passed to consumer nodes as outputs of this node. + +inputs: Arbitrary number of tensors with arbitrary data types +outputs: Arbitrary number of tensors with arbitrary data types +serialized_remote_fused_graph_execute_info: Serialized protocol buffer +of RemoteFusedGraphExecuteInfo which contains graph specifications. 
+ )doc"); } // namespace tensorflow diff --git a/tensorflow/core/platform/cpu_feature_guard.cc b/tensorflow/core/platform/cpu_feature_guard.cc index 1cfeb2580fa..4941bc12393 100644 --- a/tensorflow/core/platform/cpu_feature_guard.cc +++ b/tensorflow/core/platform/cpu_feature_guard.cc @@ -96,6 +96,14 @@ std::once_flag g_cpu_feature_guard_warn_once_flag; void WarnAboutUnusedCPUFeatures() { std::call_once(g_cpu_feature_guard_warn_once_flag, [] { +#ifdef PLATFORM_WINDOWS +#ifndef __AVX__ + WarnIfFeatureUnused(CPUFeature::AVX, "AVX"); +#endif // __AVX__ +#ifndef __AVX2__ + WarnIfFeatureUnused(CPUFeature::AVX2, "AVX2"); +#endif // __AVX2__ +#else // ifdef platform windows #ifndef __SSE__ WarnIfFeatureUnused(CPUFeature::SSE, "SSE"); #endif // __SSE__ @@ -123,6 +131,7 @@ void WarnAboutUnusedCPUFeatures() { #ifndef __FMA__ WarnIfFeatureUnused(CPUFeature::FMA, "FMA"); #endif // __FMA__ +#endif // else of ifdef platform windows }); } diff --git a/tensorflow/core/platform/hexagon/soc_interface.h b/tensorflow/core/platform/hexagon/soc_interface.h index f4a3cdf4bda..ca37b63e2bc 100644 --- a/tensorflow/core/platform/hexagon/soc_interface.h +++ b/tensorflow/core/platform/hexagon/soc_interface.h @@ -22,6 +22,8 @@ limitations under the License. // naming conflicts. #ifdef __cplusplus extern "C" { +#else +#include #endif // __cplusplus // Returns the version of loaded hexagon wrapper shared library. 
// You should assert that the version matches the expected version before @@ -39,13 +41,30 @@ bool soc_interface_Finalize(); bool soc_interface_ExecuteGraph(); // Teardown graph setup bool soc_interface_TeardownGraph(); + +// Allocate buffers for input node and output node +bool soc_interface_AllocateInOutNodeBuffers(int input_count, int* input_sizes, + int output_count, + int* output_sizes); + +// Send input data to SOC with port +bool soc_interface_FillInputNodeWithPort(int port, int x, int y, int z, int d, + const uint8_t* const buf, + uint64_t buf_byte_size); + // Send input data to SOC bool soc_interface_FillInputNodeFloat(int x, int y, int z, int d, const uint8_t* const buf, - uint64_t buf_size); + uint64_t buf_byte_size); + +// Load output data from SOC with port +bool soc_interface_ReadOutputNodeWithPort(int port, uint8_t** buf, + uint64_t* buf_byte_size); + // Load output data from SOC bool soc_interface_ReadOutputNodeFloat(const char* const node_name, - uint8_t** buf, uint64_t* buf_size); + uint8_t** buf, uint64_t* buf_byte_size); + // Setup graph // TODO(satok): Remove and use runtime version bool soc_interface_setupDummyGraph(int version); diff --git a/tensorflow/core/protobuf/config.proto b/tensorflow/core/protobuf/config.proto index 630f47633f8..c57b110847d 100644 --- a/tensorflow/core/protobuf/config.proto +++ b/tensorflow/core/protobuf/config.proto @@ -160,6 +160,8 @@ message GraphOptions { int32 timeline_step = 8; // Options that control the type and amount of graph rewriting. + // Not currently configurable via the public Python API (i.e. there is no API + // stability guarantee if you import RewriterConfig explicitly). 
RewriterConfig rewrite_options = 10; }; diff --git a/tensorflow/core/protobuf/rewriter_config.proto b/tensorflow/core/protobuf/rewriter_config.proto index 4c61e577d21..5480bdaad8a 100644 --- a/tensorflow/core/protobuf/rewriter_config.proto +++ b/tensorflow/core/protobuf/rewriter_config.proto @@ -12,24 +12,43 @@ message AutoParallelOptions { } message RewriterConfig { - // Graph rewriting is experimental and subject to change, not subject to any - // API guarantees. + // Graph rewriting is experimental and subject to change, not covered by any + // API stability guarantees. + + // Configuration options for the meta-optimizer. Unless otherwise noted, these + // configuration options do not apply to explicitly triggered optimization + // passes in the optimizers field. bool optimize_tensor_layout = 1; bool disable_model_pruning = 2; bool constant_folding = 3; enum MemOptType { - // Fully disabled + // Disabled in the meta-optimizer. NO_MEM_OPT = 0; - // Driven by manual annotations + // Driven by manual op-level annotations. MANUAL = 1; + // Driven by heuristics. The behavior of these heuristics is subject to + // change. Currently includes an experimental recomputation heuristic. + HEURISTICS = 2; } + // Configures memory optimization passes through the meta-optimizer. Has no + // effect on manually requested memory optimization passes in the optimizers + // field. MemOptType memory_optimization = 4; + // Configures AutoParallel optimization passes either through the + // meta-optimizer or when manually specified through the optimizers field. AutoParallelOptions auto_parallel = 5; // If non-empty, will use this as an alternative way to specify a list of - // optimizations to turn on and the order of the optimizations. + // optimizations to turn on and the order of the optimizations (replacing the + // meta-optimizer). 
+ // + // Of the RewriterConfig options, only the AutoParallel configuration options + // (the auto_parallel field) apply to manually requested optimization passes + // ("autoparallel"). Memory optimization passes ("memory") invoked here are + // not configurable (in contrast to memory optimization passes through the + // meta-optimizer) and act only on manual op annotations. repeated string optimizers = 100; } diff --git a/tensorflow/core/util/ctc/BUILD b/tensorflow/core/util/ctc/BUILD index 357b2535515..c955b280146 100644 --- a/tensorflow/core/util/ctc/BUILD +++ b/tensorflow/core/util/ctc/BUILD @@ -102,5 +102,4 @@ cc_library( hdrs = [ "ctc_loss_util.h", ], - deps = ["//tensorflow/core:lib"], ) diff --git a/tensorflow/core/util/tensor_bundle/BUILD b/tensorflow/core/util/tensor_bundle/BUILD index bae00f74003..ca92713a621 100644 --- a/tensorflow/core/util/tensor_bundle/BUILD +++ b/tensorflow/core/util/tensor_bundle/BUILD @@ -39,7 +39,6 @@ cc_library( ":naming", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", - "//tensorflow/core:framework_headers_lib", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", diff --git a/tensorflow/docs_src/extend/estimators.md b/tensorflow/docs_src/extend/estimators.md index 6bd21be0193..fbde4634151 100644 --- a/tensorflow/docs_src/extend/estimators.md +++ b/tensorflow/docs_src/extend/estimators.md @@ -1,52 +1,55 @@ -# Creating Estimators in tf.contrib.learn +# Creating Estimators in tf.estimator -The tf.contrib.learn framework makes it easy to construct and train machine -learning models via its high-level -@{$python/contrib.learn#estimators$Estimator} API. `Estimator` +The tf.estimator framework makes it easy to construct and train machine +learning models via its high-level Estimator API. 
`Estimator` offers classes you can instantiate to quickly configure common model types such as regressors and classifiers: -* @{tf.contrib.learn.LinearClassifier}: +* @{tf.estimator.LinearClassifier}: Constructs a linear classification model. -* @{tf.contrib.learn.LinearRegressor}: +* @{tf.estimator.LinearRegressor}: Constructs a linear regression model. -* @{tf.contrib.learn.DNNClassifier}: +* @{tf.estimator.DNNClassifier}: Construct a neural network classification model. -* @{tf.contrib.learn.DNNRegressor}: - Construct a neural network regressions model. +* @{tf.estimator.DNNRegressor}: + Construct a neural network regression model. +* @{tf.estimator.DNNLinearCombinedClassifier}: + Construct a neural network and linear combined classification model. +* @{tf.estimator.DNNLinearCombinedRegressor}: + Construct a neural network and linear combined regression model. -But what if none of `tf.contrib.learn`'s predefined model types meets your -needs? Perhaps you need more granular control over model configuration, such as +But what if none of `tf.estimator`'s predefined model types meets your needs? +Perhaps you need more granular control over model configuration, such as the ability to customize the loss function used for optimization, or specify different activation functions for each neural network layer. Or maybe you're implementing a ranking or recommendation system, and neither a classifier nor a regressor is appropriate for generating predictions. This tutorial covers how to create your own `Estimator` using the building -blocks provided in `tf.contrib.learn`, which will predict the ages of +blocks provided in `tf.estimator`, which will predict the ages of [abalones](https://en.wikipedia.org/wiki/Abalone) based on their physical measurements. 
You'll learn how to do the following: * Instantiate an `Estimator` * Construct a custom model function -* Configure a neural network using `tf.contrib.layers` +* Configure a neural network using `tf.feature_column` and `tf.layers` * Choose an appropriate loss function from `tf.losses` * Define a training op for your model * Generate and return predictions ## Prerequisites -This tutorial assumes you already know tf.contrib.learn API basics, such as -feature columns, input functions, and `fit()`/`evaluate()`/`predict()` -operations. If you've never used tf.contrib.learn before, or need a refresher, +This tutorial assumes you already know tf.estimator API basics, such as +feature columns, input functions, and `train()`/`evaluate()`/`predict()` +operations. If you've never used tf.estimator before, or need a refresher, you should first review the following tutorials: -* @{$tflearn$tf.contrib.learn Quickstart}: Quick introduction to - training a neural network using tf.contrib.learn. +* @{$estimator$tf.estimator Quickstart}: Quick introduction to + training a neural network using tf.estimator. * @{$wide$TensorFlow Linear Model Tutorial}: Introduction to feature columns, and an overview on building a linear classifier in - tf.contrib.learn. -* @{$input_fn$Building Input Functions with tf.contrib.learn}: Overview of how + tf.estimator. +* @{$input_fn$Building Input Functions with tf.estimator}: Overview of how to construct an input_fn to preprocess and feed data into your models. 
## An Abalone Age Predictor {#abalone-predictor} @@ -113,7 +116,6 @@ from six.moves import urllib import numpy as np import tensorflow as tf -from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib FLAGS = None ``` @@ -207,17 +209,17 @@ if __name__ == "__main__": ## Instantiating an Estimator -When defining a model using one of tf.contrib.learn's provided classes, such as +When defining a model using one of tf.estimator's provided classes, such as `DNNClassifier`, you supply all the configuration parameters right in the constructor, e.g.: ```python -my_nn = tf.contrib.learn.DNNClassifier(feature_columns=[age, height, weight], - hidden_units=[10, 10, 10], - activation_fn=tf.nn.relu, - dropout=0.2, - n_classes=3, - optimizer="Adam") +my_nn = tf.estimator.DNNClassifier(feature_columns=[age, height, weight], + hidden_units=[10, 10, 10], + activation_fn=tf.nn.relu, + dropout=0.2, + n_classes=3, + optimizer="Adam") ``` You don't need to write any further code to instruct TensorFlow how to train the @@ -229,8 +231,7 @@ constructor accepts just two high-level parameters for model configuration, `model_fn` and `params`: ```python -nn = tf.contrib.learn.Estimator( - model_fn=model_fn, params=model_params) +nn = tf.estimator.Estimator(model_fn=model_fn, params=model_params) ``` * `model_fn`: A function object that contains all the aforementioned logic to @@ -242,7 +243,7 @@ nn = tf.contrib.learn.Estimator( * `params`: An optional dict of hyperparameters (e.g., learning rate, dropout) that will be passed into the `model_fn`. -Note: Just like `tf.contrib.learn`'s predefined regressors and classifiers, the +Note: Just like `tf.estimator`'s predefined regressors and classifiers, the `Estimator` initializer also accepts the general configuration arguments `model_dir` and `config`. 
@@ -266,7 +267,7 @@ containing the learning rate and instantiates the `Estimator`: model_params = {"learning_rate": LEARNING_RATE} # Instantiate Estimator -nn = tf.contrib.learn.Estimator(model_fn=model_fn, params=model_params) +nn = tf.estimator.Estimator(model_fn=model_fn, params=model_params) ``` ## Constructing the `model_fn` {#constructing-modelfn} @@ -274,31 +275,31 @@ nn = tf.contrib.learn.Estimator(model_fn=model_fn, params=model_params) The basic skeleton for an `Estimator` API model function looks like this: ```python -def model_fn(features, targets, mode, params): +def model_fn(features, labels, mode, params): # Logic to do the following: # 1. Configure the model via TensorFlow operations # 2. Define the loss function for training/evaluation # 3. Define the training operation/optimizer # 4. Generate predictions - # 5. Return predictions/loss/train_op/eval_metric_ops in ModelFnOps object - return ModelFnOps(mode, predictions, loss, train_op, eval_metric_ops) + # 5. Return predictions/loss/train_op/eval_metric_ops in EstimatorSpec object + return EstimatorSpec(mode, predictions, loss, train_op, eval_metric_ops) ``` The `model_fn` must accept three arguments: -* `features`: A dict containing the features passed to the model via `fit()`, - `evaluate()`, or `predict()`. -* `targets`: A `Tensor` containing the labels passed to the model via `fit()`, - `evaluate()`, or `predict()`. Will be empty for `predict()` calls, as these - are the values the model will infer. -* `mode`: One of the following @{tf.contrib.learn.ModeKeys} string values +* `features`: A dict containing the features passed to the model via + `input_fn`. +* `labels`: A `Tensor` containing the labels passed to the model via + `input_fn`. Will be empty for `predict()` calls, as these are the values the + model will infer. 
+* `mode`: One of the following @{tf.estimator.ModeKeys} string values indicating the context in which the model_fn was invoked: - * `tf.contrib.learn.ModeKeys.TRAIN` The `model_fn` was invoked in training - mode—e.g., via a `fit()` call. - * `tf.contrib.learn.ModeKeys.EVAL`. The `model_fn` was invoked in - evaluation mode—e.g., via an `evaluate()` call. - * `tf.contrib.learn.ModeKeys.INFER`. The `model_fn` was invoked in - inference mode—e.g., via a `predict()` call. + * `tf.estimator.ModeKeys.TRAIN` The `model_fn` was invoked in training + mode, namely via a `train()` call. + * `tf.estimator.ModeKeys.EVAL`. The `model_fn` was invoked in + evaluation mode, namely via an `evaluate()` call. + * `tf.estimator.ModeKeys.PREDICT`. The `model_fn` was invoked in + predict mode, namely via a `predict()` call. `model_fn` may also accept a `params` argument containing a dict of hyperparameters used for training (as shown in the skeleton above). @@ -313,28 +314,23 @@ sections that follow): * Defining the training operation that specifies the `optimizer` algorithm to minimize the loss values calculated by the loss function. -The `model_fn` must return a @{tf.contrib.learn.ModelFnOps} +The `model_fn` must return a @{tf.estimator.EstimatorSpec} object, which contains the following values: * `mode` (required). The mode in which the model was run. Typically, you will return the `mode` argument of the `model_fn` here. -* `predictions` (required in `INFER` and `EVAL` modes). A dict that maps key - names of your choice to `Tensor`s containing the predictions from the model, - e.g.: +* `predictions` (required in `PREDICT` mode). 
A dict that maps key names of + your choice to `Tensor`s containing the predictions from the model, e.g.: ```python predictions = {"results": tensor_of_predictions} ``` - In `INFER` mode, the dict that you return in `ModelFnOps` will then be + In `PREDICT` mode, the dict that you return in `EstimatorSpec` will then be returned by `predict()`, so you can construct it in the format in which you'd like to consume it. - In `EVAL` mode, the dict is used by - @{$python/contrib.metrics#Metric_Ops_$metric functions} - to compute metrics. - * `loss` (required in `EVAL` and `TRAIN` mode). A `Tensor` containing a scalar loss value: the output of the model's loss function (discussed in more depth @@ -362,7 +358,7 @@ object, which contains the following values: If you do not specify `eval_metric_ops`, only `loss` will be calculated during evaluation. -### Configuring a neural network with `tf.contrib.layers` +### Configuring a neural network with `tf.feature_column` and `tf.layers` Constructing a [neural network](https://en.wikipedia.org/wiki/Artificial_neural_network) entails @@ -372,23 +368,21 @@ layer. The input layer is a series of nodes (one for each feature in the model) that will accept the feature data that is passed to the `model_fn` in the `features` argument. If `features` contains an n-dimensional `Tensor` with all your feature -data (which is the case if `x` and `y` `Dataset`s are passed to `fit()`, -`evaluate()`, and `predict()` directly), then it can serve as the input layer. +data, then it can serve as the input layer. If `features` contains a dict of @{$linear#feature-columns-and-transformations$feature columns} passed to the model via an input function, you can convert it to an input-layer `Tensor` -with the @{tf.contrib.layers.input_from_feature_columns} function in -@{tf.contrib.layers}. +with the @{tf.feature_column.input_layer} function. 
```python -input_layer = tf.contrib.layers.input_from_feature_columns( - columns_to_tensors=features, feature_columns=[age, height, weight]) +input_layer = tf.feature_column.input_layer( + features=features, feature_columns=[age, height, weight]) ``` -As shown above, `input_from_feature_columns()` takes two required arguments: +As shown above, `input_layer()` takes two required arguments: -* `columns_to_tensors`. A mapping of the model's `FeatureColumns` to the - `Tensors` containing the corresponding feature data. This is exactly what is - passed to the `model_fn` in the `features` argument. +* `features`. A mapping from string keys to the `Tensors` containing the + corresponding feature data. This is exactly what is passed to the `model_fn` + in the `features` argument. * `feature_columns`. A list of all the `FeatureColumns` in the model—`age`, `height`, and `weight` in the above example. @@ -397,44 +391,44 @@ hidden layers via an [activation function](https://en.wikipedia.org/wiki/Activation_function) that performs a nonlinear transformation on the data from the previous layer. The last hidden layer is then connected to the output layer, the final layer in the model. -tf.contrib.layers provides the following convenience functions for constructing -fully connected layers: +`tf.layers` provides the `tf.layers.dense` function for constructing fully +connected layers. The activation is controlled by the `activation` argument. +Some options to pass to the `activation` argument are: -* `relu(inputs, num_outputs)`. Create a layer of `num_outputs` nodes fully - connected to the previous layer `inputs` with a [ReLU activation - function](https://en.wikipedia.org/wiki/Rectifier_\(neural_networks\)) +* `tf.nn.relu`. 
The following code creates a layer of `units` nodes fully + connected to the previous layer `input_layer` with a + [ReLU activation function](https://en.wikipedia.org/wiki/Rectifier_\(neural_networks\)) (@{tf.nn.relu}): ```python - hidden_layer = tf.contrib.layers.relu(inputs=input_layer, num_outputs=10) + hidden_layer = tf.layers.dense( + inputs=input_layer, units=10, activation=tf.nn.relu) ``` -* `relu6(inputs, num_outputs)`. Create a layer of `num_outputs` nodes fully +* `tf.nn.relu6`. The following code creates a layer of `units` nodes fully connected to the previous layer `hidden_layer` with a ReLU 6 activation function (@{tf.nn.relu6}): ```python - second_hidden_layer = tf.contrib.layers.relu6(inputs=hidden_layer, num_outputs=20) + second_hidden_layer = tf.layers.dense( + inputs=hidden_layer, units=20, activation=tf.nn.relu6) ``` -* `linear(inputs, num_outputs)`. Create a layer of `num_outputs` nodes fully - connected to the previous layer `second_hidden_layer` with *no* activation - function, just a linear transformation: +* `None`. 
The following code creates a layer of `units` nodes fully connected + to the previous layer `second_hidden_layer` with *no* activation function, + just a linear transformation: ```python - output_layer = tf.contrib.layers.linear(inputs=second_hidden_layer, num_outputs=3) + output_layer = tf.layers.dense( + inputs=second_hidden_layer, units=3, activation=None) ``` -All these functions are -[partials](https://docs.python.org/2/library/functools.html#functools.partial) -of the more general @{tf.contrib.layers.fully_connected} -function, which can be used to add fully connected layers with other activation -functions, e.g.: +Other activation functions are possible, e.g.: ```python -output_layer = tf.contrib.layers.fully_connected(inputs=second_hidden_layer, - num_outputs=10, - activation_fn=tf.sigmoid) +output_layer = tf.layers.dense(inputs=second_hidden_layer, + units=10, + activation=tf.sigmoid) ``` The above code creates the neural network layer `output_layer`, which is fully @@ -446,18 +440,19 @@ Putting it all together, the following code constructs a full neural network for the abalone predictor, and captures its predictions: ```python -def model_fn(features, targets, mode, params): +def model_fn(features, labels, mode, params): """Model function for Estimator.""" # Connect the first hidden layer to input layer - # (features) with relu activation - first_hidden_layer = tf.contrib.layers.relu(features, 10) + # (features["x"]) with relu activation + first_hidden_layer = tf.layers.dense(features["x"], 10, activation=tf.nn.relu) # Connect the second hidden layer to first hidden layer with relu - second_hidden_layer = tf.contrib.layers.relu(first_hidden_layer, 10) + second_hidden_layer = tf.layers.dense( + first_hidden_layer, 10, activation=tf.nn.relu) # Connect the output layer to second hidden layer (no activation fn) - output_layer = tf.contrib.layers.linear(second_hidden_layer, 1) + output_layer = tf.layers.dense(second_hidden_layer, 1) # Reshape output layer 
to 1-dim Tensor to return predictions predictions = tf.reshape(output_layer, [-1]) @@ -465,9 +460,9 @@ def model_fn(features, targets, mode, params): ... ``` -Here, because you'll be passing the abalone `Datasets` directly to `fit()`, -`evaluate()`, and `predict()` via `x` and `y` arguments, the input layer is the -`features` `Tensor` passed to the `model_fn`. The network contains two hidden +Here, because you'll be passing the abalone `Datasets` using `numpy_input_fn` +as shown below, `features` is a dict `{"x": data_tensor}`, so +`features["x"]` is the input layer. The network contains two hidden layers, each with 10 nodes and a ReLU activation function. The output layer contains no activation function, and is @{tf.reshape} to a one-dimensional @@ -476,47 +471,49 @@ tensor to capture the model's predictions, which are stored in ### Defining loss for the model {#defining-loss} -The `ModelFnOps` returned by the `model_fn` must contain `loss`: a `Tensor` +The `EstimatorSpec` returned by the `model_fn` must contain `loss`: a `Tensor` representing the loss value, which quantifies how well the model's predictions -reflect the target values during training and evaluation runs. The @{tf.losses} +reflect the label values during training and evaluation runs. The @{tf.losses} module provides convenience functions for calculating loss using a variety of metrics, including: -* `absolute_difference(predictions, targets)`. Calculates loss using the +* `absolute_difference(labels, predictions)`. Calculates loss using the [absolute-difference formula](https://en.wikipedia.org/wiki/Deviation_\(statistics\)#Unsigned_or_absolute_deviation) (also known as L1 loss). -* `log_loss(predictions, targets)`. Calculates loss using the [logistic loss +* `log_loss(labels, predictions)`. Calculates loss using the [logistic loss forumula](https://en.wikipedia.org/wiki/Loss_functions_for_classification#Logistic_loss) (typically used in logistic regression). 
-* `mean_squared_error(predictions, targets)`. Calculates loss using the [mean +* `mean_squared_error(labels, predictions)`. Calculates loss using the [mean squared error](https://en.wikipedia.org/wiki/Mean_squared_error) (MSE; also known as L2 loss). The following example adds a definition for `loss` to the abalone `model_fn` using `mean_squared_error()` (in bold): -
def model_fn(features, targets, mode, params):
+
def model_fn(features, labels, mode, params):
   """Model function for Estimator."""
 
   # Connect the first hidden layer to input layer
-  # (features) with relu activation
-  first_hidden_layer = tf.contrib.layers.relu(features, 10)
+  # (features["x"]) with relu activation
+  first_hidden_layer = tf.layers.dense(features["x"], 10, activation=tf.nn.relu)
 
   # Connect the second hidden layer to first hidden layer with relu
-  second_hidden_layer = tf.contrib.layers.relu(first_hidden_layer, 10)
+  second_hidden_layer = tf.layers.dense(
+      first_hidden_layer, 10, activation=tf.nn.relu)
 
   # Connect the output layer to second hidden layer (no activation fn)
-  output_layer = tf.contrib.layers.linear(second_hidden_layer, 1)
+  output_layer = tf.layers.dense(second_hidden_layer, 1)
 
   # Reshape output layer to 1-dim Tensor to return predictions
   predictions = tf.reshape(output_layer, [-1])
   predictions_dict = {"ages": predictions}
 
+
   # Calculate loss using mean squared error
-  loss = tf.losses.mean_squared_error(targets, predictions)
+  loss = tf.losses.mean_squared_error(labels, predictions)
   ...
See the @{$python/contrib.losses$API guide} for a @@ -524,14 +521,14 @@ full list of loss functions and more details on supported arguments and usage. Supplementary metrics for evaluation can be added to an `eval_metric_ops` dict. The following code defines an `rmse` metric, which calculates the root mean -squared error for the model predictions. Note that the `targets` tensor is cast +squared error for the model predictions. Note that the `labels` tensor is cast to a `float64` type to match the data type of the `predictions` tensor, which will contain real values: ```python eval_metric_ops = { "rmse": tf.metrics.root_mean_squared_error( - tf.cast(targets, tf.float64), predictions) + tf.cast(labels, tf.float64), predictions) } ``` @@ -539,65 +536,25 @@ eval_metric_ops = { The training op defines the optimization algorithm TensorFlow will use when fitting the model to the training data. Typically when training, the goal is to -minimize loss. The tf.contrib.layers API provides the function `optimize_loss`, -which returns a training op that will do just that. `optimize_loss` has four -required arguments: - -* `loss`. The loss value calculated by the `model_fn` (see [Defining Loss for - the Model](#defining-loss)). -* `global_step`. An integer - @{tf.Variable} representing the - step counter to increment for each model training run. Can easily be - created/incremented in TensorFlow via the - @{tf.train.get_global_step} - function. -* `learning_rate`. The [learning - rate](https://en.wikipedia.org/wiki/Stochastic_gradient_descent#Background) - (also known as _step size_) hyperparameter that the optimization algorithm - uses when training. -* `optimizer`. The optimization algorithm to use during training. `optimizer` - can accept any of the following string values, representing an optimization - algorithm predefined in `tf.contrib.layers.optimizers`: - * `SGD`. 
Implementation of [gradient - descent](https://en.wikipedia.org/wiki/Gradient_descent) - (@{tf.train.GradientDescentOptimizer}) - * `Adagrad`. Implementation of the [AdaGrad optimization - algorithm](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) - (@{tf.train.AdagradOptimizer}) - * `Adam`. Implementation of the [Adam optimization - algorithm](http://arxiv.org/pdf/1412.6980.pdf) - (@{tf.train.AdamOptimizer}) - * `Ftrl`. Implementation of the - [FTRL-Proximal](https://www.eecs.tufts.edu/~dsculley/papers/ad-click-prediction.pdf) - ("Follow The (Proximally) Regularized Leader") algorithm - (@{tf.train.FtrlOptimizer}) - * `Momentum`. Implementation of stochastic gradient descent with - [momentum](https://en.wikipedia.org/wiki/Stochastic_gradient_descent#Momentum) - (@{tf.train.MomentumOptimizer}) - * `RMSProp`. Implementation of the - [RMSprop](http://sebastianruder.com/optimizing-gradient-descent/index.html#rmsprop) - algorithm - (@{tf.train.RMSPropOptimizer}) - -Note: The `optimize_loss` function supports additional optional arguments to -further configure the optimizer, such as for implementing decay. See the -@{tf.contrib.layers.optimize_loss$API docs} for more info. +minimize loss. A simple way to create the training op is to instantiate a +`tf.train.Optimizer` subclass and call the `minimize` method. The following code defines a training op for the abalone `model_fn` using the loss value calculated in [Defining Loss for the Model](#defining-loss), the -learning rate passed to the function in `params`, and the SGD optimizer. For -`global_step`, the convenience function -@{tf.train.get_global_step} -in tf.contrib.framework takes care of generating an integer variable: +learning rate passed to the function in `params`, and the gradient descent +optimizer. 
For `global_step`, the convenience function +@{tf.train.get_global_step} takes care of generating an integer variable: ```python -train_op = tf.contrib.layers.optimize_loss( - loss=loss, - global_step=tf.contrib.framework.get_global_step(), - learning_rate=params["learning_rate"], - optimizer="SGD") +optimizer = tf.train.GradientDescentOptimizer( + learning_rate=params["learning_rate"]) +train_op = optimizer.minimize( + loss=loss, global_step=tf.train.get_global_step()) ``` +For a full list of optimizers, and other details, see the +@{$python/train#optimizers$API guide}. + ### The complete abalone `model_fn` Here's the final, complete `model_fn` for the abalone age predictor. The @@ -606,40 +563,39 @@ and returns a `ModelFnOps` object containing `mode`, `predictions_dict`, `loss`, and `train_op`: ```python -def model_fn(features, targets, mode, params): +def model_fn(features, labels, mode, params): """Model function for Estimator.""" # Connect the first hidden layer to input layer - # (features) with relu activation - first_hidden_layer = tf.contrib.layers.relu(features, 10) + # (features["x"]) with relu activation + first_hidden_layer = tf.layers.dense(features["x"], 10, activation=tf.nn.relu) # Connect the second hidden layer to first hidden layer with relu - second_hidden_layer = tf.contrib.layers.relu(first_hidden_layer, 10) + second_hidden_layer = tf.layers.dense( + first_hidden_layer, 10, activation=tf.nn.relu) # Connect the output layer to second hidden layer (no activation fn) - output_layer = tf.contrib.layers.linear(second_hidden_layer, 1) + output_layer = tf.layers.dense(second_hidden_layer, 1) # Reshape output layer to 1-dim Tensor to return predictions predictions = tf.reshape(output_layer, [-1]) predictions_dict = {"ages": predictions} # Calculate loss using mean squared error - loss = tf.losses.mean_squared_error(targets, predictions) + loss = tf.losses.mean_squared_error(labels, predictions) # Calculate root mean squared error as additional eval 
metric eval_metric_ops = { - "rmse": - tf.metrics.root_mean_squared_error( - tf.cast(targets, tf.float64), predictions) + "rmse": tf.metrics.root_mean_squared_error( + tf.cast(labels, tf.float64), predictions) } - train_op = tf.contrib.layers.optimize_loss( - loss=loss, - global_step=tf.contrib.framework.get_global_step(), - learning_rate=params["learning_rate"], - optimizer="SGD") + optimizer = tf.train.GradientDescentOptimizer( + learning_rate=params["learning_rate"]) + train_op = optimizer.minimize( + loss=loss, global_step=tf.train.get_global_step()) - return model_fn_lib.ModelFnOps( + return tf.estimator.EstimatorSpec( mode=mode, predictions=predictions_dict, loss=loss, @@ -657,29 +613,31 @@ Add the following code to the end of `main()` to fit the neural network to the training data and evaluate accuracy: ```python -def get_train_inputs(): - x = tf.constant(training_set.data) - y = tf.constant(training_set.target) - return x, y +train_input_fn = tf.estimator.inputs.numpy_input_fn( + x={"x": np.array(training_set.data)}, + y=np.array(training_set.target), + num_epochs=None, + shuffle=True) -# Fit -nn.fit(input_fn=get_train_inputs, steps=5000) - -def get_test_inputs(): - x = tf.constant(test_set.data) - y = tf.constant(test_set.target) - return x, y +# Train +nn.train(input_fn=train_input_fn, steps=5000) # Score accuracy -ev = nn.evaluate(input_fn=get_test_inputs, steps=1) +test_input_fn = tf.estimator.inputs.numpy_input_fn( + x={"x": np.array(test_set.data)}, + y=np.array(test_set.target), + num_epochs=1, + shuffle=False) + +ev = nn.evaluate(input_fn=test_input_fn) print("Loss: %s" % ev["loss"]) print("Root Mean Squared Error: %s" % ev["rmse"]) ``` Note: The above code uses input functions to feed feature (`x`) and label (`y`) -`Tensor`s into the model for both training (`get_train_inputs()`) and evaluation -(`get_test_inputs()`). To learn more about input functions, see the tutorial -@{$input_fn$Building Input Functions with tf.contrib.learn}. 
+`Tensor`s into the model for both training (`train_input_fn`) and evaluation +(`test_input_fn`). To learn more about input functions, see the tutorial +@{$input_fn$Building Input Functions with tf.estimator}. Then run the code. You should see output like the following: @@ -701,7 +659,11 @@ To predict ages for the `ABALONE_PREDICT` data set, add the following to ```python # Print out predictions -predictions = nn.predict(x=prediction_set.data, as_iterable=True) +predict_input_fn = tf.estimator.inputs.numpy_input_fn( + x={"x": prediction_set.data}, + num_epochs=1, + shuffle=False) +predictions = nn.predict(input_fn=predict_input_fn) for i, p in enumerate(predictions): print("Prediction %s: %s" % (i + 1, p["ages"])) ``` @@ -723,11 +685,10 @@ Prediction 7: 11.1289 ## Additional Resources -Congrats! You've successfully built a tf.contrib.learn `Estimator` from scratch. +Congrats! You've successfully built a tf.estimator `Estimator` from scratch. For additional reference materials on building `Estimator`s, see the following sections of the API guides: -* @{$python/contrib.learn#Estimators$Estimators} * @{$python/contrib.layers$Layers} * @{$python/contrib.losses$Losses} * @{$python/contrib.layers#optimization$Optimization} diff --git a/tensorflow/docs_src/get_started/tflearn.md b/tensorflow/docs_src/get_started/estimator.md similarity index 70% rename from tensorflow/docs_src/get_started/tflearn.md rename to tensorflow/docs_src/get_started/estimator.md index 002118073ce..a2c2a1ece5d 100644 --- a/tensorflow/docs_src/get_started/tflearn.md +++ b/tensorflow/docs_src/get_started/estimator.md @@ -1,8 +1,8 @@ -# tf.contrib.learn Quickstart +# tf.estimator Quickstart -TensorFlow’s high-level machine learning API (tf.contrib.learn) makes it easy to +TensorFlow’s high-level machine learning API (tf.estimator) makes it easy to configure, train, and evaluate a variety of machine learning models. 
In this -tutorial, you’ll use tf.contrib.learn to construct a +tutorial, you’ll use tf.estimator to construct a [neural network](https://en.wikipedia.org/wiki/Artificial_neural_network) classifier and train it on the [Iris data set](https://en.wikipedia.org/wiki/Iris_flower_data_set) to @@ -10,8 +10,8 @@ predict flower species based on sepal/petal geometry. You'll write code to perform the following five steps: 1. Load CSVs containing Iris training/test data into a TensorFlow `Dataset` -2. Construct a @{tf.contrib.learn.DNNClassifier$neural network classifier} -3. Fit the model using the training data +2. Construct a @{tf.estimator.DNNClassifier$neural network classifier} +3. Train the model using the training data 4. Evaluate the accuracy of the model 5. Classify new samples @@ -64,47 +64,50 @@ def main(): features_dtype=np.float32) # Specify that all features have real-value data - feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)] + feature_columns = [tf.feature_column.numeric_column("x", shape=[4])] # Build 3 layer DNN with 10, 20, 10 units respectively. - classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns, - hidden_units=[10, 20, 10], - n_classes=3, - model_dir="/tmp/iris_model") + classifier = tf.estimator.DNNClassifier(feature_columns=feature_columns, + hidden_units=[10, 20, 10], + n_classes=3, + model_dir="/tmp/iris_model") # Define the training inputs - def get_train_inputs(): - x = tf.constant(training_set.data) - y = tf.constant(training_set.target) + train_input_fn = tf.estimator.inputs.numpy_input_fn( + x={"x": np.array(training_set.data)}, + y=np.array(training_set.target), + num_epochs=None, + shuffle=True) - return x, y - - # Fit model. - classifier.fit(input_fn=get_train_inputs, steps=2000) + # Train model. 
+ classifier.train(input_fn=train_input_fn, steps=2000) # Define the test inputs - def get_test_inputs(): - x = tf.constant(test_set.data) - y = tf.constant(test_set.target) - - return x, y + test_input_fn = tf.estimator.inputs.numpy_input_fn( + x={"x": np.array(test_set.data)}, + y=np.array(test_set.target), + num_epochs=1, + shuffle=False) # Evaluate accuracy. - accuracy_score = classifier.evaluate(input_fn=get_test_inputs, - steps=1)["accuracy"] + accuracy_score = classifier.evaluate(input_fn=test_input_fn)["accuracy"] print("\nTest Accuracy: {0:f}\n".format(accuracy_score)) # Classify two new flower samples. - def new_samples(): - return tf.constant( + new_samples = np.array( [[6.4, 3.2, 4.5, 1.5], - [5.8, 3.1, 5.0, 1.7]], dtype=tf.float32) + [5.8, 3.1, 5.0, 1.7]], dtype=np.float32) + predict_input_fn = tf.estimator.inputs.numpy_input_fn( + x={"x": new_samples}, + num_epochs=1, + shuffle=False) - predictions = list(classifier.predict(input_fn=new_samples)) + predictions = list(classifier.predict(input_fn=predict_input_fn)) + predicted_classes = [p["classes"] for p in predictions] print( "New Samples, Class Predictions: {}\n" - .format(predictions)) + .format(predicted_classes)) if __name__ == "__main__": main() @@ -237,31 +240,30 @@ you'll use `training_set.data` and ## Construct a Deep Neural Network Classifier -tf.contrib.learn offers a variety of predefined models, called -@{$python/contrib.learn#estimators$`Estimator`s}, which you can -use "out of the box" to run training and evaluation operations on your data. +tf.estimator offers a variety of predefined models, called `Estimator`s, which +you can use "out of the box" to run training and evaluation operations on your +data. Here, you'll configure a Deep Neural Network Classifier model to fit the Iris -data. Using tf.contrib.learn, you can instantiate your -@{tf.contrib.learn.DNNClassifier} with -just a couple lines of code: +data. 
Using tf.estimator, you can instantiate your +@{tf.estimator.DNNClassifier} with just a couple lines of code: ```python # Specify that all features have real-value data -feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)] +feature_columns = [tf.feature_column.numeric_column("x", shape=[4])] # Build 3 layer DNN with 10, 20, 10 units respectively. -classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns, - hidden_units=[10, 20, 10], - n_classes=3, - model_dir="/tmp/iris_model") +classifier = tf.estimator.DNNClassifier(feature_columns=feature_columns, + hidden_units=[10, 20, 10], + n_classes=3, + model_dir="/tmp/iris_model") ``` The code above first defines the model's feature columns, which specify the data type for the features in the data set. All the feature data is continuous, so -`tf.contrib.layers.real_valued_column` is the appropriate function to use to +`tf.feature_column.numeric_column` is the appropriate function to use to construct the feature columns. There are four features in the data set (sepal -width, sepal height, petal width, and petal height), so accordingly `dimension` -must be set to `4` to hold all the data. +width, sepal height, petal width, and petal height), so accordingly `shape` +must be set to `[4]` to hold all the data. Then, the code creates a `DNNClassifier` model using the following arguments: @@ -272,34 +274,34 @@ Then, the code creates a `DNNClassifier` model using the following arguments: * `n_classes=3`. Three target classes, representing the three Iris species. * `model_dir=/tmp/iris_model`. The directory in which TensorFlow will save checkpoint data during model training. For more on logging and monitoring - with TensorFlow, see @{$monitors$Logging and Monitoring Basics with tf.contrib.learn}. + with TensorFlow, see + @{$monitors$Logging and Monitoring Basics with tf.estimator}. 
## Describe the training input pipeline {#train-input} -The `tf.contrib.learn` API uses input functions, which create the TensorFlow -operations that generate data for the model. In this case, the data is small -enough that it can be stored in @{tf.constant$TensorFlow constants}. The -following code produces the simplest possible input pipeline: +The `tf.estimator` API uses input functions, which create the TensorFlow +operations that generate data for the model. +We can use `tf.estimator.inputs.numpy_input_fn` to produce the input pipeline: ```python # Define the training inputs -def get_train_inputs(): - x = tf.constant(training_set.data) - y = tf.constant(training_set.target) - - return x, y +train_input_fn = tf.estimator.inputs.numpy_input_fn( + x={"x": np.array(training_set.data)}, + y=np.array(training_set.target), + num_epochs=None, + shuffle=True) ``` ## Fit the DNNClassifier to the Iris Training Data {#fit-dnnclassifier} Now that you've configured your DNN `classifier` model, you can fit it to the -Iris training data using the @{tf.contrib.learn.BaseEstimator.fit$`fit`} method. -Pass `get_train_inputs` as the `input_fn`, and the number of steps to train +Iris training data using the @{tf.estimator.Estimator.train$`train`} method. +Pass `train_input_fn` as the `input_fn`, and the number of steps to train (here, 2000): ```python -# Fit model. -classifier.fit(input_fn=get_train_inputs, steps=2000) +# Train model. +classifier.train(input_fn=train_input_fn, steps=2000) ``` The state of the model is preserved in the `classifier`, which means you can @@ -307,46 +309,44 @@ train iteratively if you like. 
For example, the above is equivalent to the following: ```python -classifier.fit(x=training_set.data, y=training_set.target, steps=1000) -classifier.fit(x=training_set.data, y=training_set.target, steps=1000) +classifier.train(input_fn=train_input_fn, steps=1000) +classifier.train(input_fn=train_input_fn, steps=1000) ``` However, if you're looking to track the model while it trains, you'll likely -want to instead use a TensorFlow @{tf.contrib.learn.monitors$`monitor`} +want to instead use a TensorFlow @{tf.SessionRunHook$`SessionRunHook`} to perform logging operations. See the tutorial -@{$monitors$“Logging and Monitoring Basics with tf.contrib.learn”} +@{$monitors$Logging and Monitoring Basics with tf.estimator} for more on this topic. ## Evaluate Model Accuracy {#evaluate-accuracy} -You've fit your `DNNClassifier` model on the Iris training data; now, you can -check its accuracy on the Iris test data using the -@{tf.contrib.learn.BaseEstimator.evaluate$`evaluate`} method. Like `fit`, +You've trained your `DNNClassifier` model on the Iris training data; now, you +can check its accuracy on the Iris test data using the +@{tf.estimator.Estimator.evaluate$`evaluate`} method. Like `train`, `evaluate` takes an input function that builds its input pipeline. `evaluate` -returns a `dict` with the evaluation results. The following code passes the Iris -test data—`test_set.data` and `test_set.target`—to `evaluate` and -prints the `accuracy` from the results: +returns a `dict`s with the evaluation results. The following code passes the +Iris test data—`test_set.data` and `test_set.target`—to `evaluate` +and prints the `accuracy` from the results: ```python # Define the test inputs -def get_test_inputs(): - x = tf.constant(test_set.data) - y = tf.constant(test_set.target) - - return x, y +test_input_fn = tf.estimator.inputs.numpy_input_fn( + x={"x": np.array(test_set.data)}, + y=np.array(test_set.target), + num_epochs=1, + shuffle=False) # Evaluate accuracy. 
-accuracy_score = classifier.evaluate(input_fn=get_test_inputs, - steps=1)["accuracy"] +accuracy_score = classifier.evaluate(input_fn=test_input_fn)["accuracy"] print("\nTest Accuracy: {0:f}\n".format(accuracy_score)) ``` -Note: The `steps` argument to `evaluate` is important here. -@{tf.contrib.learn.Evaluable.evaluate$`evaluate`} normally runs until it reaches -the end of the input. This is perfect for evaluating over a set of files, but -the constants being used here will never throw the `OutOfRangeError` or -`StopIteration` that it is expecting. +Note: The `num_epochs=1` argument to `numpy_input_fn` is important here. +`test_input_fn` will iterate over the data once, and then raise +`OutOfRangeError`. This error signals the classifier to stop evaluating, so it +will evaluate over the input once. When you run the full script, it will print something close to: @@ -368,21 +368,25 @@ Sepal Length | Sepal Width | Petal Length | Petal Width 5.8 | 3.1 | 5.0 | 1.7 You can predict their species using the `predict()` method. `predict` returns a -generator, which can easily be converted to a list. The following code retrieves -and prints the class predictions: +generator of dicts, which can easily be converted to a list. The following code +retrieves and prints the class predictions: ```python # Classify two new flower samples. -def new_samples(): - return np.array( +new_samples = np.array( [[6.4, 3.2, 4.5, 1.5], [5.8, 3.1, 5.0, 1.7]], dtype=np.float32) +predict_input_fn = tf.estimator.inputs.numpy_input_fn( + x={"x": new_samples}, + num_epochs=1, + shuffle=False) -predictions = list(classifier.predict(input_fn=new_samples)) +predictions = list(classifier.predict(input_fn=predict_input_fn)) +predicted_classes = [p["classes"] for p in predictions] print( "New Samples, Class Predictions: {}\n" - .format(predictions)) + .format(predicted_classes)) ``` Your results should look as follows: @@ -396,14 +400,11 @@ second sample is *Iris virginica*. 
## Additional Resources -* For further reference materials on tf.contrib.learn, see the official - @{$python/contrib.learn$API docs}. - -* To learn more about using tf.contrib.learn to create linear models, see +* To learn more about using tf.estimator to create linear models, see @{$linear$Large-scale Linear Models with TensorFlow}. -* To build your own Estimator using tf.contrib.learn APIs, check out - @{$estimators$Creating Estimators in tf.contrib.learn}. +* To build your own Estimator using tf.estimator APIs, check out + @{$estimators$Creating Estimators in tf.estimator}. * To experiment with neural network modeling and visualization in the browser, check out [Deep Playground](http://playground.tensorflow.org/). diff --git a/tensorflow/docs_src/get_started/index.md b/tensorflow/docs_src/get_started/index.md index 241263c72c0..0d302ec3830 100644 --- a/tensorflow/docs_src/get_started/index.md +++ b/tensorflow/docs_src/get_started/index.md @@ -11,16 +11,16 @@ to training an MNIST model on TensorFlow: * @{$mnist/beginners$MNIST for ML Beginners}, which introduces MNIST through the high-level API. - * @{$mnist/pros$Deep MNIST for Experts}, which is more-in depth than - "MNIST for ML Beginners," and assumes some familiarity with machine + * @{$mnist/pros$Deep MNIST for Experts}, which is more-in depth than + "MNIST for ML Beginners," and assumes some familiarity with machine learning concepts. * @{$mnist/mechanics$TensorFlow Mechanics 101}, which introduces MNIST through the low-level API. -For developers new to TensorFlow, the high-level API is a good place to start. +For developers new to TensorFlow, the high-level API is a good place to start. To learn about the high-level API, read the following guides: - * @{$get_started/tflearn$tf.contrib.learn Quickstart}, which introduces this + * @{$get_started/estimator$tf.estimator Quickstart}, which introduces this API. 
* @{$get_started/input_fn$Building Input Functions with tf.contrib.learn}, which takes you into a somewhat more sophisticated use of this API. diff --git a/tensorflow/docs_src/get_started/input_fn.md b/tensorflow/docs_src/get_started/input_fn.md index a053617b589..2ef050d9ad0 100644 --- a/tensorflow/docs_src/get_started/input_fn.md +++ b/tensorflow/docs_src/get_started/input_fn.md @@ -1,6 +1,6 @@ # Building Input Functions with tf.contrib.learn -This tutorial introduces you to creating input functions in tf.contrib.learn. +This tutorial introduces you to creating input functions in tf.estimator. You'll get an overview of how to construct an `input_fn` to preprocess and feed data into your models. Then, you'll implement an `input_fn` that feeds training, evaluation, and prediction data into a neural network regressor for predicting @@ -8,27 +8,26 @@ median house values. ## Custom Input Pipelines with input_fn -When training a neural network using tf.contrib.learn, it's possible to pass -your feature and target data directly into your `fit`, `evaluate`, or `predict` -operations. Here's an example taken from the @{$tflearn$tf.contrib.learn quickstart tutorial}: +The `input_fn` is used to pass feature and target data to the `train`, +`evaluate`, and `predict` methods of the `Estimator`. +The user can do feature engineering or pre-processing inside the `input_fn`. +Here's an example taken from the @{$estimator$tf.estimator Quickstart tutorial}: ```python +import numpy as np + training_set = tf.contrib.learn.datasets.base.load_csv_with_header( filename=IRIS_TRAINING, target_dtype=np.int, features_dtype=np.float32) -test_set = tf.contrib.learn.datasets.base.load_csv_with_header( - filename=IRIS_TEST, target_dtype=np.int, features_dtype=np.float32) -... 
-classifier.fit(x=training_set.data, - y=training_set.target, - steps=2000) +train_input_fn = tf.estimator.inputs.numpy_input_fn( + x={"x": np.array(training_set.data)}, + y=np.array(training_set.target), + num_epochs=None, + shuffle=True) + +classifier.train(input_fn=train_input_fn, steps=2000) ``` -This approach works well when little to no manipulation of source data is -required. But in cases where more feature engineering is needed, -`tf.contrib.learn` supports using a custom input function (`input_fn`) to -encapsulate the logic for preprocessing and piping data into your models. - ### Anatomy of an input_fn The following code illustrates the basic skeleton for an input function: @@ -43,8 +42,9 @@ def my_input_fn(): return feature_cols, labels ``` -The body of the input function contains the specific logic for preprocessing your -input data, such as scrubbing out bad examples or [feature scaling](https://en.wikipedia.org/wiki/Feature_scaling). +The body of the input function contains the specific logic for preprocessing +your input data, such as scrubbing out bad examples or +[feature scaling](https://en.wikipedia.org/wiki/Feature_scaling). Input functions must return the following two values containing the final feature and label data to be fed into your model (as shown in the above code @@ -61,15 +61,27 @@ data. ### Converting Feature Data to Tensors -If your feature/label data is stored in [_pandas_](http://pandas.pydata.org/) -dataframes or [numpy](http://www.numpy.org/) arrays, you'll need to convert it -to `Tensor`s before returning it from your `input_fn`. 
- -For continuous data, you can create and populate a `Tensor` using `tf.constant`: +If your feature/label data is a python array or stored in +[_pandas_](http://pandas.pydata.org/) dataframes or +[numpy](http://www.numpy.org/) arrays, you can use the following methods to +construct `input_fn`: ```python -feature_column_data = [1, 2.4, 0, 9.9, 3, 120] -feature_tensor = tf.constant(feature_column_data) +import numpy as np +# numpy input_fn. +my_input_fn = tf.estimator.inputs.numpy_input_fn( + x={"x": np.array(x_data)}, + y=np.array(y_data), + ...) +``` + +```python +import pandas as pd +# pandas input_fn. +my_input_fn = tf.estimator.inputs.pandas_input_fn( + x=pd.DataFrame({"x": x_data}), + y=pd.Series(y_data), + ...) ``` For [sparse, categorical data](https://en.wikipedia.org/wiki/Sparse_matrix) @@ -103,33 +115,26 @@ This corresponds to the following dense tensor: [0, 0, 0, 0, 0.5]] ``` -For more on `SparseTensor`, see the -@{tf.SparseTensor}. +For more on `SparseTensor`, see @{tf.SparseTensor}. ### Passing input_fn Data to Your Model To feed data to your model for training, you simply pass the input function -you've created to your `fit` operation as the value of the `input_fn` parameter, -e.g.: +you've created to your `train` operation as the value of the `input_fn` +parameter, e.g.: ```python -classifier.fit(input_fn=my_input_fn, steps=2000) +classifier.train(input_fn=my_input_fn, steps=2000) ``` -Note that the `input_fn` is responsible for supplying both feature and label -data to the model, and replaces both the `x` and `y` parameters in `fit`. If you -supply an `input_fn` value to `fit` that is not `None` in conjunction with -either an `x` or `y` parameter that is not `None`, it will result in a -`ValueError`. - -Also note that the `input_fn` parameter must receive a function object (i.e., +Note that the `input_fn` parameter must receive a function object (i.e., `input_fn=my_input_fn`), not the return value of a function call -(`input_fn=my_input_fn()`). 
This means that if you try to pass parameters to the input -function in your `fit` call, as in the following code, it will result in a +(`input_fn=my_input_fn()`). This means that if you try to pass parameters to the +`input_fn` in your `train` call, as in the following code, it will result in a `TypeError`: ```python -classifier.fit(input_fn=my_input_fn(training_set), steps=2000) +classifier.train(input_fn=my_input_fn(training_set), steps=2000) ``` However, if you'd like to be able to parameterize your input function, there are @@ -138,29 +143,33 @@ arguments as your `input_fn` and use it to invoke your input function with the desired parameters. For example: ```python -def my_input_function_training_set(): - return my_input_function(training_set) +def my_input_fn(data_set): + ... -classifier.fit(input_fn=my_input_fn_training_set, steps=2000) +def my_input_fn_training_set(): + return my_input_fn(training_set) + +classifier.train(input_fn=my_input_fn_training_set, steps=2000) ``` Alternatively, you can use Python's [`functools.partial`](https://docs.python.org/2/library/functools.html#functools.partial) function to construct a new function object with all parameter values fixed: ```python -classifier.fit(input_fn=functools.partial(my_input_function, - data_set=training_set), steps=2000) +classifier.train( + input_fn=functools.partial(my_input_fn, data_set=training_set), + steps=2000) ``` -A third option is to wrap your input_fn invocation in a +A third option is to wrap your `input_fn` invocation in a [`lambda`](https://docs.python.org/3/tutorial/controlflow.html#lambda-expressions) and pass it to the `input_fn` parameter: ```python -classifier.fit(input_fn=lambda: my_input_fn(training_set), steps=2000) +classifier.train(input_fn=lambda: my_input_fn(training_set), steps=2000) ``` -One big advantage of architecting your input pipeline as shown above—to accept a +One big advantage of designing your input pipeline as shown above—to accept a parameter for data set—is 
that you can pass the same `input_fn` to `evaluate` and `predict` operations by just changing the data set argument, e.g.: @@ -168,9 +177,36 @@ and `predict` operations by just changing the data set argument, e.g.: ```python classifier.evaluate(input_fn=lambda: my_input_fn(test_set), steps=2000) ``` -This approach enhances code maintainability: no need to capture `x` and `y` -values in separate variables (e.g., `x_train`, `x_test`, `y_train`, `y_test`) -for each type of operation. +This approach enhances code maintainability: no need to define multiple +`input_fn` (e.g. `input_fn_train`, `input_fn_test`, `input_fn_predict`) for each +type of operation. + +Finally, you can use the methods in `tf.estimator.inputs` to create `input_fn` +from numpy or pandas data sets. The additional benefit is that you can use +more arguments, such as `num_epochs` and `shuffle` to control how the `input_fn` +iterates over the data: + +```python +import pandas as pd + +def get_input_fn_from_pandas(data_set, num_epochs=None, shuffle=True): + return tf.estimator.inputs.pandas_input_fn( + x=pd.DataFrame(...), + y=pd.Series(...), + num_epochs=num_epochs, + shuffle=shuffle) +``` + +```python +import numpy as np + +def get_input_fn_from_numpy(data_set, num_epochs=None, shuffle=True): + return tf.estimator.inputs.numpy_input_fn( + x={...}, + y=np.array(...), + num_epochs=num_epochs, + shuffle=shuffle) +``` ### A Neural Network Model for Boston House Values @@ -259,8 +295,7 @@ housing data set contain continuous values, you can create their `FeatureColumn`s using the `tf.contrib.layers.real_valued_column()` function: ```python -feature_cols = [tf.contrib.layers.real_valued_column(k) - for k in FEATURES] +feature_cols = [tf.feature_column.numeric_column(k) for k in FEATURES] ``` NOTE: For a more in-depth overview of feature columns, see @@ -275,36 +310,47 @@ with 10 nodes each), and `feature_columns`, containing the list of `FeatureColumns` you just defined: ```python -regressor = 
tf.contrib.learn.DNNRegressor(feature_columns=feature_cols, - hidden_units=[10, 10], - model_dir="/tmp/boston_model") +regressor = tf.estimator.DNNRegressor(feature_columns=feature_cols, + hidden_units=[10, 10], + model_dir="/tmp/boston_model") ``` ### Building the input_fn -To pass input data into the `regressor`, create an input function, which will -accept a _pandas_ `Dataframe` and return feature column and label values as -`Tensor`s: +To pass input data into the `regressor`, write a factory method that accepts a +_pandas_ `Dataframe` and returns an `input_fn`: ```python -def input_fn(data_set): - feature_cols = {k: tf.constant(data_set[k].values) - for k in FEATURES} - labels = tf.constant(data_set[LABEL].values) - return feature_cols, labels +def get_input_fn(data_set, num_epochs=None, shuffle=True): + return tf.estimator.inputs.pandas_input_fn( + x=pd.DataFrame({k: data_set[k].values for k in FEATURES}), + y = pd.Series(data_set[LABEL].values), + num_epochs=num_epochs, + shuffle=shuffle) ``` Note that the input data is passed into `input_fn` in the `data_set` argument, which means the function can process any of the `DataFrame`s you've imported: `training_set`, `test_set`, and `prediction_set`. +Two additional arguments are provided: +* `num_epochs`: controls the number of + epochs to iterate over data. For training, set this to `None`, so the + `input_fn` keeps returning data until the required number of train steps is + reached. For evaluate and predict, set this to 1, so the `input_fn` will + iterate over the data once and then raise `OutOfRangeError`. That error will + signal the `Estimator` to stop evaluate or predict. +* `shuffle`: Whether to shuffle the data. For evaluate and predict, set this to + `False`, so the `input_fn` iterates over the data sequentially. For train, + set this to `True`. 
+ ### Training the Regressor -To train the neural network regressor, run `fit` with the `training_set` passed -to the `input_fn` as follows: +To train the neural network regressor, run `train` with the `training_set` +passed to the `input_fn` as follows: ```python -regressor.fit(input_fn=lambda: input_fn(training_set), steps=5000) +regressor.train(input_fn=get_input_fn(training_set), steps=5000) ``` You should see log output similar to the following, which reports training loss @@ -330,7 +376,8 @@ Next, see how the trained model performs against the test data set. Run `evaluate`, and this time pass the `test_set` to the `input_fn`: ```python -ev = regressor.evaluate(input_fn=lambda: input_fn(test_set), steps=1) +ev = regressor.evaluate( + input_fn=get_input_fn(test_set, num_epochs=1, shuffle=False)) ``` Retrieve the loss from the `ev` results and print it to output: @@ -354,10 +401,12 @@ Finally, you can use the model to predict median house values for the `prediction_set`, which contains feature data but no labels for six examples: ```python -y = regressor.predict(input_fn=lambda: input_fn(prediction_set)) -# .predict() returns an iterator; convert to a list and print predictions -predictions = list(itertools.islice(y, 6)) -print ("Predictions: {}".format(str(predictions))) +y = regressor.predict( + input_fn=get_input_fn(prediction_set, num_epochs=1, shuffle=False)) +# .predict() returns an iterator of dicts; convert to a list and print +# predictions +predictions = list(p["predictions"] for p in itertools.islice(y, 6)) +print("Predictions: {}".format(str(predictions))) ``` Your results should contain six house-value predictions in thousands of dollars, diff --git a/tensorflow/docs_src/get_started/leftnav_files b/tensorflow/docs_src/get_started/leftnav_files index 812f248d3eb..656727fbfe0 100644 --- a/tensorflow/docs_src/get_started/leftnav_files +++ b/tensorflow/docs_src/get_started/leftnav_files @@ -3,7 +3,7 @@ get_started.md mnist/beginners.md mnist/pros.md 
mnist/mechanics.md -tflearn.md +estimator.md input_fn.md monitors.md summaries_and_tensorboard.md diff --git a/tensorflow/docs_src/get_started/monitors.md b/tensorflow/docs_src/get_started/monitors.md index d9c605b013c..5606e953658 100644 --- a/tensorflow/docs_src/get_started/monitors.md +++ b/tensorflow/docs_src/get_started/monitors.md @@ -4,14 +4,14 @@ When training a model, it’s often valuable to track and evaluate progress in real time. In this tutorial, you’ll learn how to use TensorFlow’s logging capabilities and the `Monitor` API to audit the in-progress training of a neural network classifier for categorizing irises. This tutorial builds on the code -developed in @{$tflearn$tf.contrib.learn Quickstart} so if you +developed in @{$estimator$tf.estimator Quickstart} so if you haven't yet completed that tutorial, you may want to explore it first, especially if you're looking for an intro/refresher on tf.contrib.learn basics. ## Setup {#setup} For this tutorial, you'll be building upon the following code from -@{$tflearn$tf.contrib.learn Quickstart}: +@{$estimator$tf.estimator Quickstart}: ```python from __future__ import absolute_import @@ -75,7 +75,7 @@ here](https://www.tensorflow.org/code/tensorflow/examples/tutorials/monitors/iri ## Overview -The @{$tflearn$tf.contrib.learn Quickstart tutorial} walked through +The @{$estimator$tf.estimator Quickstart tutorial} walked through how to implement a neural net classifier to categorize iris examples into one of three species. diff --git a/tensorflow/docs_src/performance/xla/operation_semantics.md b/tensorflow/docs_src/performance/xla/operation_semantics.md index b970aa5f5fe..120260448d5 100644 --- a/tensorflow/docs_src/performance/xla/operation_semantics.md +++ b/tensorflow/docs_src/performance/xla/operation_semantics.md @@ -61,6 +61,42 @@ Invokes a computation with the given arguments. The arity and types of the `args` must match the parameters of the `computation`. It is allowed to have no `args`. 
+## Clamp + +See also +[`ComputationBuilder::Clamp`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/computation_builder.h). + +Clamps an operand to within the range between a minimum and maximum value. + + `Clamp(operand, min, max)` + +| Arguments | Type | Semantics | | ------------- | ----------------------- | -------------------------------- | | `computation` | `Computation` | computation of type `T_0, T_1, | : : : ..., T_N -> S` with N parameters : : : : of arbitrary type : | `operand` | `ComputationDataHandle` | array of type T | | `min` | `ComputationDataHandle` | array of type T | | `max` | `ComputationDataHandle` | array of type T | + +Given an operand and minimum and maximum values, returns the operand if it is in +the range between the minimum and maximum, else returns the minimum value if the +operand is below this range or the maximum value if the operand is above this +range. That is, `clamp(x, a, b) = min(max(x, a), b)`. + +All three arrays must be the same shape. Alternately, as a restricted form of +[broadcasting](broadcasting.md), `min` and/or `max` can be a scalar of type `T`. + +Example with scalar `min` and `max`: + +``` +let operand: s32[3] = {-1, 5, 9}; +let min: s32 = 0; +let max: s32 = 6; +==> +Clamp(operand, min, max) = s32[3]{0, 5, 6}; +``` + ## Collapse See also @@ -547,6 +583,8 @@ ComputationBuilder supports these element-wise unary functions: `Ceil(operand)` Element-wise ceil `x -> ⌈x⌉`. +`Cos(operand)` Element-wise cosine `x -> cos(x)`. + `Exp(operand)` Element-wise natural exponential `x -> e^x`. `Floor(operand)` Element-wise floor `x -> ⌊x⌋`. diff --git a/tensorflow/docs_src/tutorials/kernel_methods.md b/tensorflow/docs_src/tutorials/kernel_methods.md index fbf1afc4ab4..8506b5228e7 100644 --- a/tensorflow/docs_src/tutorials/kernel_methods.md +++ b/tensorflow/docs_src/tutorials/kernel_methods.md @@ -22,7 +22,7 @@ TensorFlow will provide support for sparse features at a later release.
This tutorial uses [tf.contrib.learn](https://www.tensorflow.org/code/tensorflow/contrib/learn/python/learn) (TensorFlow's high-level Machine Learning API) Estimators for our ML models. -If you are not familiar with this API, [tf.contrib.learn Quickstart](https://www.tensorflow.org/get_started/tflearn) +If you are not familiar with this API, [tf.estimator Quickstart](https://www.tensorflow.org/get_started/estimator) is a good place to start. We will use the MNIST dataset. The tutorial consists of the following steps: diff --git a/tensorflow/docs_src/tutorials/linear.md b/tensorflow/docs_src/tutorials/linear.md index de87c164ae0..45856173fa4 100644 --- a/tensorflow/docs_src/tutorials/linear.md +++ b/tensorflow/docs_src/tutorials/linear.md @@ -16,8 +16,7 @@ give it a try. This overview uses code samples from the tutorial, but the tutorial walks through the code in greater detail. To understand this overview it will help to have some familiarity -with basic machine learning concepts, and also with -@{$tflearn$tf.contrib.learn}. +with basic machine learning concepts, and also with @{$estimator$tf.estimator}. 
[TOC] diff --git a/tensorflow/examples/ios/README.md b/tensorflow/examples/ios/README.md index 9832399d721..a412381196b 100644 --- a/tensorflow/examples/ios/README.md +++ b/tensorflow/examples/ios/README.md @@ -20,9 +20,9 @@ mkdir -p ~/graphs curl -o ~/graphs/inception5h.zip \ https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip \ && unzip ~/graphs/inception5h.zip -d ~/graphs/inception5h -cp ~/graphs/inception5h/* tensorflow/contrib/ios_examples/benchmark/data/ -cp ~/graphs/inception5h/* tensorflow/contrib/ios_examples/camera/data/ -cp ~/graphs/inception5h/* tensorflow/contrib/ios_examples/simple/data/ +cp ~/graphs/inception5h/* tensorflow/examples/ios/benchmark/data/ +cp ~/graphs/inception5h/* tensorflow/examples/ios/camera/data/ +cp ~/graphs/inception5h/* tensorflow/examples/ios/simple/data/ ``` - Change directory to one of the samples, download the TensorFlow-experimental @@ -30,7 +30,7 @@ cp ~/graphs/inception5h/* tensorflow/contrib/ios_examples/simple/data/ long time since it is big (~450MB). For example, if you want to run the simple example, then: ```bash -cd tensorflow/contrib/ios_examples/simple +cd tensorflow/examples/ios/simple pod install open tf_simple_example.xcworkspace # obs, not the .xcodeproj directory ``` @@ -51,9 +51,10 @@ open tf_simple_example.xcworkspace # obs, not the .xcodeproj directory - The TensorFlow-experimental pod is current about ~450MB. The reason it is so big is because we are bundling multiple platforms, and the pod includes - all TensorFlow functionality (e.g. operations). This is convenient during - development, but see below section on how you can build your own custom - TensorFlow library to reduce the size. + all TensorFlow functionality (e.g. operations). The final app size after + build is substantially smaller though (~25MB). Working with the complete + pod is convenient during development, but see below section on how you can + build your own custom TensorFlow library to reduce the size.
### Creating Your own App @@ -145,10 +146,10 @@ rundown: in your project settings. - Remove any use of the `-all_load` flag in your project. The protocol buffers - libraries (full and lite versions) contain duplicate symbols, and the `-all_load` - flag will cause these duplicates to become link errors. If you were using - `-all_load` to avoid issues with Objective-C categories in static libraries, - you may be able to replace it with the `-ObjC` flag. + libraries (full and lite versions) contain duplicate symbols, and the + `-all_load` flag will cause these duplicates to become link errors. If you + were using `-all_load` to avoid issues with Objective-C categories in static + libraries, you may be able to replace it with the `-ObjC` flag. ### Reducing the binary size @@ -159,7 +160,7 @@ It can be tricky to set up the right configuration in your own app to keep the size minimized, so if you do run into this issue we recommend you start by looking at the simple example to examine its size. Here's how you do that: - - Open the Xcode project in tensorflow/contrib/ios_examples/simple. + - Open the Xcode project in tensorflow/examples/ios/simple. - Make sure you've followed the steps above to get the data files. @@ -181,7 +182,7 @@ looking at the simple example to examine its size. Here's how you do that: - Running this command will show the size of the executable as the `tf_simple_example` line. -Right now you'll see a size of around 23 MB, since it's including two +Right now you'll see a size of around 25 MB, since it's including two architectures (armv7 and arm64). 
As a first step, you should make sure the size increase you see in your own app is similar, and if it's larger, look at the "Other Linker Flags" used in the Simple Xcode project settings to strip the diff --git a/tensorflow/examples/ios/benchmark/BenchmarkViewController.mm b/tensorflow/examples/ios/benchmark/BenchmarkViewController.mm index cab7b36f177..9fc5f6ded24 100644 --- a/tensorflow/examples/ios/benchmark/BenchmarkViewController.mm +++ b/tensorflow/examples/ios/benchmark/BenchmarkViewController.mm @@ -22,17 +22,7 @@ #include #include -//#include "google/protobuf/io/coded_stream.h" -//#include "google/protobuf/io/zero_copy_stream_impl.h" -//#include "google/protobuf/io/zero_copy_stream_impl_lite.h" -//#include "google/protobuf/message_lite.h" #include "tensorflow/core/framework/op_kernel.h" -//#include "tensorflow/core/framework/tensor.h" -//#include "tensorflow/core/framework/types.pb.h" -//#include "tensorflow/core/platform/env.h" -//#include "tensorflow/core/platform/logging.h" -//#include "tensorflow/core/platform/mutex.h" -//#include "tensorflow/core/platform/types.h" #include "tensorflow/core/public/session.h" #include "tensorflow/core/util/stat_summarizer.h" diff --git a/tensorflow/examples/learn/README.md b/tensorflow/examples/learn/README.md index 37157fc2967..3166f37e784 100644 --- a/tensorflow/examples/learn/README.md +++ b/tensorflow/examples/learn/README.md @@ -1,33 +1,36 @@ -# TF Learn Examples +# Estimator Examples -Learn is a high-level API for TensorFlow that allows you to create, -train, and use deep learning models easily. See the [Quickstart tutorial](https://www.tensorflow.org/get_started/tflearn) +TensorFlow Estimators are a high-level API for TensorFlow that allows you to +create, train, and use deep learning models easily. + +See the [Quickstart tutorial](https://www.tensorflow.org/get_started/estimator) for an introduction to the API. 
-To run most of these examples, you need to install the `scikit learn` library (`sudo pip install sklearn`). -Some examples use the `pandas` library for data processing (`sudo pip install pandas`). +To run most of these examples, you need to install the `scikit learn` library +(`sudo pip install sklearn`). Some examples use the `pandas` library for data +processing (`sudo pip install pandas`). ## Basics -* [Deep Neural Network Regression with Boston Data](boston.py) -* [Deep Neural Network Classification with Iris Data](iris.py) -* [Building a Custom Model](iris_custom_model.py) -* [Building a Model Using Different GPU Configurations](iris_run_config.py) +* [Deep Neural Network Regression with Boston Data]( https://www.tensorflow.org/code/tensorflow/examples/learn/boston.py) +* [Deep Neural Network Classification with Iris Data]( https://www.tensorflow.org/code/tensorflow/examples/learn/iris.py) +* [Building a Custom Model]( https://www.tensorflow.org/code/tensorflow/examples/learn/iris_custom_model.py) +* [Building a Model Using Different GPU Configurations]( https://www.tensorflow.org/code/tensorflow/examples/learn/iris_run_config.py) ## Techniques -* [Improving Performance Using Early Stopping with Iris Data](iris_val_based_early_stopping.py) -* [Using skflow with Pipeline](iris_with_pipeline.py) -* [Deep Neural Network with Customized Decay Function](iris_custom_decay_dnn.py) +* [Improving Performance Using Early Stopping with Iris Data]( https://www.tensorflow.org/code/tensorflow/examples/learn/iris_val_based_early_stopping.py) +* [Using skflow with Pipeline]( https://www.tensorflow.org/code/tensorflow/examples/learn/iris_with_pipeline.py) +* [Deep Neural Network with Customized Decay Function]( https://www.tensorflow.org/code/tensorflow/examples/learn/iris_custom_decay_dnn.py) ## Specialized Models -* [Building a Random Forest Model](random_forest_mnist.py) -* [Building a Wide & Deep Model](wide_n_deep_tutorial.py) -* [Building a Residual Network 
Model](resnet.py) +* [Building a Random Forest Model]( https://www.tensorflow.org/code/tensorflow/examples/learn/random_forest_mnist.py) +* [Building a Wide & Deep Model]( https://www.tensorflow.org/code/tensorflow/examples/learn/wide_n_deep_tutorial.py) +* [Building a Residual Network Model]( https://www.tensorflow.org/code/tensorflow/examples/learn/resnet.py) ## Text classification -* [Text Classification Using Recurrent Neural Networks on Words](text_classification.py) -* [Text Classification Using Convolutional Neural Networks on Words](text_classification_cnn.py) -* [Text Classification Using Recurrent Neural Networks on Characters](text_classification_character_rnn.py) -* [Text Classification Using Convolutional Neural Networks on Characters](text_classification_character_cnn.py) +* [Text Classification Using Recurrent Neural Networks on Words]( https://www.tensorflow.org/code/tensorflow/examples/learn/text_classification.py) +* [Text Classification Using Convolutional Neural Networks on Words]( https://www.tensorflow.org/code/tensorflow/examples/learn/text_classification_cnn.py) +* [Text Classification Using Recurrent Neural Networks on Characters]( https://www.tensorflow.org/code/tensorflow/examples/learn/text_classification_character_rnn.py) +* [Text Classification Using Convolutional Neural Networks on Characters]( https://www.tensorflow.org/code/tensorflow/examples/learn/text_classification_character_cnn.py) diff --git a/tensorflow/examples/tutorials/estimators/abalone.py b/tensorflow/examples/tutorials/estimators/abalone.py index 3c0ea2e4090..4765d5dabf4 100644 --- a/tensorflow/examples/tutorials/estimators/abalone.py +++ b/tensorflow/examples/tutorials/estimators/abalone.py @@ -25,7 +25,6 @@ from six.moves import urllib import numpy as np import tensorflow as tf -from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib FLAGS = None @@ -72,39 +71,39 @@ def maybe_download(train_data, test_data, predict_data): return 
train_file_name, test_file_name, predict_file_name -def model_fn(features, targets, mode, params): +def model_fn(features, labels, mode, params): """Model function for Estimator.""" # Connect the first hidden layer to input layer - # (features) with relu activation - first_hidden_layer = tf.contrib.layers.relu(features, 10) + # (features["x"]) with relu activation + first_hidden_layer = tf.layers.dense(features["x"], 10, activation=tf.nn.relu) # Connect the second hidden layer to first hidden layer with relu - second_hidden_layer = tf.contrib.layers.relu(first_hidden_layer, 10) + second_hidden_layer = tf.layers.dense( + first_hidden_layer, 10, activation=tf.nn.relu) # Connect the output layer to second hidden layer (no activation fn) - output_layer = tf.contrib.layers.linear(second_hidden_layer, 1) + output_layer = tf.layers.dense(second_hidden_layer, 1) # Reshape output layer to 1-dim Tensor to return predictions predictions = tf.reshape(output_layer, [-1]) predictions_dict = {"ages": predictions} # Calculate loss using mean squared error - loss = tf.losses.mean_squared_error(targets, predictions) + loss = tf.losses.mean_squared_error(labels, predictions) # Calculate root mean squared error as additional eval metric eval_metric_ops = { "rmse": tf.metrics.root_mean_squared_error( - tf.cast(targets, tf.float64), predictions) + tf.cast(labels, tf.float64), predictions) } - train_op = tf.contrib.layers.optimize_loss( - loss=loss, - global_step=tf.contrib.framework.get_global_step(), - learning_rate=params["learning_rate"], - optimizer="SGD") + optimizer = tf.train.GradientDescentOptimizer( + learning_rate=params["learning_rate"]) + train_op = optimizer.minimize( + loss=loss, global_step=tf.train.get_global_step()) - return model_fn_lib.ModelFnOps( + return tf.estimator.EstimatorSpec( mode=mode, predictions=predictions_dict, loss=loss, @@ -133,28 +132,34 @@ def main(unused_argv): model_params = {"learning_rate": LEARNING_RATE} # Instantiate Estimator - nn = 
tf.contrib.learn.Estimator(model_fn=model_fn, params=model_params) - - def get_train_inputs(): - x = tf.constant(training_set.data) - y = tf.constant(training_set.target) - return x, y - - # Fit - nn.fit(input_fn=get_train_inputs, steps=5000) + nn = tf.estimator.Estimator(model_fn=model_fn, params=model_params) + + train_input_fn = tf.estimator.inputs.numpy_input_fn( + x={"x": np.array(training_set.data)}, + y=np.array(training_set.target), + num_epochs=None, + shuffle=True) + + # Train + nn.train(input_fn=train_input_fn, steps=5000) # Score accuracy - def get_test_inputs(): - x = tf.constant(test_set.data) - y = tf.constant(test_set.target) - return x, y - - ev = nn.evaluate(input_fn=get_test_inputs, steps=1) + test_input_fn = tf.estimator.inputs.numpy_input_fn( + x={"x": np.array(test_set.data)}, + y=np.array(test_set.target), + num_epochs=1, + shuffle=False) + + ev = nn.evaluate(input_fn=test_input_fn) print("Loss: %s" % ev["loss"]) print("Root Mean Squared Error: %s" % ev["rmse"]) # Print out predictions - predictions = nn.predict(x=prediction_set.data, as_iterable=True) + predict_input_fn = tf.estimator.inputs.numpy_input_fn( + x={"x": prediction_set.data}, + num_epochs=1, + shuffle=False) + predictions = nn.predict(input_fn=predict_input_fn) for i, p in enumerate(predictions): print("Prediction %s: %s" % (i + 1, p["ages"])) diff --git a/tensorflow/examples/tutorials/input_fn/boston.py b/tensorflow/examples/tutorials/input_fn/boston.py index c7fb7e23165..34f350e9acd 100644 --- a/tensorflow/examples/tutorials/input_fn/boston.py +++ b/tensorflow/examples/tutorials/input_fn/boston.py @@ -31,10 +31,12 @@ FEATURES = ["crim", "zn", "indus", "nox", "rm", LABEL = "medv" -def input_fn(data_set): - feature_cols = {k: tf.constant(data_set[k].values) for k in FEATURES} - labels = tf.constant(data_set[LABEL].values) - return feature_cols, labels +def get_input_fn(data_set, num_epochs=None, shuffle=True): + return tf.estimator.inputs.pandas_input_fn( + x=pd.DataFrame({k: 
data_set[k].values for k in FEATURES}), + y=pd.Series(data_set[LABEL].values), + num_epochs=num_epochs, + shuffle=shuffle) def main(unused_argv): @@ -49,26 +51,28 @@ def main(unused_argv): skiprows=1, names=COLUMNS) # Feature cols - feature_cols = [tf.contrib.layers.real_valued_column(k) - for k in FEATURES] + feature_cols = [tf.feature_column.numeric_column(k) for k in FEATURES] # Build 2 layer fully connected DNN with 10, 10 units respectively. - regressor = tf.contrib.learn.DNNRegressor(feature_columns=feature_cols, - hidden_units=[10, 10], - model_dir="/tmp/boston_model") + regressor = tf.estimator.DNNRegressor(feature_columns=feature_cols, + hidden_units=[10, 10], + model_dir="/tmp/boston_model") - # Fit - regressor.fit(input_fn=lambda: input_fn(training_set), steps=5000) + # Train + regressor.train(input_fn=get_input_fn(training_set), steps=5000) - # Score accuracy - ev = regressor.evaluate(input_fn=lambda: input_fn(test_set), steps=1) + # Evaluate loss over one epoch of test_set. + ev = regressor.evaluate( + input_fn=get_input_fn(test_set, num_epochs=1, shuffle=False)) loss_score = ev["loss"] print("Loss: {0:f}".format(loss_score)) - # Print out predictions - y = regressor.predict(input_fn=lambda: input_fn(prediction_set)) - # .predict() returns an iterator; convert to a list and print predictions - predictions = list(itertools.islice(y, 6)) + # Print out predictions over a slice of prediction_set. 
+ y = regressor.predict( + input_fn=get_input_fn(prediction_set, num_epochs=1, shuffle=False)) + # .predict() returns an iterator of dicts; convert to a list and print + # predictions + predictions = list(p["predictions"] for p in itertools.islice(y, 6)) print("Predictions: {}".format(str(predictions))) if __name__ == "__main__": diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index c72dfdd17e2..664c690e347 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -2320,6 +2320,42 @@ func DecodeWav(scope *Scope, contents tf.Output, optional ...DecodeWavAttr) (aud return op.Output(0), op.Output(1) } +// Elementwise computes the bitwise XOR of `x` and `y`. +// +// The result will have those bits set, that are different in `x` and `y`. The +// computation is performed on the underlying representations of `x` and `y`. +func BitwiseXor(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "BitwiseXor", + Input: []tf.Input{ + x, y, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Elementwise computes the bitwise AND of `x` and `y`. +// +// The result will have those bits set, that are set in both `x` and `y`. The +// computation is performed on the underlying representations of `x` and `y`. +func BitwiseAnd(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "BitwiseAnd", + Input: []tf.Input{ + x, y, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // AllCandidateSamplerAttr is an optional argument to AllCandidateSampler. type AllCandidateSamplerAttr func(optionalAttr) @@ -7495,85 +7531,6 @@ func RsqrtGrad(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { return op.Output(0) } -// Execute a sub graph on a remote processor transferred by GraphTransferer. 
-// -// The graph specifications are serialized by protobuf as graph_transfer_info. -// The implementation / limitations may differ for each platform -// and each available peripheral. -func RemoteFusedGraphExecute(scope *Scope, inputs []tf.Output, Toutputs []tf.DataType, serialized_remote_fused_graph_execute_info string) (outputs []tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"Toutputs": Toutputs, "serialized_remote_fused_graph_execute_info": serialized_remote_fused_graph_execute_info} - opspec := tf.OpSpec{ - Type: "RemoteFusedGraphExecute", - Input: []tf.Input{ - tf.OutputList(inputs), - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if outputs, idx, err = makeOutputList(op, idx, "outputs"); err != nil { - scope.UpdateErr("RemoteFusedGraphExecute", err) - return - } - return outputs -} - -// Conv3DBackpropFilterV2Attr is an optional argument to Conv3DBackpropFilterV2. -type Conv3DBackpropFilterV2Attr func(optionalAttr) - -// Conv3DBackpropFilterV2DataFormat sets the optional data_format attribute to value. -// -// value: The data format of the input and output data. With the -// default format "NDHWC", the data is stored in the order of: -// [batch, in_depth, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCDHW", the data storage order is: -// [batch, in_channels, in_depth, in_height, in_width]. -// If not specified, defaults to "NDHWC" -func Conv3DBackpropFilterV2DataFormat(value string) Conv3DBackpropFilterV2Attr { - return func(m optionalAttr) { - m["data_format"] = value - } -} - -// Computes the gradients of 3-D convolution with respect to the filter. -// -// Arguments: -// input: Shape `[batch, depth, rows, cols, in_channels]`. 
-// filter_sizes: An integer vector representing the tensor shape of `filter`, -// where `filter` is a 5-D -// `[filter_depth, filter_height, filter_width, in_channels, out_channels]` -// tensor. -// out_backprop: Backprop signal of shape `[batch, out_depth, out_rows, out_cols, -// out_channels]`. -// strides: 1-D tensor of length 5. The stride of the sliding window for each -// dimension of `input`. Must have `strides[0] = strides[4] = 1`. -// padding: The type of padding algorithm to use. -func Conv3DBackpropFilterV2(scope *Scope, input tf.Output, filter_sizes tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...Conv3DBackpropFilterV2Attr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"strides": strides, "padding": padding} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "Conv3DBackpropFilterV2", - Input: []tf.Input{ - input, filter_sizes, out_backprop, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // MapClearAttr is an optional argument to MapClear. type MapClearAttr func(optionalAttr) @@ -10482,206 +10439,6 @@ func AddSparseToTensorsMap(scope *Scope, sparse_indices tf.Output, sparse_values return op.Output(0) } -// DecodeBmpAttr is an optional argument to DecodeBmp. -type DecodeBmpAttr func(optionalAttr) - -// DecodeBmpChannels sets the optional channels attribute to value. -// If not specified, defaults to 0 -func DecodeBmpChannels(value int64) DecodeBmpAttr { - return func(m optionalAttr) { - m["channels"] = value - } -} - -// Decode the first frame of a BMP-encoded image to a uint8 tensor. -// -// The attr `channels` indicates the desired number of color channels for the -// decoded image. -// -// Accepted values are: -// -// * 0: Use the number of channels in the BMP-encoded image. -// * 3: output an RGB image. -// * 4: output an RGBA image. -// -// Arguments: -// contents: 0-D. The BMP-encoded image. 
-// -// Returns 3-D with shape `[height, width, channels]`. RGB order -func DecodeBmp(scope *Scope, contents tf.Output, optional ...DecodeBmpAttr) (image tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "DecodeBmp", - Input: []tf.Input{ - contents, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Computes softmax activations. -// -// For each batch `i` and class `j` we have -// -// softmax[i, j] = exp(logits[i, j]) / sum_j(exp(logits[i, j])) -// -// Arguments: -// logits: 2-D with shape `[batch_size, num_classes]`. -// -// Returns Same shape as `logits`. -func Softmax(scope *Scope, logits tf.Output) (softmax tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Softmax", - Input: []tf.Input{ - logits, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// RandomShuffleQueueV2Attr is an optional argument to RandomShuffleQueueV2. -type RandomShuffleQueueV2Attr func(optionalAttr) - -// RandomShuffleQueueV2Shapes sets the optional shapes attribute to value. -// -// value: The shape of each component in a value. The length of this attr must -// be either 0 or the same as the length of component_types. If the length of -// this attr is 0, the shapes of queue elements are not constrained, and -// only one element may be dequeued at a time. -// If not specified, defaults to <> -// -// REQUIRES: len(value) >= 0 -func RandomShuffleQueueV2Shapes(value []tf.Shape) RandomShuffleQueueV2Attr { - return func(m optionalAttr) { - m["shapes"] = value - } -} - -// RandomShuffleQueueV2Capacity sets the optional capacity attribute to value. -// -// value: The upper bound on the number of elements in this queue. -// Negative numbers mean no limit. 
-// If not specified, defaults to -1 -func RandomShuffleQueueV2Capacity(value int64) RandomShuffleQueueV2Attr { - return func(m optionalAttr) { - m["capacity"] = value - } -} - -// RandomShuffleQueueV2MinAfterDequeue sets the optional min_after_dequeue attribute to value. -// -// value: Dequeue will block unless there would be this -// many elements after the dequeue or the queue is closed. This -// ensures a minimum level of mixing of elements. -// If not specified, defaults to 0 -func RandomShuffleQueueV2MinAfterDequeue(value int64) RandomShuffleQueueV2Attr { - return func(m optionalAttr) { - m["min_after_dequeue"] = value - } -} - -// RandomShuffleQueueV2Seed sets the optional seed attribute to value. -// -// value: If either seed or seed2 is set to be non-zero, the random number -// generator is seeded by the given seed. Otherwise, a random seed is used. -// If not specified, defaults to 0 -func RandomShuffleQueueV2Seed(value int64) RandomShuffleQueueV2Attr { - return func(m optionalAttr) { - m["seed"] = value - } -} - -// RandomShuffleQueueV2Seed2 sets the optional seed2 attribute to value. -// -// value: A second seed to avoid seed collision. -// If not specified, defaults to 0 -func RandomShuffleQueueV2Seed2(value int64) RandomShuffleQueueV2Attr { - return func(m optionalAttr) { - m["seed2"] = value - } -} - -// RandomShuffleQueueV2Container sets the optional container attribute to value. -// -// value: If non-empty, this queue is placed in the given container. -// Otherwise, a default container is used. -// If not specified, defaults to "" -func RandomShuffleQueueV2Container(value string) RandomShuffleQueueV2Attr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// RandomShuffleQueueV2SharedName sets the optional shared_name attribute to value. -// -// value: If non-empty, this queue will be shared under the given name -// across multiple sessions. 
-// If not specified, defaults to "" -func RandomShuffleQueueV2SharedName(value string) RandomShuffleQueueV2Attr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// A queue that randomizes the order of elements. -// -// Arguments: -// component_types: The type of each component in a value. -// -// Returns The handle to the queue. -func RandomShuffleQueueV2(scope *Scope, component_types []tf.DataType, optional ...RandomShuffleQueueV2Attr) (handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"component_types": component_types} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "RandomShuffleQueueV2", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Outputs a `Summary` protocol buffer with scalar values. -// -// The input `tags` and `values` must have the same shape. The generated summary -// has a summary value for each tag-value pair in `tags` and `values`. -// -// Arguments: -// tags: Tags for the summary. -// values: Same shape as `tags. Values for the summary. -// -// Returns Scalar. Serialized `Summary` protocol buffer. -func ScalarSummary(scope *Scope, tags tf.Output, values tf.Output) (summary tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "ScalarSummary", - Input: []tf.Input{ - tags, values, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Constructs a tensor by tiling a given tensor. // // This operation creates a new tensor by replicating `input` `multiples` times. @@ -10770,6 +10527,114 @@ func ShardedFilename(scope *Scope, basename tf.Output, shard tf.Output, num_shar return op.Output(0) } +// Conv3DBackpropFilterV2Attr is an optional argument to Conv3DBackpropFilterV2. +type Conv3DBackpropFilterV2Attr func(optionalAttr) + +// Conv3DBackpropFilterV2DataFormat sets the optional data_format attribute to value. +// +// value: The data format of the input and output data. 
With the +// default format "NDHWC", the data is stored in the order of: +// [batch, in_depth, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCDHW", the data storage order is: +// [batch, in_channels, in_depth, in_height, in_width]. +// If not specified, defaults to "NDHWC" +func Conv3DBackpropFilterV2DataFormat(value string) Conv3DBackpropFilterV2Attr { + return func(m optionalAttr) { + m["data_format"] = value + } +} + +// Computes the gradients of 3-D convolution with respect to the filter. +// +// Arguments: +// input: Shape `[batch, depth, rows, cols, in_channels]`. +// filter_sizes: An integer vector representing the tensor shape of `filter`, +// where `filter` is a 5-D +// `[filter_depth, filter_height, filter_width, in_channels, out_channels]` +// tensor. +// out_backprop: Backprop signal of shape `[batch, out_depth, out_rows, out_cols, +// out_channels]`. +// strides: 1-D tensor of length 5. The stride of the sliding window for each +// dimension of `input`. Must have `strides[0] = strides[4] = 1`. +// padding: The type of padding algorithm to use. +func Conv3DBackpropFilterV2(scope *Scope, input tf.Output, filter_sizes tf.Output, out_backprop tf.Output, strides []int64, padding string, optional ...Conv3DBackpropFilterV2Attr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"strides": strides, "padding": padding} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "Conv3DBackpropFilterV2", + Input: []tf.Input{ + input, filter_sizes, out_backprop, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Execute a sub graph on a remote processor. +// +// The graph specifications(such as graph itself, input tensors and output names) +// are stored as a serialized protocol buffer of RemoteFusedGraphExecuteInfo +// as serialized_remote_fused_graph_execute_info. 
+// The specifications will be passed to a dedicated registered +// remote fused graph executor. The executor will send the graph specifications +// to a remote processor and execute that graph. The execution results +// will be passed to consumer nodes as outputs of this node. +// +// Arguments: +// inputs: Arbitrary number of tensors with arbitrary data types +// +// serialized_remote_fused_graph_execute_info: Serialized protocol buffer +// of RemoteFusedGraphExecuteInfo which contains graph specifications. +// +// Returns Arbitrary number of tensors with arbitrary data types +func RemoteFusedGraphExecute(scope *Scope, inputs []tf.Output, Toutputs []tf.DataType, serialized_remote_fused_graph_execute_info string) (outputs []tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"Toutputs": Toutputs, "serialized_remote_fused_graph_execute_info": serialized_remote_fused_graph_execute_info} + opspec := tf.OpSpec{ + Type: "RemoteFusedGraphExecute", + Input: []tf.Input{ + tf.OutputList(inputs), + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + if scope.Err() != nil { + return + } + var idx int + var err error + if outputs, idx, err = makeOutputList(op, idx, "outputs"); err != nil { + scope.UpdateErr("RemoteFusedGraphExecute", err) + return + } + return outputs +} + +// Computes numerical negative value element-wise. +// +// I.e., \\(y = -x\\). +func Neg(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Neg", + Input: []tf.Input{ + x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // SparseToSparseSetOperationAttr is an optional argument to SparseToSparseSetOperation. type SparseToSparseSetOperationAttr func(optionalAttr) @@ -10846,6 +10711,24 @@ func SparseToSparseSetOperation(scope *Scope, set1_indices tf.Output, set1_value return op.Output(0), op.Output(1), op.Output(2) } +// Elementwise computes the bitwise OR of `x` and `y`. 
+// +// The result will have those bits set, that are set in `x`, `y` or both. The +// computation is performed on the underlying representations of `x` and `y`. +func BitwiseOr(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "BitwiseOr", + Input: []tf.Input{ + x, y, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // Adds up a `SparseTensor` and a dense `Tensor`, producing a dense `Tensor`. // // This Op does not require `a_indices` be sorted in standard lexicographic order. @@ -11888,23 +11771,6 @@ func StringToHashBucketStrong(scope *Scope, input tf.Output, num_buckets int64, return op.Output(0) } -// Computes numerical negative value element-wise. -// -// I.e., \\(y = -x\\). -func Neg(scope *Scope, x tf.Output) (y tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "Neg", - Input: []tf.Input{ - x, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // DestroyResourceOpAttr is an optional argument to DestroyResourceOp. type DestroyResourceOpAttr func(optionalAttr) @@ -13378,6 +13244,24 @@ func ResourceSparseApplyCenteredRMSProp(scope *Scope, var_ tf.Output, mg tf.Outp return scope.AddOperation(opspec) } +// Flips all bits elementwise. +// +// The result will have exactly those bits set, that are not set in `x`. The +// computation is performed on the underlying representation of x. +func Invert(scope *Scope, x tf.Output) (y tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Invert", + Input: []tf.Input{ + x, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // Computes the mean along segments of a tensor. 
// // Read @{$math_ops#segmentation$the section on segmentation} for an explanation of @@ -15283,97 +15167,6 @@ func ResourceApplyRMSProp(scope *Scope, var_ tf.Output, ms tf.Output, mom tf.Out return scope.AddOperation(opspec) } -// ResourceApplyAdamAttr is an optional argument to ResourceApplyAdam. -type ResourceApplyAdamAttr func(optionalAttr) - -// ResourceApplyAdamUseLocking sets the optional use_locking attribute to value. -// -// value: If `True`, updating of the var, m, and v tensors will be protected -// by a lock; otherwise the behavior is undefined, but may exhibit less -// contention. -// If not specified, defaults to false -func ResourceApplyAdamUseLocking(value bool) ResourceApplyAdamAttr { - return func(m optionalAttr) { - m["use_locking"] = value - } -} - -// ResourceApplyAdamUseNesterov sets the optional use_nesterov attribute to value. -// -// value: If `True`, uses the nesterov update. -// If not specified, defaults to false -func ResourceApplyAdamUseNesterov(value bool) ResourceApplyAdamAttr { - return func(m optionalAttr) { - m["use_nesterov"] = value - } -} - -// Update '*var' according to the Adam algorithm. -// -// lr_t <- learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t) -// m_t <- beta1 * m_{t-1} + (1 - beta1) * g_t -// v_t <- beta2 * v_{t-1} + (1 - beta2) * g_t * g_t -// variable <- variable - lr_t * m_t / (sqrt(v_t) + epsilon) -// -// Arguments: -// var_: Should be from a Variable(). -// m: Should be from a Variable(). -// v: Should be from a Variable(). -// beta1_power: Must be a scalar. -// beta2_power: Must be a scalar. -// lr: Scaling factor. Must be a scalar. -// beta1: Momentum factor. Must be a scalar. -// beta2: Momentum factor. Must be a scalar. -// epsilon: Ridge term. Must be a scalar. -// grad: The gradient. -// -// Returns the created operation. 
-func ResourceApplyAdam(scope *Scope, var_ tf.Output, m tf.Output, v tf.Output, beta1_power tf.Output, beta2_power tf.Output, lr tf.Output, beta1 tf.Output, beta2 tf.Output, epsilon tf.Output, grad tf.Output, optional ...ResourceApplyAdamAttr) (o *tf.Operation) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "ResourceApplyAdam", - Input: []tf.Input{ - var_, m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, - }, - Attrs: attrs, - } - return scope.AddOperation(opspec) -} - -// 3D fast Fourier transform. -// -// Computes the 3-dimensional discrete Fourier transform over the inner-most 3 -// dimensions of `input`. -// -// Arguments: -// input: A complex64 tensor. -// -// Returns A complex64 tensor of the same shape as `input`. The inner-most 3 -// dimensions of `input` are replaced with their 3D Fourier transform. -// -// @compatibility(numpy) -// Equivalent to np.fft.fftn with 3 dimensions. -// @end_compatibility -func FFT3D(scope *Scope, input tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "FFT3D", - Input: []tf.Input{ - input, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Output a fact about factorials. func Fact(scope *Scope) (fact tf.Output) { if scope.Err() != nil { @@ -18741,6 +18534,297 @@ func NotEqual(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { return op.Output(0) } +// DecodeBmpAttr is an optional argument to DecodeBmp. +type DecodeBmpAttr func(optionalAttr) + +// DecodeBmpChannels sets the optional channels attribute to value. +// If not specified, defaults to 0 +func DecodeBmpChannels(value int64) DecodeBmpAttr { + return func(m optionalAttr) { + m["channels"] = value + } +} + +// Decode the first frame of a BMP-encoded image to a uint8 tensor. +// +// The attr `channels` indicates the desired number of color channels for the +// decoded image. 
+// +// Accepted values are: +// +// * 0: Use the number of channels in the BMP-encoded image. +// * 3: output an RGB image. +// * 4: output an RGBA image. +// +// Arguments: +// contents: 0-D. The BMP-encoded image. +// +// Returns 3-D with shape `[height, width, channels]`. RGB order +func DecodeBmp(scope *Scope, contents tf.Output, optional ...DecodeBmpAttr) (image tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "DecodeBmp", + Input: []tf.Input{ + contents, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Computes softmax activations. +// +// For each batch `i` and class `j` we have +// +// softmax[i, j] = exp(logits[i, j]) / sum_j(exp(logits[i, j])) +// +// Arguments: +// logits: 2-D with shape `[batch_size, num_classes]`. +// +// Returns Same shape as `logits`. +func Softmax(scope *Scope, logits tf.Output) (softmax tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "Softmax", + Input: []tf.Input{ + logits, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// RandomShuffleQueueV2Attr is an optional argument to RandomShuffleQueueV2. +type RandomShuffleQueueV2Attr func(optionalAttr) + +// RandomShuffleQueueV2Shapes sets the optional shapes attribute to value. +// +// value: The shape of each component in a value. The length of this attr must +// be either 0 or the same as the length of component_types. If the length of +// this attr is 0, the shapes of queue elements are not constrained, and +// only one element may be dequeued at a time. +// If not specified, defaults to <> +// +// REQUIRES: len(value) >= 0 +func RandomShuffleQueueV2Shapes(value []tf.Shape) RandomShuffleQueueV2Attr { + return func(m optionalAttr) { + m["shapes"] = value + } +} + +// RandomShuffleQueueV2Capacity sets the optional capacity attribute to value. 
+// +// value: The upper bound on the number of elements in this queue. +// Negative numbers mean no limit. +// If not specified, defaults to -1 +func RandomShuffleQueueV2Capacity(value int64) RandomShuffleQueueV2Attr { + return func(m optionalAttr) { + m["capacity"] = value + } +} + +// RandomShuffleQueueV2MinAfterDequeue sets the optional min_after_dequeue attribute to value. +// +// value: Dequeue will block unless there would be this +// many elements after the dequeue or the queue is closed. This +// ensures a minimum level of mixing of elements. +// If not specified, defaults to 0 +func RandomShuffleQueueV2MinAfterDequeue(value int64) RandomShuffleQueueV2Attr { + return func(m optionalAttr) { + m["min_after_dequeue"] = value + } +} + +// RandomShuffleQueueV2Seed sets the optional seed attribute to value. +// +// value: If either seed or seed2 is set to be non-zero, the random number +// generator is seeded by the given seed. Otherwise, a random seed is used. +// If not specified, defaults to 0 +func RandomShuffleQueueV2Seed(value int64) RandomShuffleQueueV2Attr { + return func(m optionalAttr) { + m["seed"] = value + } +} + +// RandomShuffleQueueV2Seed2 sets the optional seed2 attribute to value. +// +// value: A second seed to avoid seed collision. +// If not specified, defaults to 0 +func RandomShuffleQueueV2Seed2(value int64) RandomShuffleQueueV2Attr { + return func(m optionalAttr) { + m["seed2"] = value + } +} + +// RandomShuffleQueueV2Container sets the optional container attribute to value. +// +// value: If non-empty, this queue is placed in the given container. +// Otherwise, a default container is used. +// If not specified, defaults to "" +func RandomShuffleQueueV2Container(value string) RandomShuffleQueueV2Attr { + return func(m optionalAttr) { + m["container"] = value + } +} + +// RandomShuffleQueueV2SharedName sets the optional shared_name attribute to value. 
+// +// value: If non-empty, this queue will be shared under the given name +// across multiple sessions. +// If not specified, defaults to "" +func RandomShuffleQueueV2SharedName(value string) RandomShuffleQueueV2Attr { + return func(m optionalAttr) { + m["shared_name"] = value + } +} + +// A queue that randomizes the order of elements. +// +// Arguments: +// component_types: The type of each component in a value. +// +// Returns The handle to the queue. +func RandomShuffleQueueV2(scope *Scope, component_types []tf.DataType, optional ...RandomShuffleQueueV2Attr) (handle tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"component_types": component_types} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "RandomShuffleQueueV2", + + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Outputs a `Summary` protocol buffer with scalar values. +// +// The input `tags` and `values` must have the same shape. The generated summary +// has a summary value for each tag-value pair in `tags` and `values`. +// +// Arguments: +// tags: Tags for the summary. +// values: Same shape as `tags. Values for the summary. +// +// Returns Scalar. Serialized `Summary` protocol buffer. +func ScalarSummary(scope *Scope, tags tf.Output, values tf.Output) (summary tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "ScalarSummary", + Input: []tf.Input{ + tags, values, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// ResourceApplyAdamAttr is an optional argument to ResourceApplyAdam. +type ResourceApplyAdamAttr func(optionalAttr) + +// ResourceApplyAdamUseLocking sets the optional use_locking attribute to value. +// +// value: If `True`, updating of the var, m, and v tensors will be protected +// by a lock; otherwise the behavior is undefined, but may exhibit less +// contention. 
+// If not specified, defaults to false +func ResourceApplyAdamUseLocking(value bool) ResourceApplyAdamAttr { + return func(m optionalAttr) { + m["use_locking"] = value + } +} + +// ResourceApplyAdamUseNesterov sets the optional use_nesterov attribute to value. +// +// value: If `True`, uses the nesterov update. +// If not specified, defaults to false +func ResourceApplyAdamUseNesterov(value bool) ResourceApplyAdamAttr { + return func(m optionalAttr) { + m["use_nesterov"] = value + } +} + +// Update '*var' according to the Adam algorithm. +// +// lr_t <- learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t) +// m_t <- beta1 * m_{t-1} + (1 - beta1) * g_t +// v_t <- beta2 * v_{t-1} + (1 - beta2) * g_t * g_t +// variable <- variable - lr_t * m_t / (sqrt(v_t) + epsilon) +// +// Arguments: +// var_: Should be from a Variable(). +// m: Should be from a Variable(). +// v: Should be from a Variable(). +// beta1_power: Must be a scalar. +// beta2_power: Must be a scalar. +// lr: Scaling factor. Must be a scalar. +// beta1: Momentum factor. Must be a scalar. +// beta2: Momentum factor. Must be a scalar. +// epsilon: Ridge term. Must be a scalar. +// grad: The gradient. +// +// Returns the created operation. +func ResourceApplyAdam(scope *Scope, var_ tf.Output, m tf.Output, v tf.Output, beta1_power tf.Output, beta2_power tf.Output, lr tf.Output, beta1 tf.Output, beta2 tf.Output, epsilon tf.Output, grad tf.Output, optional ...ResourceApplyAdamAttr) (o *tf.Operation) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "ResourceApplyAdam", + Input: []tf.Input{ + var_, m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, + }, + Attrs: attrs, + } + return scope.AddOperation(opspec) +} + +// 3D fast Fourier transform. +// +// Computes the 3-dimensional discrete Fourier transform over the inner-most 3 +// dimensions of `input`. 
+// +// Arguments: +// input: A complex64 tensor. +// +// Returns A complex64 tensor of the same shape as `input`. The inner-most 3 +// dimensions of `input` are replaced with their 3D Fourier transform. +// +// @compatibility(numpy) +// Equivalent to np.fft.fftn with 3 dimensions. +// @end_compatibility +func FFT3D(scope *Scope, input tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "FFT3D", + Input: []tf.Input{ + input, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // Given a quantized tensor described by (input, input_min, input_max), outputs a // // range that covers the actual values present in that tensor. This op is diff --git a/tensorflow/java/maven/libtensorflow/pom.xml b/tensorflow/java/maven/libtensorflow/pom.xml index 6d4c10f4caa..462e38d3324 100644 --- a/tensorflow/java/maven/libtensorflow/pom.xml +++ b/tensorflow/java/maven/libtensorflow/pom.xml @@ -6,7 +6,7 @@ org.tensorflow parentpom - 1.2.0-rc2 + 1.2.0 ../ libtensorflow diff --git a/tensorflow/java/maven/libtensorflow_jni/pom.xml b/tensorflow/java/maven/libtensorflow_jni/pom.xml index 89b7c6528f7..4002de97bde 100644 --- a/tensorflow/java/maven/libtensorflow_jni/pom.xml +++ b/tensorflow/java/maven/libtensorflow_jni/pom.xml @@ -6,7 +6,7 @@ org.tensorflow parentpom - 1.2.0-rc2 + 1.2.0 ../ libtensorflow_jni diff --git a/tensorflow/java/maven/pom.xml b/tensorflow/java/maven/pom.xml index e8d8fe63781..f1d8aa1715f 100644 --- a/tensorflow/java/maven/pom.xml +++ b/tensorflow/java/maven/pom.xml @@ -6,7 +6,7 @@ 4.0.0 org.tensorflow parentpom - 1.2.0-rc2 + 1.2.0 pom https://www.tensorflow.org diff --git a/tensorflow/java/maven/proto/pom.xml b/tensorflow/java/maven/proto/pom.xml index 1192cfe1c37..f2974220d95 100644 --- a/tensorflow/java/maven/proto/pom.xml +++ b/tensorflow/java/maven/proto/pom.xml @@ -6,7 +6,7 @@ org.tensorflow parentpom - 1.2.0-rc2 + 1.2.0 ../ proto diff --git a/tensorflow/java/maven/tensorflow/pom.xml 
b/tensorflow/java/maven/tensorflow/pom.xml index 31fb0151099..21c22074048 100644 --- a/tensorflow/java/maven/tensorflow/pom.xml +++ b/tensorflow/java/maven/tensorflow/pom.xml @@ -6,7 +6,7 @@ org.tensorflow parentpom - 1.2.0-rc2 + 1.2.0 ../ tensorflow diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index 88d5980835c..ad556c8e253 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -47,6 +47,7 @@ py_library( deps = [ ":tf_optimizer", ":array_ops", + ":bitwise_ops", ":check_ops", ":client", ":client_testlib", @@ -1056,6 +1057,17 @@ tf_gen_op_wrapper_private_py( ], ) +tf_gen_op_wrapper_private_py( + name = "bitwise_ops_gen", + require_shape_functions = True, + visibility = [ + "//learning/brain/python/ops:__pkg__", + "//tensorflow/compiler/tests:__pkg__", + "//tensorflow/contrib/quantization:__pkg__", + "//tensorflow/python/kernel_tests:__pkg__", + ], +) + tf_gen_op_wrapper_private_py( name = "candidate_sampling_ops_gen", require_shape_functions = True, @@ -1264,6 +1276,17 @@ py_library( ], ) +py_library( + name = "bitwise_ops", + srcs = ["ops/bitwise_ops.py"], + srcs_version = "PY2AND3", + deps = [ + ":bitwise_ops_gen", + ":framework", + ":util", + ], +) + py_library( name = "sets", srcs = [ @@ -1424,6 +1447,7 @@ py_library( deps = [ ":array_grad", ":array_ops", + ":bitwise_ops", ":control_flow_grad", ":control_flow_ops", ":framework", @@ -2128,6 +2152,18 @@ py_library( ], ) +cuda_py_test( + name = "bitwise_ops_test", + size = "small", + srcs = ["ops/bitwise_ops_test.py"], + additional_deps = [ + ":bitwise_ops", + ":constant_op", + ":dtypes", + ":framework_test_lib", + ], +) + cuda_py_test( name = "control_flow_ops_test", size = "small", @@ -2465,6 +2501,7 @@ py_library( srcs_version = "PY2AND3", deps = [ "//third_party/py/numpy", + "@org_python_pypi_backports_weakref", "@protobuf//:protobuf_python", "@six_archive//:six", ], @@ -3148,7 +3185,6 @@ cuda_py_tests( "training/adagrad_da_test.py", "training/adagrad_test.py", 
"training/basic_loops_test.py", - "training/checkpoint_utils_test.py", "training/coordinator_test.py", "training/device_setter_test.py", "training/ftrl_test.py", @@ -3352,6 +3388,27 @@ py_test( ], ) +py_test( + name = "checkpoint_utils_test", + size = "small", + srcs = ["training/checkpoint_utils_test.py"], + srcs_version = "PY2AND3", + tags = ["no_windows"], + deps = [ + ":client", + ":client_testlib", + ":framework_for_generated_wrappers", + ":io_ops", + ":partitioned_variables", + ":platform", + ":pywrap_tensorflow", + ":state_ops", + ":training", + ":variable_scope", + ":variables", + ], +) + py_test( name = "monitored_session_test", size = "small", @@ -3712,6 +3769,26 @@ cuda_py_test( main = "ops/concat_benchmark.py", ) +cuda_py_test( + name = "conv2d_benchmark", + size = "large", + srcs = ["ops/conv2d_benchmark.py"], + additional_deps = [ + ":client", + ":client_testlib", + ":control_flow_ops", + ":framework_for_generated_wrappers", + ":nn_ops", + ":platform", + ":platform_benchmark", + ":random_ops", + ":variables", + "//third_party/py/numpy", + "//tensorflow/core:protos_all_py", + ], + main = "ops/conv2d_benchmark.py", +) + cuda_py_test( name = "split_benchmark", srcs = ["ops/split_benchmark.py"], @@ -3801,7 +3878,13 @@ py_test( ":client_testlib", ":framework_for_generated_wrappers", ":math_ops", + ":nn", + ":random_seed", + ":session", ":tf_optimizer", + ":training", + ":variable_scope", + ":variables", "//tensorflow/core:protos_all_py", "//third_party/py/numpy", ], @@ -3833,6 +3916,20 @@ py_library( deps = [":pywrap_tensorflow_internal"], ) +py_binary( + name = "cost_analyzer_tool", + srcs = [ + "grappler/cost_analyzer_tool.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":cost_analyzer", + ":framework_for_generated_wrappers", + ":tf_optimizer", + "//tensorflow/core:protos_all_py", + ], +) + py_test( name = "cost_analyzer_test", size = "small", diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py index 5357d36330a..db2c5b1fd99 
100644 --- a/tensorflow/python/__init__.py +++ b/tensorflow/python/__init__.py @@ -77,6 +77,7 @@ from tensorflow.python.ops.standard_ops import * from tensorflow.python.estimator import estimator_lib as estimator from tensorflow.python.feature_column import feature_column_lib as feature_column from tensorflow.python.layers import layers +from tensorflow.python.ops import bitwise_ops as bitwise from tensorflow.python.ops import image_ops as image from tensorflow.python.ops import metrics from tensorflow.python.ops import nn @@ -131,7 +132,6 @@ from tensorflow.python.ops import tensor_array_ops # documentation, or remove. _allowed_symbols = [ 'AttrValue', - 'AutoParallelOptions', 'ConfigProto', 'ClusterDef', 'DeviceSpec', @@ -148,7 +148,6 @@ _allowed_symbols = [ 'NameAttrList', 'NodeDef', 'OptimizerOptions', - 'RewriterConfig', 'RunOptions', 'RunMetadata', 'SessionLog', @@ -211,6 +210,7 @@ _allowed_symbols.extend([ # Export modules and constants. _allowed_symbols.extend([ 'app', + 'bitwise', 'compat', 'errors', 'estimator', diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py index ad2ee13db58..849f3cc5e15 100644 --- a/tensorflow/python/client/session.py +++ b/tensorflow/python/client/session.py @@ -971,7 +971,6 @@ class BaseSession(SessionInterface): TypeError: If `fetches` or `feed_dict` keys are of an inappropriate type. tf.errors.OpError: Or one of its subclasses if a TensorFlow error happens. 
""" - assert not self._created_with_new_api, 'Partial runs don\'t work with C API' def _feed_fn(feed): for tensor_type, _, _, feed_fn in _REGISTERED_EXPANSIONS: @@ -999,7 +998,12 @@ class BaseSession(SessionInterface): try: subfeed_t = self.graph.as_graph_element(subfeed, allow_tensor=True, allow_operation=False) - feed_list.append(compat.as_bytes(subfeed_t.name)) + if self._created_with_new_api: + # pylint: disable=protected-access + feed_list.append(subfeed_t._as_tf_output()) + # pylint: enable=protected-access + else: + feed_list.append(compat.as_bytes(subfeed_t.name)) except Exception as e: e.message = ('Cannot interpret feed_list key as Tensor: ' + e.message) @@ -1014,12 +1018,24 @@ class BaseSession(SessionInterface): def _setup_fn(session, feed_list, fetch_list, target_list): self._extend_graph() with errors.raise_exception_on_not_ok_status() as status: - return tf_session.TF_PRunSetup(session, feed_list, fetch_list, - target_list, status) + if self._created_with_new_api: + return tf_session.TF_SessionPRunSetup_wrapper( + session, feed_list, fetch_list, target_list, status) + else: + return tf_session.TF_PRunSetup(session, feed_list, fetch_list, + target_list, status) - return self._do_call(_setup_fn, self._session, feed_list, - _name_list(fetch_handler.fetches()), - _name_list(fetch_handler.targets())) + if self._created_with_new_api: + # pylint: disable=protected-access + final_fetches = [t._as_tf_output() for t in fetch_handler.fetches()] + final_targets = [op._c_op for op in fetch_handler.targets()] + # pylint: enable=protected-access + else: + final_fetches = _name_list(fetch_handler.fetches()) + final_targets = _name_list(fetch_handler.targets()) + + return self._do_call(_setup_fn, self._session, feed_list, final_fetches, + final_targets) def _run(self, handle, fetches, feed_dict, options, run_metadata): """Perform either run or partial_run, depending the presence of `handle`.""" @@ -1248,13 +1264,15 @@ class BaseSession(SessionInterface): status, 
run_metadata) def _prun_fn(session, handle, feed_dict, fetch_list): - assert not self._created_with_new_api, ('Partial runs don\'t work with ' - 'C API') if target_list: raise RuntimeError('partial_run() requires empty target_list.') with errors.raise_exception_on_not_ok_status() as status: - return tf_session.TF_PRun(session, handle, feed_dict, fetch_list, - status) + if self._created_with_new_api: + return tf_session.TF_SessionPRun_wrapper(session, handle, feed_dict, + fetch_list, status) + else: + return tf_session.TF_PRun(session, handle, feed_dict, fetch_list, + status) if handle is None: return self._do_call(_run_fn, self._session, feeds, fetches, targets, diff --git a/tensorflow/python/client/session_partial_run_test.py b/tensorflow/python/client/session_partial_run_test.py index 9e0eca2089e..33b90e6156f 100644 --- a/tensorflow/python/client/session_partial_run_test.py +++ b/tensorflow/python/client/session_partial_run_test.py @@ -33,25 +33,15 @@ from tensorflow.python.ops import math_ops from tensorflow.python.platform import googletest from tensorflow.python.training import server_lib -ops._USE_C_API = True # NOTE(mrry): Dummy shape registration for ops used in the tests, since they # don't have C++ op registrations on which to attach C++ shape fns. 
ops.RegisterShape('ConstructionFails')(common_shapes.unknown_shape) -class PartialRunTest(test_util.TensorFlowTestCase): +class PartialRunTestMethods(object): - def setUp(self): - # Partial runs don't work with C API - ops._USE_C_API = False - super(PartialRunTest, self).setUp() - - def tearDown(self): - ops._USE_C_API = True - super(PartialRunTest, self).tearDown() - - def runTestPartialRun(self, sess): + def RunTestPartialRun(self, sess): a = array_ops.placeholder(dtypes.float32, shape=[]) b = array_ops.placeholder(dtypes.float32, shape=[]) c = array_ops.placeholder(dtypes.float32, shape=[]) @@ -73,7 +63,7 @@ class PartialRunTest(test_util.TensorFlowTestCase): res = sess.partial_run(h2, r2, feed_dict={c: temp}) self.assertEqual(162, res) - def runTestPartialRunIncomplete(self, sess): + def RunTestPartialRunIncomplete(self, sess): a = array_ops.placeholder(dtypes.float32, shape=[]) b = array_ops.placeholder(dtypes.float32, shape=[]) c = array_ops.placeholder(dtypes.float32, shape=[]) @@ -84,7 +74,7 @@ class PartialRunTest(test_util.TensorFlowTestCase): res = sess.partial_run(h, r1, feed_dict={a: 1, b: 2}) self.assertEqual(3, res) - def runTestConcurrentPartialRun(self, sess): + def RunTestConcurrentPartialRun(self, sess): a = array_ops.placeholder(dtypes.float32, shape=[]) b = array_ops.placeholder(dtypes.float32, shape=[]) c = array_ops.placeholder(dtypes.float32, shape=[]) @@ -101,7 +91,7 @@ class PartialRunTest(test_util.TensorFlowTestCase): res = sess.partial_run(h2, r2, feed_dict={c: 7}) self.assertEqual(462, res) - def runTestManyPartialRun(self, sess): + def RunTestManyPartialRun(self, sess): steps = 200 inputs = [] outputs = [] @@ -123,7 +113,7 @@ class PartialRunTest(test_util.TensorFlowTestCase): self.assertEqual(steps, len(res)) self.assertEqual(2.0, res[-1]) - def runTestRunAndPartialRun(self, sess): + def RunTestRunAndPartialRun(self, sess): a = constant_op.constant(2.0, dtypes.float32) b = a * 2 c = b * 3 @@ -132,7 +122,7 @@ class 
PartialRunTest(test_util.TensorFlowTestCase): r2 = sess.partial_run(h, [b, c]) self.assertEqual(r1, r2) - def runTestPartialRunMissingPlaceholderFeedException(self, sess): + def RunTestPartialRunMissingPlaceholderFeedException(self, sess): x = array_ops.placeholder(dtypes.float32, shape=()) fetches = [x * 2, x * 3] handle = sess.partial_run_setup(fetches=fetches, feeds=[]) @@ -140,7 +130,7 @@ class PartialRunTest(test_util.TensorFlowTestCase): 'You must feed a value for placeholder'): sess.partial_run(handle, fetches[0]) - def runTestPartialRunUnspecifiedFeed(self, sess): + def RunTestPartialRunUnspecifiedFeed(self, sess): a = array_ops.placeholder(dtypes.float32, shape=[]) b = array_ops.placeholder(dtypes.float32, shape=[]) c = array_ops.placeholder(dtypes.float32, shape=[]) @@ -151,7 +141,7 @@ class PartialRunTest(test_util.TensorFlowTestCase): 'was not specified in partial_run_setup.$'): sess.partial_run(h, r1, feed_dict={a: 1, b: 2, c: 3}) - def runTestPartialRunUnspecifiedFetch(self, sess): + def RunTestPartialRunUnspecifiedFetch(self, sess): a = array_ops.placeholder(dtypes.float32, shape=[]) b = array_ops.placeholder(dtypes.float32, shape=[]) c = array_ops.placeholder(dtypes.float32, shape=[]) @@ -163,7 +153,7 @@ class PartialRunTest(test_util.TensorFlowTestCase): 'was not specified in partial_run_setup.$'): sess.partial_run(h, r2, feed_dict={a: 1, c: 3}) - def runTestPartialRunAlreadyFed(self, sess): + def RunTestPartialRunAlreadyFed(self, sess): a = array_ops.placeholder(dtypes.float32, shape=[]) b = array_ops.placeholder(dtypes.float32, shape=[]) c = array_ops.placeholder(dtypes.float32, shape=[]) @@ -176,7 +166,7 @@ class PartialRunTest(test_util.TensorFlowTestCase): 'has already been fed.$'): sess.partial_run(h, r2, feed_dict={a: 1, c: 3}) - def runTestPartialRunAlreadyFetched(self, sess): + def RunTestPartialRunAlreadyFetched(self, sess): a = array_ops.placeholder(dtypes.float32, shape=[]) b = array_ops.placeholder(dtypes.float32, shape=[]) c = 
array_ops.placeholder(dtypes.float32, shape=[]) @@ -189,7 +179,7 @@ class PartialRunTest(test_util.TensorFlowTestCase): 'has already been fetched.$'): sess.partial_run(h, r1, feed_dict={c: 3}) - def runTestPartialRunEmptyFetches(self, sess): + def RunTestPartialRunEmptyFetches(self, sess): a = array_ops.placeholder(dtypes.float32) b = a * 2.0 @@ -207,82 +197,109 @@ class PartialRunTest(test_util.TensorFlowTestCase): sess.partial_run_setup(fetches=[], feeds=[x]) def testPartialRunDirect(self): - self.runTestPartialRun(session.Session()) + self.RunTestPartialRun(session.Session()) def testPartialRunIncompleteDirect(self): - self.runTestPartialRunIncomplete(session.Session()) + self.RunTestPartialRunIncomplete(session.Session()) def testConcurrentPartialRunDirect(self): - self.runTestConcurrentPartialRun(session.Session()) + self.RunTestConcurrentPartialRun(session.Session()) def testManyPartialRunDirect(self): - self.runTestManyPartialRun(session.Session()) + self.RunTestManyPartialRun(session.Session()) def testRunAndPartialRunDirect(self): - self.runTestRunAndPartialRun(session.Session()) + self.RunTestRunAndPartialRun(session.Session()) def testPartialRunMissingPlaceholderFeedExceptionDirect(self): - self.runTestPartialRunMissingPlaceholderFeedException(session.Session()) + self.RunTestPartialRunMissingPlaceholderFeedException(session.Session()) def testPartialRunUnspecifiedFeedDirect(self): - self.runTestPartialRunUnspecifiedFeed(session.Session()) + self.RunTestPartialRunUnspecifiedFeed(session.Session()) def testPartialRunUnspecifiedFetchDirect(self): - self.runTestPartialRunUnspecifiedFetch(session.Session()) + self.RunTestPartialRunUnspecifiedFetch(session.Session()) def testPartialRunAlreadyFedDirect(self): - self.runTestPartialRunAlreadyFed(session.Session()) + self.RunTestPartialRunAlreadyFed(session.Session()) def testPartialRunAlreadyFetchedDirect(self): - self.runTestPartialRunAlreadyFetched(session.Session()) + 
self.RunTestPartialRunAlreadyFetched(session.Session()) def testPartialRunEmptyFetchesDirect(self): - self.runTestPartialRunEmptyFetches(session.Session()) + self.RunTestPartialRunEmptyFetches(session.Session()) def testPartialRunDist(self): server = server_lib.Server.create_local_server() - self.runTestPartialRun(session.Session(server.target)) + self.RunTestPartialRun(session.Session(server.target)) def testPartialRunIncompleteDist(self): server = server_lib.Server.create_local_server() - self.runTestPartialRunIncomplete(session.Session(server.target)) + self.RunTestPartialRunIncomplete(session.Session(server.target)) def testConcurrentPartialRunDist(self): server = server_lib.Server.create_local_server() - self.runTestConcurrentPartialRun(session.Session(server.target)) + self.RunTestConcurrentPartialRun(session.Session(server.target)) def testManyPartialRunDist(self): server = server_lib.Server.create_local_server() - self.runTestManyPartialRun(session.Session(server.target)) + self.RunTestManyPartialRun(session.Session(server.target)) def testRunAndPartialRunDist(self): server = server_lib.Server.create_local_server() - self.runTestRunAndPartialRun(session.Session(server.target)) + self.RunTestRunAndPartialRun(session.Session(server.target)) def testPartialRunMissingPlaceholderFeedExceptionDist(self): server = server_lib.Server.create_local_server() - self.runTestPartialRunMissingPlaceholderFeedException( + self.RunTestPartialRunMissingPlaceholderFeedException( session.Session(server.target)) def testPartialRunUnspecifiedFeedDist(self): server = server_lib.Server.create_local_server() - self.runTestPartialRunUnspecifiedFeed(session.Session(server.target)) + self.RunTestPartialRunUnspecifiedFeed(session.Session(server.target)) def testPartialRunUnspecifiedFetchDist(self): server = server_lib.Server.create_local_server() - self.runTestPartialRunUnspecifiedFetch(session.Session(server.target)) + 
self.RunTestPartialRunUnspecifiedFetch(session.Session(server.target)) def testPartialRunAlreadyFedDist(self): server = server_lib.Server.create_local_server() - self.runTestPartialRunAlreadyFed(session.Session(server.target)) + self.RunTestPartialRunAlreadyFed(session.Session(server.target)) def testPartialRunAlreadyFetchedDist(self): server = server_lib.Server.create_local_server() - self.runTestPartialRunAlreadyFetched(session.Session(server.target)) + self.RunTestPartialRunAlreadyFetched(session.Session(server.target)) def testPartialRunEmptyFetchesDist(self): server = server_lib.Server.create_local_server() - self.runTestPartialRunEmptyFetches(session.Session(server.target)) + self.RunTestPartialRunEmptyFetches(session.Session(server.target)) + + +class PartialRunTest(PartialRunTestMethods, test_util.TensorFlowTestCase): + """Test case that invokes test methods with _USE_C_API=False.""" + + def setUp(self): + self.prev_use_c_api = ops._USE_C_API + ops._USE_C_API = False + super(PartialRunTest, self).setUp() + + def tearDown(self): + ops._USE_C_API = self.prev_use_c_api + super(PartialRunTest, self).tearDown() + + +class PartialRunWithCApiTest(PartialRunTestMethods, + test_util.TensorFlowTestCase): + """Test case that invokes test methods with _USE_C_API=True.""" + + def setUp(self): + self.prev_use_c_api = ops._USE_C_API + ops._USE_C_API = True + super(PartialRunWithCApiTest, self).setUp() + + def tearDown(self): + ops._USE_C_API = self.prev_use_c_api + super(PartialRunWithCApiTest, self).tearDown() if __name__ == '__main__': diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py index a1f98059cd8..ef8f25c873d 100644 --- a/tensorflow/python/client/session_test.py +++ b/tensorflow/python/client/session_test.py @@ -140,7 +140,6 @@ class SessionTest(test_util.TensorFlowTestCase): with self.assertRaisesOpError(lambda e: e.op == a.op): a.eval() - @test_util.disable_c_api # Partial runs don't work with C API def 
testErrorCodeWithNoNodeDef(self): with session.Session() as s: a = array_ops.placeholder(dtypes.float32, shape=[]) @@ -1525,7 +1524,7 @@ class SessionTest(test_util.TensorFlowTestCase): sess.run(enqueue_op) self.assertEqual(sess.run(q.size()), num_epochs * 2) - @test_util.disable_c_api # Partial runs don't work with C API + @test_util.disable_c_api # set_device does not work with C API def testRegisterFetchAndFeedConversionFunctions(self): class SquaredTensor(object): def __init__(self, tensor): diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i index 4a62a96935e..284c98639d7 100644 --- a/tensorflow/python/client/tf_session.i +++ b/tensorflow/python/client/tf_session.i @@ -63,6 +63,11 @@ tensorflow::ImportNumpy(); // Constants used by TensorHandle (get_session_handle). %constant const char* TENSOR_HANDLE_KEY = tensorflow::SessionState::kTensorHandleResourceTypeName; +// Convert TF_OperationName output to unicode python string +%typemap(out) const char* TF_OperationName { + $result = PyUnicode_FromString($1); +} + //////////////////////////////////////////////////////////////////////////////// // BEGIN TYPEMAPS FOR tensorflow::TF_Run_wrapper() //////////////////////////////////////////////////////////////////////////////// @@ -112,6 +117,8 @@ tensorflow::ImportNumpy(); tensorflow::PyObjectVector temp) { $1 = &temp; } +// TODO(iga): move this and the corresponding typemap(argout) to +// tf_sessionrun_wrapper.i once we get rid of this code for DeprecatedSession. %typemap(in, numinputs=0) char** out_handle ( char* temp) { $1 = &temp; @@ -142,7 +149,7 @@ tensorflow::ImportNumpy(); %#else $result = PyUnicode_FromStringAndSize( %#endif - *$1, strlen(*$1)); + *$1, *$1 == nullptr ? 0 : strlen(*$1)); delete[] *$1; } @@ -163,29 +170,37 @@ tensorflow::ImportNumpy(); // Helper function to convert a Python list of Tensors to a C++ vector of // TF_Outputs. 
// -// Caller should have already checked that `py_tensor_list` is a list (this -// isn't done in this function to allow for function-specific error messages) -void PyTensorListToVector(PyObject* py_tensor_list, - std::vector* vec) { +// Returns true if successful. Otherwise, returns false and sets error_msg. +bool PyTensorListToVector(PyObject* py_tensor_list, + std::vector* vec, + string* error_msg) { + if (!PyList_Check(py_tensor_list)) { + *error_msg = "expected Python list."; + return false; + } size_t size = PyList_Size(py_tensor_list); for (int i = 0; i < size; ++i) { PyObject* item = PyList_GetItem(py_tensor_list, i); TF_Output* input_ptr; - SWIG_ConvertPtr(item, reinterpret_cast(&input_ptr), - SWIGTYPE_p_TF_Output, 0); + if (!SWIG_IsOK(SWIG_ConvertPtr(item, reinterpret_cast(&input_ptr), + SWIGTYPE_p_TF_Output, 0))) { + *error_msg = "expected Python list of wrapped TF_Output objects. " + "Found python list of something else."; + return false; + } vec->push_back(*input_ptr); } + return true; } %} // Converts input Python list of wrapped TF_Outputs into a single array %typemap(in) (const TF_Output* inputs, int num_inputs) (std::vector inputs) { - if (!PyList_Check($input)) { - SWIG_exception_fail( - SWIG_TypeError, "$symname: expected Python list of wrapped TF_Outputs"); + string error_msg; + if (!PyTensorListToVector($input, &inputs, &error_msg)) { + SWIG_exception_fail(SWIG_TypeError, ("$symname: " + error_msg).c_str()); } - PyTensorListToVector($input, &inputs); $1 = inputs.data(); $2 = inputs.size(); } @@ -211,6 +226,17 @@ void PyTensorListToVector(PyObject* py_tensor_list, // PyArray_Return, maybe others). 
%noexception TF_SessionRun_wrapper; +// We use TF_SessionPRunSetup_wrapper instead of TF_SessionPRunSetup +%ignore TF_SessionPRunSetup; +%unignore TF_SessionPRunSetup_wrapper; +// See comment for "%noexception TF_SessionRun_wrapper;" +%noexception TF_SessionPRunSetup_wrapper; + +// We use TF_SessionPRun_wrapper instead of TF_SessionPRun +%ignore TF_SessionPRun; +%unignore TF_SessionPRun_wrapper; +// See comment for "%noexception TF_SessionRun_wrapper;" +%noexception TF_SessionPRun_wrapper; %rename("_TF_SetTarget") TF_SetTarget; %rename("_TF_SetConfig") TF_SetConfig; diff --git a/tensorflow/python/client/tf_session_helper.cc b/tensorflow/python/client/tf_session_helper.cc index 86088d0ab49..7ebb1a7fe4c 100644 --- a/tensorflow/python/client/tf_session_helper.cc +++ b/tensorflow/python/client/tf_session_helper.cc @@ -594,14 +594,6 @@ void TF_PRunSetup_wrapper(TF_DeprecatedSession* session, const_cast(output_names.data()), output_names.size(), const_cast(target_nodes.data()), target_nodes.size(), out_handle, out_status); - // TF_PRunSetup leaves out_handle undefined if it fails, but SWIG will call - // free(out_handle) on the returned handle regardless. Thus, must make sure it - // is valid. 
- if (TF_GetCode(out_status) != TF_OK) { - char* tmp = new char[1]; - tmp[0] = '\0'; - *out_handle = tmp; - } Py_END_ALLOW_THREADS; } @@ -623,7 +615,7 @@ void TF_Reset_wrapper(const TF_SessionOptions* opt, out_status); } -void TF_SessionRun_wrapper_helper(TF_Session* session, +void TF_SessionRun_wrapper_helper(TF_Session* session, const char* handle, const TF_Buffer* run_options, const std::vector& inputs, const std::vector& input_ndarrays, @@ -678,10 +670,16 @@ void TF_SessionRun_wrapper_helper(TF_Session* session, // Call TF_SessionRun() (and release GIL during execution) Py_BEGIN_ALLOW_THREADS; - TF_SessionRun(session, run_options, inputs.data(), input_vals.data(), - inputs.size(), outputs.data(), output_vals.data(), - outputs.size(), targets.data(), targets.size(), run_metadata, - out_status); + if (handle == nullptr) { + TF_SessionRun(session, run_options, inputs.data(), input_vals.data(), + inputs.size(), outputs.data(), output_vals.data(), + outputs.size(), targets.data(), targets.size(), run_metadata, + out_status); + } else { + TF_SessionPRun(session, handle, inputs.data(), input_vals.data(), + inputs.size(), outputs.data(), output_vals.data(), + outputs.size(), targets.data(), targets.size(), out_status); + } Py_END_ALLOW_THREADS; // Create scoped containers for output tensors @@ -716,9 +714,9 @@ void TF_SessionRun_wrapper(TF_Session* session, const TF_Buffer* run_options, const std::vector& targets, TF_Buffer* run_metadata, TF_Status* out_status, std::vector* py_outputs) { - TF_SessionRun_wrapper_helper(session, run_options, inputs, input_ndarrays, - outputs, targets, run_metadata, out_status, - py_outputs); + TF_SessionRun_wrapper_helper(session, nullptr, run_options, inputs, + input_ndarrays, outputs, targets, run_metadata, + out_status, py_outputs); // Release any unused ndarray references (see memory management comment in // TF_SessionRun_wrapper_helper) ClearDecrefCache(); @@ -737,4 +735,35 @@ string EqualGraphDefWrapper(const string& actual, const 
string& expected) { return EqualGraphDef(actual_def, expected_def, &diff) ? "" : diff; } +void TF_SessionPRunSetup_wrapper(TF_Session* session, + const std::vector& inputs, + const std::vector& outputs, + const std::vector& targets, + const char** out_handle, + TF_Status* out_status) { + // Call TF_SessionPRunSetup() (and release GIL during execution) + Py_BEGIN_ALLOW_THREADS; + TF_SessionPRunSetup(session, inputs.data(), inputs.size(), outputs.data(), + outputs.size(), targets.data(), targets.size(), + out_handle, out_status); + Py_END_ALLOW_THREADS; +} + +void TF_SessionPRun_wrapper(TF_Session* session, const char* handle, + const std::vector& inputs, + const std::vector& input_ndarrays, + const std::vector& outputs, + TF_Status* out_status, + std::vector* py_outputs) { + const std::vector targets; + TF_SessionRun_wrapper_helper(session, handle, + nullptr, // run_options + inputs, input_ndarrays, outputs, targets, + nullptr, // run_metadata + out_status, py_outputs); + // Release any unused ndarray references (see memory management comment in + // TF_SessionRun_wrapper_helper) + ClearDecrefCache(); +} + } // namespace tensorflow diff --git a/tensorflow/python/client/tf_session_helper.h b/tensorflow/python/client/tf_session_helper.h index 727e8ade52f..9937b6aeeb3 100644 --- a/tensorflow/python/client/tf_session_helper.h +++ b/tensorflow/python/client/tf_session_helper.h @@ -81,8 +81,6 @@ void TF_Run_wrapper(TF_DeprecatedSession* session, const TF_Buffer* run_options, // // On failure, out_status contains a tensorflow::Status with an error // message. -// -// NOTE: This is EXPERIMENTAL and subject to change. void TF_PRunSetup_wrapper(TF_DeprecatedSession* session, const NameVector& input_names, const NameVector& output_names, @@ -101,8 +99,6 @@ void TF_PRunSetup_wrapper(TF_DeprecatedSession* session, // // On failure, out_status contains a tensorflow::Status with an error // message. -// -// NOTE: This is EXPERIMENTAL and subject to change. 
void TF_PRun_wrapper(TF_DeprecatedSession* session, const char* handle, PyObject* feed_dict, const NameVector& output_names, TF_Status* out_status, PyObjectVector* out_values); @@ -128,6 +124,40 @@ void TF_SessionRun_wrapper(TF_Session* session, const TF_Buffer* run_options, const std::vector& targets, TF_Buffer* run_metadata, TF_Status* out_status, std::vector* py_outputs); + +// Set up the graph with the intended feeds (inputs) and fetches (output) for +// a sequence of partial run calls. +// +// On success, returns a handle that can be used for subsequent PRun calls. The +// handle is owned by the caller and should be deleted with TF_DeletePRunHandle +// when it is no longer needed. +// +// On failure, out_status contains a tensorflow::Status with an error +// message. +void TF_SessionPRunSetup_wrapper(TF_Session* session, + const std::vector& inputs, + const std::vector& outputs, + const std::vector& targets, + const char** out_handle, + TF_Status* out_status); + +// Continue to run the graph with additional feeds and fetches. The +// execution state is uniquely identified by the handle. +// +// On success, `py_outputs` is populated with a numpy ndarray for each output +// (the caller must decref these ndarrays, although this will likely be handled +// by the Python gc). `session`, `handle`, `out_status`, and `py_outputs` must +// be non-null. `py_outputs` should be empty. +// +// On failure, out_status contains a tensorflow::Status with an error +// message. 
+void TF_SessionPRun_wrapper(TF_Session* session, const char* handle, + const std::vector& inputs, + const std::vector& input_ndarrays, + const std::vector& outputs, + TF_Status* out_status, + std::vector* py_outputs); + } // namespace tensorflow #endif // TENSORFLOW_PYTHON_CLIENT_TF_SESSION_HELPER_H_ diff --git a/tensorflow/python/client/tf_sessionrun_wrapper.i b/tensorflow/python/client/tf_sessionrun_wrapper.i index 289792fef26..473bc3ccc53 100644 --- a/tensorflow/python/client/tf_sessionrun_wrapper.i +++ b/tensorflow/python/client/tf_sessionrun_wrapper.i @@ -73,13 +73,17 @@ tensorflow::ImportNumpy(); // $input is a Python list of wrapped TF_Outputs %typemap(in) (const std::vector& outputs) (std::vector outputs) { - if (!PyList_Check($input)) { - SWIG_exception_fail(SWIG_TypeError, "$symname: expected list"); + string error_msg; + if (!PyTensorListToVector($input, &outputs, &error_msg)) { + SWIG_exception_fail(SWIG_TypeError, ("$symname: " + error_msg).c_str()); } - PyTensorListToVector($input, &outputs); $1 = &outputs; } +// Apply the typemap above to inputs as well +%typemap(in) (const std::vector& inputs) = + (const std::vector& outputs); + // Create temporary py_outputs_vec variable to store return value %typemap(in, numinputs=0) (std::vector* py_outputs) (std::vector py_outputs_vec) { diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD index 39446b6ca27..54ae5047ca7 100644 --- a/tensorflow/python/debug/BUILD +++ b/tensorflow/python/debug/BUILD @@ -367,6 +367,7 @@ py_test( srcs_version = "PY2AND3", deps = [ ":debug_data", + "//tensorflow/python:client_testlib", "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform_test", "//third_party/py/numpy", diff --git a/tensorflow/python/debug/__init__.py b/tensorflow/python/debug/__init__.py index 750d21f80d3..147e42d878f 100644 --- a/tensorflow/python/debug/__init__.py +++ b/tensorflow/python/debug/__init__.py @@ -21,6 +21,7 @@ See the @{$python/tfdbg} guide. 
@@watch_graph_with_blacklists @@DebugTensorDatum @@DebugDumpDir +@@load_tensor_from_event @@load_tensor_from_event_file @@has_inf_or_nan @@DumpingDebugHook @@ -40,6 +41,7 @@ from __future__ import print_function from tensorflow.python.debug.lib.debug_data import DebugDumpDir from tensorflow.python.debug.lib.debug_data import DebugTensorDatum from tensorflow.python.debug.lib.debug_data import has_inf_or_nan +from tensorflow.python.debug.lib.debug_data import load_tensor_from_event from tensorflow.python.debug.lib.debug_data import load_tensor_from_event_file from tensorflow.python.debug.lib.debug_utils import add_debug_tensor_watch diff --git a/tensorflow/python/debug/cli/curses_ui.py b/tensorflow/python/debug/cli/curses_ui.py index 6a571c097ee..498e346393b 100644 --- a/tensorflow/python/debug/cli/curses_ui.py +++ b/tensorflow/python/debug/cli/curses_ui.py @@ -513,6 +513,18 @@ class CursesUI(base_ui.BaseUI): def get_help(self): return self._command_handler_registry.get_help() + def _addstr(self, *args): + try: + self._stdscr.addstr(*args) + except curses.error: + pass + + def _refresh_pad(self, pad, *args): + try: + pad.refresh(*args) + except curses.error: + pass + def _screen_create_command_textbox(self, existing_command=None): """Create command textbox on screen. @@ -522,8 +534,8 @@ class CursesUI(base_ui.BaseUI): """ # Display the tfdbg prompt. 
- self._stdscr.addstr(self._max_y - self._command_textbox_height, 0, - self.CLI_PROMPT, curses.A_BOLD) + self._addstr(self._max_y - self._command_textbox_height, 0, + self.CLI_PROMPT, curses.A_BOLD) self._stdscr.refresh() self._command_window.clear() @@ -948,7 +960,7 @@ class CursesUI(base_ui.BaseUI): color_pair = (self._default_color_pair if color is None else self._color_pairs[color]) - self._stdscr.addstr(row, 0, line, color_pair | attr) + self._addstr(row, 0, line, color_pair | attr) self._screen_refresh() def _screen_new_output_pad(self, rows, cols): @@ -1235,10 +1247,9 @@ class CursesUI(base_ui.BaseUI): def _screen_scroll_output_pad(self, pad, viewport_top, viewport_left, screen_location_top, screen_location_left, screen_location_bottom, screen_location_right): - pad.refresh(viewport_top, viewport_left, screen_location_top, - screen_location_left, screen_location_bottom, - screen_location_right) - + self._refresh_pad(pad, viewport_top, viewport_left, screen_location_top, + screen_location_left, screen_location_bottom, + screen_location_right) self._scroll_bar = ScrollBar( self._max_x - 2, 3, @@ -1249,9 +1260,9 @@ class CursesUI(base_ui.BaseUI): (scroll_pad, _, _) = self._display_lines( self._scroll_bar.layout(), self._output_num_rows - 1) - scroll_pad.refresh( - 0, 0, self._output_top_row + 1, self._max_x - 2, - self._output_num_rows + 1, self._max_x - 1) + self._refresh_pad(scroll_pad, 0, 0, self._output_top_row + 1, + self._max_x - 2, self._output_num_rows + 1, + self._max_x - 1) def _scroll_output(self, direction, line_index=None): """Scroll the output pad. 
@@ -1332,15 +1343,14 @@ class CursesUI(base_ui.BaseUI): def _screen_render_nav_bar(self): if self._nav_bar_pad: - self._nav_bar_pad.refresh(0, 0, self._nav_bar_row, 0, - self._output_pad_screen_location.top, - self._max_x) + self._refresh_pad(self._nav_bar_pad, 0, 0, self._nav_bar_row, 0, + self._output_pad_screen_location.top, self._max_x) def _screen_render_menu_pad(self): if self._main_menu_pad: - self._main_menu_pad.refresh(0, 0, self._output_pad_screen_location.top, 0, - self._output_pad_screen_location.top, - self._max_x) + self._refresh_pad( + self._main_menu_pad, 0, 0, self._output_pad_screen_location.top, 0, + self._output_pad_screen_location.top, self._max_x) def _compile_ui_status_summary(self): """Compile status summary about this Curses UI instance. diff --git a/tensorflow/python/debug/lib/debug_data.py b/tensorflow/python/debug/lib/debug_data.py index 0cdf1891272..24214c0dddd 100644 --- a/tensorflow/python/debug/lib/debug_data.py +++ b/tensorflow/python/debug/lib/debug_data.py @@ -22,6 +22,7 @@ import collections import glob import json import os +import platform import numpy as np import six @@ -40,11 +41,19 @@ METADATA_FILE_PREFIX = "_tfdbg_" CORE_METADATA_TAG = "core_metadata_" GRAPH_FILE_TAG = "graph_" DEVICE_TAG = "device_" +HASH_TAG = "hash" FETCHES_INFO_FILE_TAG = "fetches_info_" FEED_KEYS_INFO_FILE_TAG = "feed_keys_info_" +def _glob(glob_pattern): + if platform.system() == "Windows": + return glob.glob(glob_pattern) + else: + return gfile.Glob(glob_pattern) + + class InconvertibleTensorProto(object): """Represents a TensorProto that cannot be converted to np.ndarray.""" @@ -679,7 +688,7 @@ class DebugDumpDir(object): def _load_all_device_dumps(self, partition_graphs, validate): """Load the dump data for all devices.""" - device_dirs = glob.glob(os.path.join( + device_dirs = _glob(os.path.join( self._dump_root, METADATA_FILE_PREFIX + DEVICE_TAG + "*")) self._device_names = [] @@ -762,7 +771,7 @@ class DebugDumpDir(object): self._t0 = min(t0s) 
if t0s else None def _load_core_metadata(self): - core_metadata_files = glob.glob(os.path.join( + core_metadata_files = _glob(os.path.join( self._dump_root, METADATA_FILE_PREFIX + CORE_METADATA_TAG + "*")) for core_metadata_file in core_metadata_files: with gfile.Open(core_metadata_file, "rb") as f: @@ -772,7 +781,7 @@ class DebugDumpDir(object): extract_core_metadata_from_event_proto(event)) def _load_fetches_info(self): - fetches_info_files = glob.glob(os.path.join( + fetches_info_files = _glob(os.path.join( self._dump_root, METADATA_FILE_PREFIX + FETCHES_INFO_FILE_TAG + "*")) self._run_fetches_info = [] for fetches_info_file in fetches_info_files: @@ -780,7 +789,7 @@ class DebugDumpDir(object): _load_log_message_from_event_file(fetches_info_file)) def _load_feeds_info(self): - feeds_info_files = glob.glob(os.path.join( + feeds_info_files = _glob(os.path.join( self._dump_root, METADATA_FILE_PREFIX + FEED_KEYS_INFO_FILE_TAG + "*")) self._run_feed_keys_info = [] for feeds_info_file in feeds_info_files: diff --git a/tensorflow/python/debug/lib/debug_data_test.py b/tensorflow/python/debug/lib/debug_data_test.py index 70dc8c11500..eff70b662bd 100644 --- a/tensorflow/python/debug/lib/debug_data_test.py +++ b/tensorflow/python/debug/lib/debug_data_test.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function import os +import platform import shutil import tempfile @@ -27,7 +28,9 @@ from tensorflow.core.framework import graph_pb2 from tensorflow.core.framework import tensor_pb2 from tensorflow.python.debug.lib import debug_data from tensorflow.python.framework import test_util +from tensorflow.python.platform import gfile from tensorflow.python.platform import googletest +from tensorflow.python.platform import test class DeviceNamePathConversionTest(test_util.TensorFlowTestCase): @@ -339,6 +342,38 @@ class DebugDumpDirTest(test_util.TensorFlowTestCase): self.assertIsNone(dump_dir.t0) self.assertEqual([], dump_dir.dumped_tensor_data) + 
def testDebugDumpDir_usesGfileGlob(self): + if platform.system() == "Windows": + self.skipTest("gfile.Glob is not used on Windows.") + + self._makeDataDirWithMultipleDevicesAndDuplicateNodeNames() + + def fake_gfile_glob(glob_pattern): + del glob_pattern + return [] + + with test.mock.patch.object( + gfile, "Glob", side_effect=fake_gfile_glob, autospec=True) as fake: + debug_data.DebugDumpDir(self._dump_root) + expected_calls = [ + test.mock.call(os.path.join( + self._dump_root, + (debug_data.METADATA_FILE_PREFIX + + debug_data.CORE_METADATA_TAG + "*"))), + test.mock.call(os.path.join( + self._dump_root, + (debug_data.METADATA_FILE_PREFIX + + debug_data.FETCHES_INFO_FILE_TAG + "*"))), + test.mock.call(os.path.join( + self._dump_root, + (debug_data.METADATA_FILE_PREFIX + + debug_data.FEED_KEYS_INFO_FILE_TAG + "*"))), + test.mock.call(os.path.join( + self._dump_root, + (debug_data.METADATA_FILE_PREFIX + + debug_data.DEVICE_TAG + "*")))] + fake.assert_has_calls(expected_calls, any_order=True) + class GetNodeNameAndOutputSlotTest(test_util.TensorFlowTestCase): diff --git a/tensorflow/python/estimator/canned/dnn.py b/tensorflow/python/estimator/canned/dnn.py index 013a43a8b08..b7a943d2fb9 100644 --- a/tensorflow/python/estimator/canned/dnn.py +++ b/tensorflow/python/estimator/canned/dnn.py @@ -188,6 +188,8 @@ class DNNClassifier(estimator.Estimator): name. Both features' `value` must be a `SparseTensor`. - if `column` is a `_DenseColumn`, a feature with `key=column.name` whose `value` is a `Tensor`. + + Loss is calculated by using softmax cross entropy. """ def __init__(self, @@ -319,6 +321,8 @@ class DNNRegressor(estimator.Estimator): name. Both features' `value` must be a `SparseTensor`. - if `column` is a `_DenseColumn`, a feature with `key=column.name` whose `value` is a `Tensor`. + + Loss is calculated by using mean squared error. 
""" def __init__(self, diff --git a/tensorflow/python/estimator/canned/dnn_linear_combined.py b/tensorflow/python/estimator/canned/dnn_linear_combined.py index 935f6564eb5..03bcf7ae571 100644 --- a/tensorflow/python/estimator/canned/dnn_linear_combined.py +++ b/tensorflow/python/estimator/canned/dnn_linear_combined.py @@ -295,6 +295,7 @@ class DNNLinearCombinedClassifier(estimator.Estimator): - if `column` is a `_DenseColumn`, a feature with `key=column.name` whose `value` is a `Tensor`. + Loss is calculated by using softmax cross entropy. """ def __init__(self, @@ -453,6 +454,7 @@ class DNNLinearCombinedRegressor(estimator.Estimator): - if `column` is a `_DenseColumn`, a feature with `key=column.name` whose `value` is a `Tensor`. + Loss is calculated by using mean squared error. """ def __init__(self, diff --git a/tensorflow/python/estimator/canned/linear.py b/tensorflow/python/estimator/canned/linear.py index fd929b260bd..552b1bdf01e 100644 --- a/tensorflow/python/estimator/canned/linear.py +++ b/tensorflow/python/estimator/canned/linear.py @@ -151,6 +151,8 @@ class LinearClassifier(estimator.Estimator): Both features' `value` must be a `SparseTensor`. - if `column` is a `RealValuedColumn`, a feature with `key=column.name` whose `value` is a `Tensor`. + + Loss is calculated by using softmax cross entropy. """ def __init__(self, @@ -260,6 +262,8 @@ class LinearRegressor(estimator.Estimator): key=weight column name, value=a `SparseTensor`} - if isinstance(column, `RealValuedColumn`): key=column.name, value=a `Tensor` + + Loss is calculated by using mean squared error. 
""" def __init__(self, diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py index 293aa752531..ab49f36e5e8 100644 --- a/tensorflow/python/estimator/estimator.py +++ b/tensorflow/python/estimator/estimator.py @@ -603,12 +603,15 @@ class Estimator(object): if not (estimator_spec.scaffold.saver or ops.get_collection(ops.GraphKeys.SAVERS)): - ops.add_to_collection(ops.GraphKeys.SAVERS, - training.Saver( - sharded=True, - max_to_keep=self._config.keep_checkpoint_max, - defer_build=True, - save_relative_paths=True)) + ops.add_to_collection( + ops.GraphKeys.SAVERS, + training.Saver( + sharded=True, + max_to_keep=self._config.keep_checkpoint_max, + keep_checkpoint_every_n_hours=( + self._config.keep_checkpoint_every_n_hours), + defer_build=True, + save_relative_paths=True)) chief_hooks = [] if (self._config.save_checkpoints_secs or @@ -862,7 +865,8 @@ def _write_dict_to_summary(output_dir, value.simple_value = int(dictionary[key]) else: logging.warn( - 'Skipping summary for %s, must be a float, np.float32, np.int64, np.int32 or int.', + 'Skipping summary for %s, must be a float, np.float32, np.int64, ' + 'np.int32 or int.', key) summary_writer.add_summary(summary_proto, current_global_step) summary_writer.flush() diff --git a/tensorflow/python/estimator/inputs/numpy_io.py b/tensorflow/python/estimator/inputs/numpy_io.py index b31c5492d86..c9f37f06e83 100644 --- a/tensorflow/python/estimator/inputs/numpy_io.py +++ b/tensorflow/python/estimator/inputs/numpy_io.py @@ -55,6 +55,7 @@ def numpy_input_fn(x, of numpy arrays. The dict `features` has the same keys as the `x`. 
Example: + ```python age = np.arange(4) * 1.0 height = np.arange(32, 36) diff --git a/tensorflow/python/framework/importer.py b/tensorflow/python/framework/importer.py index 025e2136206..5043e7285fe 100644 --- a/tensorflow/python/framework/importer.py +++ b/tensorflow/python/framework/importer.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections import contextlib import copy @@ -311,9 +312,10 @@ def import_graph_def(graph_def, input_map=None, return_elements=None, compute_shapes=False, compute_device=False, op_def=op_def) - # Maps from a node to the op it is colocated with, if colocation + # Maps from a node to the ops it is colocated with, if colocation # is specified in the attributes. - colocation_pairs = {} + colocation_pairs = collections.defaultdict(list) + # 2. Add inputs to the operations. for node in graph_def.node: op = name_to_op[node.name] @@ -339,7 +341,7 @@ def import_graph_def(graph_def, input_map=None, return_elements=None, 'loc:@' + original_op.name)) if op_to_bind_to != node.name: # Keep track of this mapping for a later phase. - colocation_pairs[op] = original_op + colocation_pairs[op].append(original_op) # Don't apply this op's device function, # the colocation constraint will ensure # the proper device gets assigned at runtime. @@ -474,13 +476,24 @@ def import_graph_def(graph_def, input_map=None, return_elements=None, # The following loop populates the device field of ops that are # colocated with another op. This is implied by the colocation # attribute, but we propagate the device field for completeness. - for op, coloc_op in colocation_pairs.items(): - # If the colocation op has no device, even after a device - # application, there's nothing to do here. 
- if not coloc_op.device: - continue - coloc_device = pydev.DeviceSpec.from_string(coloc_op.device) - op._set_device(coloc_device) # pylint: disable=protected-access + for op, coloc_op_list in colocation_pairs.items(): + coloc_device = None + # Find any device in the list of colocated ops that have a + # device, if it exists. We assume that if multiple ops + # have devices, they refer to the same device. Otherwise, a + # runtime error will occur since the colocation property + # cannot be guaranteed. + # + # One possible improvement is to try to check for compatibility + # of all devices in this list at import time here, which would + # require implementing a compatibility function for device specs + # in python. + for coloc_op in coloc_op_list: + if coloc_op.device: + coloc_device = pydev.DeviceSpec.from_string(coloc_op.device) + break + if coloc_device: + op._set_device(coloc_device) # pylint: disable=protected-access # Treat unused input mappings as an error, because they are likely to be # due to a typo. diff --git a/tensorflow/python/framework/importer_test.py b/tensorflow/python/framework/importer_test.py index 7fdbcfd8561..5a683dc733e 100644 --- a/tensorflow/python/framework/importer_test.py +++ b/tensorflow/python/framework/importer_test.py @@ -682,6 +682,42 @@ class ImportGraphDefTest(test.TestCase): key: '_class' value { list { s: 'loc:@imported_graph/A' } } } }""", b.graph.as_graph_def()) + def testMultipleColocationWithDeviceFn(self): + original_graph_def = self._MakeGraphDef(""" + node { name: 'A' op: 'None'} + node { name: 'B' op: 'None'} + node { name: 'C' op: 'None' attr { + key: '_class' + value { list { s: 'loc:@A' s: 'loc:@B' } } + } }""") + + # A device function that places "B" on a device, and "A" is empty. + # + # B and C should contain "/device:B". A will not right now. But + # because of the colocation property, at runtime it would be + # placed with B and C. 
+ def CustomDeviceFn(op): + if "B" in op.name: + return "/device:B:0" + return "" + + with ops.Graph().as_default(): + with ops.device(CustomDeviceFn): + c, = importer.import_graph_def( + original_graph_def, return_elements=["C"], name="imported_graph") + + self.assertProtoEqualsVersion(""" + node { name: 'imported_graph/A' op: 'None' } + node { name: 'imported_graph/B' op: 'None' device: "/device:B:0" } + node { name: 'imported_graph/C' op: 'None' device: "/device:B:0" + attr { + key: '_class' value { + list { s: 'loc:@imported_graph/A' + s: 'loc:@imported_graph/B' } + } + } + }""", c.graph.as_graph_def()) + def testNamePrefixColocationAttrsMultipleImport(self): original_graph_def = self._MakeGraphDef(""" node { name: 'A' op: 'None' } diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 46417c23246..ccd1099f80c 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -1350,7 +1350,13 @@ class Operation(object): @property def name(self): """The full name of this operation.""" - return self._node_def.name + if _USE_C_API: + # TODO(iga): Remove this assert after converting to C API by default. + # Just being a bit paranoid here. 
+ assert self._node_def.name == c_api.TF_OperationName(self._c_op) + return c_api.TF_OperationName(self._c_op) + else: + return self._node_def.name @property def _id(self): diff --git a/tensorflow/python/grappler/cost_analyzer.cc b/tensorflow/python/grappler/cost_analyzer.cc index 29976b79495..88bf900dca6 100644 --- a/tensorflow/python/grappler/cost_analyzer.cc +++ b/tensorflow/python/grappler/cost_analyzer.cc @@ -30,11 +30,11 @@ CostAnalyzer::CostAnalyzer(const GrapplerItem& item, Cluster* cluster, analytical_estimator_(cluster, false), suffix_(suffix) {} -Status CostAnalyzer::GenerateReport(std::ostream& os) { +Status CostAnalyzer::GenerateReport(std::ostream& os, bool per_node_report) { GatherCosts(); PreprocessCosts(); AnalyzeCosts(); - PrintAnalysis(os); + PrintAnalysis(os, per_node_report); return Status::OK(); } @@ -158,7 +158,7 @@ void CostAnalyzer::AnalyzeCosts() { } } -void CostAnalyzer::PrintAnalysis(std::ostream& os) const { +void CostAnalyzer::PrintAnalysis(std::ostream& os, bool per_node_report) const { os << std::endl; os << std::left << std::setw(50) << "Total time measured in ns (serialized): " << std::right @@ -225,6 +225,11 @@ void CostAnalyzer::PrintAnalysis(std::ostream& os) const { os << std::endl; } os << std::endl; + + if (per_node_report) { + os << "Below is the per-node report:" << std::endl; + os << op_perf_.DebugString(); + } } } // end namespace grappler diff --git a/tensorflow/python/grappler/cost_analyzer.h b/tensorflow/python/grappler/cost_analyzer.h index 3700bf5fb37..0e860e0fee9 100644 --- a/tensorflow/python/grappler/cost_analyzer.h +++ b/tensorflow/python/grappler/cost_analyzer.h @@ -50,7 +50,7 @@ class CostAnalyzer { public: explicit CostAnalyzer(const GrapplerItem& item, Cluster* cluster, const string& suffix); - Status GenerateReport(std::ostream& os); + Status GenerateReport(std::ostream& os, bool per_node_report); private: void PredictCosts(CostEstimator* cost_estimator, CostGraphDef* cost_graph, @@ -59,7 +59,7 @@ class 
CostAnalyzer { void PreprocessCosts(); void AnalyzeCosts(); void SortOpsByTime(std::map ops); - void PrintAnalysis(std::ostream& os) const; + void PrintAnalysis(std::ostream& os, bool per_node_report) const; const GrapplerItem* item_; MeasuringCostEstimator measure_estimator_; diff --git a/tensorflow/python/grappler/cost_analyzer.i b/tensorflow/python/grappler/cost_analyzer.i index a51d8673c99..6066b6131ff 100644 --- a/tensorflow/python/grappler/cost_analyzer.i +++ b/tensorflow/python/grappler/cost_analyzer.i @@ -42,8 +42,10 @@ limitations under the License. %} %{ -string GenerateCostReport(const tensorflow::MetaGraphDef& metagraph) { +string GenerateCostReport(const tensorflow::MetaGraphDef& metagraph, bool +per_node_report) { tensorflow::grappler::ItemConfig cfg; + cfg.apply_optimizations = false; std::unique_ptr item = tensorflow::grappler::GrapplerItemFromMetaGraphDef("metagraph", metagraph, cfg); @@ -53,16 +55,20 @@ string GenerateCostReport(const tensorflow::MetaGraphDef& metagraph) { int num_cpu_cores = tensorflow::grappler::GetNumAvailableLogicalCPUCores(); int num_gpus = tensorflow::grappler::GetNumAvailableGPUs(); tensorflow::grappler::SingleMachine cluster(timeout_s, num_cpu_cores, num_gpus); + cluster.SetNumWarmupSteps(10); + cluster.AllowSoftPlacement(true); + cluster.DisableDetailedStats(false); TF_CHECK_OK(cluster.Provision()); string suffix; tensorflow::grappler::CostAnalyzer analyzer(*item, &cluster, suffix); std::stringstream os; - analyzer.GenerateReport(os); + analyzer.GenerateReport(os, per_node_report); return os.str(); } %} -string GenerateCostReport(const tensorflow::MetaGraphDef& metagraph); +string GenerateCostReport(const tensorflow::MetaGraphDef& metagraph, bool +per_node_report); diff --git a/tensorflow/python/grappler/cost_analyzer.py b/tensorflow/python/grappler/cost_analyzer.py index d16614c7c75..75c21e57271 100644 --- a/tensorflow/python/grappler/cost_analyzer.py +++ b/tensorflow/python/grappler/cost_analyzer.py @@ -22,8 +22,19 @@ 
from tensorflow.python import pywrap_tensorflow as tf_wrap from tensorflow.python.framework import errors -def GenerateCostReport(metagraph): - """Analyze the cost of each TensorFlow operation in the provided metagraph.""" +def GenerateCostReport(metagraph, per_node_report=False): + """Analyze the cost of each TensorFlow op and node in the provided metagraph. + + Args: + metagraph: An TensorFlow MetaGraphDef. + per_node_report: by default the report contains stats aggregated on a per op + type basis, setting per_node_report to True adds results for each + individual node to the report. + + Returns: + A string of cost report. + """ with errors.raise_exception_on_not_ok_status(): - ret_from_swig = tf_wrap.GenerateCostReport(metagraph.SerializeToString()) + ret_from_swig = tf_wrap.GenerateCostReport(metagraph.SerializeToString(), + per_node_report) return ret_from_swig diff --git a/tensorflow/python/grappler/cost_analyzer_tool.py b/tensorflow/python/grappler/cost_analyzer_tool.py new file mode 100644 index 00000000000..146bb4311cb --- /dev/null +++ b/tensorflow/python/grappler/cost_analyzer_tool.py @@ -0,0 +1,69 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================= +"""A tool for cost analysis.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import sys + +from google.protobuf import text_format + +from tensorflow.core.protobuf import meta_graph_pb2 +from tensorflow.core.protobuf import rewriter_config_pb2 +from tensorflow.python.grappler import cost_analyzer +from tensorflow.python.grappler import tf_optimizer +from tensorflow.python.platform import app +from tensorflow.python.platform import gfile + + +def main(_): + with gfile.GFile(FLAGS.input) as input_file: + metagraph = meta_graph_pb2.MetaGraphDef() + metagraph.ParseFromString(input_file.read()) + + if FLAGS.rewriter_config is not None: + rewriter_config = rewriter_config_pb2.RewriterConfig() + text_format.Merge(FLAGS.rewriter_config, rewriter_config) + optimized_graph = tf_optimizer.OptimizeGraph(rewriter_config, metagraph) + metagraph.graph_def.CopyFrom(optimized_graph) + + report = cost_analyzer.GenerateCostReport(metagraph, FLAGS.per_node_report) + print(report) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--input", type=str, default=None, help="Input .meta file path.") + parser.add_argument( + "--rewriter_config", + type=str, + default=None, + help="Configuration for the grappler optimizers, described as a " + "RewriterConfig protocol buffer. Usage example 1: " + "--rewriter_config='optimize_tensor_layout: true " + "disable_model_pruning: true'. Usage example 2: " + "--rewriter_config='optimizers: \"constfold\" optimizers: \"layout\"'") + parser.add_argument( + "--per_node_report", + action="store_true", + help="Generate per-node report. 
By default the report contains stats " + "aggregated on a per op type basis, per_node_report adds results " + "for each individual node to the report.") + FLAGS, unparsed = parser.parse_known_args() + app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/python/grappler/memory_optimizer_test.py b/tensorflow/python/grappler/memory_optimizer_test.py index 581f17c2ca2..4db8fa72451 100644 --- a/tensorflow/python/grappler/memory_optimizer_test.py +++ b/tensorflow/python/grappler/memory_optimizer_test.py @@ -18,16 +18,23 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import rewriter_config_pb2 +from tensorflow.python.client import session from tensorflow.python.framework import constant_op from tensorflow.python.framework import meta_graph from tensorflow.python.framework import ops +from tensorflow.python.framework import random_seed from tensorflow.python.grappler import tf_optimizer from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables from tensorflow.python.platform import test +from tensorflow.python.training import training as train -class MemoryOptimizerTest(test.TestCase): +class MemoryOptimizerSwapTest(test.TestCase): """Tests the Grappler memory optimizer.""" def testNoSwapping(self): @@ -85,5 +92,51 @@ class MemoryOptimizerTest(test.TestCase): self.assertEqual('c', node.input[1]) +class MemoryOptimizerRecomputeTest(test.TestCase): + + def _RunGraphWithConfig(self, config, batch_size=14, image_dim=12): + """Run a simple layered graph with conv, an intermediate op, and a ReLU.""" + graph = ops.Graph() + with graph.as_default(): + random_seed.set_random_seed(1) + current_activation = variable_scope.get_variable( + name='start', shape=[batch_size, image_dim, image_dim, 5]) + 
conv_filter = variable_scope.get_variable( + name='filter', shape=[5, 5, 5, 5]) + for layer_number in range(10): + with variable_scope.variable_scope('layer_{}'.format(layer_number)): + after_conv = nn.conv2d(current_activation, conv_filter, [1, 1, 1, 1], + 'SAME') + current_activation = 2. * after_conv + current_activation = nn.relu(current_activation) + loss = math_ops.reduce_mean(current_activation) + optimizer = train.AdamOptimizer(0.001) + train_op = optimizer.minimize(loss) + init_op = variables.global_variables_initializer() + with session.Session(config=config, graph=graph) as sess: + sess.run(init_op) + sess.run(train_op) + sess.run(train_op) + return sess.run(loss) + + def _GetMemoryOptimizerConfig(self): + rewrite_options = rewriter_config_pb2.RewriterConfig( + memory_optimization=rewriter_config_pb2.RewriterConfig.HEURISTICS) + graph_options = config_pb2.GraphOptions(rewrite_options=rewrite_options) + return config_pb2.ConfigProto(graph_options=graph_options) + + def testRecomputationRewritingNoErrors(self): + """Tests that there are no errors when we request a memory optimizer pass. + + Does not test that the memory optimizer actually runs. See + core/grappler/optimizers/memory_optimizer_test.cc for a functional test of + the graph rewriting. + """ + original_loss = self._RunGraphWithConfig(config_pb2.ConfigProto()) + memory_optimized_loss = self._RunGraphWithConfig( + config=self._GetMemoryOptimizerConfig()) + self.assertAllClose(original_loss, memory_optimized_loss, rtol=1e-4) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/kernel_tests/depthwise_conv_op_test.py b/tensorflow/python/kernel_tests/depthwise_conv_op_test.py index 8ba9d0efff7..3298092fbea 100644 --- a/tensorflow/python/kernel_tests/depthwise_conv_op_test.py +++ b/tensorflow/python/kernel_tests/depthwise_conv_op_test.py @@ -37,18 +37,21 @@ def ConfigsToTest(): Tuple (input_size, filter_size, out_size, stride, padding), the depthwise convolution parameters. 
""" - input_sizes = [[4, 5, 5, 48], [4, 8, 8, 84], [4, 17, 17, 48], [4, 35, 35, 2], - [4, 147, 147, 2], [3, 299, 299, 3], [5, 183, 183, 1]] - filter_sizes = [[1, 1, 48, 2], [1, 3, 84, 1], [3, 1, 48, 4], [5, 5, 2, 1], - [3, 3, 2, 8], [2, 2, 3, 8], [5, 5, 1, 2]] - out_sizes = [[4, 5, 5, 96], [4, 8, 8, 84], [4, 17, 17, 192], [4, 35, 35, 2], - [4, 49, 49, 16], [3, 150, 150, 24], [5, 92, 92, 2]] - strides = [1, 1, 1, 1, 3, 2, 2] + input_sizes = [[4, 5, 5, 48], [4, 8, 8, 84], [4, 17, 17, 48], [4, 9, 27, 8], + [4, 31, 31, 7], [4, 35, 35, 2], [4, 147, 147, 2], + [3, 299, 299, 3], [5, 183, 183, 1]] + filter_sizes = [[1, 1, 48, 2], [1, 3, 84, 1], [3, 1, 48, 4], [3, 3, 8, 1], + [3, 3, 7, 1], [5, 5, 2, 1], [3, 3, 2, 8], [2, 2, 3, + 8], [5, 5, 1, 2]] + out_sizes = [[4, 5, 5, 96], [4, 8, 8, 84], [4, 17, 17, 192], [4, 9, 27, 8], + [4, 31, 31, 7], [4, 35, 35, 2], [4, 49, 49, 16], + [3, 150, 150, 24], [5, 92, 92, 2]] + strides = [1, 1, 1, 1, 1, 1, 3, 2, 2] # pylint: disable=invalid-name VALID = "VALID" SAME = "SAME" # pylint: enable=invalid-name - paddings = [SAME, SAME, SAME, SAME, VALID, SAME, SAME, SAME] + paddings = [SAME, SAME, SAME, SAME, SAME, SAME, VALID, SAME, SAME, SAME] for i, f, o, s, p in zip(input_sizes, filter_sizes, out_sizes, strides, paddings): yield i, f, o, s, p diff --git a/tensorflow/python/kernel_tests/reader_ops_test.py b/tensorflow/python/kernel_tests/reader_ops_test.py index 12932219fc3..c7e1d88360f 100644 --- a/tensorflow/python/kernel_tests/reader_ops_test.py +++ b/tensorflow/python/kernel_tests/reader_ops_test.py @@ -21,6 +21,7 @@ from __future__ import print_function import collections import gzip import os +import shutil import threading import zlib @@ -36,6 +37,8 @@ from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.util import compat +prefix_path = "tensorflow/core/lib" + # pylint: disable=invalid-name TFRecordCompressionType = tf_record.TFRecordCompressionType # pylint: enable=invalid-name 
@@ -858,48 +861,50 @@ class AsyncReaderTest(test.TestCase): output.append(sess.run(args)) -# TODO(jhseu): Restore after fixing. -#class LMDBReaderTest(test.TestCase): -# -# def setUp(self): -# super(LMDBReaderTest, self).setUp() -# -# def testReadFromFile(self): -# with self.test_session() as sess: -# reader = io_ops.LMDBReader(name="test_read_from_file") -# path = os.path.join("tensorflow", "core", "lib", "lmdb", "testdata", -# "data.mdb") -# queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=()) -# key, value = reader.read(queue) -# -# queue.enqueue([path]).run() -# queue.close().run() -# for i in range(10): -# k, v = sess.run([key, value]) -# self.assertAllEqual(compat.as_bytes(k), compat.as_bytes(str(i))) -# self.assertAllEqual(compat.as_bytes(v), compat.as_bytes(str(chr(ord('a') + i)))) -# -# with self.assertRaisesOpError("is closed and has insufficient elements " -# "\\(requested 1, current size 0\\)"): -# k, v = sess.run([key, value]) -# -# def testReadFromFolder(self): -# with self.test_session() as sess: -# reader = io_ops.LMDBReader(name="test_read_from_folder") -# path = os.path.join("tensorflow", "core", "lib", "lmdb", "testdata") -# queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=()) -# key, value = reader.read(queue) -# -# queue.enqueue([path]).run() -# queue.close().run() -# for i in range(10): -# k, v = sess.run([key, value]) -# self.assertAllEqual(compat.as_bytes(k), compat.as_bytes(str(i))) -# self.assertAllEqual(compat.as_bytes(v), compat.as_bytes(str(chr(ord('a') + i)))) -# -# with self.assertRaisesOpError("is closed and has insufficient elements " -# "\\(requested 1, current size 0\\)"): -# k, v = sess.run([key, value]) +class LMDBReaderTest(test.TestCase): + + def setUp(self): + super(LMDBReaderTest, self).setUp() + # Copy database out because we need the path to be writable to use locks. 
+ path = os.path.join(prefix_path, "lmdb", "testdata", "data.mdb") + self.db_path = os.path.join(self.get_temp_dir(), "data.mdb") + shutil.copy(path, self.db_path) + + def testReadFromFile(self): + with self.test_session() as sess: + reader = io_ops.LMDBReader(name="test_read_from_file") + queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=()) + key, value = reader.read(queue) + + queue.enqueue([self.db_path]).run() + queue.close().run() + for i in range(10): + k, v = sess.run([key, value]) + self.assertAllEqual(compat.as_bytes(k), compat.as_bytes(str(i))) + self.assertAllEqual( + compat.as_bytes(v), compat.as_bytes(str(chr(ord("a") + i)))) + + with self.assertRaisesOpError("is closed and has insufficient elements " + "\\(requested 1, current size 0\\)"): + k, v = sess.run([key, value]) + + def testReadFromFolder(self): + with self.test_session() as sess: + reader = io_ops.LMDBReader(name="test_read_from_folder") + queue = data_flow_ops.FIFOQueue(99, [dtypes.string], shapes=()) + key, value = reader.read(queue) + + queue.enqueue([self.db_path]).run() + queue.close().run() + for i in range(10): + k, v = sess.run([key, value]) + self.assertAllEqual(compat.as_bytes(k), compat.as_bytes(str(i))) + self.assertAllEqual( + compat.as_bytes(v), compat.as_bytes(str(chr(ord("a") + i)))) + + with self.assertRaisesOpError("is closed and has insufficient elements " + "\\(requested 1, current size 0\\)"): + k, v = sess.run([key, value]) if __name__ == "__main__": diff --git a/tensorflow/python/layers/convolutional.py b/tensorflow/python/layers/convolutional.py index fdf1b134b9c..63c7280b3d7 100644 --- a/tensorflow/python/layers/convolutional.py +++ b/tensorflow/python/layers/convolutional.py @@ -388,7 +388,7 @@ class Conv2D(_Conv): filters: Integer, the dimensionality of the output space (i.e. the number of filters in the convolution). kernel_size: An integer or tuple/list of 2 integers, specifying the - width and height of the 2D convolution window. 
+ height and width of the 2D convolution window. Can be a single integer to specify the same value for all spatial dimensions. strides: An integer or tuple/list of 2 integers, @@ -489,7 +489,7 @@ def conv2d(inputs, filters: Integer, the dimensionality of the output space (i.e. the number of filters in the convolution). kernel_size: An integer or tuple/list of 2 integers, specifying the - width and height of the 2D convolution window. + height and width of the 2D convolution window. Can be a single integer to specify the same value for all spatial dimensions. strides: An integer or tuple/list of 2 integers, diff --git a/tensorflow/python/layers/normalization.py b/tensorflow/python/layers/normalization.py index 780d1c2b8e0..ad0f202f959 100644 --- a/tensorflow/python/layers/normalization.py +++ b/tensorflow/python/layers/normalization.py @@ -123,6 +123,10 @@ class BatchNormalization(base.Layer): if self.fused and renorm: raise ValueError( 'Batch renorm is currently not supported with fused batch norm.') + if self.fused and (beta_regularizer is not None or + gamma_regularizer is not None): + raise ValueError('Regularizers are not currently ' + 'supported for fused batch norm.') if renorm: renorm_clipping = renorm_clipping or {} keys = ['rmax', 'rmin', 'dmax'] @@ -153,7 +157,12 @@ class BatchNormalization(base.Layer): ' is out of range for input with rank ' + str(ndim)) if self.fused is None: - self.fused = not self.renorm and ndim == 4 and axis in [1, 3] + # Currently fused batch norm doesn't support renorm and beta/gamma + # regularizer; and only supports an input tensor of rank 4 and a channel + # dimension on axis 1 and 3. 
+ self.fused = not self.renorm and ndim == 4 and axis in [ + 1, 3 + ] and self.beta_regularizer is None and self.gamma_regularizer is None if self.fused: if axis == 1: diff --git a/tensorflow/python/layers/normalization_test.py b/tensorflow/python/layers/normalization_test.py index fa6c9c4a5db..64bebb1021c 100644 --- a/tensorflow/python/layers/normalization_test.py +++ b/tensorflow/python/layers/normalization_test.py @@ -143,45 +143,46 @@ class BNTest(test.TestCase): self.assertAlmostEqual(np.std(normed_np_output), 1., places=1) def test4DInputAxis1(self): - epsilon = 1e-3 - bn = normalization_layers.BatchNormalization( - axis=1, epsilon=epsilon, momentum=0.9) - inputs = variables.Variable( - np.random.random((5, 4, 3, 6)) + 100, dtype=dtypes.float32) - training = array_ops.placeholder(dtype='bool') - outputs = bn.apply(inputs, training=training) + if test.is_gpu_available(cuda_only=True): + epsilon = 1e-3 + bn = normalization_layers.BatchNormalization( + axis=1, epsilon=epsilon, momentum=0.9) + inputs = variables.Variable( + np.random.random((5, 4, 3, 6)) + 100, dtype=dtypes.float32) + training = array_ops.placeholder(dtype='bool') + outputs = bn.apply(inputs, training=training) - with self.test_session() as sess: - # Test training with placeholder learning phase. - sess.run(variables.global_variables_initializer()) - np_gamma, np_beta = sess.run([bn.gamma, bn.beta]) - np_gamma = np.reshape(np_gamma, (1, 4, 1, 1)) - np_beta = np.reshape(np_beta, (1, 4, 1, 1)) - for _ in range(100): - np_output, _, _ = sess.run([outputs] + bn.updates, - feed_dict={training: True}) - # Verify that the axis is normalized during training. + with self.test_session(use_gpu=True) as sess: + # Test training with placeholder learning phase. 
+ sess.run(variables.global_variables_initializer()) + np_gamma, np_beta = sess.run([bn.gamma, bn.beta]) + np_gamma = np.reshape(np_gamma, (1, 4, 1, 1)) + np_beta = np.reshape(np_beta, (1, 4, 1, 1)) + for _ in range(100): + np_output, _, _ = sess.run( + [outputs] + bn.updates, feed_dict={training: True}) + # Verify that the axis is normalized during training. + normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta + self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1) + self.assertAlmostEqual(np.std(normed_np_output), 1., places=1) + + # Verify that the statistics are updated during training. + moving_mean, moving_var = sess.run([bn.moving_mean, bn.moving_variance]) + np_inputs = sess.run(inputs) + mean = np.mean(np_inputs, axis=(0, 2, 3)) + std = np.std(np_inputs, axis=(0, 2, 3)) + variance = np.square(std) + self.assertAllClose(mean, moving_mean, atol=1e-2) + self.assertAllClose(variance, moving_var, atol=1e-2) + + # Test inference with placeholder learning phase. + np_output = sess.run(outputs, feed_dict={training: False}) + + # Verify that the axis is normalized during inference. normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1) self.assertAlmostEqual(np.std(normed_np_output), 1., places=1) - # Verify that the statistics are updated during training. - moving_mean, moving_var = sess.run([bn.moving_mean, bn.moving_variance]) - np_inputs = sess.run(inputs) - mean = np.mean(np_inputs, axis=(0, 2, 3)) - std = np.std(np_inputs, axis=(0, 2, 3)) - variance = np.square(std) - self.assertAllClose(mean, moving_mean, atol=1e-2) - self.assertAllClose(variance, moving_var, atol=1e-2) - - # Test inference with placeholder learning phase. - np_output = sess.run(outputs, feed_dict={training: False}) - - # Verify that the axis is normalized during inference. 
- normed_np_output = ((np_output - epsilon) * np_gamma) + np_beta - self.assertAlmostEqual(np.mean(normed_np_output), 0., places=1) - self.assertAlmostEqual(np.std(normed_np_output), 1., places=1) - def test4DInputAxis2(self): epsilon = 1e-3 bn = normalization_layers.BatchNormalization( diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 86a59ff9e30..a8f596c7a3c 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -514,6 +514,10 @@ def slice(input_, begin, size, name=None): words, `begin[i]` is the offset into the 'i'th dimension of `input` that you want to slice from. + Note that @{tf.Tensor.__getitem__} is typically a more pythonic way to + perform slices, as it allows you to write `foo[3:7, :-2]` instead of + `tf.slice([3, 0], [4, foo.get_shape()[1]-2])`. + `begin` is zero-based; `size` is one-based. If `size[i]` is -1, all remaining elements in dimension i are included in the slice. In other words, this is equivalent to setting: diff --git a/tensorflow/python/ops/bitwise_ops.py b/tensorflow/python/ops/bitwise_ops.py new file mode 100644 index 00000000000..cbabc3ed9ba --- /dev/null +++ b/tensorflow/python/ops/bitwise_ops.py @@ -0,0 +1,40 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +"""Operations for manipulating the binary representations of integers. + +@@bitwise_and +@@bitwise_or +@@bitwise_xor +@@invert +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +# go/tf-wildcard-import +# pylint: disable=wildcard-import +from tensorflow.python.ops.gen_bitwise_ops import * +# pylint: enable=wildcard-import +from tensorflow.python.util.all_util import remove_undocumented + +ops.NotDifferentiable("BitwiseAnd") +ops.NotDifferentiable("BitwiseOr") +ops.NotDifferentiable("BitwiseXor") +ops.NotDifferentiable("Invert") + +remove_undocumented(__name__) diff --git a/tensorflow/python/ops/bitwise_ops_test.py b/tensorflow/python/ops/bitwise_ops_test.py new file mode 100644 index 00000000000..904cf99a5ab --- /dev/null +++ b/tensorflow/python/ops/bitwise_ops_test.py @@ -0,0 +1,74 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for bitwise operations.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import test_util +from tensorflow.python.ops import bitwise_ops +from tensorflow.python.platform import googletest + + +class BitwiseOpTest(test_util.TensorFlowTestCase): + + def __init__(self, method_name="runTest"): + super(BitwiseOpTest, self).__init__(method_name) + + def testBinaryOps(self): + dtype_list = [dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64, + dtypes.uint8, dtypes.uint16] + + with self.test_session(use_gpu=True) as sess: + for dtype in dtype_list: + lhs = constant_op.constant([0, 5, 3, 14], dtype=dtype) + rhs = constant_op.constant([5, 0, 7, 11], dtype=dtype) + and_result, or_result, xor_result = sess.run( + [bitwise_ops.bitwise_and(lhs, rhs), + bitwise_ops.bitwise_or(lhs, rhs), + bitwise_ops.bitwise_xor(lhs, rhs)]) + self.assertAllEqual(and_result, [0, 0, 3, 10]) + self.assertAllEqual(or_result, [5, 5, 7, 15]) + self.assertAllEqual(xor_result, [5, 5, 4, 5]) + + def testInvertOp(self): + dtype_list = [dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64, + dtypes.uint8, dtypes.uint16] + inputs = [0, 5, 3, 14] + with self.test_session(use_gpu=True) as sess: + for dtype in dtype_list: + # Because of issues with negative numbers, let's test this indirectly. + # 1. invert(a) and a = 0 + # 2. 
invert(a) or a = invert(0) + input_tensor = constant_op.constant(inputs, dtype=dtype) + not_a_and_a, not_a_or_a, not_0 = sess.run( + [bitwise_ops.bitwise_and( + input_tensor, bitwise_ops.invert(input_tensor)), + bitwise_ops.bitwise_or( + input_tensor, bitwise_ops.invert(input_tensor)), + bitwise_ops.invert(constant_op.constant(0, dtype=dtype))]) + self.assertAllEqual(not_a_and_a, [0, 0, 0, 0]) + self.assertAllEqual(not_a_or_a, [not_0] * 4) + # For unsigned dtypes let's also check the result directly. + if dtype.is_unsigned: + inverted = sess.run(bitwise_ops.invert(input_tensor)) + expected = [dtype.max - x for x in inputs] + self.assertAllEqual(inverted, expected) + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/python/ops/conv2d_benchmark.py b/tensorflow/python/ops/conv2d_benchmark.py new file mode 100644 index 00000000000..c6352b2a98c --- /dev/null +++ b/tensorflow/python/ops/conv2d_benchmark.py @@ -0,0 +1,141 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Benchmark for Conv2D op.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools +import time + +from tensorflow.python.client import session as session_lib +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +def build_graph(device, input_shape, filter_shape, strides, padding, num_iters): + """builds a graph containing a sequence of conv2d operations. + + Args: + device: String, the device to run on. + input_shape: Shape of the input tensor. + filter_shape: Shape of the filter tensor. + strides: A list of ints. 1-D of length 4. The stride of sliding + window for each dimension of input. + padding: A string from: "SAME", "VALID". The type of padding + algorithm to use. + num_iters: number of iterations to run conv2d. + + Returns: + An array of tensors to run() + """ + with ops.device("/%s:0" % device): + inp = variables.Variable(random_ops.truncated_normal(input_shape)) + filt = variables.Variable(random_ops.truncated_normal(filter_shape)) + + outputs = [] + conv2d_op = nn_ops.conv2d(inp, filt, strides, padding, data_format="NHWC") + outputs.append(conv2d_op) + for _ in range(1, num_iters): + with ops.control_dependencies([conv2d_op]): + conv2d_op = nn_ops.conv2d( + inp, filt, strides, padding, data_format="NHWC") + outputs.append(conv2d_op) + return control_flow_ops.group(*outputs) + + +class Conv2DBenchmark(test.Benchmark): + """Benchmark conv2d!""" + + def _run_graph(self, device, input_shape, filter_shape, strides, padding, + num_iters): + """runs the graph and print its execution time. + + Args: + device: String, the device to run on. 
+ input_shape: Shape of the input tensor. + filter_shape: Shape of the filter tensor. + strides: A list of ints. 1-D of length 4. The stride of sliding + window for each dimension of input. + padding: A string from: "SAME", "VALID". The type of padding + algorithm to use. num_iters: Number of iterations to run the + benchmark. + num_iters: number of iterations to run conv2d. + + Returns: + The duration of the run in seconds. + """ + graph = ops.Graph() + with graph.as_default(): + outputs = build_graph(device, input_shape, filter_shape, strides, padding, + num_iters) + with session_lib.Session(graph=graph) as session: + variables.global_variables_initializer().run() + # warmup runs + session.run(outputs) + + start_time = time.time() + session.run(outputs) + duration = (time.time() - start_time) / num_iters + + print("%s inputshape:%s filtershape:%s strides:%s padding:%s " + "%d iters: %.8f sec" % + (device, str(input_shape).replace(" ", ""), + str(filter_shape).replace(" ", ""), + str(strides).replace(" ", ""), padding, num_iters, duration)) + + name_template = ( + "conv2d_{device}_input_shape_{inputshape}_filter_shape_{filtershape}_" + "strides_{strides}_padding_{padding}") + + self.report_benchmark( + name=name_template.format( + device=device, + inputshape=str(input_shape).replace(" ", ""), + filtershape=str(filter_shape).replace(" ", ""), + strides=str(strides).replace(" ", ""), + padding=padding).replace(" ", ""), + iters=num_iters, + wall_time=duration / num_iters) + + return duration + + def benchmark_conv2d(self): + print("conv2d benchmark:") + + h = 1000 + w = 1000 + fh = 3 + fw = 3 + input_shapes = [] + filter_shapes = [] + for b, c in itertools.product([4, 16, 32], [i for i in range(3, 16)]): + input_shapes += [[b, h, w, c]] + filter_shapes += [[fh, fw, c, b]] + strides = [[1, 2, 2, 1]] + paddings = ["VALID", "SAME"] + for ishape, fshape in zip(input_shapes, filter_shapes): + for stride in strides: + for padding in paddings: + self._run_graph("gpu", 
ishape, fshape, stride, padding, 80) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/ops/io_ops.py b/tensorflow/python/ops/io_ops.py index 0b1a77969a0..8f5d0e5cd43 100644 --- a/tensorflow/python/ops/io_ops.py +++ b/tensorflow/python/ops/io_ops.py @@ -26,6 +26,7 @@ See the @{$python/io_ops} guide. @@WholeFileReader @@IdentityReader @@TFRecordReader +@@LMDBReader @@FixedLengthRecordReader @@decode_csv @@decode_raw diff --git a/tensorflow/python/ops/lookup_ops.py b/tensorflow/python/ops/lookup_ops.py index 3ae9316fee2..dade0535892 100644 --- a/tensorflow/python/ops/lookup_ops.py +++ b/tensorflow/python/ops/lookup_ops.py @@ -691,8 +691,8 @@ class IdTableWithHashBuckets(LookupInterface): - emerson -> 0 - lake -> 1 - palmer -> 2 - - -> bucket id between 3 and 3 + num_oov_buckets, calculated by: - hash() % num_oov_buckets + vocab_size + - -> bucket id between 3 and 3 + num_oov_buckets - 1, calculated + by: hash() % num_oov_buckets + vocab_size If input_tensor is ["emerson", "lake", "palmer", "king", "crimson"], the lookup result is [0, 1, 2, 4, 7] @@ -870,7 +870,8 @@ def index_table_from_file(vocabulary_file=None, Any lookup of an out-of-vocabulary token will return a bucket ID based on its hash if `num_oov_buckets` is greater than zero. Otherwise it is assigned the `default_value`. - The bucket ID range is `[vocabulary size, vocabulary size + num_oov_buckets]`. + The bucket ID range is + `[vocabulary size, vocabulary size + num_oov_buckets - 1]`. The underlying table must be initialized by calling `tf.tables_initializer.run()` or `table.init.run()` once. @@ -977,7 +978,7 @@ def index_table_from_tensor(vocabulary_list, Any lookup of an out-of-vocabulary token will return a bucket ID based on its hash if `num_oov_buckets` is greater than zero. Otherwise it is assigned the `default_value`. - The bucket ID range is `[mapping size, mapping size + num_oov_buckets]`. + The bucket ID range is `[mapping size, mapping size + num_oov_buckets - 1]`. 
The underlying table must be initialized by calling `tf.tables_initializer.run()` or `table.init.run()` once. diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py index c00949a322a..bffd3275ee1 100644 --- a/tensorflow/python/ops/rnn_cell_impl.py +++ b/tensorflow/python/ops/rnn_cell_impl.py @@ -345,6 +345,8 @@ class BasicLSTMCell(RNNCell): Args: num_units: int, The number of units in the LSTM cell. forget_bias: float, The bias added to forget gates (see above). + Must set to `0.0` manually when restoring from CudnnLSTM-trained + checkpoints. state_is_tuple: If True, accepted and returned states are 2-tuples of the `c_state` and `m_state`. If False, they are concatenated along the column axis. The latter behavior will soon be deprecated. @@ -444,7 +446,8 @@ class LSTMCell(RNNCell): Use a variable_scope partitioner instead. forget_bias: Biases of the forget gate are initialized by default to 1 in order to reduce the scale of forgetting at the beginning of - the training. + the training. Must set it manually to `0.0` when restoring from + CudnnLSTM trained checkpoints. state_is_tuple: If True, accepted and returned states are 2-tuples of the `c_state` and `m_state`. If False, they are concatenated along the column axis. This latter behavior will soon be deprecated. 
diff --git a/tensorflow/python/ops/summary_ops.py b/tensorflow/python/ops/summary_ops.py index 4ad0862dcc7..3d19bc25261 100644 --- a/tensorflow/python/ops/summary_ops.py +++ b/tensorflow/python/ops/summary_ops.py @@ -131,7 +131,6 @@ def _tensor_summary_v2( # pylint: disable=invalid-name val = gen_logging_ops._tensor_summary_v2( tensor=tensor, tag=tag, - description="", name=scope, serialized_summary_metadata=serialized_summary_metadata) summary_op_util.collect(val, collections, [ops.GraphKeys.SUMMARIES]) diff --git a/tensorflow/python/ops/transpose_benchmark.py b/tensorflow/python/ops/transpose_benchmark.py index 6bd3fe5e5a0..cddefacf2ef 100644 --- a/tensorflow/python/ops/transpose_benchmark.py +++ b/tensorflow/python/ops/transpose_benchmark.py @@ -1,4 +1,4 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -32,7 +32,7 @@ from tensorflow.python.platform import test def build_graph(device, input_shape, perm, datatype, num_iters): - """Build a graph containing a sequence of conv2d operations. + """builds a graph containing a sequence of conv2d operations. Args: device: String, the device to run on. 
@@ -50,10 +50,12 @@ def build_graph(device, input_shape, perm, datatype, num_iters): t = constant_op.constant(inp, shape=input_shape) outputs = [] - outputs.append(array_ops.transpose(t, perm)) - for i in range(1, num_iters): - with ops.control_dependencies([outputs[i - 1]]): - outputs.append(array_ops.transpose(t, perm)) + transpose_op = array_ops.transpose(t, perm) + outputs.append(transpose_op) + for _ in range(1, num_iters): + with ops.control_dependencies([transpose_op]): + transpose_op = array_ops.transpose(t, perm) + outputs.append(transpose_op) return control_flow_ops.group(*outputs) @@ -61,7 +63,7 @@ class TransposeBenchmark(test.Benchmark): """Benchmark transpose!""" def _run_graph(self, device, input_shape, perm, num_iters, datatype): - """Run the graph and print its execution time. + """runs the graph and print its execution time. Args: device: String, the device to run on. @@ -82,9 +84,11 @@ class TransposeBenchmark(test.Benchmark): session.run(outputs) start_time = time.time() session.run(outputs) + duration = (time.time() - start_time) / num_iters - throughput = np.prod(np.array( - input_shape)) * datatype().itemsize * 2 / duration / 1e9 + throughput = np.prod( + np.array(input_shape)) * datatype().itemsize * 2 / duration / 1e9 + print("%s %s inputshape:%s perm:%s %d %.6fsec, %.4fGB/s." 
% (device, str(datatype), str(input_shape).replace(" ", ""), str(perm).replace(" ", ""), num_iters, duration, throughput)) @@ -108,19 +112,19 @@ class TransposeBenchmark(test.Benchmark): datatypes = [np.complex128, np.float64, np.float32, np.float16, np.int8] - small_shapes = [[2, 20, 20, 20, 16], [2, 16, 20, 20, 20]] * 2 + [[ - 2, 100, 100, 16 - ], [2, 16, 100, 100]] * 2 + [[2, 5000, 16], [2, 16, 5000]] * 2 - small_perms = [[0, 4, 1, 2, 3], [0, 2, 3, 4, 1]] + [[4, 1, 2, 3, 0]] * 2 + [ - [0, 3, 1, 2], [0, 2, 3, 1] - ] + [[3, 1, 2, 0]] * 2 + [[0, 2, 1]] * 2 + [[2, 1, 0]] * 2 + small_shapes = [[2, 20, 20, 20, 16], [2, 16, 20, 20, 20]] * 2 + small_shapes += [[2, 100, 100, 16], [2, 16, 100, 100]] * 2 + small_shapes += [[2, 5000, 16], [2, 16, 5000]] * 2 + small_perms = [[0, 4, 1, 2, 3], [0, 2, 3, 4, 1]] + [[4, 1, 2, 3, 0]] * 2 + small_perms += [[0, 3, 1, 2], [0, 2, 3, 1]] + [[3, 1, 2, 0]] * 2 + small_perms += [[0, 2, 1]] * 2 + [[2, 1, 0]] * 2 - large_shapes = [[2, 100, 100, 100, 32], [2, 100, 100, 100, 64]] * 2 + [[ - 2, 1000, 1000, 32 - ], [2, 1000, 1000, 64]] * 2 + [[2, 1000000, 32], [2, 1000000, 64]] * 2 - large_perms = [[0, 4, 1, 2, 3], [0, 2, 3, 4, 1]] + [[4, 1, 2, 3, 0]] * 2 + [ - [0, 3, 1, 2], [0, 2, 3, 1] - ] + [[3, 1, 2, 0]] * 2 + [[0, 2, 1]] * 2 + [[2, 1, 0]] * 2 + large_shapes = [[2, 100, 100, 100, 32], [2, 100, 100, 100, 64]] * 2 + large_shapes += [[2, 1000, 1000, 32], [2, 1000, 1000, 64]] * 2 + large_shapes += [[2, 1000000, 32], [2, 1000000, 64]] * 2 + large_perms = [[0, 4, 1, 2, 3], [0, 2, 3, 4, 1]] + [[4, 1, 2, 3, 0]] * 2 + large_perms += [[0, 3, 1, 2], [0, 2, 3, 1]] + [[3, 1, 2, 0]] * 2 + large_perms += [[0, 2, 1]] * 2 + [[2, 1, 0]] * 2 huge_shapes = [[2, 100, 100, 100, 128], [2, 1000, 1000, 128], [2, 1000000, 128]] * 2 @@ -143,5 +147,23 @@ class TransposeBenchmark(test.Benchmark): for ishape, perm in zip(huge_shapes, huge_perms): self._run_graph("gpu", ishape, perm, num_iters, datatype) + small_dim_large_shapes = [[2, 1000000, 3], [2, 3, 1000000], [2, 
1000000, 8], + [2, 8, 1000000]] + small_dim_small_shapes = [[2, 5000, 3], [2, 3, 5000], [2, 5000, 8], + [2, 8, 5000]] + small_dim_perms = [[0, 2, 1]] * 4 + + num_iters = 320 + small_dim_large_shape_datatypes = [np.float64, np.float32, np.int8] + for datatype in small_dim_large_shape_datatypes: + for ishape, perm in zip(small_dim_large_shapes, small_dim_perms): + self._run_graph("gpu", ishape, perm, num_iters, datatype) + + small_dim_small_shape_datatypes = [np.complex128, np.float16] + for datatype in small_dim_small_shape_datatypes: + for ishape, perm in zip(small_dim_small_shapes, small_dim_perms): + self._run_graph("gpu", ishape, perm, num_iters, datatype) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/saved_model/builder_impl.py b/tensorflow/python/saved_model/builder_impl.py index e6d71a48dfd..5c988bf95a8 100644 --- a/tensorflow/python/saved_model/builder_impl.py +++ b/tensorflow/python/saved_model/builder_impl.py @@ -287,16 +287,12 @@ class SavedModelBuilder(object): # The graph almost certainly previously contained at least one Saver, and # possibly several (e.g. one for loading a pretrained embedding, and another # for the model weights). However, a *new* Saver was just created that - # includes all of the variables. In the context of the SavedModel, this - # new Saver is the only one that needs to be retained. The associated - # checkpoint produced in add_meta_graph_and_variables() contains all of the - # variable values. Thus, any preexisting Savers are redundant and useless - # at best, but worse may break downstream graph-processing tools, and can be - # confusing during debugging. It is therefore safe and wise to set - # `clear_extraneous_savers` to `True`, since it removes both the extraneous - # SaverDefs and their associated Save/Restore Ops from the graph. - meta_graph_def = saver.export_meta_graph(clear_devices=clear_devices, - clear_extraneous_savers=True) + # includes all of the variables. 
Removing the preexisting ones was the + # motivation for the clear_extraneous_savers option, but it turns out that + # there are edge cases where that option breaks the graph. Until that is + # resolved, we just leave the option set to False for now. + # TODO(soergel): Reinstate clear_extraneous_savers=True when possible. + meta_graph_def = saver.export_meta_graph(clear_devices=clear_devices) # Tag the meta graph def and add it to the SavedModel. self._tag_and_add_meta_graph(meta_graph_def, tags, signature_def_map) @@ -378,16 +374,12 @@ class SavedModelBuilder(object): # The graph almost certainly previously contained at least one Saver, and # possibly several (e.g. one for loading a pretrained embedding, and another # for the model weights). However, a *new* Saver was just created that - # includes all of the variables. In the context of the SavedModel, this - # new Saver is the only one that needs to be retained. The associated - # checkpoint that was saved just above contains all of the variable values. - # Thus, any preexisting Savers are redundant and useless at best, but worse - # may break downstream graph-processing tools, and can be confusing during - # debugging. It is therefore safe and wise to set `clear_extraneous_savers` - # to `True`, since it removes both the extraneous SaverDefs and their - # associated Save/Restore Ops from the graph. - meta_graph_def = saver.export_meta_graph(clear_devices=clear_devices, - clear_extraneous_savers=True) + # includes all of the variables. Removing the preexisting ones was the + # motivation for the clear_extraneous_savers option, but it turns out that + # there are edge cases where that option breaks the graph. Until that is + # resolved, we just leave the option set to False for now. + # TODO(soergel): Reinstate clear_extraneous_savers=True when possible. + meta_graph_def = saver.export_meta_graph(clear_devices=clear_devices) # Tag the meta graph def and add it to the SavedModel. 
self._tag_and_add_meta_graph(meta_graph_def, tags, signature_def_map) diff --git a/tensorflow/python/saved_model/loader_impl.py b/tensorflow/python/saved_model/loader_impl.py index 32526521749..5ff954fd9f8 100644 --- a/tensorflow/python/saved_model/loader_impl.py +++ b/tensorflow/python/saved_model/loader_impl.py @@ -206,8 +206,11 @@ def load(sess, tags, export_dir, **saver_kwargs): break if not found_match: - raise RuntimeError("MetaGraphDef associated with tags " + str(tags).strip( - "[]") + " could not be found in SavedModel") + raise RuntimeError( + "MetaGraphDef associated with tags " + str(tags).strip("[]") + + " could not be found in SavedModel. To inspect available tag-sets in" + " the SavedModel, please use the SavedModel CLI: `saved_model_cli`" + ) # Build a saver by importing the meta graph def to load. saver = tf_saver.import_meta_graph(meta_graph_def_to_load, **saver_kwargs) diff --git a/tensorflow/python/saved_model/saved_model_test.py b/tensorflow/python/saved_model/saved_model_test.py index 0eb9f49fed0..fcd6bc39547 100644 --- a/tensorflow/python/saved_model/saved_model_test.py +++ b/tensorflow/python/saved_model/saved_model_test.py @@ -39,7 +39,6 @@ from tensorflow.python.saved_model import loader from tensorflow.python.saved_model import main_op from tensorflow.python.saved_model import signature_def_utils from tensorflow.python.saved_model import tag_constants -from tensorflow.python.training import saver as tf_saver from tensorflow.python.training import saver_test_utils from tensorflow.python.util import compat @@ -810,66 +809,6 @@ class SavedModelTest(test.TestCase): self.assertEqual( 42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval()) - def testClearExtraneousSavers(self): - export_dir = os.path.join(test.get_temp_dir(), - "test_clear_extraneous_savers") - builder = saved_model_builder.SavedModelBuilder(export_dir) - - # Create a variable and a Saver. 
- with ops.Graph().as_default() as graph: - with session.Session( - target="", - config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess: - self._init_and_validate_variable(sess, "v", 42) - - # Add two Savers, which should be removed in - # add_meta_graph_and_variables() in favor of the locally added one. - saver1 = tf_saver.Saver() - graph.add_to_collection(ops.GraphKeys.SAVERS, saver1) - saver2 = tf_saver.Saver() - graph.add_to_collection(ops.GraphKeys.SAVERS, saver2) - - # Confirm there are two SaverDefs. - savers = graph.get_collection(ops.GraphKeys.SAVERS) - self.assertEqual(2, len(savers)) - - # Confirm there are two Save and two Restore ops. - save_op_names = set([x.name for x in graph.get_operations() - if x.type == "SaveV2"]) - self.assertSetEqual(set(["save/SaveV2", "save_1/SaveV2"]), - save_op_names) - - restore_op_names = set([x.name for x in graph.get_operations() - if x.type == "RestoreV2"]) - self.assertSetEqual(set(["save/RestoreV2", "save_1/RestoreV2"]), - restore_op_names) - - # The SavedModel builder adds its own Saver' for a total of three. - builder.add_meta_graph_and_variables( - sess, [tag_constants.TRAINING], clear_devices=True) - - # Save the SavedModel to disk. - builder.save() - - # Restore the graph. - with ops.Graph().as_default() as graph: - with self.test_session(graph=graph) as sess: - loader.load(sess, [tag_constants.TRAINING], export_dir) - self.assertEqual( - 42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval()) - - # Confirm that the reloaded graph has only one SaverDef. - savers = ops.get_collection(ops.GraphKeys.SAVERS) - self.assertEqual(1, len(savers)) - - # The reloaded graph should have exactly one Save and one Restore op. 
- save_op_names = set([x.name for x in graph.get_operations() - if x.type == "SaveV2"]) - self.assertSetEqual(set(["save_2/SaveV2"]), save_op_names) - restore_op_names = set([x.name for x in graph.get_operations() - if x.type == "RestoreV2"]) - self.assertSetEqual(set(["save_2/RestoreV2"]), restore_op_names) - if __name__ == "__main__": test.main() diff --git a/tensorflow/python/summary/summary.py b/tensorflow/python/summary/summary.py index 7ff01a51f3d..f3600793a62 100644 --- a/tensorflow/python/summary/summary.py +++ b/tensorflow/python/summary/summary.py @@ -20,6 +20,7 @@ See the @{$python/summary} guide. @@FileWriter @@FileWriterCache @@tensor_summary +@@_tensor_summary_v2 @@scalar @@histogram @@audio @@ -28,6 +29,7 @@ See the @{$python/summary} guide. @@merge @@merge_all @@get_summary_description +@@PluginAsset @@get_plugin_asset @@get_all_plugin_assets """ diff --git a/tensorflow/python/summary/text_summary.py b/tensorflow/python/summary/text_summary.py index 52bc913b2ad..2132dc6eb8a 100644 --- a/tensorflow/python/summary/text_summary.py +++ b/tensorflow/python/summary/text_summary.py @@ -23,13 +23,26 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from collections import namedtuple import json +from tensorflow.core.framework import summary_pb2 from tensorflow.python.framework import dtypes -from tensorflow.python.ops.summary_ops import tensor_summary +from tensorflow.python.ops.summary_ops import _tensor_summary_v2 from tensorflow.python.summary import plugin_asset +from tensorflow.python.util import deprecation + +PLUGIN_NAME = "text" + +# Contains event-related data specific to the text plugin. +_TextPluginData = namedtuple("_TextPluginData", []) +@deprecation.deprecated_args( + "2017-06-13", + "collections is deprecated. 
Instead of using collections to associate " + "plugins to events, add a PluginData field to the SummaryMetadata of a " + "Value proto.", "collections") def text_summary(name, tensor, collections=None): """Summarizes textual data. @@ -60,9 +73,16 @@ def text_summary(name, tensor, collections=None): raise ValueError("Expected tensor %s to have dtype string, got %s" % (tensor.name, tensor.dtype)) - t_summary = tensor_summary(name, tensor, collections=collections) - text_assets = plugin_asset.get_plugin_asset(TextSummaryPluginAsset) - text_assets.register_tensor(t_summary.op.name) + summary_metadata = summary_pb2.SummaryMetadata() + text_plugin_data = _TextPluginData() + data_dict = text_plugin_data._asdict() # pylint: disable=protected-access + summary_metadata.plugin_data.add( + plugin_name=PLUGIN_NAME, content=json.dumps(data_dict)) + t_summary = _tensor_summary_v2( + name=name, + tensor=tensor, + summary_metadata=summary_metadata, + collections=collections) return t_summary diff --git a/tensorflow/python/summary/text_summary_test.py b/tensorflow/python/summary/text_summary_test.py index 31009702ca4..4d357918f6c 100644 --- a/tensorflow/python/summary/text_summary_test.py +++ b/tensorflow/python/summary/text_summary_test.py @@ -17,7 +17,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.framework import ops as framework_ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.platform import googletest @@ -43,16 +42,11 @@ class TextPluginTest(test_util.TensorFlowTestCase): # The API accepts vectors. 
arr = array_ops.constant(["one", "two", "three"]) summ = text_summary.text_summary("foo", arr) - self.assertEqual(summ.op.type, "TensorSummary") + self.assertEqual(summ.op.type, "TensorSummaryV2") # the API accepts scalars summ = text_summary.text_summary("foo", array_ops.constant("one")) - self.assertEqual(summ.op.type, "TensorSummary") - - def testTextSummaryCollections(self): - text_summary.text_summary("bar", array_ops.constant("2"), collections=[]) - summaries = framework_ops.get_collection(framework_ops.GraphKeys.SUMMARIES) - self.assertEqual(len(summaries), 0) + self.assertEqual(summ.op.type, "TensorSummaryV2") if __name__ == "__main__": diff --git a/tensorflow/python/summary/writer/writer.py b/tensorflow/python/summary/writer/writer.py index 05f97fb2841..8ce49d623d8 100644 --- a/tensorflow/python/summary/writer/writer.py +++ b/tensorflow/python/summary/writer/writer.py @@ -86,6 +86,14 @@ class SummaryToEventTransformer(object): meta_graph.create_meta_graph_def(graph_def=graph_def or maybe_graph_as_def)) + # This set contains tags of Summary Values that have been encountered + # already. The motivation here is that the SummaryWriter only keeps the + # metadata property (which is a SummaryMetadata proto) of the first Summary + # Value encountered for each tag. The SummaryWriter strips away the + # SummaryMetadata for all subsequent Summary Values with tags seen + # previously. This saves space. + self._seen_summary_tags = set() + def add_summary(self, summary, global_step=None): """Adds a `Summary` protocol buffer to the event file. @@ -108,6 +116,24 @@ class SummaryToEventTransformer(object): summ = summary_pb2.Summary() summ.ParseFromString(summary) summary = summ + + # We strip metadata from values with tags that we have seen before in order + # to save space - we just store the metadata on the first value with a + # specific tag. 
+ for value in summary.value: + if not value.metadata: + continue + + if value.tag in self._seen_summary_tags: + # This tag has been encountered before. Strip the metadata. + value.ClearField("metadata") + continue + + # We encounter a value with a tag we have not encountered previously. And + # it has metadata. Remember to strip metadata from future values with this + # tag string. + self._seen_summary_tags.add(value.tag) + event = event_pb2.Event(summary=summary) self._add_event(event, global_step) diff --git a/tensorflow/python/summary/writer/writer_test.py b/tensorflow/python/summary/writer/writer_test.py index 8c34eb82e35..3d27b11cb9f 100644 --- a/tensorflow/python/summary/writer/writer_test.py +++ b/tensorflow/python/summary/writer/writer_test.py @@ -317,6 +317,63 @@ class SummaryWriterTestCase(test.TestCase): # We should be done. self.assertRaises(StopIteration, lambda: next(rr)) + def testPluginMetadataStrippedFromSubsequentEvents(self): + test_dir = self._CleanTestDir("basics") + sw = writer.FileWriter(test_dir) + + sw.add_session_log(event_pb2.SessionLog(status=SessionLog.START), 1) + + # We add 2 summaries with the same tags. They both have metadata. The writer + # should strip the metadata from the second one. + value = summary_pb2.Summary.Value(tag="foo", simple_value=10.0) + value.metadata.plugin_data.add(plugin_name="bar", content="... content ...") + sw.add_summary(summary_pb2.Summary(value=[value]), 10) + value = summary_pb2.Summary.Value(tag="foo", simple_value=10.0) + value.metadata.plugin_data.add(plugin_name="bar", content="... content ...") + sw.add_summary(summary_pb2.Summary(value=[value]), 10) + + sw.close() + rr = self._EventsReader(test_dir) + + # The first event should list the file_version. + ev = next(rr) + self._assertRecent(ev.wall_time) + self.assertEquals("brain.Event:2", ev.file_version) + + # The next event should be the START message. 
+ ev = next(rr) + self._assertRecent(ev.wall_time) + self.assertEquals(1, ev.step) + self.assertEquals(SessionLog.START, ev.session_log.status) + + # This is the first event with tag foo. It should contain SummaryMetadata. + ev = next(rr) + self.assertProtoEquals(""" + value { + tag: "foo" + simple_value: 10.0 + metadata { + plugin_data { + plugin_name: "bar" + content: "... content ..." + } + } + } + """, ev.summary) + + # This is the second event with tag foo. It should lack SummaryMetadata + # because the file writer should have stripped it. + ev = next(rr) + self.assertProtoEquals(""" + value { + tag: "foo" + simple_value: 10.0 + } + """, ev.summary) + + # We should be done. + self.assertRaises(StopIteration, lambda: next(rr)) + def testFileWriterWithSuffix(self): test_dir = self._CleanTestDir("test_suffix") sw = writer.FileWriter(test_dir, filename_suffix="_test_suffix") diff --git a/tensorflow/python/training/checkpoint_utils.py b/tensorflow/python/training/checkpoint_utils.py index d52cf9a4367..ddf04e21e61 100644 --- a/tensorflow/python/training/checkpoint_utils.py +++ b/tensorflow/python/training/checkpoint_utils.py @@ -20,6 +20,7 @@ from __future__ import print_function import six +from tensorflow.python import pywrap_tensorflow from tensorflow.python.ops import io_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope as vs @@ -27,7 +28,7 @@ from tensorflow.python.ops import variables from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import saver -from tensorflow.python.training import training as train + __all__ = [ "load_checkpoint", "load_variable", "list_variables", "init_from_checkpoint" @@ -55,7 +56,7 @@ def load_checkpoint(ckpt_dir_or_file): if filename is None: raise ValueError("Couldn't find 'checkpoint' file or checkpoints in " "given directory %s" % ckpt_dir_or_file) - return train.NewCheckpointReader(filename) + 
return pywrap_tensorflow.NewCheckpointReader(filename) def load_variable(ckpt_dir_or_file, name): diff --git a/tensorflow/python/training/training.py b/tensorflow/python/training/training.py index f4ac3c97587..e2a7b28e2bc 100644 --- a/tensorflow/python/training/training.py +++ b/tensorflow/python/training/training.py @@ -85,6 +85,10 @@ See the @{$python/train} guide. @@create_global_step @@assert_global_step @@write_graph +@@load_checkpoint +@@load_variable +@@list_variables +@@init_from_checkpoint """ # Optimizers. @@ -142,6 +146,11 @@ from tensorflow.python.training.basic_session_run_hooks import GlobalStepWaiterH from tensorflow.python.training.basic_session_run_hooks import FinalOpsHook from tensorflow.python.training.basic_session_run_hooks import FeedFnHook from tensorflow.python.training.basic_loops import basic_train_loop +from tensorflow.python.training.checkpoint_utils import init_from_checkpoint +from tensorflow.python.training.checkpoint_utils import list_variables +from tensorflow.python.training.checkpoint_utils import load_checkpoint +from tensorflow.python.training.checkpoint_utils import load_variable + from tensorflow.python.training.device_setter import replica_device_setter from tensorflow.python.training.monitored_session import Scaffold from tensorflow.python.training.monitored_session import MonitoredTrainingSession diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc index 34a683ab575..062e3e59cac 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.cc +++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc @@ -2432,17 +2432,16 @@ bool CudnnSupport::DoTransformTensor(Stream* stream, dnn::DataType input_type, const DeviceMemoryBase& input_data, const dnn::BatchDescriptor& output_desc, - dnn::DataType output_type, + dnn::DataType output_type, float scale, DeviceMemoryBase* output_data) { mutex_lock lock{dnn_handle_mutex_}; - float alpha = 1.0f; float beta = 0.0f; ScopedTensorDescriptor 
input_tensor_desc( parent_, input_desc, ToCudnnDataType(input_type, input_desc.layout())); ScopedTensorDescriptor output_tensor_desc( parent_, output_desc, ToCudnnDataType(output_type, output_desc.layout())); cudnnStatus_t status = wrap::cudnnTransformTensor( - parent_, ToHandle(dnn_handle_), &alpha, input_tensor_desc.handle(), + parent_, ToHandle(dnn_handle_), &scale, input_tensor_desc.handle(), input_data.opaque(), &beta, output_tensor_desc.handle(), output_data->opaque()); if (status != CUDNN_STATUS_SUCCESS) { diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.h b/tensorflow/stream_executor/cuda/cuda_dnn.h index cc37c8bb9f3..16fa656ef60 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.h +++ b/tensorflow/stream_executor/cuda/cuda_dnn.h @@ -455,7 +455,7 @@ class CudnnSupport : public dnn::DnnSupport { dnn::DataType input_type, const DeviceMemoryBase& input_data, const dnn::BatchDescriptor& output_desc, - dnn::DataType output_type, + dnn::DataType output_type, float scale, DeviceMemoryBase* output_data) override; private: diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h index 8c8ac8662d1..f12aa6d38b5 100644 --- a/tensorflow/stream_executor/dnn.h +++ b/tensorflow/stream_executor/dnn.h @@ -1980,13 +1980,14 @@ class DnnSupport { // input_data: the device memory region that contains the input tensor. // output_desc: specifies the shape and the data layout of the output tensor. // output_type: the data type of the output tensor. + // scale: an element-wise scaling factor to apply. // output_data: the device memory region that contains the output tensor. 
virtual bool DoTransformTensor(Stream* stream, const dnn::BatchDescriptor& input_desc, dnn::DataType input_type, const DeviceMemoryBase& input_data, const dnn::BatchDescriptor& output_desc, - dnn::DataType output_type, + dnn::DataType output_type, float scale, DeviceMemoryBase* output_data) { return false; } diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc index fff9accae32..cbf8c11ef10 100644 --- a/tensorflow/stream_executor/stream.cc +++ b/tensorflow/stream_executor/stream.cc @@ -4406,15 +4406,16 @@ Stream &Stream::ThenTransformTensor(const dnn::BatchDescriptor &input_desc, dnn::DataType input_type, const DeviceMemoryBase &input_data, const dnn::BatchDescriptor &output_desc, - dnn::DataType output_type, + dnn::DataType output_type, float scale, DeviceMemoryBase *output_data) { VLOG_CALL(PARAM(input_desc), PARAM(input_type), PARAM(input_data), - PARAM(output_desc), PARAM(output_type), PARAM(output_data)); + PARAM(output_desc), PARAM(output_type), PARAM(scale), + PARAM(output_data)); if (ok()) { if (dnn::DnnSupport *dnn = parent_->AsDnn()) { CheckError(dnn->DoTransformTensor(this, input_desc, input_type, input_data, output_desc, output_type, - output_data)); + scale, output_data)); } else { SetErrorAndLogNoDnnSupport(); } diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h index b07d3021c92..2ab3f44af51 100644 --- a/tensorflow/stream_executor/stream.h +++ b/tensorflow/stream_executor/stream.h @@ -1655,7 +1655,7 @@ class Stream { dnn::DataType input_type, const DeviceMemoryBase &input_data, const dnn::BatchDescriptor &output_desc, - dnn::DataType output_type, + dnn::DataType output_type, float scale, DeviceMemoryBase *output_data); // The templated version of the above ThenTransformTensor. 
Useful when the diff --git a/tensorflow/tensorboard/BUILD b/tensorflow/tensorboard/BUILD deleted file mode 100644 index bbd4251731e..00000000000 --- a/tensorflow/tensorboard/BUILD +++ /dev/null @@ -1,84 +0,0 @@ -# Description: -# TensorBoard, a dashboard for investigating TensorFlow - -package(default_visibility = [":internal"]) - -licenses(["notice"]) # Apache 2.0 - -package_group( - name = "internal", - packages = [ - "//learning/brain/tensorboard/...", - "//learning/vis/...", - "//tensorflow/...", - "//tensorflow/tensorboard/...", - ], -) - -py_binary( - name = "tensorboard", - srcs = ["main.py"], - data = [":assets"], - main = "main.py", - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/tensorboard/backend:application", - "//tensorflow/tensorboard/backend/event_processing:event_file_inspector", - "//tensorflow/tensorboard/plugins/audio:audio_plugin", - "//tensorflow/tensorboard/plugins/distributions:distributions_plugin", - "//tensorflow/tensorboard/plugins/graphs:graphs_plugin", - "//tensorflow/tensorboard/plugins/histograms:histograms_plugin", - "//tensorflow/tensorboard/plugins/images:images_plugin", - "//tensorflow/tensorboard/plugins/projector:projector_plugin", - "//tensorflow/tensorboard/plugins/scalars:scalars_plugin", - "//tensorflow/tensorboard/plugins/text:text_plugin", - "@org_pocoo_werkzeug//:werkzeug", - ], -) - -py_library( - name = "expect_tensorflow_installed", - # This is a dummy rule used as a TensorFlow dependency in open-source. - # We expect TensorFlow to already be installed on the system, e.g. via - # `pip install tensorflow` -) - -py_library( - name = "expect_numpy_installed", - # This is a dummy rule used as a numpy dependency in open-source. - # We expect numpy to already be installed on the system, e.g. 
via - # `pip install numpy` -) - -filegroup( - name = "assets", - srcs = [ - "TAG", - "//tensorflow/tensorboard/components:index.html", - "//tensorflow/tensorboard/components:trace_viewer_index.html", - ], -) - -filegroup( - name = "ts_web_library_default_typings", - srcs = [ - # Ordering probably matters. - "@com_microsoft_typescript//:lib.es6.d.ts", - "@io_angular_clutz//:src/resources/closure.lib.d.ts", - "//tensorflow/tensorboard/defs:clutz.d.ts", - ], - visibility = ["//visibility:public"], -) - -filegroup( - name = "all_files", - srcs = glob( - ["**"], - exclude = [ - "METADATA", - "OWNERS", - "tensorboard.google.bzl", - ], - ), - tags = ["notsan"], -) diff --git a/tensorflow/tensorboard/DEVELOPMENT.md b/tensorflow/tensorboard/DEVELOPMENT.md deleted file mode 100644 index 79d534a26c9..00000000000 --- a/tensorflow/tensorboard/DEVELOPMENT.md +++ /dev/null @@ -1,25 +0,0 @@ -# How to Develop TensorBoard - -## Launching a Development Instance - -Run the following to launch a demo of TensorBoard in raw sources mode: - -```sh -bazel run third_party/tensorflow/tensorboard/components/tf_tensorboard:demo -``` - -Now you can navigate to and play with -the demo TensorBoard instance. This will have live source reloading. - -This demo TensorBoard will have a small amount of demo data generated by -[generate_testdata.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tensorboard/scripts/generate_testdata.py). -You can use [serialize_tensorboard.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tensorboard/scripts/serialize_tensorboard.py) -to create a realistic demo directory from your own data files. 
- -## Launching TensorBoard Proper - -Running TensorBoard automatically asks Bazel to create a vulcanized HTML binary: - -```sh -bazel run //tensorflow/tensorboard:tensorboard -- --logdir=/path/to/logs -``` diff --git a/tensorflow/tensorboard/README.md b/tensorflow/tensorboard/README.md deleted file mode 100644 index a9ab4d3bd2a..00000000000 --- a/tensorflow/tensorboard/README.md +++ /dev/null @@ -1,366 +0,0 @@ -# TensorBoard - -TensorBoard is a suite of web applications for inspecting and understanding your -TensorFlow runs and graphs. - -This README gives an overview of key concepts in TensorBoard, as well as how to -interpret the visualizations TensorBoard provides. For an in-depth example of -using TensorBoard, see the tutorial: [TensorBoard: Visualizing -Learning](https://www.tensorflow.org/get_started/summaries_and_tensorboard). -For in-depth information on the Graph Visualizer, see this tutorial: [TensorBoard: Graph Visualization](https://www.tensorflow.org/get_started/graph_viz). - -You may also want to watch -[this video tutorial](https://www.youtube.com/watch?v=eBbEDRsCmv4) that walks -through setting up and using TensorBoard. - -# Usage - -Before running TensorBoard, make sure you have generated summary data in a log -directory by creating a summary writer: - -``` python -# sess.graph contains the graph definition; that enables the Graph Visualizer. - -file_writer = tf.summary.FileWriter('/path/to/logs', sess.graph) -``` - -For more details, see [the TensorBoard tutorial](https://www.tensorflow.org/get_started/summaries_and_tensorboard). -Once you have event files, run TensorBoard and provide the log directory. If -you're using a precompiled TensorFlow package (e.g. 
you installed via pip), run: - -``` -tensorboard --logdir=path/to/logs -``` - -Or, if you are building from source: - -``` -bazel build tensorflow/tensorboard:tensorboard -./bazel-bin/tensorflow/tensorboard/tensorboard --logdir=path/to/logs -``` - -This should print that TensorBoard has started. Next, connect to -http://localhost:6006. - -TensorBoard requires a `logdir` to read logs from. For info on configuring -TensorBoard, run `tensorboard --help`. - -TensorBoard can be used in Google Chrome or Firefox. Other browsers might -work, but there may be bugs or performance issues. - -# Key Concepts - -### Summary Ops: How TensorBoard gets data from TensorFlow - -The first step in using TensorBoard is acquiring data from your TensorFlow run. -For this, you need [summary ops](https://www.tensorflow.org/api_docs/python/tf/summary). -Summary ops are ops, like -[`tf.matmul`](https://www.tensorflow.org/versions/r1.2/api_docs/python/tf/matmul) -or -[`tf.nn.relu`](https://www.tensorflow.org/versions/master/api_docs/python/tf/nn/relu), -which means they take in tensors, produce tensors, and are evaluated from within -a TensorFlow graph. However, summary ops have a twist: the Tensors they produce -contain serialized protobufs, which are written to disk and sent to TensorBoard. -To visualize the summary data in TensorBoard, you should evaluate the summary -op, retrieve the result, and then write that result to disk using a -summary.FileWriter. A full explanation, with examples, is in [the -tutorial](https://www.tensorflow.org/get_started/summaries_and_tensorboard). - -The supported summary ops include: -* tf.summary.scalar -* tf.summary.image -* tf.summary.audio -* tf.summary.text -* tf.summary.histogram - -### Tags: Giving names to data - -When you make a summary op, you will also give it a `tag`. The tag is basically -a name for the data recorded by that op, and will be used to organize the data -in the frontend. 
The scalar and histogram dashboards organize data by tag, and -group the tags into folders according to a directory/like/hierarchy. If you have -a lot of tags, we recommend grouping them with slashes. - -### Event Files & LogDirs: How TensorBoard loads the data - -`summary.FileWriters` take summary data from TensorFlow, and then write them to a -specified directory, known as the `logdir`. Specifically, the data is written to -an append-only record dump that will have "tfevents" in the filename. -TensorBoard reads data from a full directory, and organizes it into the history -of a single TensorFlow execution. - -Why does it read the whole directory, rather than an individual file? You might -have been using -[supervisor.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/training/supervisor.py) -to run your model, in which case if TensorFlow crashes, the supervisor will -restart it from a checkpoint. When it restarts, it will start writing to a new -events file, and TensorBoard will stitch the various event files together to -produce a consistent history of what happened. - -### Runs: Comparing different executions of your model - -You may want to visually compare multiple executions of your model; for example, -suppose you've changed the hyperparameters and want to see if it's converging -faster. TensorBoard enables this through different "runs". When TensorBoard is -passed a `logdir` at startup, it recursively walks the directory tree rooted at -`logdir` looking for subdirectories that contain tfevents data. Every time it -encounters such a subdirectory, it loads it as a new `run`, and the frontend -will organize the data accordingly. - -For example, here is a well-organized TensorBoard log directory, with two runs, -"run1" and "run2". 
- -``` -/some/path/mnist_experiments/ -/some/path/mnist_experiments/run1/ -/some/path/mnist_experiments/run1/events.out.tfevents.1456525581.name -/some/path/mnist_experiments/run1/events.out.tfevents.1456525585.name -/some/path/mnist_experiments/run2/ -/some/path/mnist_experiments/run2/events.out.tfevents.1456525385.name -/tensorboard --logdir=/some/path/mnist_experiments -``` - -You may also pass a comma separated list of log directories, and TensorBoard -will watch each directory. You can also assign names to individual log -directories by putting a colon between the name and the path, as in - -``` -tensorboard --logdir=name1:/path/to/logs/1,name2:/path/to/logs/2 -``` - -# The Visualizations - -### Scalar Dashboard - -TensorBoard's Scalar Dashboard visualizes scalar statistics that vary over time; -for example, you might want to track the model's loss or learning rate. As -described in *Key Concepts*, you can compare multiple runs, and the data is -organized by tag. The line charts have the following interactions: - -* Clicking on the small blue icon in the lower-left corner of each chart will -expand the chart - -* Dragging a rectangular region on the chart will zoom in - -* Double clicking on the chart will zoom out - -* Mousing over the chart will produce crosshairs, with data values recorded in -the run-selector on the left. - -Additionally, you can create new folders to organize tags by writing regular -expressions in the box in the top-left of the dashboard. - -### Histogram Dashboard - -The HistogramDashboard displays how the statistical distribution of a Tensor -has varied over time. It visualizes data recorded via `tf.summary.histogram`. -Each chart shows temporal "slices" of data, where each slice is a histogram of -the tensor at a given step. It's organized with the oldest timestep in the back, -and the most recent timestep in front. 
By changing the Histogram Mode from -"offset" to "overlay", the perspective will rotate so that every histogram slice -is rendered as a line and overlaid with one another. - -### Distribution Dashboard - -The Distribution Dashboard is another way of visualizing histogram data from -`tf.summary.histogram`. It shows some high-level statistics on a distribution. -Each line on the chart represents a percentile in the distribution over the -data: for example, the bottom line shows how the minimum value has changed over -time, and the line in the middle shows how the median has changed. Reading from -top to bottom, the lines have the following meaning: `[maximum, 93%, 84%, 69%, -50%, 31%, 16%, 7%, minimum]` - -These percentiles can also be viewed as standard deviation boundaries on a -normal distribution: `[maximum, μ+1.5σ, μ+σ, μ+0.5σ, μ, μ-0.5σ, μ-σ, μ-1.5σ, -minimum]` so that the colored regions, read from inside to outside, have widths -`[σ, 2σ, 3σ]` respectively. - - -### Image Dashboard - -The Image Dashboard can display pngs that were saved via a `tf.summary.image`. -The dashboard is set up so that each row corresponds to a different tag, and -each column corresponds to a run. Since the image dashboard supports arbitrary -pngs, you can use this to embed custom visualizations (e.g. matplotlib -scatterplots) into TensorBoard. This dashboard always shows you the latest image -for each tag. - -### Audio Dashboard - -The Audio Dashboard can embed playable audio widgets for audio saved via a -`tf.summary.audio`. The dashboard is set up so that each row corresponds to a -different tag, and each column corresponds to a run. This dashboard always -embeds the latest audio for each tag. - -### Graph Explorer - -The Graph Explorer can visualize a TensorBoard graph, enabling inspection of the -TensorFlow model. To get best use of the graph visualizer, you should use name -scopes to hierarchically group the ops in your graph - otherwise, the graph may -be difficult to decipher. 
For more information, including examples, see [the -graph visualizer tutorial](https://www.tensorflow.org/get_started/graph_viz). - -### Embedding Projector - -The Embedding Projector allows you to visualize high-dimensional data; for -example, you may view your input data after it has been embedded in a high- -dimensional space by your model. The embedding projector reads data from your -model checkpoint file, and may be configured with additional metadata, like -a vocabulary file or sprite images. For more details, see [the embedding -projector tutorial](https://www.tensorflow.org/get_started/embedding_viz). - -### Text Dashboard - -The Text Dashboard displays text snippets saved via `tf.summary.text`. Markdown -features including hyperlinks, lists, and tables are all supported. - -# Frequently Asked Questions - -### My TensorBoard isn't showing any data! What's wrong? - -The first thing to do is ensure that TensorBoard is properly loading data from -the correct directory. Launch `tensorboard --logdir=DIRECTORY_PATH --debug` and -look for output of the form - -`INFO:tensorflow:TensorBoard path_to_run is: {'DIRECTORY_PATH': None}` - -Verify that the DIRECTORY_PATH TensorBoard is looking at is the path you expect. -(Note: There's a known issue where TensorBoard [does not handle paths starting -in ~ properly](https://github.com/tensorflow/tensorflow/issues/1587)). - -If you're loading from the proper path, make sure that event files are present. -TensorBoard will recursively walk its logdir, it's fine if the data is nested -under a subdirectory. Try running the command: - -`find DIRECTORY_PATH | grep tfevents` - -If you have at least one result, then TensorBoard should be able to load data. - -Finally, let's make sure that the event files actually have data. Run -tensorboard in inspector mode to inspect the contents of your event files. 
- -`tensorboard --inspect --logdir=DIRECTORY_PATH` - -If after running this procedure, it's still not working, please file an [issue -on GitHub](https://github.com/tensorflow/tensorflow/issues). It will be much -easier for us to debug it if you provide an event file that isn't working. - -### TensorBoard is showing only some of my data, or isn't properly updating! - -This issue usually comes about because of how TensorBoard iterates through the -`tfevents` files: it progresses through the events file in timestamp order, and -only reads one file at a time. Let's suppose we have files with timestamps `a` -and `b`, where `a self._path) - if next_paths: - return min(next_paths) - else: - return None - - def _HasOOOWrite(self, path): - """Returns whether the path has had an out-of-order write.""" - # Check the sizes of each path before the current one. - size = tf.gfile.Stat(path).length - old_size = self._finalized_sizes.get(path, None) - if size != old_size: - if old_size is None: - tf.logging.error('File %s created after file %s even though it\'s ' - 'lexicographically earlier', path, self._path) - else: - tf.logging.error('File %s updated even though the current file is %s', - path, self._path) - return True - else: - return False - - -class DirectoryDeletedError(Exception): - """Thrown by Load() when the directory is *permanently* gone. - - We distinguish this from temporary errors so that other code can decide to - drop all of our data only when a directory has been intentionally deleted, - as opposed to due to transient filesystem errors. - """ - pass diff --git a/tensorflow/tensorboard/backend/event_processing/directory_watcher_test.py b/tensorflow/tensorboard/backend/event_processing/directory_watcher_test.py deleted file mode 100644 index d44f74a8a43..00000000000 --- a/tensorflow/tensorboard/backend/event_processing/directory_watcher_test.py +++ /dev/null @@ -1,208 +0,0 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -"""Tests for directory_watcher.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os -import shutil - -import tensorflow as tf - -from tensorflow.tensorboard.backend.event_processing import directory_watcher -from tensorflow.tensorboard.backend.event_processing import io_wrapper - - -class _ByteLoader(object): - """A loader that loads individual bytes from a file.""" - - def __init__(self, path): - self._f = open(path) - self.bytes_read = 0 - - def Load(self): - while True: - self._f.seek(self.bytes_read) - byte = self._f.read(1) - if byte: - self.bytes_read += 1 - yield byte - else: - return - - -class DirectoryWatcherTest(tf.test.TestCase): - - def setUp(self): - # Put everything in a directory so it's easier to delete. - self._directory = os.path.join(self.get_temp_dir(), 'monitor_dir') - os.mkdir(self._directory) - self._watcher = directory_watcher.DirectoryWatcher(self._directory, - _ByteLoader) - self.stubs = tf.test.StubOutForTesting() - - def tearDown(self): - self.stubs.CleanUp() - try: - shutil.rmtree(self._directory) - except OSError: - # Some tests delete the directory. 
- pass - - def _WriteToFile(self, filename, data): - path = os.path.join(self._directory, filename) - with open(path, 'a') as f: - f.write(data) - - def _LoadAllEvents(self): - """Loads all events in the watcher.""" - for _ in self._watcher.Load(): - pass - - def assertWatcherYields(self, values): - self.assertEqual(list(self._watcher.Load()), values) - - def testRaisesWithBadArguments(self): - with self.assertRaises(ValueError): - directory_watcher.DirectoryWatcher(None, lambda x: None) - with self.assertRaises(ValueError): - directory_watcher.DirectoryWatcher('dir', None) - - def testEmptyDirectory(self): - self.assertWatcherYields([]) - - def testSingleWrite(self): - self._WriteToFile('a', 'abc') - self.assertWatcherYields(['a', 'b', 'c']) - self.assertFalse(self._watcher.OutOfOrderWritesDetected()) - - def testMultipleWrites(self): - self._WriteToFile('a', 'abc') - self.assertWatcherYields(['a', 'b', 'c']) - self._WriteToFile('a', 'xyz') - self.assertWatcherYields(['x', 'y', 'z']) - self.assertFalse(self._watcher.OutOfOrderWritesDetected()) - - def testMultipleLoads(self): - self._WriteToFile('a', 'a') - self._watcher.Load() - self._watcher.Load() - self.assertWatcherYields(['a']) - self.assertFalse(self._watcher.OutOfOrderWritesDetected()) - - def testMultipleFilesAtOnce(self): - self._WriteToFile('b', 'b') - self._WriteToFile('a', 'a') - self.assertWatcherYields(['a', 'b']) - self.assertFalse(self._watcher.OutOfOrderWritesDetected()) - - def testFinishesLoadingFileWhenSwitchingToNewFile(self): - self._WriteToFile('a', 'a') - # Empty the iterator. - self.assertEquals(['a'], list(self._watcher.Load())) - self._WriteToFile('a', 'b') - self._WriteToFile('b', 'c') - # The watcher should finish its current file before starting a new one. 
- self.assertWatcherYields(['b', 'c']) - self.assertFalse(self._watcher.OutOfOrderWritesDetected()) - - def testIntermediateEmptyFiles(self): - self._WriteToFile('a', 'a') - self._WriteToFile('b', '') - self._WriteToFile('c', 'c') - self.assertWatcherYields(['a', 'c']) - self.assertFalse(self._watcher.OutOfOrderWritesDetected()) - - def testPathFilter(self): - self._watcher = directory_watcher.DirectoryWatcher( - self._directory, _ByteLoader, - lambda path: 'do_not_watch_me' not in path) - - self._WriteToFile('a', 'a') - self._WriteToFile('do_not_watch_me', 'b') - self._WriteToFile('c', 'c') - self.assertWatcherYields(['a', 'c']) - self.assertFalse(self._watcher.OutOfOrderWritesDetected()) - - def testDetectsNewOldFiles(self): - self._WriteToFile('b', 'a') - self._LoadAllEvents() - self._WriteToFile('a', 'a') - self._LoadAllEvents() - self.assertTrue(self._watcher.OutOfOrderWritesDetected()) - - def testIgnoresNewerFiles(self): - self._WriteToFile('a', 'a') - self._LoadAllEvents() - self._WriteToFile('q', 'a') - self._LoadAllEvents() - self.assertFalse(self._watcher.OutOfOrderWritesDetected()) - - def testDetectsChangingOldFiles(self): - self._WriteToFile('a', 'a') - self._WriteToFile('b', 'a') - self._LoadAllEvents() - self._WriteToFile('a', 'c') - self._LoadAllEvents() - self.assertTrue(self._watcher.OutOfOrderWritesDetected()) - - def testDoesntCrashWhenFileIsDeleted(self): - self._WriteToFile('a', 'a') - self._LoadAllEvents() - os.remove(os.path.join(self._directory, 'a')) - self._WriteToFile('b', 'b') - self.assertWatcherYields(['b']) - - def testRaisesRightErrorWhenDirectoryIsDeleted(self): - self._WriteToFile('a', 'a') - self._LoadAllEvents() - shutil.rmtree(self._directory) - with self.assertRaises(directory_watcher.DirectoryDeletedError): - self._LoadAllEvents() - - def testDoesntRaiseDirectoryDeletedErrorIfOutageIsTransient(self): - self._WriteToFile('a', 'a') - self._LoadAllEvents() - shutil.rmtree(self._directory) - - # Fake a single transient I/O 
error. - def FakeFactory(original): - - def Fake(*args, **kwargs): - if FakeFactory.has_been_called: - original(*args, **kwargs) - else: - raise OSError('lp0 temporarily on fire') - - return Fake - - FakeFactory.has_been_called = False - - for stub_name in ['ListDirectoryAbsolute', 'ListRecursively']: - self.stubs.Set(io_wrapper, stub_name, - FakeFactory(getattr(io_wrapper, stub_name))) - for stub_name in ['IsDirectory', 'Exists', 'Stat']: - self.stubs.Set(tf.gfile, stub_name, - FakeFactory(getattr(tf.gfile, stub_name))) - - with self.assertRaises((IOError, OSError)): - self._LoadAllEvents() - - -if __name__ == '__main__': - tf.test.main() diff --git a/tensorflow/tensorboard/backend/event_processing/event_accumulator.py b/tensorflow/tensorboard/backend/event_processing/event_accumulator.py deleted file mode 100644 index 1562f0f8339..00000000000 --- a/tensorflow/tensorboard/backend/event_processing/event_accumulator.py +++ /dev/null @@ -1,851 +0,0 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Takes a generator of values, and accumulates them for a frontend.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import collections -import os -import re -import threading - -import numpy as np -import tensorflow as tf - -from tensorflow.tensorboard.backend.event_processing import directory_watcher -from tensorflow.tensorboard.backend.event_processing import event_file_loader -from tensorflow.tensorboard.backend.event_processing import plugin_asset_util -from tensorflow.tensorboard.backend.event_processing import reservoir - -namedtuple = collections.namedtuple -ScalarEvent = namedtuple('ScalarEvent', ['wall_time', 'step', 'value']) - -HealthPillEvent = namedtuple('HealthPillEvent', [ - 'wall_time', 'step', 'device_name', 'node_name', 'output_slot', 'dtype', - 'shape', 'value']) - -CompressedHistogramEvent = namedtuple('CompressedHistogramEvent', - ['wall_time', 'step', - 'compressed_histogram_values']) - -CompressedHistogramValue = namedtuple('CompressedHistogramValue', - ['basis_point', 'value']) - -HistogramEvent = namedtuple('HistogramEvent', - ['wall_time', 'step', 'histogram_value']) - -HistogramValue = namedtuple('HistogramValue', ['min', 'max', 'num', 'sum', - 'sum_squares', 'bucket_limit', - 'bucket']) - -ImageEvent = namedtuple('ImageEvent', ['wall_time', 'step', - 'encoded_image_string', 'width', - 'height']) - -AudioEvent = namedtuple('AudioEvent', ['wall_time', 'step', - 'encoded_audio_string', 'content_type', - 'sample_rate', 'length_frames']) - -TensorEvent = namedtuple('TensorEvent', ['wall_time', 'step', 'tensor_proto']) - -## Different types of summary events handled by the event_accumulator -SUMMARY_TYPES = { - 'simple_value': '_ProcessScalar', - 'histo': '_ProcessHistogram', - 'image': '_ProcessImage', - 'audio': '_ProcessAudio', - 'tensor': '_ProcessTensor', -} - -## The tagTypes below are 
just arbitrary strings chosen to pass the type -## information of the tag from the backend to the frontend -COMPRESSED_HISTOGRAMS = 'distributions' -HISTOGRAMS = 'histograms' -IMAGES = 'images' -AUDIO = 'audio' -SCALARS = 'scalars' -TENSORS = 'tensors' -HEALTH_PILLS = 'health_pills' -GRAPH = 'graph' -META_GRAPH = 'meta_graph' -RUN_METADATA = 'run_metadata' - -## Normal CDF for std_devs: (-Inf, -1.5, -1, -0.5, 0, 0.5, 1, 1.5, Inf) -## naturally gives bands around median of width 1 std dev, 2 std dev, 3 std dev, -## and then the long tail. -NORMAL_HISTOGRAM_BPS = (0, 668, 1587, 3085, 5000, 6915, 8413, 9332, 10000) - -DEFAULT_SIZE_GUIDANCE = { - COMPRESSED_HISTOGRAMS: 500, - IMAGES: 4, - AUDIO: 4, - SCALARS: 10000, - # We store this many health pills per op. - HEALTH_PILLS: 100, - HISTOGRAMS: 1, - TENSORS: 10, -} - -STORE_EVERYTHING_SIZE_GUIDANCE = { - COMPRESSED_HISTOGRAMS: 0, - IMAGES: 0, - AUDIO: 0, - SCALARS: 0, - HEALTH_PILLS: 0, - HISTOGRAMS: 0, - TENSORS: 0, -} - -# The tag that values containing health pills have. Health pill data is stored -# in tensors. In order to distinguish health pill values from scalar values, we -# rely on how health pill values have this special tag value. -HEALTH_PILL_EVENT_TAG_PREFIX = '__health_pill__/' - - -def IsTensorFlowEventsFile(path): - """Check the path name to see if it is probably a TF Events file. - - Args: - path: A file path to check if it is an event file. - - Raises: - ValueError: If the path is an empty string. - - Returns: - If path is formatted like a TensorFlowEventsFile. - """ - if not path: - raise ValueError('Path must be a nonempty string') - return 'tfevents' in tf.compat.as_str_any(os.path.basename(path)) - - -class EventAccumulator(object): - """An `EventAccumulator` takes an event generator, and accumulates the values. - - The `EventAccumulator` is intended to provide a convenient Python interface - for loading Event data written during a TensorFlow run. 
TensorFlow writes out - `Event` protobuf objects, which have a timestamp and step number, and often - contain a `Summary`. Summaries can have different kinds of data like an image, - a scalar value, or a histogram. The Summaries also have a tag, which we use to - organize logically related data. The `EventAccumulator` supports retrieving - the `Event` and `Summary` data by its tag. - - Calling `Tags()` gets a map from `tagType` (e.g. `'images'`, - `'compressedHistograms'`, `'scalars'`, etc) to the associated tags for those - data types. Then, various functional endpoints (eg - `Accumulator.Scalars(tag)`) allow for the retrieval of all data - associated with that tag. - - The `Reload()` method synchronously loads all of the data written so far. - - Histograms, audio, and images are very large, so storing all of them is not - recommended. - @@Tensors - """ - - def __init__(self, - path, - size_guidance=DEFAULT_SIZE_GUIDANCE, - compression_bps=NORMAL_HISTOGRAM_BPS, - purge_orphaned_data=True): - """Construct the `EventAccumulator`. - - Args: - path: A file path to a directory containing tf events files, or a single - tf events file. The accumulator will load events from this path. - size_guidance: Information on how much data the EventAccumulator should - store in memory. The DEFAULT_SIZE_GUIDANCE tries not to store too much - so as to avoid OOMing the client. The size_guidance should be a map - from a `tagType` string to an integer representing the number of - items to keep per tag for items of that `tagType`. If the size is 0, - all events are stored. - compression_bps: Information on how the `EventAccumulator` should compress - histogram data for the `CompressedHistograms` tag (for details see - `ProcessCompressedHistogram`). - purge_orphaned_data: Whether to discard any events that were "orphaned" by - a TensorFlow restart. 
- """ - sizes = {} - for key in DEFAULT_SIZE_GUIDANCE: - if key in size_guidance: - sizes[key] = size_guidance[key] - else: - sizes[key] = DEFAULT_SIZE_GUIDANCE[key] - - self._first_event_timestamp = None - self._scalars = reservoir.Reservoir(size=sizes[SCALARS]) - - # Unlike the other reservoir, the reservoir for health pills is keyed by the - # name of the op instead of the tag. This lets us efficiently obtain the - # health pills per node. - self._health_pills = reservoir.Reservoir(size=sizes[HEALTH_PILLS]) - - self._graph = None - self._graph_from_metagraph = False - self._meta_graph = None - self._tagged_metadata = {} - self._histograms = reservoir.Reservoir(size=sizes[HISTOGRAMS]) - self._compressed_histograms = reservoir.Reservoir( - size=sizes[COMPRESSED_HISTOGRAMS], always_keep_last=False) - self._images = reservoir.Reservoir(size=sizes[IMAGES]) - self._audio = reservoir.Reservoir(size=sizes[AUDIO]) - self._tensors = reservoir.Reservoir(size=sizes[TENSORS]) - - self._generator_mutex = threading.Lock() - self.path = path - self._generator = _GeneratorFromPath(path) - - self._compression_bps = compression_bps - self.purge_orphaned_data = purge_orphaned_data - - self.most_recent_step = -1 - self.most_recent_wall_time = -1 - self.file_version = None - - # The attributes that get built up by the accumulator - self.accumulated_attrs = ('_scalars', '_histograms', - '_compressed_histograms', '_images', '_audio') - self._tensor_summaries = {} - - def Reload(self): - """Loads all events added since the last call to `Reload`. - - If `Reload` was never called, loads all events in the file. - - Returns: - The `EventAccumulator`. - """ - with self._generator_mutex: - for event in self._generator.Load(): - self._ProcessEvent(event) - return self - - def PluginAssets(self, plugin_name): - """Return a list of all plugin assets for the given plugin. - - Args: - plugin_name: The string name of a plugin to retrieve assets for. 
- - Returns: - A list of string plugin asset names, or empty list if none are available. - If the plugin was not registered, an empty list is returned. - """ - return plugin_asset_util.ListAssets(self.path, plugin_name) - - def RetrievePluginAsset(self, plugin_name, asset_name): - """Return the contents of a given plugin asset. - - Args: - plugin_name: The string name of a plugin. - asset_name: The string name of an asset. - - Returns: - The string contents of the plugin asset. - - Raises: - KeyError: If the asset is not available. - """ - return plugin_asset_util.RetrieveAsset(self.path, plugin_name, asset_name) - - def FirstEventTimestamp(self): - """Returns the timestamp in seconds of the first event. - - If the first event has been loaded (either by this method or by `Reload`, - this returns immediately. Otherwise, it will load in the first event. Note - that this means that calling `Reload` will cause this to block until - `Reload` has finished. - - Returns: - The timestamp in seconds of the first event that was loaded. - - Raises: - ValueError: If no events have been loaded and there were no events found - on disk. - """ - if self._first_event_timestamp is not None: - return self._first_event_timestamp - with self._generator_mutex: - try: - event = next(self._generator.Load()) - self._ProcessEvent(event) - return self._first_event_timestamp - - except StopIteration: - raise ValueError('No event timestamp could be found') - - def _ProcessEvent(self, event): - """Called whenever an event is loaded.""" - if self._first_event_timestamp is None: - self._first_event_timestamp = event.wall_time - - if event.HasField('file_version'): - new_file_version = _ParseFileVersion(event.file_version) - if self.file_version and self.file_version != new_file_version: - ## This should not happen. - tf.logging.warn(('Found new file_version for event.proto. This will ' - 'affect purging logic for TensorFlow restarts. 
' - 'Old: {0} New: {1}').format(self.file_version, - new_file_version)) - self.file_version = new_file_version - - self._MaybePurgeOrphanedData(event) - - ## Process the event. - # GraphDef and MetaGraphDef are handled in a special way: - # If no graph_def Event is available, but a meta_graph_def is, and it - # contains a graph_def, then use the meta_graph_def.graph_def as our graph. - # If a graph_def Event is available, always prefer it to the graph_def - # inside the meta_graph_def. - if event.HasField('graph_def'): - if self._graph is not None: - tf.logging.warn( - ('Found more than one graph event per run, or there was ' - 'a metagraph containing a graph_def, as well as one or ' - 'more graph events. Overwriting the graph with the ' - 'newest event.')) - self._graph = event.graph_def - self._graph_from_metagraph = False - elif event.HasField('meta_graph_def'): - if self._meta_graph is not None: - tf.logging.warn(('Found more than one metagraph event per run. ' - 'Overwriting the metagraph with the newest event.')) - self._meta_graph = event.meta_graph_def - if self._graph is None or self._graph_from_metagraph: - # We may have a graph_def in the metagraph. If so, and no - # graph_def is directly available, use this one instead. - meta_graph = tf.MetaGraphDef() - meta_graph.ParseFromString(self._meta_graph) - if meta_graph.graph_def: - if self._graph is not None: - tf.logging.warn( - ('Found multiple metagraphs containing graph_defs,' - 'but did not find any graph events. Overwriting the ' - 'graph with the newest metagraph version.')) - self._graph_from_metagraph = True - self._graph = meta_graph.graph_def.SerializeToString() - elif event.HasField('tagged_run_metadata'): - tag = event.tagged_run_metadata.tag - if tag in self._tagged_metadata: - tf.logging.warn('Found more than one "run metadata" event with tag ' + - tag + '. 
Overwriting it with the newest event.') - self._tagged_metadata[tag] = event.tagged_run_metadata.run_metadata - elif event.HasField('summary'): - for value in event.summary.value: - if (value.HasField('tensor') and - value.tag.startswith(HEALTH_PILL_EVENT_TAG_PREFIX)): - self._ProcessHealthPillSummary(value, event) - else: - for summary_type, summary_func in SUMMARY_TYPES.items(): - if value.HasField(summary_type): - datum = getattr(value, summary_type) - tag = value.node_name if summary_type == 'tensor' else value.tag - getattr(self, summary_func)(tag, event.wall_time, event.step, - datum) - - def _ProcessHealthPillSummary(self, value, event): - """Process summaries containing health pills. - - These summaries are distinguished by the fact that they have a Tensor field - and have a special tag value. - - This method emits ERROR-level messages to the logs if it encounters Tensor - summaries that it cannot process. - - Args: - value: A tf.Summary.Value with a Tensor field. - event: The tf.Event containing that value. - """ - elements = tf.make_ndarray(value.tensor) - - # The node_name property of the value object is actually a watch key: a - # combination of node name, output slot, and a suffix. We capture the - # actual node name and the output slot with a regular expression. - match = re.match(r'^(.*):(\d+):DebugNumericSummary$', value.node_name) - if not match: - tf.logging.log_first_n( - tf.logging.ERROR, - 'Unsupported watch key %s for health pills; skipping this sequence.', - 1, value.node_name) - return - - node_name = match.group(1) - output_slot = int(match.group(2)) - device_name = value.tag[len(HEALTH_PILL_EVENT_TAG_PREFIX):] - self._ProcessHealthPill(event.wall_time, event.step, device_name, node_name, - output_slot, elements) - - def Tags(self): - """Return all tags found in the value stream. - - Returns: - A `{tagType: ['list', 'of', 'tags']}` dictionary. 
- """ - return { - IMAGES: self._images.Keys(), - AUDIO: self._audio.Keys(), - HISTOGRAMS: self._histograms.Keys(), - SCALARS: self._scalars.Keys(), - COMPRESSED_HISTOGRAMS: self._compressed_histograms.Keys(), - TENSORS: self._tensors.Keys(), - # Use a heuristic: if the metagraph is available, but - # graph is not, then we assume the metagraph contains the graph. - GRAPH: self._graph is not None, - META_GRAPH: self._meta_graph is not None, - RUN_METADATA: list(self._tagged_metadata.keys()) - } - - def Scalars(self, tag): - """Given a summary tag, return all associated `ScalarEvent`s. - - Args: - tag: A string tag associated with the events. - - Raises: - KeyError: If the tag is not found. - - Returns: - An array of `ScalarEvent`s. - """ - return self._scalars.Items(tag) - - def HealthPills(self, node_name): - """Returns all health pill values for a certain node. - - Args: - node_name: The name of the node to obtain health pills for. - - Raises: - KeyError: If the node name is not found. - - Returns: - An array of `HealthPillEvent`s. - """ - return self._health_pills.Items(node_name) - - def GetOpsWithHealthPills(self): - """Determines which ops have at least 1 health pill event. - - Returns: - A list of names of ops with at least 1 health pill event. - """ - return self._health_pills.Keys() - - def Graph(self): - """Return the graph definition, if there is one. - - If the graph is stored directly, return that. If no graph is stored - directly but a metagraph is stored containing a graph, return that. - - Raises: - ValueError: If there is no graph for this run. - - Returns: - The `graph_def` proto. - """ - graph = tf.GraphDef() - if self._graph is not None: - graph.ParseFromString(self._graph) - return graph - raise ValueError('There is no graph in this EventAccumulator') - - def MetaGraph(self): - """Return the metagraph definition, if there is one. - - Raises: - ValueError: If there is no metagraph for this run. - - Returns: - The `meta_graph_def` proto. 
- """ - if self._meta_graph is None: - raise ValueError('There is no metagraph in this EventAccumulator') - meta_graph = tf.MetaGraphDef() - meta_graph.ParseFromString(self._meta_graph) - return meta_graph - - def RunMetadata(self, tag): - """Given a tag, return the associated session.run() metadata. - - Args: - tag: A string tag associated with the event. - - Raises: - ValueError: If the tag is not found. - - Returns: - The metadata in form of `RunMetadata` proto. - """ - if tag not in self._tagged_metadata: - raise ValueError('There is no run metadata with this tag name') - - run_metadata = tf.RunMetadata() - run_metadata.ParseFromString(self._tagged_metadata[tag]) - return run_metadata - - def Histograms(self, tag): - """Given a summary tag, return all associated histograms. - - Args: - tag: A string tag associated with the events. - - Raises: - KeyError: If the tag is not found. - - Returns: - An array of `HistogramEvent`s. - """ - return self._histograms.Items(tag) - - def CompressedHistograms(self, tag): - """Given a summary tag, return all associated compressed histograms. - - Args: - tag: A string tag associated with the events. - - Raises: - KeyError: If the tag is not found. - - Returns: - An array of `CompressedHistogramEvent`s. - """ - return self._compressed_histograms.Items(tag) - - def Images(self, tag): - """Given a summary tag, return all associated images. - - Args: - tag: A string tag associated with the events. - - Raises: - KeyError: If the tag is not found. - - Returns: - An array of `ImageEvent`s. - """ - return self._images.Items(tag) - - def Audio(self, tag): - """Given a summary tag, return all associated audio. - - Args: - tag: A string tag associated with the events. - - Raises: - KeyError: If the tag is not found. - - Returns: - An array of `AudioEvent`s. - """ - return self._audio.Items(tag) - - def Tensors(self, tag): - """Given a summary tag, return all associated tensors. - - Args: - tag: A string tag associated with the events. 
- - Raises: - KeyError: If the tag is not found. - - Returns: - An array of `TensorEvent`s. - """ - return self._tensors.Items(tag) - - def _MaybePurgeOrphanedData(self, event): - """Maybe purge orphaned data due to a TensorFlow crash. - - When TensorFlow crashes at step T+O and restarts at step T, any events - written after step T are now "orphaned" and will be at best misleading if - they are included in TensorBoard. - - This logic attempts to determine if there is orphaned data, and purge it - if it is found. - - Args: - event: The event to use as a reference, to determine if a purge is needed. - """ - if not self.purge_orphaned_data: - return - ## Check if the event happened after a crash, and purge expired tags. - if self.file_version and self.file_version >= 2: - ## If the file_version is recent enough, use the SessionLog enum - ## to check for restarts. - self._CheckForRestartAndMaybePurge(event) - else: - ## If there is no file version, default to old logic of checking for - ## out of order steps. - self._CheckForOutOfOrderStepAndMaybePurge(event) - - def _CheckForRestartAndMaybePurge(self, event): - """Check and discard expired events using SessionLog.START. - - Check for a SessionLog.START event and purge all previously seen events - with larger steps, because they are out of date. Because of supervisor - threading, it is possible that this logic will cause the first few event - messages to be discarded since supervisor threading does not guarantee - that the START message is deterministically written first. - - This method is preferred over _CheckForOutOfOrderStepAndMaybePurge which - can inadvertently discard events due to supervisor threading. - - Args: - event: The event to use as reference. If the event is a START event, all - previously seen events with a greater event.step will be purged. 
- """ - if event.HasField( - 'session_log') and event.session_log.status == tf.SessionLog.START: - self._Purge(event, by_tags=False) - - def _CheckForOutOfOrderStepAndMaybePurge(self, event): - """Check for out-of-order event.step and discard expired events for tags. - - Check if the event is out of order relative to the global most recent step. - If it is, purge outdated summaries for tags that the event contains. - - Args: - event: The event to use as reference. If the event is out-of-order, all - events with the same tags, but with a greater event.step will be purged. - """ - if event.step < self.most_recent_step and event.HasField('summary'): - self._Purge(event, by_tags=True) - else: - self.most_recent_step = event.step - self.most_recent_wall_time = event.wall_time - - def _ConvertHistogramProtoToTuple(self, histo): - return HistogramValue(min=histo.min, - max=histo.max, - num=histo.num, - sum=histo.sum, - sum_squares=histo.sum_squares, - bucket_limit=list(histo.bucket_limit), - bucket=list(histo.bucket)) - - def _ProcessHistogram(self, tag, wall_time, step, histo): - """Processes a proto histogram by adding it to accumulated state.""" - histo = self._ConvertHistogramProtoToTuple(histo) - histo_ev = HistogramEvent(wall_time, step, histo) - self._histograms.AddItem(tag, histo_ev) - self._compressed_histograms.AddItem( - tag, histo_ev, lambda x: _CompressHistogram(x, self._compression_bps)) - - def _ProcessImage(self, tag, wall_time, step, image): - """Processes an image by adding it to accumulated state.""" - event = ImageEvent(wall_time=wall_time, - step=step, - encoded_image_string=image.encoded_image_string, - width=image.width, - height=image.height) - self._images.AddItem(tag, event) - - def _ProcessAudio(self, tag, wall_time, step, audio): - """Processes a audio by adding it to accumulated state.""" - event = AudioEvent(wall_time=wall_time, - step=step, - encoded_audio_string=audio.encoded_audio_string, - content_type=audio.content_type, - 
sample_rate=audio.sample_rate, - length_frames=audio.length_frames) - self._audio.AddItem(tag, event) - - def _ProcessScalar(self, tag, wall_time, step, scalar): - """Processes a simple value by adding it to accumulated state.""" - sv = ScalarEvent(wall_time=wall_time, step=step, value=scalar) - self._scalars.AddItem(tag, sv) - - def _ProcessTensor(self, tag, wall_time, step, tensor): - tv = TensorEvent(wall_time=wall_time, step=step, tensor_proto=tensor) - self._tensors.AddItem(tag, tv) - - def _ProcessHealthPill(self, wall_time, step, device_name, node_name, - output_slot, elements): - """Processes a health pill value by adding it to accumulated state. - - Args: - wall_time: The time at which the health pill was created. Provided by the - debugger. - step: The step at which the health pill was created. Provided by the - debugger. - device_name: The name of the node's device. - node_name: The name of the node for this health pill. - output_slot: The output slot for this health pill. - elements: An ND array of 20 floats. The elements of the health pill. - """ - # Key by the node name for fast retrieval of health pills by node name. The - # array is cast to a list so that it is JSON-able. The debugger data plugin - # serves a JSON response. - self._health_pills.AddItem(node_name, - HealthPillEvent( - wall_time=wall_time, - step=step, - device_name=device_name, - node_name=node_name, - output_slot=output_slot, - dtype=repr(tf.as_dtype(elements[12])), - shape=list(elements[14:]), - value=list(elements))) - - def _Purge(self, event, by_tags): - """Purge all events that have occurred after the given event.step. - - If by_tags is True, purge all events that occurred after the given - event.step, but only for the tags that the event has. Non-sequential - event.steps suggest that a TensorFlow restart occurred, and we discard - the out-of-order events to display a consistent view in TensorBoard. 
- - Discarding by tags is the safer method, when we are unsure whether a restart - has occurred, given that threading in supervisor can cause events of - different tags to arrive with unsynchronized step values. - - If by_tags is False, then purge all events with event.step greater than the - given event.step. This can be used when we are certain that a TensorFlow - restart has occurred and these events can be discarded. - - Args: - event: The event to use as reference for the purge. All events with - the same tags, but with a greater event.step will be purged. - by_tags: Bool to dictate whether to discard all out-of-order events or - only those that are associated with the given reference event. - """ - ## Keep data in reservoirs that has a step less than event.step - _NotExpired = lambda x: x.step < event.step - - if by_tags: - - def _ExpiredPerTag(value): - return [getattr(self, x).FilterItems(_NotExpired, value.tag) - for x in self.accumulated_attrs] - - expired_per_tags = [_ExpiredPerTag(value) - for value in event.summary.value] - expired_per_type = [sum(x) for x in zip(*expired_per_tags)] - else: - expired_per_type = [getattr(self, x).FilterItems(_NotExpired) - for x in self.accumulated_attrs] - - if sum(expired_per_type) > 0: - purge_msg = _GetPurgeMessage(self.most_recent_step, - self.most_recent_wall_time, event.step, - event.wall_time, *expired_per_type) - tf.logging.warn(purge_msg) - - -def _GetPurgeMessage(most_recent_step, most_recent_wall_time, event_step, - event_wall_time, num_expired_scalars, num_expired_histos, - num_expired_comp_histos, num_expired_images, - num_expired_audio): - """Return the string message associated with TensorBoard purges.""" - return ('Detected out of order event.step likely caused by ' - 'a TensorFlow restart. Purging expired events from Tensorboard' - ' display between the previous step: {} (timestamp: {}) and ' - 'current step: {} (timestamp: {}). 
Removing {} scalars, {} ' - 'histograms, {} compressed histograms, {} images, ' - 'and {} audio.').format(most_recent_step, most_recent_wall_time, - event_step, event_wall_time, - num_expired_scalars, num_expired_histos, - num_expired_comp_histos, num_expired_images, - num_expired_audio) - - -def _GeneratorFromPath(path): - """Create an event generator for file or directory at given path string.""" - if not path: - raise ValueError('path must be a valid string') - if IsTensorFlowEventsFile(path): - return event_file_loader.EventFileLoader(path) - else: - return directory_watcher.DirectoryWatcher( - path, event_file_loader.EventFileLoader, IsTensorFlowEventsFile) - - -def _ParseFileVersion(file_version): - """Convert the string file_version in event.proto into a float. - - Args: - file_version: String file_version from event.proto - - Returns: - Version number as a float. - """ - tokens = file_version.split('brain.Event:') - try: - return float(tokens[-1]) - except ValueError: - ## This should never happen according to the definition of file_version - ## specified in event.proto. - tf.logging.warn( - ('Invalid event.proto file_version. Defaulting to use of ' - 'out-of-order event.step logic for purging expired events.')) - return -1 - - -def _CompressHistogram(histo_ev, bps): - """Creates fixed size histogram by adding compression to accumulated state. - - This routine transforms a histogram at a particular step by linearly - interpolating its variable number of buckets to represent their cumulative - weight at a constant number of compression points. This significantly reduces - the size of the histogram and makes it suitable for a two-dimensional area - plot where the output of this routine constitutes the ranges for a single x - coordinate. - - Args: - histo_ev: A HistogramEvent namedtuple. - bps: Compression points represented in basis points, 1/100ths of a percent. - - Returns: - CompressedHistogramEvent namedtuple. 
- """ - # See also: Histogram::Percentile() in core/lib/histogram/histogram.cc - histo = histo_ev.histogram_value - if not histo.num: - return CompressedHistogramEvent( - histo_ev.wall_time, - histo_ev.step, - [CompressedHistogramValue(b, 0.0) for b in bps]) - bucket = np.array(histo.bucket) - weights = (bucket * bps[-1] / (bucket.sum() or 1.0)).cumsum() - values = [] - j = 0 - while j < len(bps): - i = np.searchsorted(weights, bps[j], side='right') - while i < len(weights): - cumsum = weights[i] - cumsum_prev = weights[i - 1] if i > 0 else 0.0 - if cumsum == cumsum_prev: # prevent remap divide by zero - i += 1 - continue - if not i or not cumsum_prev: - lhs = histo.min - else: - lhs = max(histo.bucket_limit[i - 1], histo.min) - rhs = min(histo.bucket_limit[i], histo.max) - weight = _Remap(bps[j], cumsum_prev, cumsum, lhs, rhs) - values.append(CompressedHistogramValue(bps[j], weight)) - j += 1 - break - else: - break - while j < len(bps): - values.append(CompressedHistogramValue(bps[j], histo.max)) - j += 1 - return CompressedHistogramEvent(histo_ev.wall_time, histo_ev.step, values) - - -def _Remap(x, x0, x1, y0, y1): - """Linearly map from [x0, x1] unto [y0, y1].""" - return y0 + (x - x0) * float(y1 - y0) / (x1 - x0) diff --git a/tensorflow/tensorboard/backend/event_processing/event_accumulator_test.py b/tensorflow/tensorboard/backend/event_processing/event_accumulator_test.py deleted file mode 100644 index 4ce766f4204..00000000000 --- a/tensorflow/tensorboard/backend/event_processing/event_accumulator_test.py +++ /dev/null @@ -1,976 +0,0 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os - -import numpy as np -import six -from six.moves import xrange # pylint: disable=redefined-builtin -import tensorflow as tf - -from tensorflow.tensorboard.backend.event_processing import event_accumulator as ea - - -class _EventGenerator(object): - """Class that can add_events and then yield them back. - - Satisfies the EventGenerator API required for the EventAccumulator. - Satisfies the EventWriter API required to create a SummaryWriter. - - Has additional convenience methods for adding test events. 
- """ - - def __init__(self, testcase, zero_out_timestamps=False): - self._testcase = testcase - self.items = [] - self.zero_out_timestamps = zero_out_timestamps - - def Load(self): - while self.items: - yield self.items.pop(0) - - def AddScalar(self, tag, wall_time=0, step=0, value=0): - event = tf.Event( - wall_time=wall_time, - step=step, - summary=tf.Summary( - value=[tf.Summary.Value(tag=tag, simple_value=value)])) - self.AddEvent(event) - - def AddHealthPill(self, wall_time, step, device_name, op_name, output_slot, - elements): - event = tf.Event(step=step, wall_time=wall_time) - value = event.summary.value.add( - tag=ea.HEALTH_PILL_EVENT_TAG_PREFIX + device_name, - node_name='%s:%d:DebugNumericSummary' % (op_name, output_slot)) - value.tensor.tensor_shape.dim.add(size=len(elements)) - value.tensor.dtype = 2 # DT_DOUBLE - value.tensor.tensor_content = np.array(elements, dtype=np.float64).tobytes() - self.AddEvent(event) - - def AddHistogram(self, - tag, - wall_time=0, - step=0, - hmin=1, - hmax=2, - hnum=3, - hsum=4, - hsum_squares=5, - hbucket_limit=None, - hbucket=None): - histo = tf.HistogramProto( - min=hmin, - max=hmax, - num=hnum, - sum=hsum, - sum_squares=hsum_squares, - bucket_limit=hbucket_limit, - bucket=hbucket) - event = tf.Event( - wall_time=wall_time, - step=step, - summary=tf.Summary(value=[tf.Summary.Value(tag=tag, histo=histo)])) - self.AddEvent(event) - - def AddImage(self, - tag, - wall_time=0, - step=0, - encoded_image_string=b'imgstr', - width=150, - height=100): - image = tf.Summary.Image( - encoded_image_string=encoded_image_string, width=width, height=height) - event = tf.Event( - wall_time=wall_time, - step=step, - summary=tf.Summary(value=[tf.Summary.Value(tag=tag, image=image)])) - self.AddEvent(event) - - def AddAudio(self, - tag, - wall_time=0, - step=0, - encoded_audio_string=b'sndstr', - content_type='audio/wav', - sample_rate=44100, - length_frames=22050): - audio = tf.Summary.Audio( - encoded_audio_string=encoded_audio_string, 
- content_type=content_type, - sample_rate=sample_rate, - length_frames=length_frames) - event = tf.Event( - wall_time=wall_time, - step=step, - summary=tf.Summary(value=[tf.Summary.Value(tag=tag, audio=audio)])) - self.AddEvent(event) - - def AddEvent(self, event): - if self.zero_out_timestamps: - event.wall_time = 0 - self.items.append(event) - - def add_event(self, event): # pylint: disable=invalid-name - """Match the EventWriter API.""" - self.AddEvent(event) - - def get_logdir(self): # pylint: disable=invalid-name - """Return a temp directory for asset writing.""" - return self._testcase.get_temp_dir() - - -class EventAccumulatorTest(tf.test.TestCase): - - def assertTagsEqual(self, actual, expected): - """Utility method for checking the return value of the Tags() call. - - It fills out the `expected` arg with the default (empty) values for every - tag type, so that the author needs only specify the non-empty values they - are interested in testing. - - Args: - actual: The actual Accumulator tags response. - expected: The expected tags response (empty fields may be omitted) - """ - - empty_tags = { - ea.IMAGES: [], - ea.AUDIO: [], - ea.SCALARS: [], - ea.HISTOGRAMS: [], - ea.COMPRESSED_HISTOGRAMS: [], - ea.GRAPH: False, - ea.META_GRAPH: False, - ea.RUN_METADATA: [], - ea.TENSORS: [], - } - - # Verifies that there are no unexpected keys in the actual response. - # If this line fails, likely you added a new tag type, and need to update - # the empty_tags dictionary above. 
- self.assertItemsEqual(actual.keys(), empty_tags.keys()) - - for key in actual: - expected_value = expected.get(key, empty_tags[key]) - if isinstance(expected_value, list): - self.assertItemsEqual(actual[key], expected_value) - else: - self.assertEqual(actual[key], expected_value) - - -class MockingEventAccumulatorTest(EventAccumulatorTest): - - def setUp(self): - super(MockingEventAccumulatorTest, self).setUp() - self.stubs = tf.test.StubOutForTesting() - self._real_constructor = ea.EventAccumulator - self._real_generator = ea._GeneratorFromPath - - def _FakeAccumulatorConstructor(generator, *args, **kwargs): - ea._GeneratorFromPath = lambda x: generator - return self._real_constructor(generator, *args, **kwargs) - - ea.EventAccumulator = _FakeAccumulatorConstructor - - def tearDown(self): - self.stubs.CleanUp() - ea.EventAccumulator = self._real_constructor - ea._GeneratorFromPath = self._real_generator - - def testEmptyAccumulator(self): - gen = _EventGenerator(self) - x = ea.EventAccumulator(gen) - x.Reload() - self.assertTagsEqual(x.Tags(), {}) - - def testTags(self): - """Tags should be found in EventAccumulator after adding some events.""" - gen = _EventGenerator(self) - gen.AddScalar('s1') - gen.AddScalar('s2') - gen.AddHistogram('hst1') - gen.AddHistogram('hst2') - gen.AddImage('im1') - gen.AddImage('im2') - gen.AddAudio('snd1') - gen.AddAudio('snd2') - acc = ea.EventAccumulator(gen) - acc.Reload() - self.assertTagsEqual(acc.Tags(), { - ea.IMAGES: ['im1', 'im2'], - ea.AUDIO: ['snd1', 'snd2'], - ea.SCALARS: ['s1', 's2'], - ea.HISTOGRAMS: ['hst1', 'hst2'], - ea.COMPRESSED_HISTOGRAMS: ['hst1', 'hst2'], - }) - - def testReload(self): - """EventAccumulator contains suitable tags after calling Reload.""" - gen = _EventGenerator(self) - acc = ea.EventAccumulator(gen) - acc.Reload() - self.assertTagsEqual(acc.Tags(), {}) - gen.AddScalar('s1') - gen.AddScalar('s2') - gen.AddHistogram('hst1') - gen.AddHistogram('hst2') - gen.AddImage('im1') - gen.AddImage('im2') - 
gen.AddAudio('snd1') - gen.AddAudio('snd2') - acc.Reload() - self.assertTagsEqual(acc.Tags(), { - ea.IMAGES: ['im1', 'im2'], - ea.AUDIO: ['snd1', 'snd2'], - ea.SCALARS: ['s1', 's2'], - ea.HISTOGRAMS: ['hst1', 'hst2'], - ea.COMPRESSED_HISTOGRAMS: ['hst1', 'hst2'], - }) - - def testScalars(self): - """Tests whether EventAccumulator contains scalars after adding them.""" - gen = _EventGenerator(self) - acc = ea.EventAccumulator(gen) - s1 = ea.ScalarEvent(wall_time=1, step=10, value=32) - s2 = ea.ScalarEvent(wall_time=2, step=12, value=64) - gen.AddScalar('s1', wall_time=1, step=10, value=32) - gen.AddScalar('s2', wall_time=2, step=12, value=64) - acc.Reload() - self.assertEqual(acc.Scalars('s1'), [s1]) - self.assertEqual(acc.Scalars('s2'), [s2]) - - def _compareHealthPills(self, expected_event, gotten_event): - """Compares 2 health pills. - - Args: - expected_event: The expected HealthPillEvent. - gotten_event: The gotten HealthPillEvent. - """ - self.assertEqual(expected_event.wall_time, gotten_event.wall_time) - self.assertEqual(expected_event.step, gotten_event.step) - self.assertEqual(expected_event.device_name, gotten_event.device_name) - self.assertEqual(expected_event.node_name, gotten_event.node_name) - self.assertEqual(expected_event.output_slot, gotten_event.output_slot) - self.assertEqual(len(expected_event.value), len(gotten_event.value)) - for i, expected_value in enumerate(expected_event.value): - self.assertEqual(expected_value, gotten_event.value[i]) - - def testHealthPills(self): - """HealthPills should be properly inserted into EventAccumulator.""" - gen = _EventGenerator(self) - acc = ea.EventAccumulator(gen) - health_pill_elements_1 = list(range(1, 13)) + [ - float(1), 2.0, 1.0, 2.0, 0.0, 0.0, 0.0, 0.0] - gen.AddHealthPill(13371337, 41, '/job:localhost/replica:0/task:0/cpu:0', - 'Add', 0, health_pill_elements_1) - health_pill_elements_2 = list(range(42, 54)) + [ - float(2), 2.0, 1.0, 2.0, 0.0, 0.0, 0.0, 0.0] - gen.AddHealthPill(13381338, 42, 
'/job:localhost/replica:0/task:0/gpu:0', - 'Add', 1, health_pill_elements_2) - acc.Reload() - - # Retrieve the health pills for each node name. - gotten_events = acc.HealthPills('Add') - self.assertEquals(2, len(gotten_events)) - self._compareHealthPills( - ea.HealthPillEvent( - wall_time=13371337, - step=41, - device_name='/job:localhost/replica:0/task:0/cpu:0', - node_name='Add', - output_slot=0, - dtype='tf.float32', - shape=[1, 2], - value=health_pill_elements_1), gotten_events[0]) - self._compareHealthPills( - ea.HealthPillEvent( - wall_time=13381338, - device_name='/job:localhost/replica:0/task:0/gpu:0', - step=42, - node_name='Add', - output_slot=1, - dtype='tf.float64', - shape=[3, 4], - value=health_pill_elements_2), gotten_events[1]) - - def testGetOpsWithHealthPills(self): - gen = _EventGenerator(self) - acc = ea.EventAccumulator(gen) - health_pill_elements_1 = list(range(1, 13)) + [ - float(1), 2.0, 1.0, 2.0, 0.0, 0.0, 0.0, 0.0] - gen.AddHealthPill(13371337, 41, '/job:localhost/replica:0/task:0/cpu:0', - 'Add', 0, health_pill_elements_1) - health_pill_elements_2 = list(range(42, 54)) + [ - float(2), 2.0, 1.0, 2.0, 0.0, 0.0, 0.0, 0.0] - gen.AddHealthPill(13381338, 42, '/job:localhost/replica:0/task:0/cpu:0', - 'MatMul', 1, health_pill_elements_2) - acc.Reload() - self.assertItemsEqual(['Add', 'MatMul'], acc.GetOpsWithHealthPills()) - - def testHistograms(self): - """Tests whether histograms are inserted into EventAccumulator.""" - gen = _EventGenerator(self) - acc = ea.EventAccumulator(gen) - - val1 = ea.HistogramValue( - min=1, - max=2, - num=3, - sum=4, - sum_squares=5, - bucket_limit=[1, 2, 3], - bucket=[0, 3, 0]) - val2 = ea.HistogramValue( - min=-2, - max=3, - num=4, - sum=5, - sum_squares=6, - bucket_limit=[2, 3, 4], - bucket=[1, 3, 0]) - - hst1 = ea.HistogramEvent(wall_time=1, step=10, histogram_value=val1) - hst2 = ea.HistogramEvent(wall_time=2, step=12, histogram_value=val2) - gen.AddHistogram( - 'hst1', - wall_time=1, - step=10, - hmin=1, - 
hmax=2, - hnum=3, - hsum=4, - hsum_squares=5, - hbucket_limit=[1, 2, 3], - hbucket=[0, 3, 0]) - gen.AddHistogram( - 'hst2', - wall_time=2, - step=12, - hmin=-2, - hmax=3, - hnum=4, - hsum=5, - hsum_squares=6, - hbucket_limit=[2, 3, 4], - hbucket=[1, 3, 0]) - acc.Reload() - self.assertEqual(acc.Histograms('hst1'), [hst1]) - self.assertEqual(acc.Histograms('hst2'), [hst2]) - - def testCompressedHistograms(self): - """Tests compressed histograms inserted into EventAccumulator.""" - gen = _EventGenerator(self) - acc = ea.EventAccumulator(gen, compression_bps=(0, 2500, 5000, 7500, 10000)) - - gen.AddHistogram( - 'hst1', - wall_time=1, - step=10, - hmin=1, - hmax=2, - hnum=3, - hsum=4, - hsum_squares=5, - hbucket_limit=[1, 2, 3], - hbucket=[0, 3, 0]) - gen.AddHistogram( - 'hst2', - wall_time=2, - step=12, - hmin=-2, - hmax=3, - hnum=4, - hsum=5, - hsum_squares=6, - hbucket_limit=[2, 3, 4], - hbucket=[1, 3, 0]) - acc.Reload() - - # Create the expected values after compressing hst1 - expected_vals1 = [ - ea.CompressedHistogramValue(bp, val) - for bp, val in [(0, 1.0), (2500, 1.25), (5000, 1.5), (7500, 1.75 - ), (10000, 2.0)] - ] - expected_cmphst1 = ea.CompressedHistogramEvent( - wall_time=1, step=10, compressed_histogram_values=expected_vals1) - self.assertEqual(acc.CompressedHistograms('hst1'), [expected_cmphst1]) - - # Create the expected values after compressing hst2 - expected_vals2 = [ - ea.CompressedHistogramValue(bp, val) - for bp, val in [(0, -2), - (2500, 2), - (5000, 2 + 1 / 3), - (7500, 2 + 2 / 3), - (10000, 3)] - ] - expected_cmphst2 = ea.CompressedHistogramEvent( - wall_time=2, step=12, compressed_histogram_values=expected_vals2) - self.assertEqual(acc.CompressedHistograms('hst2'), [expected_cmphst2]) - - def testCompressedHistogramsWithEmptyHistogram(self): - """Tests that empty histograms compressed properly in EventAccumulator.""" - gen = _EventGenerator(self) - acc = ea.EventAccumulator(gen, compression_bps=(0, 2500, 5000, 7500, 10000)) - - 
gen.AddHistogram( - 'hst1', - wall_time=1, - step=10, - hmin=None, - hmax=None, - hnum=0, - hsum=0, - hsum_squares=0, - hbucket_limit=[1, 2, 3], - hbucket=[0, 0, 0]) - acc.Reload() - - # Create the expected values after compressing hst1 - expected_vals1 = [ - ea.CompressedHistogramValue(bp, val) - for bp, val in [(0, 0.0), (2500, 0), (5000, 0), (7500, 0), (10000, 0)] - ] - expected_cmphst1 = ea.CompressedHistogramEvent( - wall_time=1, step=10, compressed_histogram_values=expected_vals1) - self.assertEqual(acc.CompressedHistograms('hst1'), [expected_cmphst1]) - - def testCompressHistogram_uglyHistogram(self): - bps = (0, 668, 1587, 3085, 5000, 6915, 8413, 9332, 10000) - histogram_values = ea.HistogramValue( - min=0.0, - max=1.0, - num=960.0, - sum=64.0, - sum_squares=64.0, - bucket_limit=[ - 0.0, 1e-12, 0.917246389039776, 1.0089710279437536, - 1.7976931348623157e+308 - ], - bucket=[0.0, 896.0, 0.0, 64.0, 0.0]) - histogram_event = ea.HistogramEvent(0, 0, histogram_values) - compressed_event = ea._CompressHistogram(histogram_event, bps) - vals = compressed_event.compressed_histogram_values - self.assertEquals(tuple(v.basis_point for v in vals), bps) - self.assertAlmostEqual(vals[0].value, 0.0) - self.assertAlmostEqual(vals[1].value, 7.157142857142856e-14) - self.assertAlmostEqual(vals[2].value, 1.7003571428571426e-13) - self.assertAlmostEqual(vals[3].value, 3.305357142857143e-13) - self.assertAlmostEqual(vals[4].value, 5.357142857142857e-13) - self.assertAlmostEqual(vals[5].value, 7.408928571428571e-13) - self.assertAlmostEqual(vals[6].value, 9.013928571428571e-13) - self.assertAlmostEqual(vals[7].value, 9.998571428571429e-13) - self.assertAlmostEqual(vals[8].value, 1.0) - - def testImages(self): - """Tests 2 images inserted/accessed in EventAccumulator.""" - gen = _EventGenerator(self) - acc = ea.EventAccumulator(gen) - im1 = ea.ImageEvent( - wall_time=1, - step=10, - encoded_image_string=b'big', - width=400, - height=300) - im2 = ea.ImageEvent( - wall_time=2, - 
step=12, - encoded_image_string=b'small', - width=40, - height=30) - gen.AddImage( - 'im1', - wall_time=1, - step=10, - encoded_image_string=b'big', - width=400, - height=300) - gen.AddImage( - 'im2', - wall_time=2, - step=12, - encoded_image_string=b'small', - width=40, - height=30) - acc.Reload() - self.assertEqual(acc.Images('im1'), [im1]) - self.assertEqual(acc.Images('im2'), [im2]) - - def testAudio(self): - """Tests 2 audio events inserted/accessed in EventAccumulator.""" - gen = _EventGenerator(self) - acc = ea.EventAccumulator(gen) - snd1 = ea.AudioEvent( - wall_time=1, - step=10, - encoded_audio_string=b'big', - content_type='audio/wav', - sample_rate=44100, - length_frames=441000) - snd2 = ea.AudioEvent( - wall_time=2, - step=12, - encoded_audio_string=b'small', - content_type='audio/wav', - sample_rate=44100, - length_frames=44100) - gen.AddAudio( - 'snd1', - wall_time=1, - step=10, - encoded_audio_string=b'big', - content_type='audio/wav', - sample_rate=44100, - length_frames=441000) - gen.AddAudio( - 'snd2', - wall_time=2, - step=12, - encoded_audio_string=b'small', - content_type='audio/wav', - sample_rate=44100, - length_frames=44100) - acc.Reload() - self.assertEqual(acc.Audio('snd1'), [snd1]) - self.assertEqual(acc.Audio('snd2'), [snd2]) - - def testKeyError(self): - """KeyError should be raised when accessing non-existing keys.""" - gen = _EventGenerator(self) - acc = ea.EventAccumulator(gen) - acc.Reload() - with self.assertRaises(KeyError): - acc.Scalars('s1') - with self.assertRaises(KeyError): - acc.Scalars('hst1') - with self.assertRaises(KeyError): - acc.Scalars('im1') - with self.assertRaises(KeyError): - acc.Histograms('s1') - with self.assertRaises(KeyError): - acc.Histograms('im1') - with self.assertRaises(KeyError): - acc.Images('s1') - with self.assertRaises(KeyError): - acc.Images('hst1') - with self.assertRaises(KeyError): - acc.Audio('s1') - with self.assertRaises(KeyError): - acc.Audio('hst1') - - def testNonValueEvents(self): - 
"""Non-value events in the generator don't cause early exits.""" - gen = _EventGenerator(self) - acc = ea.EventAccumulator(gen) - gen.AddScalar('s1', wall_time=1, step=10, value=20) - gen.AddEvent(tf.Event(wall_time=2, step=20, file_version='nots2')) - gen.AddScalar('s3', wall_time=3, step=100, value=1) - gen.AddHistogram('hst1') - gen.AddImage('im1') - gen.AddAudio('snd1') - - acc.Reload() - self.assertTagsEqual(acc.Tags(), { - ea.IMAGES: ['im1'], - ea.AUDIO: ['snd1'], - ea.SCALARS: ['s1', 's3'], - ea.HISTOGRAMS: ['hst1'], - ea.COMPRESSED_HISTOGRAMS: ['hst1'], - }) - - def testExpiredDataDiscardedAfterRestartForFileVersionLessThan2(self): - """Tests that events are discarded after a restart is detected. - - If a step value is observed to be lower than what was previously seen, - this should force a discard of all previous items with the same tag - that are outdated. - - Only file versions < 2 use this out-of-order discard logic. Later versions - discard events based on the step value of SessionLog.START. - """ - warnings = [] - self.stubs.Set(tf.logging, 'warn', warnings.append) - - gen = _EventGenerator(self) - acc = ea.EventAccumulator(gen) - - gen.AddEvent(tf.Event(wall_time=0, step=0, file_version='brain.Event:1')) - gen.AddScalar('s1', wall_time=1, step=100, value=20) - gen.AddScalar('s1', wall_time=1, step=200, value=20) - gen.AddScalar('s1', wall_time=1, step=300, value=20) - acc.Reload() - ## Check that number of items are what they should be - self.assertEqual([x.step for x in acc.Scalars('s1')], [100, 200, 300]) - - gen.AddScalar('s1', wall_time=1, step=101, value=20) - gen.AddScalar('s1', wall_time=1, step=201, value=20) - gen.AddScalar('s1', wall_time=1, step=301, value=20) - acc.Reload() - ## Check that we have discarded 200 and 300 from s1 - self.assertEqual([x.step for x in acc.Scalars('s1')], [100, 101, 201, 301]) - - def testOrphanedDataNotDiscardedIfFlagUnset(self): - """Tests that events are not discarded if purge_orphaned_data is false. 
- """ - gen = _EventGenerator(self) - acc = ea.EventAccumulator(gen, purge_orphaned_data=False) - - gen.AddEvent(tf.Event(wall_time=0, step=0, file_version='brain.Event:1')) - gen.AddScalar('s1', wall_time=1, step=100, value=20) - gen.AddScalar('s1', wall_time=1, step=200, value=20) - gen.AddScalar('s1', wall_time=1, step=300, value=20) - acc.Reload() - ## Check that number of items are what they should be - self.assertEqual([x.step for x in acc.Scalars('s1')], [100, 200, 300]) - - gen.AddScalar('s1', wall_time=1, step=101, value=20) - gen.AddScalar('s1', wall_time=1, step=201, value=20) - gen.AddScalar('s1', wall_time=1, step=301, value=20) - acc.Reload() - ## Check that we have discarded 200 and 300 from s1 - self.assertEqual([x.step for x in acc.Scalars('s1')], - [100, 200, 300, 101, 201, 301]) - - def testEventsDiscardedPerTagAfterRestartForFileVersionLessThan2(self): - """Tests that event discards after restart, only affect the misordered tag. - - If a step value is observed to be lower than what was previously seen, - this should force a discard of all previous items that are outdated, but - only for the out of order tag. Other tags should remain unaffected. - - Only file versions < 2 use this out-of-order discard logic. Later versions - discard events based on the step value of SessionLog.START. 
- """ - warnings = [] - self.stubs.Set(tf.logging, 'warn', warnings.append) - - gen = _EventGenerator(self) - acc = ea.EventAccumulator(gen) - - gen.AddEvent(tf.Event(wall_time=0, step=0, file_version='brain.Event:1')) - gen.AddScalar('s1', wall_time=1, step=100, value=20) - gen.AddScalar('s1', wall_time=1, step=200, value=20) - gen.AddScalar('s1', wall_time=1, step=300, value=20) - gen.AddScalar('s1', wall_time=1, step=101, value=20) - gen.AddScalar('s1', wall_time=1, step=201, value=20) - gen.AddScalar('s1', wall_time=1, step=301, value=20) - - gen.AddScalar('s2', wall_time=1, step=101, value=20) - gen.AddScalar('s2', wall_time=1, step=201, value=20) - gen.AddScalar('s2', wall_time=1, step=301, value=20) - - acc.Reload() - ## Check that we have discarded 200 and 300 - self.assertEqual([x.step for x in acc.Scalars('s1')], [100, 101, 201, 301]) - - ## Check that s1 discards do not affect s2 - ## i.e. check that only events from the out of order tag are discarded - self.assertEqual([x.step for x in acc.Scalars('s2')], [101, 201, 301]) - - def testOnlySummaryEventsTriggerDiscards(self): - """Test that file version event does not trigger data purge.""" - gen = _EventGenerator(self) - acc = ea.EventAccumulator(gen) - gen.AddScalar('s1', wall_time=1, step=100, value=20) - ev1 = tf.Event(wall_time=2, step=0, file_version='brain.Event:1') - graph_bytes = tf.GraphDef().SerializeToString() - ev2 = tf.Event(wall_time=3, step=0, graph_def=graph_bytes) - gen.AddEvent(ev1) - gen.AddEvent(ev2) - acc.Reload() - self.assertEqual([x.step for x in acc.Scalars('s1')], [100]) - - def testSessionLogStartMessageDiscardsExpiredEvents(self): - """Test that SessionLog.START message discards expired events. - - This discard logic is preferred over the out-of-order step discard logic, - but this logic can only be used for event protos which have the SessionLog - enum, which was introduced to event.proto for file_version >= brain.Event:2. 
- """ - gen = _EventGenerator(self) - acc = ea.EventAccumulator(gen) - gen.AddEvent(tf.Event(wall_time=0, step=1, file_version='brain.Event:2')) - - gen.AddScalar('s1', wall_time=1, step=100, value=20) - gen.AddScalar('s1', wall_time=1, step=200, value=20) - gen.AddScalar('s1', wall_time=1, step=300, value=20) - gen.AddScalar('s1', wall_time=1, step=400, value=20) - - gen.AddScalar('s2', wall_time=1, step=202, value=20) - gen.AddScalar('s2', wall_time=1, step=203, value=20) - - slog = tf.SessionLog(status=tf.SessionLog.START) - gen.AddEvent(tf.Event(wall_time=2, step=201, session_log=slog)) - acc.Reload() - self.assertEqual([x.step for x in acc.Scalars('s1')], [100, 200]) - self.assertEqual([x.step for x in acc.Scalars('s2')], []) - - def testFirstEventTimestamp(self): - """Test that FirstEventTimestamp() returns wall_time of the first event.""" - gen = _EventGenerator(self) - acc = ea.EventAccumulator(gen) - gen.AddEvent(tf.Event(wall_time=10, step=20, file_version='brain.Event:2')) - gen.AddScalar('s1', wall_time=30, step=40, value=20) - self.assertEqual(acc.FirstEventTimestamp(), 10) - - def testReloadPopulatesFirstEventTimestamp(self): - """Test that Reload() means FirstEventTimestamp() won't load events.""" - gen = _EventGenerator(self) - acc = ea.EventAccumulator(gen) - gen.AddEvent(tf.Event(wall_time=1, step=2, file_version='brain.Event:2')) - - acc.Reload() - - def _Die(*args, **kwargs): # pylint: disable=unused-argument - raise RuntimeError('Load() should not be called') - - self.stubs.Set(gen, 'Load', _Die) - self.assertEqual(acc.FirstEventTimestamp(), 1) - - def testFirstEventTimestampLoadsEvent(self): - """Test that FirstEventTimestamp() doesn't discard the loaded event.""" - gen = _EventGenerator(self) - acc = ea.EventAccumulator(gen) - gen.AddEvent(tf.Event(wall_time=1, step=2, file_version='brain.Event:2')) - - self.assertEqual(acc.FirstEventTimestamp(), 1) - acc.Reload() - self.assertEqual(acc.file_version, 2.0) - - def testTFSummaryScalar(self): - 
"""Verify processing of tf.summary.scalar.""" - event_sink = _EventGenerator(self, zero_out_timestamps=True) - writer = tf.summary.FileWriter(self.get_temp_dir()) - writer.event_writer = event_sink - with self.test_session() as sess: - ipt = tf.placeholder(tf.float32) - tf.summary.scalar('scalar1', ipt) - tf.summary.scalar('scalar2', ipt * ipt) - merged = tf.summary.merge_all() - writer.add_graph(sess.graph) - for i in xrange(10): - summ = sess.run(merged, feed_dict={ipt: i}) - writer.add_summary(summ, global_step=i) - - accumulator = ea.EventAccumulator(event_sink) - accumulator.Reload() - - seq1 = [ea.ScalarEvent(wall_time=0, step=i, value=i) for i in xrange(10)] - seq2 = [ - ea.ScalarEvent( - wall_time=0, step=i, value=i * i) for i in xrange(10) - ] - - self.assertTagsEqual(accumulator.Tags(), { - ea.SCALARS: ['scalar1', 'scalar2'], - ea.GRAPH: True, - ea.META_GRAPH: False, - }) - - self.assertEqual(accumulator.Scalars('scalar1'), seq1) - self.assertEqual(accumulator.Scalars('scalar2'), seq2) - first_value = accumulator.Scalars('scalar1')[0].value - self.assertTrue(isinstance(first_value, float)) - - def testTFSummaryImage(self): - """Verify processing of tf.summary.image.""" - event_sink = _EventGenerator(self, zero_out_timestamps=True) - writer = tf.summary.FileWriter(self.get_temp_dir()) - writer.event_writer = event_sink - with self.test_session() as sess: - ipt = tf.ones([10, 4, 4, 3], tf.uint8) - # This is an interesting example, because the old tf.image_summary op - # would throw an error here, because it would be tag reuse. - # Using the tf node name instead allows argument re-use to the image - # summary. 
- with tf.name_scope('1'): - tf.summary.image('images', ipt, max_outputs=1) - with tf.name_scope('2'): - tf.summary.image('images', ipt, max_outputs=2) - with tf.name_scope('3'): - tf.summary.image('images', ipt, max_outputs=3) - merged = tf.summary.merge_all() - writer.add_graph(sess.graph) - for i in xrange(10): - summ = sess.run(merged) - writer.add_summary(summ, global_step=i) - - accumulator = ea.EventAccumulator(event_sink) - accumulator.Reload() - - tags = [ - u'1/images/image', u'2/images/image/0', u'2/images/image/1', - u'3/images/image/0', u'3/images/image/1', u'3/images/image/2' - ] - - self.assertTagsEqual(accumulator.Tags(), { - ea.IMAGES: tags, - ea.GRAPH: True, - ea.META_GRAPH: False, - }) - - def testTFSummaryTensor(self): - """Verify processing of tf.summary.tensor.""" - event_sink = _EventGenerator(self, zero_out_timestamps=True) - writer = tf.summary.FileWriter(self.get_temp_dir()) - writer.event_writer = event_sink - with self.test_session() as sess: - tf.summary.tensor_summary('scalar', tf.constant(1.0)) - tf.summary.tensor_summary('vector', tf.constant([1.0, 2.0, 3.0])) - tf.summary.tensor_summary('string', tf.constant(six.b('foobar'))) - merged = tf.summary.merge_all() - summ = sess.run(merged) - writer.add_summary(summ, 0) - - accumulator = ea.EventAccumulator(event_sink) - accumulator.Reload() - - self.assertTagsEqual(accumulator.Tags(), { - ea.TENSORS: ['scalar', 'vector', 'string'], - }) - - scalar_proto = accumulator.Tensors('scalar')[0].tensor_proto - scalar = tf.make_ndarray(scalar_proto) - vector_proto = accumulator.Tensors('vector')[0].tensor_proto - vector = tf.make_ndarray(vector_proto) - string_proto = accumulator.Tensors('string')[0].tensor_proto - string = tf.make_ndarray(string_proto) - - self.assertTrue(np.array_equal(scalar, 1.0)) - self.assertTrue(np.array_equal(vector, [1.0, 2.0, 3.0])) - self.assertTrue(np.array_equal(string, six.b('foobar'))) - - -class RealisticEventAccumulatorTest(EventAccumulatorTest): - - def 
setUp(self): - super(RealisticEventAccumulatorTest, self).setUp() - - def testScalarsRealistically(self): - """Test accumulator by writing values and then reading them.""" - - def FakeScalarSummary(tag, value): - value = tf.Summary.Value(tag=tag, simple_value=value) - summary = tf.Summary(value=[value]) - return summary - - directory = os.path.join(self.get_temp_dir(), 'values_dir') - if tf.gfile.IsDirectory(directory): - tf.gfile.DeleteRecursively(directory) - tf.gfile.MkDir(directory) - - writer = tf.summary.FileWriter(directory, max_queue=100) - - with tf.Graph().as_default() as graph: - _ = tf.constant([2.0, 1.0]) - # Add a graph to the summary writer. - writer.add_graph(graph) - meta_graph_def = tf.train.export_meta_graph(graph_def=graph.as_graph_def( - add_shapes=True)) - writer.add_meta_graph(meta_graph_def) - - run_metadata = tf.RunMetadata() - device_stats = run_metadata.step_stats.dev_stats.add() - device_stats.device = 'test device' - writer.add_run_metadata(run_metadata, 'test run') - - # Write a bunch of events using the writer. 
- for i in xrange(30): - summ_id = FakeScalarSummary('id', i) - summ_sq = FakeScalarSummary('sq', i * i) - writer.add_summary(summ_id, i * 5) - writer.add_summary(summ_sq, i * 5) - writer.flush() - - # Verify that we can load those events properly - acc = ea.EventAccumulator(directory) - acc.Reload() - self.assertTagsEqual(acc.Tags(), { - ea.SCALARS: ['id', 'sq'], - ea.GRAPH: True, - ea.META_GRAPH: True, - ea.RUN_METADATA: ['test run'], - }) - id_events = acc.Scalars('id') - sq_events = acc.Scalars('sq') - self.assertEqual(30, len(id_events)) - self.assertEqual(30, len(sq_events)) - for i in xrange(30): - self.assertEqual(i * 5, id_events[i].step) - self.assertEqual(i * 5, sq_events[i].step) - self.assertEqual(i, id_events[i].value) - self.assertEqual(i * i, sq_events[i].value) - - # Write a few more events to test incremental reloading - for i in xrange(30, 40): - summ_id = FakeScalarSummary('id', i) - summ_sq = FakeScalarSummary('sq', i * i) - writer.add_summary(summ_id, i * 5) - writer.add_summary(summ_sq, i * 5) - writer.flush() - - # Verify we can now see all of the data - acc.Reload() - id_events = acc.Scalars('id') - sq_events = acc.Scalars('sq') - self.assertEqual(40, len(id_events)) - self.assertEqual(40, len(sq_events)) - for i in xrange(40): - self.assertEqual(i * 5, id_events[i].step) - self.assertEqual(i * 5, sq_events[i].step) - self.assertEqual(i, id_events[i].value) - self.assertEqual(i * i, sq_events[i].value) - self.assertProtoEquals(graph.as_graph_def(add_shapes=True), acc.Graph()) - self.assertProtoEquals(meta_graph_def, acc.MetaGraph()) - - def testGraphFromMetaGraphBecomesAvailable(self): - """Test accumulator by writing values and then reading them.""" - - directory = os.path.join(self.get_temp_dir(), 'metagraph_test_values_dir') - if tf.gfile.IsDirectory(directory): - tf.gfile.DeleteRecursively(directory) - tf.gfile.MkDir(directory) - - writer = tf.summary.FileWriter(directory, max_queue=100) - - with tf.Graph().as_default() as graph: - _ = 
tf.constant([2.0, 1.0]) - # Add a graph to the summary writer. - meta_graph_def = tf.train.export_meta_graph(graph_def=graph.as_graph_def( - add_shapes=True)) - writer.add_meta_graph(meta_graph_def) - - writer.flush() - - # Verify that we can load those events properly - acc = ea.EventAccumulator(directory) - acc.Reload() - self.assertTagsEqual(acc.Tags(), { - ea.GRAPH: True, - ea.META_GRAPH: True, - }) - self.assertProtoEquals(graph.as_graph_def(add_shapes=True), acc.Graph()) - self.assertProtoEquals(meta_graph_def, acc.MetaGraph()) - - -if __name__ == '__main__': - tf.test.main() diff --git a/tensorflow/tensorboard/backend/event_processing/event_file_inspector.py b/tensorflow/tensorboard/backend/event_processing/event_file_inspector.py deleted file mode 100644 index e120dd2ab16..00000000000 --- a/tensorflow/tensorboard/backend/event_processing/event_file_inspector.py +++ /dev/null @@ -1,427 +0,0 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -"""Logic for TensorBoard inspector to help humans investigate event files. 
- -Example usages: -tensorboard --inspect --event_file=myevents.out -tensorboard --inspect --event_file=myevents.out --tag=loss -tensorboard --inspect --logdir=mylogdir -tensorboard --inspect --logdir=mylogdir --tag=loss - - -This script runs over a logdir and creates an InspectionUnit for every -subdirectory with event files. If running over an event file, it creates only -one InspectionUnit. One block of output is printed to console for each -InspectionUnit. - -The primary content of an InspectionUnit is the dict field_to_obs that maps -fields (e.g. "scalar", "histogram", "session_log:start", etc.) to a list of -Observations for the field. Observations correspond one-to-one with Events in an -event file but contain less information because they only store what is -necessary to generate the final console output. - -The final output is rendered to console by applying some aggregating function -to the lists of Observations. Different functions are applied depending on the -type of field. For instance, for "scalar" fields, the inspector shows aggregate -statistics. For other fields like "session_log:start", all observed steps are -printed in order to aid debugging. - - -[1] Query a logdir or an event file for its logged tags and summary statistics -using --logdir or --event_file. 
- -[[event_file]] contains these tags: -histograms - binary/Sign/Activations - binary/nn_tanh/act/Activations - binary/nn_tanh/biases - binary/nn_tanh/biases:gradient - binary/nn_tanh/weights - binary/nn_tanh/weights:gradient -images - input_images/image/0 - input_images/image/1 - input_images/image/2 -scalars - Learning Rate - Total Cost - Total Cost (raw) - -Debug output aggregated over all tags: -graph - first_step 0 - last_step 0 - max_step 0 - min_step 0 - num_steps 1 - outoforder_steps [] -histograms - first_step 491 - last_step 659823 - max_step 659823 - min_step 491 - num_steps 993 - outoforder_steps [] -images - -scalars - first_step 0 - last_step 659823 - max_step 659823 - min_step 0 - num_steps 1985 - outoforder_steps [] -sessionlog:checkpoint - first_step 7129 - last_step 657167 - max_step 657167 - min_step 7129 - num_steps 99 - outoforder_steps [] -sessionlog:start - outoforder_steps [] - steps [0L] -sessionlog:stop - - - -[2] Drill down into a particular tag using --tag. - -Debug output for binary/Sign/Activations: -histograms - first_step 491 - last_step 659823 - max_step 659823 - min_step 491 - num_steps 993 - outoforder_steps [] -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import collections -import itertools -import os - -import tensorflow as tf - -from tensorflow.tensorboard.backend.event_processing import event_accumulator -from tensorflow.tensorboard.backend.event_processing import event_file_loader -from tensorflow.tensorboard.backend.event_processing import event_multiplexer - -FLAGS = tf.flags.FLAGS - - -# Map of field names within summary.proto to the user-facing names that this -# script outputs. 
-SUMMARY_TYPE_TO_FIELD = {'simple_value': 'scalars', - 'histo': 'histograms', - 'image': 'images', - 'audio': 'audio'} -for summary_type in event_accumulator.SUMMARY_TYPES: - if summary_type not in SUMMARY_TYPE_TO_FIELD: - SUMMARY_TYPE_TO_FIELD[summary_type] = summary_type - -# Types of summaries that we may want to query for by tag. -TAG_FIELDS = list(SUMMARY_TYPE_TO_FIELD.values()) - -# Summaries that we want to see every instance of. -LONG_FIELDS = ['sessionlog:start', 'sessionlog:stop'] - -# Summaries that we only want an abridged digest of, since they would -# take too much screen real estate otherwise. -SHORT_FIELDS = ['graph', 'sessionlog:checkpoint'] + TAG_FIELDS - -# All summary types that we can inspect. -TRACKED_FIELDS = SHORT_FIELDS + LONG_FIELDS - -# An `Observation` contains the data within each Event file that the inspector -# cares about. The inspector accumulates Observations as it processes events. -Observation = collections.namedtuple('Observation', ['step', 'wall_time', - 'tag']) - -# An InspectionUnit is created for each organizational structure in the event -# files visible in the final terminal output. For instance, one InspectionUnit -# is created for each subdirectory in logdir. When asked to inspect a single -# event file, there may only be one InspectionUnit. - -# The InspectionUnit contains the `name` of the organizational unit that will be -# printed to console, a `generator` that yields `Event` protos, and a mapping -# from string fields to `Observations` that the inspector creates. -InspectionUnit = collections.namedtuple('InspectionUnit', ['name', 'generator', - 'field_to_obs']) - -PRINT_SEPARATOR = '=' * 70 + '\n' - - -def get_field_to_observations_map(generator, query_for_tag=''): - """Return a field to `Observations` dict for the event generator. - - Args: - generator: A generator over event protos. - query_for_tag: A string that if specified, only create observations for - events with this tag name. 
- - Returns: - A dict mapping keys in `TRACKED_FIELDS` to an `Observation` list. - """ - - def increment(stat, event, tag=''): - assert stat in TRACKED_FIELDS - field_to_obs[stat].append(Observation(step=event.step, - wall_time=event.wall_time, - tag=tag)._asdict()) - - field_to_obs = dict([(t, []) for t in TRACKED_FIELDS]) - - for event in generator: - ## Process the event - if event.HasField('graph_def') and (not query_for_tag): - increment('graph', event) - if event.HasField('session_log') and (not query_for_tag): - status = event.session_log.status - if status == tf.SessionLog.START: - increment('sessionlog:start', event) - elif status == tf.SessionLog.STOP: - increment('sessionlog:stop', event) - elif status == tf.SessionLog.CHECKPOINT: - increment('sessionlog:checkpoint', event) - elif event.HasField('summary'): - for value in event.summary.value: - if query_for_tag and value.tag != query_for_tag: - continue - - for proto_name, display_name in SUMMARY_TYPE_TO_FIELD.items(): - if value.HasField(proto_name): - increment(display_name, event, value.tag) - return field_to_obs - - -def get_unique_tags(field_to_obs): - """Returns a dictionary of tags that a user could query over. - - Args: - field_to_obs: Dict that maps string field to `Observation` list. - - Returns: - A dict that maps keys in `TAG_FIELDS` to a list of string tags present in - the event files. If the dict does not have any observations of the type, - maps to an empty list so that we can render this to console. - """ - return {field: sorted(set([x.get('tag', '') for x in observations])) - for field, observations in field_to_obs.items() - if field in TAG_FIELDS} - - -def print_dict(d, show_missing=True): - """Prints a shallow dict to console. - - Args: - d: Dict to print. - show_missing: Whether to show keys with empty values. - """ - for k, v in sorted(d.items()): - if (not v) and show_missing: - # No instances of the key, so print missing symbol. 
- print('{} -'.format(k)) - elif isinstance(v, list): - # Value is a list, so print each item of the list. - print(k) - for item in v: - print(' {}'.format(item)) - elif isinstance(v, dict): - # Value is a dict, so print each (key, value) pair of the dict. - print(k) - for kk, vv in sorted(v.items()): - print(' {:<20} {}'.format(kk, vv)) - - -def get_dict_to_print(field_to_obs): - """Transform the field-to-obs mapping into a printable dictionary. - - Args: - field_to_obs: Dict that maps string field to `Observation` list. - - Returns: - A dict with the keys and values to print to console. - """ - - def compressed_steps(steps): - return {'num_steps': len(set(steps)), - 'min_step': min(steps), - 'max_step': max(steps), - 'last_step': steps[-1], - 'first_step': steps[0], - 'outoforder_steps': get_out_of_order(steps)} - - def full_steps(steps): - return {'steps': steps, 'outoforder_steps': get_out_of_order(steps)} - - output = {} - for field, observations in field_to_obs.items(): - if not observations: - output[field] = None - continue - - steps = [x['step'] for x in observations] - if field in SHORT_FIELDS: - output[field] = compressed_steps(steps) - if field in LONG_FIELDS: - output[field] = full_steps(steps) - - return output - - -def get_out_of_order(list_of_numbers): - """Returns elements that break the monotonically non-decreasing trend. - - This is used to find instances of global step values that are "out-of-order", - which may trigger TensorBoard event discarding logic. - - Args: - list_of_numbers: A list of numbers. - - Returns: - A list of tuples in which each tuple are two elements are adjacent, but the - second element is lower than the first. - """ - # TODO(cassandrax): Consider changing this to only check for out-of-order - # steps within a particular tag. 
- result = [] - for i in range(len(list_of_numbers)): - if i == 0: - continue - if list_of_numbers[i] < list_of_numbers[i - 1]: - result.append((list_of_numbers[i - 1], list_of_numbers[i])) - return result - - -def generators_from_logdir(logdir): - """Returns a list of event generators for subdirectories with event files. - - The number of generators returned should equal the number of directories - within logdir that contain event files. If only logdir contains event files, - returns a list of length one. - - Args: - logdir: A log directory that contains event files. - - Returns: - List of event generators for each subdirectory with event files. - """ - subdirs = event_multiplexer.GetLogdirSubdirectories(logdir) - generators = [ - itertools.chain(*[ - generator_from_event_file(os.path.join(subdir, f)) - for f in tf.gfile.ListDirectory(subdir) - if event_accumulator.IsTensorFlowEventsFile(os.path.join(subdir, f)) - ]) for subdir in subdirs - ] - return generators - - -def generator_from_event_file(event_file): - """Returns a generator that yields events from an event file.""" - return event_file_loader.EventFileLoader(event_file).Load() - - -def get_inspection_units(logdir='', event_file='', tag=''): - """Returns a list of InspectionUnit objects given either logdir or event_file. - - If logdir is given, the number of InspectionUnits should equal the - number of directories or subdirectories that contain event files. - - If event_file is given, the number of InspectionUnits should be 1. - - Args: - logdir: A log directory that contains event files. - event_file: Or, a particular event file path. - tag: An optional tag name to query for. - - Returns: - A list of InspectionUnit objects. 
- """ - if logdir: - subdirs = event_multiplexer.GetLogdirSubdirectories(logdir) - inspection_units = [] - for subdir in subdirs: - generator = itertools.chain(*[ - generator_from_event_file(os.path.join(subdir, f)) - for f in tf.gfile.ListDirectory(subdir) - if event_accumulator.IsTensorFlowEventsFile(os.path.join(subdir, f)) - ]) - inspection_units.append(InspectionUnit( - name=subdir, - generator=generator, - field_to_obs=get_field_to_observations_map(generator, tag))) - if inspection_units: - print('Found event files in:\n{}\n'.format('\n'.join( - [u.name for u in inspection_units]))) - elif event_accumulator.IsTensorFlowEventsFile(logdir): - print( - 'It seems that {} may be an event file instead of a logdir. If this ' - 'is the case, use --event_file instead of --logdir to pass ' - 'it in.'.format(logdir)) - else: - print('No event files found within logdir {}'.format(logdir)) - return inspection_units - elif event_file: - generator = generator_from_event_file(event_file) - return [InspectionUnit( - name=event_file, - generator=generator, - field_to_obs=get_field_to_observations_map(generator, tag))] - - -def inspect(logdir='', event_file='', tag=''): - """Main function for inspector that prints out a digest of event files. - - Args: - logdir: A log directory that contains event files. - event_file: Or, a particular event file path. - tag: An optional tag name to query for. - - Raises: - ValueError: If neither logdir and event_file are given, or both are given. - """ - if logdir and event_file: - raise ValueError( - 'Must specify either --logdir or --event_file, but not both.') - if not (logdir or event_file): - raise ValueError('Must specify either --logdir or --event_file.') - - print(PRINT_SEPARATOR + - 'Processing event files... 
(this can take a few minutes)\n' + - PRINT_SEPARATOR) - inspection_units = get_inspection_units(logdir, event_file, tag) - - for unit in inspection_units: - if tag: - print('Event statistics for tag {} in {}:'.format(tag, unit.name)) - else: - # If the user is not inspecting a particular tag, also print the list of - # all available tags that they can query. - print('These tags are in {}:'.format(unit.name)) - print_dict(get_unique_tags(unit.field_to_obs)) - print(PRINT_SEPARATOR) - print('Event statistics for {}:'.format(unit.name)) - - print_dict(get_dict_to_print(unit.field_to_obs), show_missing=(not tag)) - print(PRINT_SEPARATOR) - - -if __name__ == '__main__': - tf.app.run() diff --git a/tensorflow/tensorboard/backend/event_processing/event_file_inspector_test.py b/tensorflow/tensorboard/backend/event_processing/event_file_inspector_test.py deleted file mode 100644 index 084043d5110..00000000000 --- a/tensorflow/tensorboard/backend/event_processing/event_file_inspector_test.py +++ /dev/null @@ -1,189 +0,0 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os -import shutil - -import tensorflow as tf - -from tensorflow.tensorboard.backend.event_processing import event_file_inspector as efi - - -class EventFileInspectorTest(tf.test.TestCase): - - def setUp(self): - self.logdir = os.path.join(self.get_temp_dir(), 'tfevents') - self._MakeDirectoryIfNotExists(self.logdir) - - def tearDown(self): - shutil.rmtree(self.logdir) - - def _MakeDirectoryIfNotExists(self, path): - if not os.path.exists(path): - os.mkdir(path) - - def _WriteScalarSummaries(self, data, subdirs=('',)): - # Writes data to a tempfile in subdirs, and returns generator for the data. - # If subdirs is given, writes data identically to all subdirectories. - for subdir_ in subdirs: - subdir = os.path.join(self.logdir, subdir_) - self._MakeDirectoryIfNotExists(subdir) - - sw = tf.summary.FileWriter(subdir) - for datum in data: - summary = tf.Summary() - if 'simple_value' in datum: - summary.value.add(tag=datum['tag'], - simple_value=datum['simple_value']) - sw.add_summary(summary, global_step=datum['step']) - elif 'histo' in datum: - summary.value.add(tag=datum['tag'], histo=tf.HistogramProto()) - sw.add_summary(summary, global_step=datum['step']) - elif 'session_log' in datum: - sw.add_session_log(datum['session_log'], global_step=datum['step']) - sw.close() - - def testEmptyLogdir(self): - # Nothing was written to logdir - units = efi.get_inspection_units(self.logdir) - self.assertEqual([], units) - - def testGetAvailableTags(self): - data = [{'tag': 'c', 'histo': 2, 'step': 10}, - {'tag': 'c', 'histo': 2, 'step': 11}, - {'tag': 'c', 'histo': 2, 'step': 9}, - {'tag': 'b', 'simple_value': 2, 'step': 20}, - {'tag': 'b', 'simple_value': 2, 'step': 15}, - {'tag': 'a', 'simple_value': 2, 'step': 3}] - self._WriteScalarSummaries(data) - units = 
efi.get_inspection_units(self.logdir) - tags = efi.get_unique_tags(units[0].field_to_obs) - self.assertEqual(['a', 'b'], tags['scalars']) - self.assertEqual(['c'], tags['histograms']) - - def testInspectAll(self): - data = [{'tag': 'c', 'histo': 2, 'step': 10}, - {'tag': 'c', 'histo': 2, 'step': 11}, - {'tag': 'c', 'histo': 2, 'step': 9}, - {'tag': 'b', 'simple_value': 2, 'step': 20}, - {'tag': 'b', 'simple_value': 2, 'step': 15}, - {'tag': 'a', 'simple_value': 2, 'step': 3}] - self._WriteScalarSummaries(data) - units = efi.get_inspection_units(self.logdir) - printable = efi.get_dict_to_print(units[0].field_to_obs) - self.assertEqual(printable['histograms']['max_step'], 11) - self.assertEqual(printable['histograms']['min_step'], 9) - self.assertEqual(printable['histograms']['num_steps'], 3) - self.assertEqual(printable['histograms']['last_step'], 9) - self.assertEqual(printable['histograms']['first_step'], 10) - self.assertEqual(printable['histograms']['outoforder_steps'], [(11, 9)]) - - self.assertEqual(printable['scalars']['max_step'], 20) - self.assertEqual(printable['scalars']['min_step'], 3) - self.assertEqual(printable['scalars']['num_steps'], 3) - self.assertEqual(printable['scalars']['last_step'], 3) - self.assertEqual(printable['scalars']['first_step'], 20) - self.assertEqual(printable['scalars']['outoforder_steps'], [(20, 15), - (15, 3)]) - - def testInspectTag(self): - data = [{'tag': 'c', 'histo': 2, 'step': 10}, - {'tag': 'c', 'histo': 2, 'step': 11}, - {'tag': 'c', 'histo': 2, 'step': 9}, - {'tag': 'b', 'histo': 2, 'step': 20}, - {'tag': 'b', 'simple_value': 2, 'step': 15}, - {'tag': 'a', 'simple_value': 2, 'step': 3}] - self._WriteScalarSummaries(data) - units = efi.get_inspection_units(self.logdir, tag='c') - printable = efi.get_dict_to_print(units[0].field_to_obs) - self.assertEqual(printable['histograms']['max_step'], 11) - self.assertEqual(printable['histograms']['min_step'], 9) - self.assertEqual(printable['histograms']['num_steps'], 3) - 
self.assertEqual(printable['histograms']['last_step'], 9) - self.assertEqual(printable['histograms']['first_step'], 10) - self.assertEqual(printable['histograms']['outoforder_steps'], [(11, 9)]) - self.assertEqual(printable['scalars'], None) - - def testSessionLogSummaries(self): - data = [ - { - 'session_log': tf.SessionLog(status=tf.SessionLog.START), - 'step': 0 - }, - { - 'session_log': tf.SessionLog(status=tf.SessionLog.CHECKPOINT), - 'step': 1 - }, - { - 'session_log': tf.SessionLog(status=tf.SessionLog.CHECKPOINT), - 'step': 2 - }, - { - 'session_log': tf.SessionLog(status=tf.SessionLog.CHECKPOINT), - 'step': 3 - }, - { - 'session_log': tf.SessionLog(status=tf.SessionLog.STOP), - 'step': 4 - }, - { - 'session_log': tf.SessionLog(status=tf.SessionLog.START), - 'step': 5 - }, - { - 'session_log': tf.SessionLog(status=tf.SessionLog.STOP), - 'step': 6 - }, - ] - - self._WriteScalarSummaries(data) - units = efi.get_inspection_units(self.logdir) - self.assertEqual(1, len(units)) - printable = efi.get_dict_to_print(units[0].field_to_obs) - self.assertEqual(printable['sessionlog:start']['steps'], [0, 5]) - self.assertEqual(printable['sessionlog:stop']['steps'], [4, 6]) - self.assertEqual(printable['sessionlog:checkpoint']['num_steps'], 3) - - def testInspectAllWithNestedLogdirs(self): - data = [{'tag': 'c', 'simple_value': 2, 'step': 10}, - {'tag': 'c', 'simple_value': 2, 'step': 11}, - {'tag': 'c', 'simple_value': 2, 'step': 9}, - {'tag': 'b', 'simple_value': 2, 'step': 20}, - {'tag': 'b', 'simple_value': 2, 'step': 15}, - {'tag': 'a', 'simple_value': 2, 'step': 3}] - - subdirs = ['eval', 'train'] - self._WriteScalarSummaries(data, subdirs=subdirs) - units = efi.get_inspection_units(self.logdir) - self.assertEqual(2, len(units)) - directory_names = [os.path.join(self.logdir, name) for name in subdirs] - self.assertEqual(directory_names, sorted([unit.name for unit in units])) - - for unit in units: - printable = efi.get_dict_to_print(unit.field_to_obs)['scalars'] - 
self.assertEqual(printable['max_step'], 20) - self.assertEqual(printable['min_step'], 3) - self.assertEqual(printable['num_steps'], 6) - self.assertEqual(printable['last_step'], 3) - self.assertEqual(printable['first_step'], 10) - self.assertEqual(printable['outoforder_steps'], [(11, 9), (20, 15), - (15, 3)]) - -if __name__ == '__main__': - tf.test.main() diff --git a/tensorflow/tensorboard/backend/event_processing/event_file_loader.py b/tensorflow/tensorboard/backend/event_processing/event_file_loader.py deleted file mode 100644 index 896142daaf4..00000000000 --- a/tensorflow/tensorboard/backend/event_processing/event_file_loader.py +++ /dev/null @@ -1,74 +0,0 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== - -"""Functionality for loading events from a record file.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import tensorflow as tf - - -class EventFileLoader(object): - """An EventLoader is an iterator that yields Event protos.""" - - def __init__(self, file_path): - if file_path is None: - raise ValueError('A file path is required') - file_path = tf.resource_loader.readahead_file_path(file_path) - tf.logging.debug('Opening a record reader pointing at %s', file_path) - with tf.errors.raise_exception_on_not_ok_status() as status: - self._reader = tf.pywrap_tensorflow.PyRecordReader_New( - tf.compat.as_bytes(file_path), 0, tf.compat.as_bytes(''), status) - # Store it for logging purposes. - self._file_path = file_path - if not self._reader: - raise IOError('Failed to open a record reader pointing to %s' % file_path) - - def Load(self): - """Loads all new values from disk. - - Calling Load multiple times in a row will not 'drop' events as long as the - return value is not iterated over. - - Yields: - All values that were written to disk that have not been yielded yet. - """ - while True: - try: - with tf.errors.raise_exception_on_not_ok_status() as status: - self._reader.GetNext(status) - except (tf.errors.DataLossError, tf.errors.OutOfRangeError): - # We ignore partial read exceptions, because a record may be truncated. - # PyRecordReader holds the offset prior to the failed read, so retrying - # will succeed. 
- break - event = tf.Event() - event.ParseFromString(self._reader.record()) - yield event - tf.logging.debug('No more events in %s', self._file_path) - - -def main(argv): - if len(argv) != 2: - print('Usage: event_file_loader ') - return 1 - loader = EventFileLoader(argv[1]) - for event in loader.Load(): - print(event) - - -if __name__ == '__main__': - tf.app.run() diff --git a/tensorflow/tensorboard/backend/event_processing/event_file_loader_test.py b/tensorflow/tensorboard/backend/event_processing/event_file_loader_test.py deleted file mode 100644 index 210a7bc52ed..00000000000 --- a/tensorflow/tensorboard/backend/event_processing/event_file_loader_test.py +++ /dev/null @@ -1,93 +0,0 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -"""Tests for event_file_loader.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os -import tempfile - -import tensorflow as tf - - -from tensorflow.tensorboard.backend.event_processing import event_file_loader - - -class EventFileLoaderTest(tf.test.TestCase): - # A record containing a simple event. 
- RECORD = (b'\x18\x00\x00\x00\x00\x00\x00\x00\xa3\x7fK"\t\x00\x00\xc0%\xddu' - b'\xd5A\x1a\rbrain.Event:1\xec\xf32\x8d') - - def _WriteToFile(self, filename, data): - with open(filename, 'ab') as f: - f.write(data) - - def _LoaderForTestFile(self, filename): - return event_file_loader.EventFileLoader( - os.path.join(self.get_temp_dir(), filename)) - - def testEmptyEventFile(self): - filename = tempfile.NamedTemporaryFile(dir=self.get_temp_dir()).name - self._WriteToFile(filename, b'') - loader = self._LoaderForTestFile(filename) - self.assertEqual(len(list(loader.Load())), 0) - - def testSingleWrite(self): - filename = tempfile.NamedTemporaryFile(dir=self.get_temp_dir()).name - self._WriteToFile(filename, EventFileLoaderTest.RECORD) - loader = self._LoaderForTestFile(filename) - events = list(loader.Load()) - self.assertEqual(len(events), 1) - self.assertEqual(events[0].wall_time, 1440183447.0) - self.assertEqual(len(list(loader.Load())), 0) - - def testMultipleWrites(self): - filename = tempfile.NamedTemporaryFile(dir=self.get_temp_dir()).name - self._WriteToFile(filename, EventFileLoaderTest.RECORD) - loader = self._LoaderForTestFile(filename) - self.assertEqual(len(list(loader.Load())), 1) - self._WriteToFile(filename, EventFileLoaderTest.RECORD) - self.assertEqual(len(list(loader.Load())), 1) - - def testMultipleLoads(self): - filename = tempfile.NamedTemporaryFile(dir=self.get_temp_dir()).name - self._WriteToFile(filename, EventFileLoaderTest.RECORD) - loader = self._LoaderForTestFile(filename) - loader.Load() - loader.Load() - self.assertEqual(len(list(loader.Load())), 1) - - def testMultipleWritesAtOnce(self): - filename = tempfile.NamedTemporaryFile(dir=self.get_temp_dir()).name - self._WriteToFile(filename, EventFileLoaderTest.RECORD) - self._WriteToFile(filename, EventFileLoaderTest.RECORD) - loader = self._LoaderForTestFile(filename) - self.assertEqual(len(list(loader.Load())), 2) - - def testMultipleWritesWithBadWrite(self): - filename = 
tempfile.NamedTemporaryFile(dir=self.get_temp_dir()).name - self._WriteToFile(filename, EventFileLoaderTest.RECORD) - self._WriteToFile(filename, EventFileLoaderTest.RECORD) - # Test that we ignore partial record writes at the end of the file. - self._WriteToFile(filename, b'123') - loader = self._LoaderForTestFile(filename) - self.assertEqual(len(list(loader.Load())), 2) - - -if __name__ == '__main__': - tf.test.main() diff --git a/tensorflow/tensorboard/backend/event_processing/event_multiplexer.py b/tensorflow/tensorboard/backend/event_processing/event_multiplexer.py deleted file mode 100644 index e4b8814c929..00000000000 --- a/tensorflow/tensorboard/backend/event_processing/event_multiplexer.py +++ /dev/null @@ -1,475 +0,0 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Provides an interface for working with multiple event files.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os -import threading - -import six -import tensorflow as tf - -from tensorflow.tensorboard.backend.event_processing import directory_watcher -from tensorflow.tensorboard.backend.event_processing import event_accumulator -from tensorflow.tensorboard.backend.event_processing import io_wrapper - - -class EventMultiplexer(object): - """An `EventMultiplexer` manages access to multiple `EventAccumulator`s. - - Each `EventAccumulator` is associated with a `run`, which is a self-contained - TensorFlow execution. The `EventMultiplexer` provides methods for extracting - information about events from multiple `run`s. - - Example usage for loading specific runs from files: - - ```python - x = EventMultiplexer({'run1': 'path/to/run1', 'run2': 'path/to/run2'}) - x.Reload() - ``` - - Example usage for loading a directory where each subdirectory is a run - - ```python - (eg:) /parent/directory/path/ - /parent/directory/path/run1/ - /parent/directory/path/run1/events.out.tfevents.1001 - /parent/directory/path/run1/events.out.tfevents.1002 - - /parent/directory/path/run2/ - /parent/directory/path/run2/events.out.tfevents.9232 - - /parent/directory/path/run3/ - /parent/directory/path/run3/events.out.tfevents.9232 - x = EventMultiplexer().AddRunsFromDirectory('/parent/directory/path') - (which is equivalent to:) - x = EventMultiplexer({'run1': '/parent/directory/path/run1', 'run2':...} - ``` - - If you would like to watch `/parent/directory/path`, wait for it to be created - (if necessary) and then periodically pick up new runs, use - `AutoloadingMultiplexer` - @@Tensors - """ - - def __init__(self, - run_path_map=None, - size_guidance=event_accumulator.DEFAULT_SIZE_GUIDANCE, - purge_orphaned_data=True): - 
"""Constructor for the `EventMultiplexer`. - - Args: - run_path_map: Dict `{run: path}` which specifies the - name of a run, and the path to find the associated events. If it is - None, then the EventMultiplexer initializes without any runs. - size_guidance: A dictionary mapping from `tagType` to the number of items - to store for each tag of that type. See - `event_accumulator.EventAccumulator` for details. - purge_orphaned_data: Whether to discard any events that were "orphaned" by - a TensorFlow restart. - """ - tf.logging.info('Event Multiplexer initializing.') - self._accumulators_mutex = threading.Lock() - self._accumulators = {} - self._paths = {} - self._reload_called = False - self._size_guidance = size_guidance - self.purge_orphaned_data = purge_orphaned_data - if run_path_map is not None: - tf.logging.info('Event Multplexer doing initialization load for %s', - run_path_map) - for (run, path) in six.iteritems(run_path_map): - self.AddRun(path, run) - tf.logging.info('Event Multiplexer done initializing') - - def AddRun(self, path, name=None): - """Add a run to the multiplexer. - - If the name is not specified, it is the same as the path. - - If a run by that name exists, and we are already watching the right path, - do nothing. If we are watching a different path, replace the event - accumulator. - - If `Reload` has been called, it will `Reload` the newly created - accumulators. - - Args: - path: Path to the event files (or event directory) for given run. - name: Name of the run to add. If not provided, is set to path. - - Returns: - The `EventMultiplexer`. 
- """ - if name is None or name is '': - name = path - accumulator = None - with self._accumulators_mutex: - if name not in self._accumulators or self._paths[name] != path: - if name in self._paths and self._paths[name] != path: - # TODO(danmane) - Make it impossible to overwrite an old path with - # a new path (just give the new path a distinct name) - tf.logging.warning('Conflict for name %s: old path %s, new path %s', - name, self._paths[name], path) - tf.logging.info('Constructing EventAccumulator for %s', path) - accumulator = event_accumulator.EventAccumulator( - path, - size_guidance=self._size_guidance, - purge_orphaned_data=self.purge_orphaned_data) - self._accumulators[name] = accumulator - self._paths[name] = path - if accumulator: - if self._reload_called: - accumulator.Reload() - return self - - def AddRunsFromDirectory(self, path, name=None): - """Load runs from a directory; recursively walks subdirectories. - - If path doesn't exist, no-op. This ensures that it is safe to call - `AddRunsFromDirectory` multiple times, even before the directory is made. - - If path is a directory, load event files in the directory (if any exist) and - recursively call AddRunsFromDirectory on any subdirectories. This mean you - can call AddRunsFromDirectory at the root of a tree of event logs and - TensorBoard will load them all. - - If the `EventMultiplexer` is already loaded this will cause - the newly created accumulators to `Reload()`. - Args: - path: A string path to a directory to load runs from. - name: Optionally, what name to apply to the runs. If name is provided - and the directory contains run subdirectories, the name of each subrun - is the concatenation of the parent name and the subdirectory name. If - name is provided and the directory contains event files, then a run - is added called "name" and with the events from the path. - - Raises: - ValueError: If the path exists and isn't a directory. - - Returns: - The `EventMultiplexer`. 
- """ - tf.logging.info('Starting AddRunsFromDirectory: %s', path) - for subdir in GetLogdirSubdirectories(path): - tf.logging.info('Adding events from directory %s', subdir) - rpath = os.path.relpath(subdir, path) - subname = os.path.join(name, rpath) if name else rpath - self.AddRun(subdir, name=subname) - tf.logging.info('Done with AddRunsFromDirectory: %s', path) - return self - - def Reload(self): - """Call `Reload` on every `EventAccumulator`.""" - tf.logging.info('Beginning EventMultiplexer.Reload()') - self._reload_called = True - # Build a list so we're safe even if the list of accumulators is modified - # even while we're reloading. - with self._accumulators_mutex: - items = list(self._accumulators.items()) - - names_to_delete = set() - for name, accumulator in items: - try: - accumulator.Reload() - except (OSError, IOError) as e: - tf.logging.error("Unable to reload accumulator '%s': %s", name, e) - except directory_watcher.DirectoryDeletedError: - names_to_delete.add(name) - - with self._accumulators_mutex: - for name in names_to_delete: - tf.logging.warning("Deleting accumulator '%s'", name) - del self._accumulators[name] - tf.logging.info('Finished with EventMultiplexer.Reload()') - return self - - def PluginAssets(self, plugin_name): - """Get index of runs and assets for a given plugin. - - Args: - plugin_name: Name of the plugin we are checking for. - - Returns: - A dictionary that maps from run_name to a list of plugin - assets for that run. - """ - with self._accumulators_mutex: - # To avoid nested locks, we construct a copy of the run-accumulator map - items = list(six.iteritems(self._accumulators)) - - return {run: accum.PluginAssets(plugin_name) for run, accum in items} - - def RetrievePluginAsset(self, run, plugin_name, asset_name): - """Return the contents for a specific plugin asset from a run. - - Args: - run: The string name of the run. - plugin_name: The string name of a plugin. - asset_name: The string name of an asset. 
- - Returns: - The string contents of the plugin asset. - - Raises: - KeyError: If the asset is not available. - """ - accumulator = self._GetAccumulator(run) - return accumulator.RetrievePluginAsset(plugin_name, asset_name) - - def FirstEventTimestamp(self, run): - """Return the timestamp of the first event of the given run. - - This may perform I/O if no events have been loaded yet for the run. - - Args: - run: A string name of the run for which the timestamp is retrieved. - - Returns: - The wall_time of the first event of the run, which will typically be - seconds since the epoch. - - Raises: - KeyError: If the run is not found. - ValueError: If the run has no events loaded and there are no events on - disk to load. - """ - accumulator = self._GetAccumulator(run) - return accumulator.FirstEventTimestamp() - - def Scalars(self, run, tag): - """Retrieve the scalar events associated with a run and tag. - - Args: - run: A string name of the run for which values are retrieved. - tag: A string name of the tag for which values are retrieved. - - Raises: - KeyError: If the run is not found, or the tag is not available for - the given run. - - Returns: - An array of `event_accumulator.ScalarEvents`. - """ - accumulator = self._GetAccumulator(run) - return accumulator.Scalars(tag) - - def HealthPills(self, run, node_name): - """Retrieve the health pill events associated with a run and node name. - - Args: - run: A string name of the run for which health pills are retrieved. - node_name: A string name of the node for which health pills are retrieved. - - Raises: - KeyError: If the run is not found, or the node name is not available for - the given run. - - Returns: - An array of `event_accumulator.HealthPillEvents`. - """ - accumulator = self._GetAccumulator(run) - return accumulator.HealthPills(node_name) - - def GetOpsWithHealthPills(self, run): - """Determines which ops have at least 1 health pill event for a given run. - - Args: - run: The name of the run. 
- - Raises: - KeyError: If the run is not found, or the node name is not available for - the given run. - - Returns: - The list of names of ops with health pill events. - """ - return self._GetAccumulator(run).GetOpsWithHealthPills() - - def Graph(self, run): - """Retrieve the graph associated with the provided run. - - Args: - run: A string name of a run to load the graph for. - - Raises: - KeyError: If the run is not found. - ValueError: If the run does not have an associated graph. - - Returns: - The `GraphDef` protobuf data structure. - """ - accumulator = self._GetAccumulator(run) - return accumulator.Graph() - - def MetaGraph(self, run): - """Retrieve the metagraph associated with the provided run. - - Args: - run: A string name of a run to load the graph for. - - Raises: - KeyError: If the run is not found. - ValueError: If the run does not have an associated graph. - - Returns: - The `MetaGraphDef` protobuf data structure. - """ - accumulator = self._GetAccumulator(run) - return accumulator.MetaGraph() - - def RunMetadata(self, run, tag): - """Get the session.run() metadata associated with a TensorFlow run and tag. - - Args: - run: A string name of a TensorFlow run. - tag: A string name of the tag associated with a particular session.run(). - - Raises: - KeyError: If the run is not found, or the tag is not available for the - given run. - - Returns: - The metadata in the form of `RunMetadata` protobuf data structure. - """ - accumulator = self._GetAccumulator(run) - return accumulator.RunMetadata(tag) - - def Histograms(self, run, tag): - """Retrieve the histogram events associated with a run and tag. - - Args: - run: A string name of the run for which values are retrieved. - tag: A string name of the tag for which values are retrieved. - - Raises: - KeyError: If the run is not found, or the tag is not available for - the given run. - - Returns: - An array of `event_accumulator.HistogramEvents`. 
- """ - accumulator = self._GetAccumulator(run) - return accumulator.Histograms(tag) - - def CompressedHistograms(self, run, tag): - """Retrieve the compressed histogram events associated with a run and tag. - - Args: - run: A string name of the run for which values are retrieved. - tag: A string name of the tag for which values are retrieved. - - Raises: - KeyError: If the run is not found, or the tag is not available for - the given run. - - Returns: - An array of `event_accumulator.CompressedHistogramEvents`. - """ - accumulator = self._GetAccumulator(run) - return accumulator.CompressedHistograms(tag) - - def Images(self, run, tag): - """Retrieve the image events associated with a run and tag. - - Args: - run: A string name of the run for which values are retrieved. - tag: A string name of the tag for which values are retrieved. - - Raises: - KeyError: If the run is not found, or the tag is not available for - the given run. - - Returns: - An array of `event_accumulator.ImageEvents`. - """ - accumulator = self._GetAccumulator(run) - return accumulator.Images(tag) - - def Audio(self, run, tag): - """Retrieve the audio events associated with a run and tag. - - Args: - run: A string name of the run for which values are retrieved. - tag: A string name of the tag for which values are retrieved. - - Raises: - KeyError: If the run is not found, or the tag is not available for - the given run. - - Returns: - An array of `event_accumulator.AudioEvents`. - """ - accumulator = self._GetAccumulator(run) - return accumulator.Audio(tag) - - def Tensors(self, run, tag): - """Retrieve the tensor events associated with a run and tag. - - Args: - run: A string name of the run for which values are retrieved. - tag: A string name of the tag for which values are retrieved. - - Raises: - KeyError: If the run is not found, or the tag is not available for - the given run. - - Returns: - An array of `event_accumulator.TensorEvent`s. 
- """ - accumulator = self._GetAccumulator(run) - return accumulator.Tensors(tag) - - def Runs(self): - """Return all the run names in the `EventMultiplexer`. - - Returns: - ``` - {runName: { images: [tag1, tag2, tag3], - scalarValues: [tagA, tagB, tagC], - histograms: [tagX, tagY, tagZ], - compressedHistograms: [tagX, tagY, tagZ], - graph: true, meta_graph: true}} - ``` - """ - with self._accumulators_mutex: - # To avoid nested locks, we construct a copy of the run-accumulator map - items = list(six.iteritems(self._accumulators)) - return {run_name: accumulator.Tags() for run_name, accumulator in items} - - def RunPaths(self): - """Returns a dict mapping run names to event file paths.""" - return self._paths - - def _GetAccumulator(self, run): - with self._accumulators_mutex: - return self._accumulators[run] - - -def GetLogdirSubdirectories(path): - """Returns subdirectories with event files on path.""" - if tf.gfile.Exists(path) and not tf.gfile.IsDirectory(path): - raise ValueError('GetLogdirSubdirectories: path exists and is not a ' - 'directory, %s' % path) - - # ListRecursively just yields nothing if the path doesn't exist. - return ( - subdir - for (subdir, files) in io_wrapper.ListRecursively(path) - if list(filter(event_accumulator.IsTensorFlowEventsFile, files)) - ) diff --git a/tensorflow/tensorboard/backend/event_processing/event_multiplexer_test.py b/tensorflow/tensorboard/backend/event_processing/event_multiplexer_test.py deleted file mode 100644 index ea536dfaad6..00000000000 --- a/tensorflow/tensorboard/backend/event_processing/event_multiplexer_test.py +++ /dev/null @@ -1,360 +0,0 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import functools -import os -import os.path -import shutil - -import tensorflow as tf - -from tensorflow.tensorboard.backend.event_processing import event_accumulator -from tensorflow.tensorboard.backend.event_processing import event_multiplexer - - -def _AddEvents(path): - if not tf.gfile.IsDirectory(path): - tf.gfile.MakeDirs(path) - fpath = os.path.join(path, 'hypothetical.tfevents.out') - with tf.gfile.GFile(fpath, 'w') as f: - f.write('') - return fpath - - -def _CreateCleanDirectory(path): - if tf.gfile.IsDirectory(path): - tf.gfile.DeleteRecursively(path) - tf.gfile.MkDir(path) - - -class _FakeAccumulator(object): - - def __init__(self, path, health_pill_mapping=None): - """Constructs a fake accumulator with some fake events. - - Args: - path: The path for the run that this accumulator is for. - health_pill_mapping: An optional mapping from Op to health pill strings. 
- """ - self._path = path - self.reload_called = False - self._node_names_to_health_pills = health_pill_mapping or {} - - def Tags(self): - return {event_accumulator.IMAGES: ['im1', 'im2'], - event_accumulator.AUDIO: ['snd1', 'snd2'], - event_accumulator.HISTOGRAMS: ['hst1', 'hst2'], - event_accumulator.COMPRESSED_HISTOGRAMS: ['cmphst1', 'cmphst2'], - event_accumulator.SCALARS: ['sv1', 'sv2']} - - def FirstEventTimestamp(self): - return 0 - - def _TagHelper(self, tag_name, enum): - if tag_name not in self.Tags()[enum]: - raise KeyError - return ['%s/%s' % (self._path, tag_name)] - - def Scalars(self, tag_name): - return self._TagHelper(tag_name, event_accumulator.SCALARS) - - def HealthPills(self, node_name): - if node_name not in self._node_names_to_health_pills: - raise KeyError - health_pills = self._node_names_to_health_pills[node_name] - return [self._path + '/' + health_pill for health_pill in health_pills] - - def GetOpsWithHealthPills(self): - return self._node_names_to_health_pills.keys() - - def Histograms(self, tag_name): - return self._TagHelper(tag_name, event_accumulator.HISTOGRAMS) - - def CompressedHistograms(self, tag_name): - return self._TagHelper(tag_name, event_accumulator.COMPRESSED_HISTOGRAMS) - - def Images(self, tag_name): - return self._TagHelper(tag_name, event_accumulator.IMAGES) - - def Audio(self, tag_name): - return self._TagHelper(tag_name, event_accumulator.AUDIO) - - def Tensors(self, tag_name): - return self._TagHelper(tag_name, event_accumulator.TENSORS) - - def Reload(self): - self.reload_called = True - - -def _GetFakeAccumulator(path, - size_guidance=None, - compression_bps=None, - purge_orphaned_data=None, - health_pill_mapping=None): - del size_guidance, compression_bps, purge_orphaned_data # Unused. 
- return _FakeAccumulator(path, health_pill_mapping=health_pill_mapping) - - -class EventMultiplexerTest(tf.test.TestCase): - - def setUp(self): - super(EventMultiplexerTest, self).setUp() - self.stubs = tf.test.StubOutForTesting() - - self.stubs.Set(event_accumulator, 'EventAccumulator', _GetFakeAccumulator) - - def tearDown(self): - self.stubs.CleanUp() - - def testEmptyLoader(self): - """Tests empty EventMultiplexer creation.""" - x = event_multiplexer.EventMultiplexer() - self.assertEqual(x.Runs(), {}) - - def testRunNamesRespected(self): - """Tests two EventAccumulators inserted/accessed in EventMultiplexer.""" - x = event_multiplexer.EventMultiplexer({'run1': 'path1', 'run2': 'path2'}) - self.assertItemsEqual(sorted(x.Runs().keys()), ['run1', 'run2']) - self.assertEqual(x._GetAccumulator('run1')._path, 'path1') - self.assertEqual(x._GetAccumulator('run2')._path, 'path2') - - def testReload(self): - """EventAccumulators should Reload after EventMultiplexer call it.""" - x = event_multiplexer.EventMultiplexer({'run1': 'path1', 'run2': 'path2'}) - self.assertFalse(x._GetAccumulator('run1').reload_called) - self.assertFalse(x._GetAccumulator('run2').reload_called) - x.Reload() - self.assertTrue(x._GetAccumulator('run1').reload_called) - self.assertTrue(x._GetAccumulator('run2').reload_called) - - def testScalars(self): - """Tests Scalars function returns suitable values.""" - x = event_multiplexer.EventMultiplexer({'run1': 'path1', 'run2': 'path2'}) - - run1_actual = x.Scalars('run1', 'sv1') - run1_expected = ['path1/sv1'] - - self.assertEqual(run1_expected, run1_actual) - - def testHealthPills(self): - """Tests HealthPills() returns events associated with run1/Add.""" - self.stubs.Set(event_accumulator, 'EventAccumulator', - functools.partial( - _GetFakeAccumulator, - health_pill_mapping={'Add': ['hp1', 'hp2']})) - x = event_multiplexer.EventMultiplexer({'run1': 'path1', 'run2': 'path2'}) - self.assertEqual(['path1/hp1', 'path1/hp2'], x.HealthPills('run1', 
'Add')) - - def testGetOpsWithHealthPillsWhenHealthPillsAreNotAvailable(self): - # The event accumulator lacks health pills for the run. - x = event_multiplexer.EventMultiplexer({'run1': 'path1', 'run2': 'path2'}) - self.assertItemsEqual([], x.GetOpsWithHealthPills('run1')) - - def testGetOpsWithHealthPillsWhenHealthPillsAreAvailable(self): - # The event accumulator has health pills for the run. - self.stubs.Set(event_accumulator, 'EventAccumulator', - functools.partial( - _GetFakeAccumulator, - health_pill_mapping={'Add': ['hp1', 'hp2']})) - x = event_multiplexer.EventMultiplexer({'run1': 'path1', 'run2': 'path2'}) - self.assertItemsEqual(['Add'], x.GetOpsWithHealthPills('run1')) - - def testExceptions(self): - """KeyError should be raised when accessing non-existing keys.""" - x = event_multiplexer.EventMultiplexer({'run1': 'path1', 'run2': 'path2'}) - with self.assertRaises(KeyError): - x.Scalars('sv1', 'xxx') - - def testInitialization(self): - """Tests EventMultiplexer is created properly with its params.""" - x = event_multiplexer.EventMultiplexer() - self.assertEqual(x.Runs(), {}) - x = event_multiplexer.EventMultiplexer({'run1': 'path1', 'run2': 'path2'}) - self.assertItemsEqual(x.Runs(), ['run1', 'run2']) - self.assertEqual(x._GetAccumulator('run1')._path, 'path1') - self.assertEqual(x._GetAccumulator('run2')._path, 'path2') - - def testAddRunsFromDirectory(self): - """Tests AddRunsFromDirectory function. - - Tests the following scenarios: - - When the directory does not exist. - - When the directory is empty. - - When the directory has empty subdirectory. - - Contains proper EventAccumulators after adding events. 
- """ - x = event_multiplexer.EventMultiplexer() - tmpdir = self.get_temp_dir() - join = os.path.join - fakedir = join(tmpdir, 'fake_accumulator_directory') - realdir = join(tmpdir, 'real_accumulator_directory') - self.assertEqual(x.Runs(), {}) - x.AddRunsFromDirectory(fakedir) - self.assertEqual(x.Runs(), {}, 'loading fakedir had no effect') - - _CreateCleanDirectory(realdir) - x.AddRunsFromDirectory(realdir) - self.assertEqual(x.Runs(), {}, 'loading empty directory had no effect') - - path1 = join(realdir, 'path1') - tf.gfile.MkDir(path1) - x.AddRunsFromDirectory(realdir) - self.assertEqual(x.Runs(), {}, 'creating empty subdirectory had no effect') - - _AddEvents(path1) - x.AddRunsFromDirectory(realdir) - self.assertItemsEqual(x.Runs(), ['path1'], 'loaded run: path1') - loader1 = x._GetAccumulator('path1') - self.assertEqual(loader1._path, path1, 'has the correct path') - - path2 = join(realdir, 'path2') - _AddEvents(path2) - x.AddRunsFromDirectory(realdir) - self.assertItemsEqual(x.Runs(), ['path1', 'path2']) - self.assertEqual( - x._GetAccumulator('path1'), loader1, 'loader1 not regenerated') - - path2_2 = join(path2, 'path2') - _AddEvents(path2_2) - x.AddRunsFromDirectory(realdir) - self.assertItemsEqual(x.Runs(), ['path1', 'path2', 'path2/path2']) - self.assertEqual( - x._GetAccumulator('path2/path2')._path, path2_2, 'loader2 path correct') - - def testAddRunsFromDirectoryThatContainsEvents(self): - x = event_multiplexer.EventMultiplexer() - tmpdir = self.get_temp_dir() - join = os.path.join - realdir = join(tmpdir, 'event_containing_directory') - - _CreateCleanDirectory(realdir) - - self.assertEqual(x.Runs(), {}) - - _AddEvents(realdir) - x.AddRunsFromDirectory(realdir) - self.assertItemsEqual(x.Runs(), ['.']) - - subdir = join(realdir, 'subdir') - _AddEvents(subdir) - x.AddRunsFromDirectory(realdir) - self.assertItemsEqual(x.Runs(), ['.', 'subdir']) - - def testAddRunsFromDirectoryWithRunNames(self): - x = event_multiplexer.EventMultiplexer() - tmpdir = 
self.get_temp_dir() - join = os.path.join - realdir = join(tmpdir, 'event_containing_directory') - - _CreateCleanDirectory(realdir) - - self.assertEqual(x.Runs(), {}) - - _AddEvents(realdir) - x.AddRunsFromDirectory(realdir, 'foo') - self.assertItemsEqual(x.Runs(), ['foo/.']) - - subdir = join(realdir, 'subdir') - _AddEvents(subdir) - x.AddRunsFromDirectory(realdir, 'foo') - self.assertItemsEqual(x.Runs(), ['foo/.', 'foo/subdir']) - - def testAddRunsFromDirectoryWalksTree(self): - x = event_multiplexer.EventMultiplexer() - tmpdir = self.get_temp_dir() - join = os.path.join - realdir = join(tmpdir, 'event_containing_directory') - - _CreateCleanDirectory(realdir) - _AddEvents(realdir) - sub = join(realdir, 'subdirectory') - sub1 = join(sub, '1') - sub2 = join(sub, '2') - sub1_1 = join(sub1, '1') - _AddEvents(sub1) - _AddEvents(sub2) - _AddEvents(sub1_1) - x.AddRunsFromDirectory(realdir) - - self.assertItemsEqual(x.Runs(), ['.', 'subdirectory/1', 'subdirectory/2', - 'subdirectory/1/1']) - - def testAddRunsFromDirectoryThrowsException(self): - x = event_multiplexer.EventMultiplexer() - tmpdir = self.get_temp_dir() - - filepath = _AddEvents(tmpdir) - with self.assertRaises(ValueError): - x.AddRunsFromDirectory(filepath) - - def testAddRun(self): - x = event_multiplexer.EventMultiplexer() - x.AddRun('run1_path', 'run1') - run1 = x._GetAccumulator('run1') - self.assertEqual(sorted(x.Runs().keys()), ['run1']) - self.assertEqual(run1._path, 'run1_path') - - x.AddRun('run1_path', 'run1') - self.assertEqual(run1, x._GetAccumulator('run1'), 'loader not recreated') - - x.AddRun('run2_path', 'run1') - new_run1 = x._GetAccumulator('run1') - self.assertEqual(new_run1._path, 'run2_path') - self.assertNotEqual(run1, new_run1) - - x.AddRun('runName3') - self.assertItemsEqual(sorted(x.Runs().keys()), ['run1', 'runName3']) - self.assertEqual(x._GetAccumulator('runName3')._path, 'runName3') - - def testAddRunMaintainsLoading(self): - x = event_multiplexer.EventMultiplexer() - x.Reload() 
- x.AddRun('run1') - x.AddRun('run2') - self.assertTrue(x._GetAccumulator('run1').reload_called) - self.assertTrue(x._GetAccumulator('run2').reload_called) - - -class EventMultiplexerWithRealAccumulatorTest(tf.test.TestCase): - - def testDeletingDirectoryRemovesRun(self): - x = event_multiplexer.EventMultiplexer() - tmpdir = self.get_temp_dir() - join = os.path.join - run1_dir = join(tmpdir, 'run1') - run2_dir = join(tmpdir, 'run2') - run3_dir = join(tmpdir, 'run3') - - for dirname in [run1_dir, run2_dir, run3_dir]: - _AddEvents(dirname) - - x.AddRun(run1_dir, 'run1') - x.AddRun(run2_dir, 'run2') - x.AddRun(run3_dir, 'run3') - - x.Reload() - - # Delete the directory, then reload. - shutil.rmtree(run2_dir) - x.Reload() - self.assertNotIn('run2', x.Runs().keys()) - - -if __name__ == '__main__': - tf.test.main() diff --git a/tensorflow/tensorboard/backend/event_processing/io_wrapper.py b/tensorflow/tensorboard/backend/event_processing/io_wrapper.py deleted file mode 100644 index c185f26a4fd..00000000000 --- a/tensorflow/tensorboard/backend/event_processing/io_wrapper.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""IO helper functions.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os - -import tensorflow as tf - - -def IsGCSPath(path): - return path.startswith("gs://") - - -def ListDirectoryAbsolute(directory): - """Yields all files in the given directory. The paths are absolute.""" - return (os.path.join(directory, path) - for path in tf.gfile.ListDirectory(directory)) - - -def ListRecursively(top): - """Walks a directory tree, yielding (dir_path, file_paths) tuples. - - For each of `top` and its subdirectories, yields a tuple containing the path - to the directory and the path to each of the contained files. Note that - unlike os.Walk()/tf.gfile.Walk(), this does not list subdirectories and the - file paths are all absolute. - - If the directory does not exist, this yields nothing. - - Args: - top: A path to a directory.. - Yields: - A list of (dir_path, file_paths) tuples. - """ - for dir_path, _, filenames in tf.gfile.Walk(top): - yield (dir_path, (os.path.join(dir_path, filename) - for filename in filenames)) diff --git a/tensorflow/tensorboard/backend/event_processing/plugin_asset_util.py b/tensorflow/tensorboard/backend/event_processing/plugin_asset_util.py deleted file mode 100644 index 5fb71284244..00000000000 --- a/tensorflow/tensorboard/backend/event_processing/plugin_asset_util.py +++ /dev/null @@ -1,99 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Load plugin assets from disk.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os.path - - -import tensorflow as tf - - -_PLUGINS_DIR = "plugins" - - -def _IsDirectory(parent, item): - """Helper that returns if parent/item is a directory.""" - return tf.gfile.IsDirectory(os.path.join(parent, item)) - - -def PluginDirectory(logdir, plugin_name): - """Returns the plugin directory for plugin_name.""" - return os.path.join(logdir, _PLUGINS_DIR, plugin_name) - - -def ListPlugins(logdir): - """List all the plugins that have registered assets in logdir. - - If the plugins_dir does not exist, it returns an empty list. This maintains - compatibility with old directories that have no plugins written. - - Args: - logdir: A directory that was created by a TensorFlow events writer. - - Returns: - a list of plugin names, as strings - """ - plugins_dir = os.path.join(logdir, _PLUGINS_DIR) - if not tf.gfile.IsDirectory(plugins_dir): - return [] - entries = tf.gfile.ListDirectory(plugins_dir) - return [x for x in entries if _IsDirectory(plugins_dir, x)] - - -def ListAssets(logdir, plugin_name): - """List all the assets that are available for given plugin in a logdir. - - Args: - logdir: A directory that was created by a TensorFlow summary.FileWriter. - plugin_name: A string name of a plugin to list assets for. - - Returns: - A string list of available plugin assets. 
If the plugin subdirectory does - not exist (either because the logdir doesn't exist, or because the plugin - didn't register) an empty list is returned. - """ - plugin_dir = PluginDirectory(logdir, plugin_name) - if not tf.gfile.IsDirectory(plugin_dir): - return [] - entries = tf.gfile.ListDirectory(plugin_dir) - return [x for x in entries if not _IsDirectory(plugin_dir, x)] - - -def RetrieveAsset(logdir, plugin_name, asset_name): - """Retrieve a particular plugin asset from a logdir. - - Args: - logdir: A directory that was created by a TensorFlow summary.FileWriter. - plugin_name: The plugin we want an asset from. - asset_name: The name of the requested asset. - - Returns: - string contents of the plugin asset. - - Raises: - KeyError: if the asset does not exist. - """ - - asset_path = os.path.join(PluginDirectory(logdir, plugin_name), asset_name) - try: - with tf.gfile.Open(asset_path, "r") as f: - return f.read() - except tf.errors.NotFoundError: - raise KeyError("Asset path %s not found" % asset_path) - except tf.errors.OpError as e: - raise KeyError("Couldn't read asset path: %s, OpError %s" % (asset_path, e)) diff --git a/tensorflow/tensorboard/backend/event_processing/reservoir.py b/tensorflow/tensorboard/backend/event_processing/reservoir.py deleted file mode 100644 index 0a1252e6352..00000000000 --- a/tensorflow/tensorboard/backend/event_processing/reservoir.py +++ /dev/null @@ -1,253 +0,0 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -"""A key-value[] store that implements reservoir sampling on the values.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import collections -import random -import threading - - -class Reservoir(object): - """A map-to-arrays container, with deterministic Reservoir Sampling. - - Items are added with an associated key. Items may be retrieved by key, and - a list of keys can also be retrieved. If size is not zero, then it dictates - the maximum number of items that will be stored with each key. Once there are - more items for a given key, they are replaced via reservoir sampling, such - that each item has an equal probability of being included in the sample. - - Deterministic means that for any given seed and bucket size, the sequence of - values that are kept for any given tag will always be the same, and that this - is independent of any insertions on other tags. That is: - - >>> separate_reservoir = reservoir.Reservoir(10) - >>> interleaved_reservoir = reservoir.Reservoir(10) - >>> for i in xrange(100): - >>> separate_reservoir.AddItem('key1', i) - >>> for i in xrange(100): - >>> separate_reservoir.AddItem('key2', i) - >>> for i in xrange(100): - >>> interleaved_reservoir.AddItem('key1', i) - >>> interleaved_reservoir.AddItem('key2', i) - - separate_reservoir and interleaved_reservoir will be in identical states. - - See: https://en.wikipedia.org/wiki/Reservoir_sampling - - Adding items has amortized O(1) runtime. - - """ - - def __init__(self, size, seed=0, always_keep_last=True): - """Creates a new reservoir. - - Args: - size: The number of values to keep in the reservoir for each tag. If 0, - all values will be kept. - seed: The seed of the random number generator to use when sampling. 
- Different values for |seed| will produce different samples from the same - input items. - always_keep_last: Whether to always keep the latest seen item in the - end of the reservoir. Defaults to True. - - Raises: - ValueError: If size is negative or not an integer. - """ - if size < 0 or size != round(size): - raise ValueError('size must be nonegative integer, was %s' % size) - self._buckets = collections.defaultdict( - lambda: _ReservoirBucket(size, random.Random(seed), always_keep_last)) - # _mutex guards the keys - creating new keys, retrieving by key, etc - # the internal items are guarded by the ReservoirBuckets' internal mutexes - self._mutex = threading.Lock() - - def Keys(self): - """Return all the keys in the reservoir. - - Returns: - ['list', 'of', 'keys'] in the Reservoir. - """ - with self._mutex: - return list(self._buckets.keys()) - - def Items(self, key): - """Return items associated with given key. - - Args: - key: The key for which we are finding associated items. - - Raises: - KeyError: If the key is not found in the reservoir. - - Returns: - [list, of, items] associated with that key. - """ - with self._mutex: - if key not in self._buckets: - raise KeyError('Key %s was not found in Reservoir' % key) - bucket = self._buckets[key] - return bucket.Items() - - def AddItem(self, key, item, f=lambda x: x): - """Add a new item to the Reservoir with the given tag. - - If the reservoir has not yet reached full size, the new item is guaranteed - to be added. If the reservoir is full, then behavior depends on the - always_keep_last boolean. - - If always_keep_last was set to true, the new item is guaranteed to be added - to the reservoir, and either the previous last item will be replaced, or - (with low probability) an older item will be replaced. - - If always_keep_last was set to false, then the new item will replace an - old item with low probability. 
- - If f is provided, it will be applied to transform item (lazily, iff item is - going to be included in the reservoir). - - Args: - key: The key to store the item under. - item: The item to add to the reservoir. - f: An optional function to transform the item prior to addition. - """ - with self._mutex: - bucket = self._buckets[key] - bucket.AddItem(item, f) - - def FilterItems(self, filterFn, key=None): - """Filter items within a Reservoir, using a filtering function. - - Args: - filterFn: A function that returns True for the items to be kept. - key: An optional bucket key to filter. If not specified, will filter all - all buckets. - - Returns: - The number of items removed. - """ - with self._mutex: - if key: - if key in self._buckets: - return self._buckets[key].FilterItems(filterFn) - else: - return 0 - else: - return sum(bucket.FilterItems(filterFn) - for bucket in self._buckets.values()) - - -class _ReservoirBucket(object): - """A container for items from a stream, that implements reservoir sampling. - - It always stores the most recent item as its final item. - """ - - def __init__(self, _max_size, _random=None, always_keep_last=True): - """Create the _ReservoirBucket. - - Args: - _max_size: The maximum size the reservoir bucket may grow to. If size is - zero, the bucket has unbounded size. - _random: The random number generator to use. If not specified, defaults to - random.Random(0). - always_keep_last: Whether the latest seen item should always be included - in the end of the bucket. - - Raises: - ValueError: if the size is not a nonnegative integer. 
- """ - if _max_size < 0 or _max_size != round(_max_size): - raise ValueError('_max_size must be nonegative int, was %s' % _max_size) - self.items = [] - # This mutex protects the internal items, ensuring that calls to Items and - # AddItem are thread-safe - self._mutex = threading.Lock() - self._max_size = _max_size - self._num_items_seen = 0 - if _random is not None: - self._random = _random - else: - self._random = random.Random(0) - self.always_keep_last = always_keep_last - - def AddItem(self, item, f=lambda x: x): - """Add an item to the ReservoirBucket, replacing an old item if necessary. - - The new item is guaranteed to be added to the bucket, and to be the last - element in the bucket. If the bucket has reached capacity, then an old item - will be replaced. With probability (_max_size/_num_items_seen) a random item - in the bucket will be popped out and the new item will be appended - to the end. With probability (1 - _max_size/_num_items_seen) - the last item in the bucket will be replaced. - - Since the O(n) replacements occur with O(1/_num_items_seen) likelihood, - the amortized runtime is O(1). - - Args: - item: The item to add to the bucket. - f: A function to transform item before addition, if it will be kept in - the reservoir. - """ - with self._mutex: - if len(self.items) < self._max_size or self._max_size == 0: - self.items.append(f(item)) - else: - r = self._random.randint(0, self._num_items_seen) - if r < self._max_size: - self.items.pop(r) - self.items.append(f(item)) - elif self.always_keep_last: - self.items[-1] = f(item) - self._num_items_seen += 1 - - def FilterItems(self, filterFn): - """Filter items in a ReservoirBucket, using a filtering function. - - Filtering items from the reservoir bucket must update the - internal state variable self._num_items_seen, which is used for determining - the rate of replacement in reservoir sampling. 
Ideally, self._num_items_seen - would contain the exact number of items that have ever seen by the - ReservoirBucket and satisfy filterFn. However, the ReservoirBucket does not - have access to all items seen -- it only has access to the subset of items - that have survived sampling (self.items). Therefore, we estimate - self._num_items_seen by scaling it by the same ratio as the ratio of items - not removed from self.items. - - Args: - filterFn: A function that returns True for items to be kept. - - Returns: - The number of items removed from the bucket. - """ - with self._mutex: - size_before = len(self.items) - self.items = list(filter(filterFn, self.items)) - size_diff = size_before - len(self.items) - - # Estimate a correction the number of items seen - prop_remaining = len(self.items) / float( - size_before) if size_before > 0 else 0 - self._num_items_seen = int(round(self._num_items_seen * prop_remaining)) - return size_diff - - def Items(self): - """Get all the items in the bucket.""" - with self._mutex: - return list(self.items) diff --git a/tensorflow/tensorboard/backend/event_processing/reservoir_test.py b/tensorflow/tensorboard/backend/event_processing/reservoir_test.py deleted file mode 100644 index df4757e2454..00000000000 --- a/tensorflow/tensorboard/backend/event_processing/reservoir_test.py +++ /dev/null @@ -1,279 +0,0 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from six.moves import xrange # pylint: disable=redefined-builtin -import tensorflow as tf - -from tensorflow.tensorboard.backend.event_processing import reservoir - - -class ReservoirTest(tf.test.TestCase): - - def testEmptyReservoir(self): - r = reservoir.Reservoir(1) - self.assertFalse(r.Keys()) - - def testRespectsSize(self): - r = reservoir.Reservoir(42) - self.assertEqual(r._buckets['meaning of life']._max_size, 42) - - def testItemsAndKeys(self): - r = reservoir.Reservoir(42) - r.AddItem('foo', 4) - r.AddItem('bar', 9) - r.AddItem('foo', 19) - self.assertItemsEqual(r.Keys(), ['foo', 'bar']) - self.assertEqual(r.Items('foo'), [4, 19]) - self.assertEqual(r.Items('bar'), [9]) - - def testExceptions(self): - with self.assertRaises(ValueError): - reservoir.Reservoir(-1) - with self.assertRaises(ValueError): - reservoir.Reservoir(13.3) - - r = reservoir.Reservoir(12) - with self.assertRaises(KeyError): - r.Items('missing key') - - def testDeterminism(self): - """Tests that the reservoir is deterministic.""" - key = 'key' - r1 = reservoir.Reservoir(10) - r2 = reservoir.Reservoir(10) - for i in xrange(100): - r1.AddItem('key', i) - r2.AddItem('key', i) - - self.assertEqual(r1.Items(key), r2.Items(key)) - - def testBucketDeterminism(self): - """Tests that reservoirs are deterministic at a bucket level. - - This means that only the order elements are added within a bucket matters. 
- """ - separate_reservoir = reservoir.Reservoir(10) - interleaved_reservoir = reservoir.Reservoir(10) - for i in xrange(100): - separate_reservoir.AddItem('key1', i) - for i in xrange(100): - separate_reservoir.AddItem('key2', i) - for i in xrange(100): - interleaved_reservoir.AddItem('key1', i) - interleaved_reservoir.AddItem('key2', i) - - for key in ['key1', 'key2']: - self.assertEqual( - separate_reservoir.Items(key), interleaved_reservoir.Items(key)) - - def testUsesSeed(self): - """Tests that reservoirs with different seeds keep different samples.""" - key = 'key' - r1 = reservoir.Reservoir(10, seed=0) - r2 = reservoir.Reservoir(10, seed=1) - for i in xrange(100): - r1.AddItem('key', i) - r2.AddItem('key', i) - self.assertNotEqual(r1.Items(key), r2.Items(key)) - - def testFilterItemsByKey(self): - r = reservoir.Reservoir(100, seed=0) - for i in xrange(10): - r.AddItem('key1', i) - r.AddItem('key2', i) - - self.assertEqual(len(r.Items('key1')), 10) - self.assertEqual(len(r.Items('key2')), 10) - - self.assertEqual(r.FilterItems(lambda x: x <= 7, 'key2'), 2) - self.assertEqual(len(r.Items('key2')), 8) - self.assertEqual(len(r.Items('key1')), 10) - - self.assertEqual(r.FilterItems(lambda x: x <= 3, 'key1'), 6) - self.assertEqual(len(r.Items('key1')), 4) - self.assertEqual(len(r.Items('key2')), 8) - - -class ReservoirBucketTest(tf.test.TestCase): - - def testEmptyBucket(self): - b = reservoir._ReservoirBucket(1) - self.assertFalse(b.Items()) - - def testFillToSize(self): - b = reservoir._ReservoirBucket(100) - for i in xrange(100): - b.AddItem(i) - self.assertEqual(b.Items(), list(xrange(100))) - self.assertEqual(b._num_items_seen, 100) - - def testDoesntOverfill(self): - b = reservoir._ReservoirBucket(10) - for i in xrange(1000): - b.AddItem(i) - self.assertEqual(len(b.Items()), 10) - self.assertEqual(b._num_items_seen, 1000) - - def testMaintainsOrder(self): - b = reservoir._ReservoirBucket(100) - for i in xrange(10000): - b.AddItem(i) - items = b.Items() - 
prev = -1 - for item in items: - self.assertTrue(item > prev) - prev = item - - def testKeepsLatestItem(self): - b = reservoir._ReservoirBucket(5) - for i in xrange(100): - b.AddItem(i) - last = b.Items()[-1] - self.assertEqual(last, i) - - def testSizeOneBucket(self): - b = reservoir._ReservoirBucket(1) - for i in xrange(20): - b.AddItem(i) - self.assertEqual(b.Items(), [i]) - self.assertEqual(b._num_items_seen, 20) - - def testSizeZeroBucket(self): - b = reservoir._ReservoirBucket(0) - for i in xrange(20): - b.AddItem(i) - self.assertEqual(b.Items(), list(range(i + 1))) - self.assertEqual(b._num_items_seen, 20) - - def testSizeRequirement(self): - with self.assertRaises(ValueError): - reservoir._ReservoirBucket(-1) - with self.assertRaises(ValueError): - reservoir._ReservoirBucket(10.3) - - def testRemovesItems(self): - b = reservoir._ReservoirBucket(100) - for i in xrange(10): - b.AddItem(i) - self.assertEqual(len(b.Items()), 10) - self.assertEqual(b._num_items_seen, 10) - self.assertEqual(b.FilterItems(lambda x: x <= 7), 2) - self.assertEqual(len(b.Items()), 8) - self.assertEqual(b._num_items_seen, 8) - - def testRemovesItemsWhenItemsAreReplaced(self): - b = reservoir._ReservoirBucket(100) - for i in xrange(10000): - b.AddItem(i) - self.assertEqual(b._num_items_seen, 10000) - - # Remove items - num_removed = b.FilterItems(lambda x: x <= 7) - self.assertGreater(num_removed, 92) - self.assertEqual([], [item for item in b.Items() if item > 7]) - self.assertEqual(b._num_items_seen, - int(round(10000 * (1 - float(num_removed) / 100)))) - - def testLazyFunctionEvaluationAndAlwaysKeepLast(self): - - class FakeRandom(object): - - def randint(self, a, b): # pylint:disable=unused-argument - return 999 - - class Incrementer(object): - - def __init__(self): - self.n = 0 - - def increment_and_double(self, x): - self.n += 1 - return x * 2 - - # We've mocked the randomness generator, so that once it is full, the last - # item will never get durable reservoir inclusion. 
Since always_keep_last is - # false, the function should only get invoked 100 times while filling up - # the reservoir. This laziness property is an essential performance - # optimization. - b = reservoir._ReservoirBucket(100, FakeRandom(), always_keep_last=False) - incrementer = Incrementer() - for i in xrange(1000): - b.AddItem(i, incrementer.increment_and_double) - self.assertEqual(incrementer.n, 100) - self.assertEqual(b.Items(), [x * 2 for x in xrange(100)]) - - # This time, we will always keep the last item, meaning that the function - # should get invoked once for every item we add. - b = reservoir._ReservoirBucket(100, FakeRandom(), always_keep_last=True) - incrementer = Incrementer() - - for i in xrange(1000): - b.AddItem(i, incrementer.increment_and_double) - self.assertEqual(incrementer.n, 1000) - self.assertEqual(b.Items(), [x * 2 for x in xrange(99)] + [999 * 2]) - - -class ReservoirBucketStatisticalDistributionTest(tf.test.TestCase): - - def setUp(self): - self.total = 1000000 - self.samples = 10000 - self.n_buckets = 100 - self.total_per_bucket = self.total // self.n_buckets - self.assertEqual(self.total % self.n_buckets, 0, 'total must be evenly ' - 'divisible by the number of buckets') - self.assertTrue(self.total > self.samples, 'need to have more items ' - 'than samples') - - def AssertBinomialQuantity(self, measured): - p = 1.0 * self.n_buckets / self.samples - mean = p * self.samples - variance = p * (1 - p) * self.samples - error = measured - mean - # Given that the buckets were actually binomially distributed, this - # fails with probability ~2E-9 - passed = error * error <= 36.0 * variance - self.assertTrue(passed, 'found a bucket with measured %d ' - 'too far from expected %d' % (measured, mean)) - - def testBucketReservoirSamplingViaStatisticalProperties(self): - # Not related to a 'ReservoirBucket', but instead number of buckets we put - # samples into for testing the shape of the distribution - b = 
reservoir._ReservoirBucket(_max_size=self.samples) - # add one extra item because we always keep the most recent item, which - # would skew the distribution; we can just slice it off the end instead. - for i in xrange(self.total + 1): - b.AddItem(i) - - divbins = [0] * self.n_buckets - modbins = [0] * self.n_buckets - # Slice off the last item when we iterate. - for item in b.Items()[0:-1]: - divbins[item // self.total_per_bucket] += 1 - modbins[item % self.n_buckets] += 1 - - for bucket_index in xrange(self.n_buckets): - divbin = divbins[bucket_index] - modbin = modbins[bucket_index] - self.AssertBinomialQuantity(divbin) - self.AssertBinomialQuantity(modbin) - - -if __name__ == '__main__': - tf.test.main() diff --git a/tensorflow/tensorboard/backend/http_util.py b/tensorflow/tensorboard/backend/http_util.py deleted file mode 100644 index 81a06a5f14c..00000000000 --- a/tensorflow/tensorboard/backend/http_util.py +++ /dev/null @@ -1,148 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""TensorBoard HTTP utilities.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -from __future__ import unicode_literals - -import gzip -import json -import re -import time -import wsgiref.handlers - -import six -import tensorflow as tf -from werkzeug import wrappers - -from tensorflow.tensorboard.backend import json_util - - -_EXTRACT_MIMETYPE_PATTERN = re.compile(r'^[^;\s]*') -_EXTRACT_CHARSET_PATTERN = re.compile(r'charset=([-_0-9A-Za-z]+)') - -# Allows *, gzip or x-gzip, but forbid gzip;q=0 -# https://tools.ietf.org/html/rfc7231#section-5.3.4 -_ALLOWS_GZIP_PATTERN = re.compile( - r'(?:^|,|\s)(?:(?:x-)?gzip|\*)(?!;q=0)(?:\s|,|$)') - -_TEXTUAL_MIMETYPES = set([ - 'application/javascript', - 'application/json', - 'application/json+protobuf', - 'image/svg+xml', - 'text/css', - 'text/csv', - 'text/html', - 'text/plain', - 'text/tab-separated-values', - 'text/x-protobuf', -]) - -_JSON_MIMETYPES = set([ - 'application/json', - 'application/json+protobuf', -]) - - -def Respond(request, - content, - content_type, - code=200, - expires=0, - content_encoding=None, - encoding='utf-8'): - """Construct a werkzeug Response. - - Responses are transmitted to the browser with compression if: a) the browser - supports it; b) it's sane to compress the content_type in question; and c) - the content isn't already compressed, as indicated by the content_encoding - parameter. - - Browser and proxy caching is completely disabled by default. If the expires - parameter is greater than zero then the response will be able to be cached by - the browser for that many seconds; however, proxies are still forbidden from - caching so that developers can bypass the cache with Ctrl+Shift+R. 
- - For textual content that isn't JSON, the encoding parameter is used as the - transmission charset which is automatically appended to the Content-Type - header. That is unless of course the content_type parameter contains a - charset parameter. If the two disagree, the characters in content will be - transcoded to the latter. - - If content_type declares a JSON media type, then content MAY be a dict, list, - tuple, or set, in which case this function has an implicit composition with - json_util.Cleanse and json.dumps. The encoding parameter is used to decode - byte strings within the JSON object; therefore transmitting binary data - within JSON is not permitted. JSON is transmitted as ASCII unless the - content_type parameter explicitly defines a charset parameter, in which case - the serialized JSON bytes will use that instead of escape sequences. - - Args: - request: A werkzeug Request object. Used mostly to check the - Accept-Encoding header. - content: Payload data as byte string, unicode string, or maybe JSON. - content_type: Media type and optionally an output charset. - code: Numeric HTTP status code to use. - expires: Second duration for browser caching. - content_encoding: Encoding if content is already encoded, e.g. 'gzip'. - encoding: Input charset if content parameter has byte strings. - - Returns: - A werkzeug Response object (a WSGI application). 
- """ - - mimetype = _EXTRACT_MIMETYPE_PATTERN.search(content_type).group(0) - charset_match = _EXTRACT_CHARSET_PATTERN.search(content_type) - charset = charset_match.group(1) if charset_match else encoding - textual = charset_match or mimetype in _TEXTUAL_MIMETYPES - if mimetype in _JSON_MIMETYPES and (isinstance(content, dict) or - isinstance(content, list) or - isinstance(content, set) or - isinstance(content, tuple)): - content = json.dumps(json_util.Cleanse(content, encoding), - ensure_ascii=not charset_match) - if charset != encoding: - content = tf.compat.as_text(content, encoding) - content = tf.compat.as_bytes(content, charset) - if textual and not charset_match and mimetype not in _JSON_MIMETYPES: - content_type += '; charset=' + charset - if (not content_encoding and textual and - _ALLOWS_GZIP_PATTERN.search(request.headers.get('Accept-Encoding', ''))): - out = six.BytesIO() - f = gzip.GzipFile(fileobj=out, mode='wb', compresslevel=3) - f.write(content) - f.close() - content = out.getvalue() - content_encoding = 'gzip' - if request.method == 'HEAD': - content = '' - headers = [] - - headers.append(('Content-Length', str(len(content)))) - if content_encoding: - headers.append(('Content-Encoding', content_encoding)) - if expires > 0: - e = wsgiref.handlers.format_date_time(time.time() + float(expires)) - headers.append(('Expires', e)) - headers.append(('Cache-Control', 'private, max-age=%d' % expires)) - else: - headers.append(('Expires', '0')) - headers.append(('Cache-Control', 'no-cache, must-revalidate')) - - return wrappers.Response( - response=content, status=code, headers=headers, content_type=content_type) diff --git a/tensorflow/tensorboard/backend/http_util_test.py b/tensorflow/tensorboard/backend/http_util_test.py deleted file mode 100644 index 6b0c8d3403b..00000000000 --- a/tensorflow/tensorboard/backend/http_util_test.py +++ /dev/null @@ -1,156 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests HTTP utilities.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -from __future__ import unicode_literals - -import gzip - -import six -import tensorflow as tf -from werkzeug import test as wtest -from werkzeug import wrappers -from tensorflow.tensorboard.backend import http_util - - -class RespondTest(tf.test.TestCase): - - def testHelloWorld(self): - q = wrappers.Request(wtest.EnvironBuilder().get_environ()) - r = http_util.Respond(q, 'hello world', 'text/html') - self.assertEqual(r.status_code, 200) - self.assertEqual(r.response[0], six.b('hello world')) - - def testHeadRequest_doesNotWrite(self): - builder = wtest.EnvironBuilder(method='HEAD') - env = builder.get_environ() - request = wrappers.Request(env) - r = http_util.Respond(request, 'hello world', 'text/html') - self.assertEqual(r.status_code, 200) - self.assertEqual(r.response[0], six.b('')) - - def testPlainText_appendsUtf8ToContentType(self): - q = wrappers.Request(wtest.EnvironBuilder().get_environ()) - r = http_util.Respond(q, 'hello', 'text/plain') - h = r.headers - self.assertEqual(h.get('Content-Type'), 'text/plain; charset=utf-8') - - def testContentLength_isInBytes(self): - q = wrappers.Request(wtest.EnvironBuilder().get_environ()) - r = http_util.Respond(q, '爱', 'text/plain') - 
self.assertEqual(r.headers.get('Content-Length'), '3') - q = wrappers.Request(wtest.EnvironBuilder().get_environ()) - r = http_util.Respond(q, '爱'.encode('utf-8'), 'text/plain') - self.assertEqual(r.headers.get('Content-Length'), '3') - - def testResponseCharsetTranscoding(self): - bean = '要依法治国是赞美那些谁是公义的和惩罚恶人。 - 韩非' - - # input is unicode string, output is gbk string - q = wrappers.Request(wtest.EnvironBuilder().get_environ()) - r = http_util.Respond(q, bean, 'text/plain; charset=gbk') - self.assertEqual(r.response[0], bean.encode('gbk')) - - # input is utf-8 string, output is gbk string - q = wrappers.Request(wtest.EnvironBuilder().get_environ()) - r = http_util.Respond(q, bean.encode('utf-8'), 'text/plain; charset=gbk') - self.assertEqual(r.response[0], bean.encode('gbk')) - - # input is object with unicode strings, output is gbk json - q = wrappers.Request(wtest.EnvironBuilder().get_environ()) - r = http_util.Respond(q, {'red': bean}, 'application/json; charset=gbk') - self.assertEqual(r.response[0], b'{"red": "' + bean.encode('gbk') + b'"}') - - # input is object with utf-8 strings, output is gbk json - q = wrappers.Request(wtest.EnvironBuilder().get_environ()) - r = http_util.Respond( - q, {'red': bean.encode('utf-8')}, 'application/json; charset=gbk') - self.assertEqual(r.response[0], b'{"red": "' + bean.encode('gbk') + b'"}') - - # input is object with gbk strings, output is gbk json - q = wrappers.Request(wtest.EnvironBuilder().get_environ()) - r = http_util.Respond( - q, {'red': bean.encode('gbk')}, - 'application/json; charset=gbk', - encoding='gbk') - self.assertEqual(r.response[0], b'{"red": "' + bean.encode('gbk') + b'"}') - - def testAcceptGzip_compressesResponse(self): - fall_of_hyperion_canto1_stanza1 = '\n'.join([ - 'Fanatics have their dreams, wherewith they weave', - 'A paradise for a sect; the savage too', - 'From forth the loftiest fashion of his sleep', - 'Guesses at Heaven; pity these have not', - 'Trac\'d upon vellum or wild Indian leaf', - 
'The shadows of melodious utterance.', - 'But bare of laurel they live, dream, and die;', - 'For Poesy alone can tell her dreams,', - 'With the fine spell of words alone can save', - 'Imagination from the sable charm', - 'And dumb enchantment. Who alive can say,', - '\'Thou art no Poet may\'st not tell thy dreams?\'', - 'Since every man whose soul is not a clod', - 'Hath visions, and would speak, if he had loved', - 'And been well nurtured in his mother tongue.', - 'Whether the dream now purpos\'d to rehearse', - 'Be poet\'s or fanatic\'s will be known', - 'When this warm scribe my hand is in the grave.', - ]) - - e1 = wtest.EnvironBuilder(headers={'Accept-Encoding': '*'}).get_environ() - any_encoding = wrappers.Request(e1) - - r = http_util.Respond( - any_encoding, fall_of_hyperion_canto1_stanza1, 'text/plain') - self.assertEqual(r.headers.get('Content-Encoding'), 'gzip') - - self.assertEqual( - _gunzip(r.response[0]), fall_of_hyperion_canto1_stanza1.encode('utf-8')) - - e2 = wtest.EnvironBuilder(headers={'Accept-Encoding': 'gzip'}).get_environ() - gzip_encoding = wrappers.Request(e2) - - r = http_util.Respond( - gzip_encoding, fall_of_hyperion_canto1_stanza1, 'text/plain') - self.assertEqual(r.headers.get('Content-Encoding'), 'gzip') - self.assertEqual( - _gunzip(r.response[0]), fall_of_hyperion_canto1_stanza1.encode('utf-8')) - - r = http_util.Respond( - any_encoding, fall_of_hyperion_canto1_stanza1, 'image/png') - self.assertEqual( - r.response[0], fall_of_hyperion_canto1_stanza1.encode('utf-8')) - - def testJson_getsAutoSerialized(self): - q = wrappers.Request(wtest.EnvironBuilder().get_environ()) - r = http_util.Respond(q, [1, 2, 3], 'application/json') - self.assertEqual(r.response[0], b'[1, 2, 3]') - - def testExpires_setsCruiseControl(self): - q = wrappers.Request(wtest.EnvironBuilder().get_environ()) - r = http_util.Respond(q, 'hello world', 'text/html', expires=60) - self.assertEqual(r.headers.get('Cache-Control'), 'private, max-age=60') - - -def 
_gunzip(bs): - return gzip.GzipFile('', 'rb', 9, six.BytesIO(bs)).read() - - -if __name__ == '__main__': - tf.test.main() diff --git a/tensorflow/tensorboard/backend/json_util.py b/tensorflow/tensorboard/backend/json_util.py deleted file mode 100644 index ab8f34a2fb9..00000000000 --- a/tensorflow/tensorboard/backend/json_util.py +++ /dev/null @@ -1,74 +0,0 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -"""A module providing a function for serializing JSON values with Infinity. - -Python provides no way to override how json.dumps serializes -Infinity/-Infinity/NaN; if allow_nan is true, it encodes them as -Infinity/-Infinity/NaN, in violation of the JSON spec and in violation of what -JSON.parse accepts. If it's false, it throws a ValueError, Neither subclassing -JSONEncoder nor passing a function in the |default| keyword argument overrides -this. -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import math - -import tensorflow as tf - - -_INFINITY = float('inf') -_NEGATIVE_INFINITY = float('-inf') - - -def Cleanse(obj, encoding='utf-8'): - """Makes Python object appropriate for JSON serialization. - - - Replaces instances of Infinity/-Infinity/NaN with strings. - - Turns byte strings into unicode strings. - - Turns sets into sorted lists. 
- - Turns tuples into lists. - - Args: - obj: Python data structure. - encoding: Charset used to decode byte strings. - - Returns: - Unicode JSON data structure. - """ - if isinstance(obj, int): - return obj - elif isinstance(obj, float): - if obj == _INFINITY: - return 'Infinity' - elif obj == _NEGATIVE_INFINITY: - return '-Infinity' - elif math.isnan(obj): - return 'NaN' - else: - return obj - elif isinstance(obj, bytes): - return tf.compat.as_text(obj, encoding) - elif isinstance(obj, list) or isinstance(obj, tuple): - return [Cleanse(i, encoding) for i in obj] - elif isinstance(obj, set): - return [Cleanse(i, encoding) for i in sorted(obj)] - elif isinstance(obj, dict): - return {Cleanse(k, encoding): Cleanse(v, encoding) for k, v in obj.items()} - else: - return obj diff --git a/tensorflow/tensorboard/backend/json_util_test.py b/tensorflow/tensorboard/backend/json_util_test.py deleted file mode 100644 index 22e815564e4..00000000000 --- a/tensorflow/tensorboard/backend/json_util_test.py +++ /dev/null @@ -1,66 +0,0 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import tensorflow as tf - - -from tensorflow.tensorboard.backend import json_util - -_INFINITY = float('inf') - - -class FloatWrapperTest(tf.test.TestCase): - - def _assertWrapsAs(self, to_wrap, expected): - """Asserts that |to_wrap| becomes |expected| when wrapped.""" - actual = json_util.Cleanse(to_wrap) - for a, e in zip(actual, expected): - self.assertEqual(e, a) - - def testWrapsPrimitives(self): - self._assertWrapsAs(_INFINITY, 'Infinity') - self._assertWrapsAs(-_INFINITY, '-Infinity') - self._assertWrapsAs(float('nan'), 'NaN') - - def testWrapsObjectValues(self): - self._assertWrapsAs({'x': _INFINITY}, {'x': 'Infinity'}) - - def testWrapsObjectKeys(self): - self._assertWrapsAs({_INFINITY: 'foo'}, {'Infinity': 'foo'}) - - def testWrapsInListsAndTuples(self): - self._assertWrapsAs([_INFINITY], ['Infinity']) - # map() returns a list even if the argument is a tuple. - self._assertWrapsAs((_INFINITY,), ['Infinity',]) - - def testWrapsRecursively(self): - self._assertWrapsAs({'x': [_INFINITY]}, {'x': ['Infinity']}) - - def testTuple_turnsIntoList(self): - self.assertEqual(json_util.Cleanse(('a', 'b')), ['a', 'b']) - - def testSet_turnsIntoSortedList(self): - self.assertEqual(json_util.Cleanse(set(['b', 'a'])), ['a', 'b']) - - def testByteString_turnsIntoUnicodeString(self): - self.assertEqual(json_util.Cleanse(b'\xc2\xa3'), u'\u00a3') # is # sterling - - -if __name__ == '__main__': - tf.test.main() diff --git a/tensorflow/tensorboard/backend/process_graph.py b/tensorflow/tensorboard/backend/process_graph.py deleted file mode 100644 index 2b314d79cb1..00000000000 --- a/tensorflow/tensorboard/backend/process_graph.py +++ /dev/null @@ -1,69 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Graph post-processing logic. Used by both TensorBoard and mldash.""" - - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - - -import tensorflow as tf - - -def prepare_graph_for_ui(graph, limit_attr_size=1024, - large_attrs_key='_too_large_attrs'): - """Prepares (modifies in-place) the graph to be served to the front-end. - - For now, it supports filtering out attributes that are - too large to be shown in the graph UI. - - Args: - graph: The GraphDef proto message. - limit_attr_size: Maximum allowed size in bytes, before the attribute - is considered large. Default is 1024 (1KB). Must be > 0 or None. - If None, there will be no filtering. - large_attrs_key: The attribute key that will be used for storing attributes - that are too large. Default is '_too_large_attrs'. Must be != None if - `limit_attr_size` is != None. - - Raises: - ValueError: If `large_attrs_key is None` while `limit_attr_size != None`. - ValueError: If `limit_attr_size` is defined, but <= 0. - """ - # Check input for validity. 
- if limit_attr_size is not None: - if large_attrs_key is None: - raise ValueError('large_attrs_key must be != None when limit_attr_size' - '!= None.') - - if limit_attr_size <= 0: - raise ValueError('limit_attr_size must be > 0, but is %d' % - limit_attr_size) - - # Filter only if a limit size is defined. - if limit_attr_size is not None: - for node in graph.node: - # Go through all the attributes and filter out ones bigger than the - # limit. - keys = list(node.attr.keys()) - for key in keys: - size = node.attr[key].ByteSize() - if size > limit_attr_size or size < 0: - del node.attr[key] - # Add the attribute key to the list of "too large" attributes. - # This is used in the info card in the graph UI to show the user - # that some attributes are too large to be shown. - node.attr[large_attrs_key].list.s.append(tf.compat.as_bytes(key)) diff --git a/tensorflow/tensorboard/components/BUILD b/tensorflow/tensorboard/components/BUILD deleted file mode 100644 index 2d7613dbfdc..00000000000 --- a/tensorflow/tensorboard/components/BUILD +++ /dev/null @@ -1,47 +0,0 @@ -package(default_visibility = ["//tensorflow/tensorboard:internal"]) - -load("//tensorflow/tensorboard/defs:web.bzl", "ts_web_library") -load("//tensorflow/tensorboard/defs:vulcanize.bzl", "tensorboard_html_binary") - -licenses(["notice"]) # Apache 2.0 - -ts_web_library( - name = "tensorboard", - srcs = [ - "analytics.html", - "tensorboard.html", - ], - path = "/", - deps = ["//tensorflow/tensorboard/components/tf_tensorboard"], -) - -tensorboard_html_binary( - name = "index", - input_path = "/tensorboard.html", - output_path = "/index.html", - deps = [":tensorboard"], -) - -ts_web_library( - name = "trace_viewer", - srcs = [ - "trace_viewer.html", - ], - path = "/", - deps = [ - "//tensorflow/tensorboard/components/tf_trace_viewer", - ], -) - -tensorboard_html_binary( - name = "trace_viewer_index", - input_path = "/trace_viewer.html", - output_path = "/trace_viewer_index.html", - deps = [":trace_viewer"], -) 
- -filegroup( - name = "all_files", - srcs = glob(["**"]), - tags = ["notsan"], -) diff --git a/tensorflow/tensorboard/components/analytics.html b/tensorflow/tensorboard/components/analytics.html deleted file mode 100644 index d319f576fc1..00000000000 --- a/tensorflow/tensorboard/components/analytics.html +++ /dev/null @@ -1,18 +0,0 @@ - - - diff --git a/tensorflow/tensorboard/components/tensorboard.html b/tensorflow/tensorboard/components/tensorboard.html deleted file mode 100644 index afaf396614f..00000000000 --- a/tensorflow/tensorboard/components/tensorboard.html +++ /dev/null @@ -1,26 +0,0 @@ - - - - -TensorBoard - - - - - - diff --git a/tensorflow/tensorboard/components/tf_audio_dashboard/BUILD b/tensorflow/tensorboard/components/tf_audio_dashboard/BUILD deleted file mode 100644 index 3bc754063c7..00000000000 --- a/tensorflow/tensorboard/components/tf_audio_dashboard/BUILD +++ /dev/null @@ -1,50 +0,0 @@ -package(default_visibility = ["//tensorflow/tensorboard:internal"]) - -load("//tensorflow/tensorboard/defs:web.bzl", "ts_web_library") - -licenses(["notice"]) # Apache 2.0 - -ts_web_library( - name = "tf_audio_dashboard", - srcs = [ - "tf-audio-dashboard.html", - "tf-audio-grid.html", - "tf-audio-loader.html", - ], - path = "/tf-audio-dashboard", - deps = [ - "//tensorflow/tensorboard/components/tf_backend", - "//tensorflow/tensorboard/components/tf_dashboard_common", - "//tensorflow/tensorboard/components/tf_imports:d3", - "//tensorflow/tensorboard/components/tf_imports:lodash", - "//tensorflow/tensorboard/components/tf_imports:polymer", - "@org_polymer_paper_icon_button", - "@org_polymer_paper_slider", - "@org_polymer_paper_spinner", - "@org_polymer_paper_styles", - ], -) - -ts_web_library( - name = "index", - srcs = [ - "demo/index.html", - "index.html", - ], - path = "/tf-audio-dashboard", - deps = [ - ":tf_audio_dashboard", - "//tensorflow/tensorboard/components/tf_imports:d3", - "//tensorflow/tensorboard/components/tf_imports:webcomponentsjs", - 
"//tensorflow/tensorboard/demo:demo_data", - "@org_polymer_iron_component_page", - "@org_polymer_iron_demo_helpers", - "@org_polymer_paper_styles", - ], -) - -filegroup( - name = "all_files", - srcs = glob(["**"]), - tags = ["notsan"], -) diff --git a/tensorflow/tensorboard/components/tf_audio_dashboard/demo/index.html b/tensorflow/tensorboard/components/tf_audio_dashboard/demo/index.html deleted file mode 100644 index a1d7e968e8f..00000000000 --- a/tensorflow/tensorboard/components/tf_audio_dashboard/demo/index.html +++ /dev/null @@ -1,67 +0,0 @@ - - - - - - - - -Audio Dashboard Demo - - - - diff --git a/tensorflow/tensorboard/components/tf_audio_dashboard/index.html b/tensorflow/tensorboard/components/tf_audio_dashboard/index.html deleted file mode 100644 index 157f1692658..00000000000 --- a/tensorflow/tensorboard/components/tf_audio_dashboard/index.html +++ /dev/null @@ -1,25 +0,0 @@ - - - - -tf-audio-dashboard - - - - - diff --git a/tensorflow/tensorboard/components/tf_audio_dashboard/test/BUILD b/tensorflow/tensorboard/components/tf_audio_dashboard/test/BUILD deleted file mode 100644 index 3d50e5d2caa..00000000000 --- a/tensorflow/tensorboard/components/tf_audio_dashboard/test/BUILD +++ /dev/null @@ -1,33 +0,0 @@ -package( - default_testonly = True, - default_visibility = ["//tensorflow/tensorboard:internal"], -) - -load("//tensorflow/tensorboard/defs:web.bzl", "ts_web_library") - -licenses(["notice"]) # Apache 2.0 - -ts_web_library( - name = "test", - srcs = [ - "audioDashboardTests.ts", - "tests.html", - ], - path = "/tf-audio-dashboard/test", - deps = [ - "//tensorflow/tensorboard/components/tf_audio_dashboard", - "//tensorflow/tensorboard/components/tf_backend", - "//tensorflow/tensorboard/components/tf_imports:d3", - "//tensorflow/tensorboard/components/tf_imports:polymer", - "//tensorflow/tensorboard/components/tf_imports:web_component_tester", - "//tensorflow/tensorboard/components/tf_imports:webcomponentsjs", - 
"//tensorflow/tensorboard/demo:demo_data", - ], -) - -filegroup( - name = "all_files", - testonly = 0, - srcs = glob(["**"]), - tags = ["notsan"], -) diff --git a/tensorflow/tensorboard/components/tf_audio_dashboard/test/audioDashboardTests.ts b/tensorflow/tensorboard/components/tf_audio_dashboard/test/audioDashboardTests.ts deleted file mode 100644 index 6ccd9bede66..00000000000 --- a/tensorflow/tensorboard/components/tf_audio_dashboard/test/audioDashboardTests.ts +++ /dev/null @@ -1,49 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the 'License'); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an 'AS IS' BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -import * as backend_backend from '../../tf-backend/backend'; -import {createRouter, setRouter} from '../../tf-backend/router'; - -// TODO(dandelion): Fix me. 
-declare function fixture(id: string): any; -declare function stub(x, y: any): void; - -describe('audio dashboard tests', () => { - let audioDash; - let reloadCount = 0; - beforeEach(() => { - audioDash = fixture('testElementFixture'); - const router = createRouter('/data', true); - setRouter(router); - const backend = new backend_backend.Backend(); - audioDash.backend = backend; - stub('tf-audio-loader', { - reload: () => { reloadCount++; }, - }); - }); - - it('calling reload on dashboard reloads the audio-loaders', (done) => { - audioDash.backendReload().then(() => { - reloadCount = 0; - const loaders = - [].slice.call(audioDash.getElementsByTagName('tf-audio-loader')); - audioDash.frontendReload(); - setTimeout(() => { - chai.assert.isTrue(reloadCount >= 2); - done(); - }); - }); - }); -}); diff --git a/tensorflow/tensorboard/components/tf_audio_dashboard/test/tests.html b/tensorflow/tensorboard/components/tf_audio_dashboard/test/tests.html deleted file mode 100644 index 891e8bf0c29..00000000000 --- a/tensorflow/tensorboard/components/tf_audio_dashboard/test/tests.html +++ /dev/null @@ -1,38 +0,0 @@ - - - - - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/tf_audio_dashboard/tf-audio-dashboard.html b/tensorflow/tensorboard/components/tf_audio_dashboard/tf-audio-dashboard.html deleted file mode 100644 index 7caea7130d0..00000000000 --- a/tensorflow/tensorboard/components/tf_audio_dashboard/tf-audio-dashboard.html +++ /dev/null @@ -1,94 +0,0 @@ - - - - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/tf_audio_dashboard/tf-audio-grid.html b/tensorflow/tensorboard/components/tf_audio_dashboard/tf-audio-grid.html deleted file mode 100644 index c71d8bdd4bf..00000000000 --- a/tensorflow/tensorboard/components/tf_audio_dashboard/tf-audio-grid.html +++ /dev/null @@ -1,183 +0,0 @@ - - - - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/tf_audio_dashboard/tf-audio-loader.html 
b/tensorflow/tensorboard/components/tf_audio_dashboard/tf-audio-loader.html deleted file mode 100644 index 71539537d0e..00000000000 --- a/tensorflow/tensorboard/components/tf_audio_dashboard/tf-audio-loader.html +++ /dev/null @@ -1,237 +0,0 @@ - - - - - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/tf_backend/BUILD b/tensorflow/tensorboard/components/tf_backend/BUILD deleted file mode 100644 index 50fc267dc4d..00000000000 --- a/tensorflow/tensorboard/components/tf_backend/BUILD +++ /dev/null @@ -1,45 +0,0 @@ -package(default_visibility = ["//tensorflow/tensorboard:internal"]) - -load("//tensorflow/tensorboard/defs:defs.bzl", "tensorboard_webcomponent_library") -load("//tensorflow/tensorboard/defs:web.bzl", "ts_web_library") - -licenses(["notice"]) # Apache 2.0 - -ts_web_library( - name = "tf_backend", - srcs = [ - "backend.ts", - "behavior.ts", - "requestManager.ts", - "router.ts", - "runsStore.ts", - "tf-backend.html", - "urlPathHelpers.ts", - ], - path = "/tf-backend", - visibility = ["//visibility:public"], - deps = [ - "//tensorflow/tensorboard/components/tf_imports:d3", - "//tensorflow/tensorboard/components/tf_imports:lodash", - "//tensorflow/tensorboard/components/tf_imports:plottable", - "//tensorflow/tensorboard/components/tf_imports:polymer", - "//tensorflow/tensorboard/components/vz_sorting", - ], -) - -tensorboard_webcomponent_library( - name = "legacy", - srcs = [":tf_backend"], - destdir = "tf-backend", - deps = [ - "//tensorflow/tensorboard/components/tf_imports_google:lib", - "//tensorflow/tensorboard/components/vz_sorting:legacy", - "//third_party/javascript/polymer/v1/polymer:lib", - ], -) - -filegroup( - name = "all_files", - srcs = glob(["**"]), - tags = ["notsan"], -) diff --git a/tensorflow/tensorboard/components/tf_backend/backend.ts b/tensorflow/tensorboard/components/tf_backend/backend.ts deleted file mode 100644 index 023414b6b75..00000000000 --- a/tensorflow/tensorboard/components/tf_backend/backend.ts +++ /dev/null @@ 
-1,608 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -import {compareTagNames} from '../vz-sorting/sorting'; -import {RequestManager} from './requestManager'; -import {getRouter} from './router'; -import {demoify, queryEncoder} from './urlPathHelpers'; - -export interface RunEnumeration { - histograms: string[]; - compressedHistogramTuples: string[]; - scalars: string[]; - images: string[]; - audio: string[]; - graph: boolean; - run_metadata: string[]; -} - -export interface LogdirResponse { logdir: string; } - -export interface RunsResponse { [runName: string]: RunEnumeration; } - -export type RunToTag = { - [run: string]: string[]; -}; - -export interface Datum { - wall_time: Date; - step: number; -} - -export type ScalarDatum = Datum & Scalar; -export interface Scalar { scalar: number; } - -export interface Text { text: string; } -export type TextDatum = Datum & Text; - -export type HistogramDatum = Datum & Histogram; -export interface Histogram { - min: number; - max: number; - nItems?: number; - sum?: number; - sumSquares?: number; - bucketRightEdges: number[]; - bucketCounts: number[]; -} - -export interface HistogramBin { - x: number; - dx: number; - y: number; -} -export type HistogramSeriesDatum = HistogramSeries & Datum; -export interface HistogramSeries { bins: HistogramBin[]; } - -export type ImageDatum = Datum & Image; 
-export interface Image { - width: number; - height: number; - url: string; -} - -export type AudioDatum = Datum & Audio; -export interface Audio { - content_type: string; - url: string; -} - -// A health pill encapsulates an overview of tensor element values. The value -// field is a list of 12 numbers that shed light on the status of the tensor. -export interface HealthPill { - device_name: string; - node_name: string; - output_slot: number; - dtype: string; - shape: number[]; - value: number[]; -} - -// When updating this type, keep it consistent with the HealthPill interface -// in tf_graph_common/lib/scene/scene.ts. -export type HealthPillDatum = Datum & HealthPill; -// A health pill response is a mapping from node name to a list of health pill -// data entries. -export interface HealthPillsResponse { [key: string]: HealthPillDatum[]; } - -// An object that encapsulates an alert issued by the debugger. This alert is -// sent by debugging libraries after bad values (NaN, +/- Inf) are encountered. -export interface DebuggerNumericsAlertReport { - device_name: string; - tensor_name: string; - first_timestamp: number; - nan_event_count: number; - neg_inf_event_count: number; - pos_inf_event_count: number; -} -// A DebuggerNumericsAlertReportResponse contains alerts issued by the debugger -// in ascending order of timestamp. This helps the user identify for instance -// when bad values first appeared in the model. -export type DebuggerNumericsAlertReportResponse = DebuggerNumericsAlertReport[]; - -export const TYPES = [ - 'scalar', 'histogram', 'compressedHistogram', 'graph', 'image', 'audio', - 'runMetadata', 'text' -]; -/** - * The Backend class provides a convenient and typed interface to the backend. - * - * It provides methods corresponding to the different data sources on the - * TensorBoard backend. These methods return a promise containing the data - * from the backend. 
This class does some post-processing on the data; for - * example, converting data elements tuples into js objects so that they can - * be accessed in a more convenient and clearly-documented fashion. - */ -export class Backend { - public requestManager: RequestManager; - - /** - * Construct a Backend instance. - * @param requestManager The RequestManager, overwritable so you may - * manually clear request queue, etc. Defaults to a new RequestManager. - */ - constructor(requestManager?: RequestManager) { - this.requestManager = requestManager || new RequestManager(); - } - - /** - * Returns a promise for requesting the logdir string. - */ - public logdir(): Promise { - return this.requestManager.request(getRouter().logdir()); - } - - /** - * Returns a listing of all the available data in the TensorBoard backend. - */ - public runs(): Promise { - return this.requestManager.request(getRouter().runs()); - } - - /** - * Return a promise showing the Run-to-Tag mapping for scalar data. - */ - public scalarTags(): Promise { - return this.requestManager.request( - getRouter().pluginRoute('scalars', '/tags')); - } - - /** - * Return a promise showing the Run-to-Tag mapping for histogram data. - */ - public histogramTags(): Promise { - return this.requestManager.request( - getRouter().pluginRoute('histograms', '/tags')); - } - - /** - * Return a promise showing the Run-to-Tag mapping for image data. - */ - public imageTags(): Promise { - return this.requestManager.request( - getRouter().pluginRoute('images', '/tags')); - } - - /** - * Return a promise showing the Run-to-Tag mapping for audio data. - */ - public audioTags(): Promise { - return this.requestManager.request( - getRouter().pluginRoute('audio', '/tags')); - } - - /** - * Return a promise showing the Run-to-Tag mapping for compressedHistogram - * data. 
- */ - public compressedHistogramTags(): Promise { - return this.requestManager.request( - getRouter().pluginRoute('distributions', '/tags')); - } - - /** - * Returns a promise showing the Run-to-Tag mapping for profile data. - */ - public profileTags(): Promise { - let url = getRouter().pluginRoute('profile', '/tags'); - if (getRouter().isDemoMode()) { - url += '.json'; - } - return this.requestManager.request(url); - } - - /** - * Return a promise showing list of runs that contain graphs. - */ - public graphRuns(): Promise { - return this.requestManager.request( - getRouter().pluginRoute('graphs', '/runs')); - } - - /** - * Return a promise showing the Run-to-Tag mapping for run_metadata objects. - */ - public runMetadataTags(): Promise { - return this.requestManager.request( - getRouter().pluginRoute('graphs', '/run_metadata_tags')); - } - - - /** - * Returns a promise showing the Run-to-Tag mapping for text data. - */ - public textRuns(): Promise { - return this.requestManager.request(getRouter().textRuns()); - } - - - /** - * Returns a promise containing TextDatums for given run and tag. - */ - public text(tag: string, run: string): Promise { - const url = getRouter().text(tag, run); - // tslint:disable-next-line:no-any it's convenient and harmless here - return this.requestManager.request(url).then(map((x: any) => { - x.wall_time = timeToDate(x.wall_time); - return x; - })); - } - - /** - * Return a URL to fetch a graph (cf. method 'graph'). - */ - public graphUrl(run: string, limitAttrSize?: number, largeAttrsKey?: string): - string { - const demoMode = getRouter().isDemoMode(); - const base = getRouter().pluginRoute('graphs', '/graph'); - const optional = (p) => (p != null && !demoMode || undefined) && p; - const parameters = { - 'run': run, - 'limit_attr_size': optional(limitAttrSize), - 'large_attrs_key': optional(largeAttrsKey), - }; - const extension = demoMode ? 
'.pbtxt' : ''; - return base + queryEncoder(parameters) + extension; - } - - public graph(run: string, limitAttrSize?: number, largeAttrsKey?: string): - Promise { - const url = this.graphUrl(run, limitAttrSize, largeAttrsKey); - return this.requestManager.request(url); - } - - /** - * Return a promise containing ScalarDatums for given run and tag. - */ - public scalar(tag: string, run: string): Promise> { - let p: Promise[]>; - const url = getRouter().pluginRunTagRoute('scalars', '/scalars')(tag, run); - p = this.requestManager.request(url); - return p.then(map(detupler(createScalar))); - } - - /** - * Returns a promise for requesting the health pills for a list of nodes. This - * route is used by the debugger plugin. - */ - public healthPills(nodeNames: string[], step?: number): - Promise { - const postData = { - 'node_names': JSON.stringify(nodeNames), - - // Events files with debugger data fall under this special run. - 'run': '__debugger_data__', - }; - if (step !== undefined) { - // The user requested health pills for a specific step. This request - // might be slow since the backend reads events sequentially from disk. - postData['step'] = step; - } - return this.requestManager.request(getRouter().healthPills(), postData); - } - - /** - * Returns a promise for alerts for bad values (detected by the debugger). - * This route is used by the debugger plugin. - */ - public debuggerNumericsAlerts(): - Promise { - return this.requestManager.request( - getRouter().pluginRoute('debugger', '/numerics_alert_report')); - } - - /** - * Return a promise containing HistogramDatums for given run and tag. 
- */ - public histogram(tag: string, run: string): - Promise> { - let p: Promise[]>; - const url = - getRouter().pluginRunTagRoute('histograms', '/histograms')(tag, run); - p = this.requestManager.request(url); - return p.then(map(detupler(createHistogram))).then(function(histos) { - // Get the minimum and maximum values across all histograms so that the - // visualization is aligned for all timesteps. - const min = d3.min(histos, d => d.min); - const max = d3.max(histos, d => d.max); - - return histos.map(function(histo, i) { - return { - wall_time: histo.wall_time, - step: histo.step, - bins: convertBins(histo, min, max) - }; - }); - }); - } - - /** - * Return a promise containing ImageDatums for given run and tag. - */ - public image(tag: string, run: string): Promise> { - const url = (getRouter().pluginRunTagRoute('images', '/images')(tag, run)); - let p: Promise; - p = this.requestManager.request(url); - return p.then(map(this.createImage.bind(this))); - } - - /** - * Return a promise containing AudioDatums for given run and tag. - */ - public audio(tag: string, run: string): Promise> { - const url = (getRouter().pluginRunTagRoute('audio', '/audio')(tag, run)); - let p: Promise; - p = this.requestManager.request(url); - return p.then(map(this.createAudio.bind(this))); - } - - /** - * Returns a promise containing profile data for given run and tag. - */ - public profile(tag: string, run: string): Promise { - let url = (getRouter().pluginRunTagRoute('profile', '/data')(tag, run)); - if (getRouter().isDemoMode()) { - url += '.json'; - } - return this.requestManager.request(url); - } - - /** - * Returns the url for the RunMetadata for the given run/tag. - */ - public runMetadataUrl(tag: string, run: string): string { - return getRouter().pluginRunTagRoute('graphs', '/run_metadata')(tag, run); - } - - /** - * Returns a promise to load the string RunMetadata for given run/tag. 
- */ - public runMetadata(tag: string, run: string): Promise { - const url = this.runMetadataUrl(tag, run); - return this.requestManager.request(url); - } - - /** - * Get compressedHistogram data. - * Unlike other methods, don't bother reprocessing this data into a nicer - * format. This is because we will deprecate this route. - */ - private compressedHistogram(tag: string, run: string): - Promise> { - const url = (getRouter().pluginRunTagRoute( - 'distributions', '/distributions')(tag, run)); - let p: Promise[]>; - p = this.requestManager.request(url); - return p.then(map(detupler((x) => x))); - } - - private createImage(x: ImageMetadata): Image&Datum { - const pluginRoute = getRouter().pluginRoute('images', '/individualImage'); - - let query = x.query; - if (pluginRoute.indexOf('?') > -1) { - // The route already has GET parameters. Append our parameters to them. - query = '&' + query; - } else { - // The route lacks GET parameters. We append them. - query = '?' + query; - } - - if (getRouter().isDemoMode()) { - query = demoify(query); - } - - let individualImageUrl = pluginRoute + query; - // Include wall_time just to disambiguate the URL and force the browser - // to reload the image when the URL changes. The backend doesn't care - // about the value. - individualImageUrl += - getRouter().isDemoMode() ? '.png' : '&ts=' + x.wall_time; - - return { - width: x.width, - height: x.height, - wall_time: timeToDate(x.wall_time), - step: x.step, - url: individualImageUrl, - }; - } - - private createAudio(x: AudioMetadata): Audio&Datum { - const pluginRoute = getRouter().pluginRoute('audio', '/individualAudio'); - - let query = x.query; - if (pluginRoute.indexOf('?') > -1) { - // The route already has GET parameters. Append our parameters to them. - query = '&' + query; - } else { - // The route lacks GET parameters. We append them. - query = '?' 
+ query; - } - - if (getRouter().isDemoMode()) { - query = demoify(query); - } - - let individualAudioUrl = pluginRoute + query; - // Include wall_time just to disambiguate the URL and force the browser - // to reload the audio when the URL changes. The backend doesn't care - // about the value. - individualAudioUrl += - getRouter().isDemoMode() ? '.wav' : '&ts=' + x.wall_time; - - return { - content_type: x.content_type, - wall_time: timeToDate(x.wall_time), - step: x.step, - url: individualAudioUrl, - }; - } -} - -/** Given a RunToTag, return sorted array of all runs */ -export function getRuns(r: RunToTag): string[] { - return _.keys(r).sort(compareTagNames); -} - -/** Given a RunToTag, return array of all tags (sorted + dedup'd) */ -export function getTags(r: RunToTag): string[] { - return _.union.apply(null, _.values(r)).sort(compareTagNames); -} - -/** - * Given a RunToTag and an array of runs, return every tag that appears for - * at least one run. - * Sorted, deduplicated. - */ -export function filterTags(r: RunToTag, runs: string[]): string[] { - let result = []; - runs.forEach((x) => result = result.concat(r[x])); - return _.uniq(result).sort(compareTagNames); -} - -function timeToDate(x: number): Date { - return new Date(x * 1000); -}; - -/** Just a curryable map to make things cute and tidy. */ -function map(f: (x: T) => U): (arr: T[]) => U[] { - return function(arr: T[]): U[] { - return arr.map(f); - }; -}; - -/** - * This is a higher order function that takes a function that transforms a - * T into a G, and returns a function that takes TupleDatas and converts - * them into the intersection of a G and a Datum. - */ -function detupler(xform: (x: T) => G): (t: TupleData) => Datum & G { - return function(x: TupleData): Datum & G { - // Create a G, assert it has type - let obj = xform(x[2]); - // ... 
patch in the properties of datum - obj.wall_time = timeToDate(x[0]); - obj.step = x[1]; - return obj; - }; -}; - -function createScalar(x: number): Scalar { - return {scalar: x}; -} - -function createHistogram(x: HistogramTuple): Histogram { - return { - min: x[0], - max: x[1], - nItems: x[2], - sum: x[3], - sumSquares: x[4], - bucketRightEdges: x[5], - bucketCounts: x[6], - }; -} - -/** - * Takes histogram data as stored by tensorboard backend and converts it to - * the standard d3 histogram data format to make it more compatible and easier - * to visualize. When visualizing histograms, having the left edge and width - * makes things quite a bit easier. The bins are also converted to have an - * uniform width, what makes the visualization easier to understand. - * - * @param histogram A histogram from tensorboard backend. - * @param min The leftmost edge. The binning will start on it. - * @param max The rightmost edge. The binning will end on it. - * @param numBins The number of bins of the converted data. The default of 30 - * is a sensible default, using more starts to get artifacts because the event - * data is stored in buckets, and you start being able to see the aliased - * borders between each bucket. - * @return A histogram bin. Each bin has an x (left edge), a dx (width), - * and a y (count). - * - * If given rightedges are inclusive, then these left edges (x) are exclusive. - */ -export function convertBins( - histogram: Histogram, min: number, max: number, numBins = 30) { - if (histogram.bucketRightEdges.length !== histogram.bucketCounts.length) { - throw(new Error('Edges and counts are of different lengths.')); - } - - if (max === min) { - // Create bins even if all the data has a single value. - max = min * 1.1 + 1; - min = min / 1.1 - 1; - } - const binWidth = (max - min) / numBins; - let bucketLeft = min; // Use the min as the starting point for the bins. 
- let bucketPos = 0; - return d3.range(min, max, binWidth).map((binLeft) => { - const binRight = binLeft + binWidth; - - // Take the count of each existing bucket, multiply it by the proportion - // of overlap with the new bin, then sum and store as the count for the - // new bin. If no overlap, will add to zero, if 100% overlap, will include - // the full count into new bin. - let binY = 0; - while (bucketPos < histogram.bucketRightEdges.length) { - // Clip the right edge because right-most edge can be infinite-sized. - const bucketRight = Math.min(max, histogram.bucketRightEdges[bucketPos]); - - const intersect = - Math.min(bucketRight, binRight) - Math.max(bucketLeft, binLeft); - const count = (intersect / (bucketRight - bucketLeft)) * - histogram.bucketCounts[bucketPos]; - - binY += intersect > 0 ? count : 0; - - // If bucketRight is bigger than binRight, than this bin is finished and - // there is data for the next bin, so don't increment bucketPos. - if (bucketRight > binRight) { - break; - } - bucketLeft = Math.max(min, bucketRight); - bucketPos++; - } - - return {x: binLeft, dx: binWidth, y: binY}; - }); -} - -/** - * The following interfaces (TupleData, HistogramTuple, - * CompressedHistogramTuple, ImageMetadata, and AudioMetadata) describe how - * the data is sent over from the backend. 
- */ -type TupleData = [number, number, T]; // wall_time, step - -// Min, Max, nItems, Sum, Sum_Squares, right edges of buckets, nItems in -// buckets -type HistogramTuple = - [number, number, number, number, number, number[], number[]]; -type CompressedHistogramTuple = [number, number][]; // percentile, value -interface ImageMetadata { - width: number; - height: number; - wall_time: number; - step: number; - query: string; -} -interface AudioMetadata { - content_type: string; - wall_time: number; - step: number; - query: string; -} diff --git a/tensorflow/tensorboard/components/tf_backend/behavior.ts b/tensorflow/tensorboard/components/tf_backend/behavior.ts deleted file mode 100644 index 8df791eface..00000000000 --- a/tensorflow/tensorboard/components/tf_backend/behavior.ts +++ /dev/null @@ -1,148 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -import {getRuns, getTags, TYPES} from './backend'; - -/** @polymerBehavior */ -export const BackendBehavior = { - properties: { - /** *** Required properties *** */ - /** Data type. One of Backend.TYPES */ - dataType: { - type: String, - observer: '_throwErrorOnUnrecognizedType', - }, - - /** Backend for data loading. */ - backend: { - type: Object, - }, - - /** Should it automatically load when configured ready? Default true. 
*/ - autoLoad: { - type: Boolean, - value: true, - }, - - /** *** Component-provided properties *** */ - /** Every tag available for data type (sorted, dedpulicated) */ - tags: { - type: Array, - readOnly: true, - notify: true, - }, - - /** Every run available for data type (sorted) */ - runs: { - type: Array, - readOnly: true, - notify: true, - }, - - /** Mapping from runs to tags for the data type */ - run2tag: { - type: Object, - readOnly: true, - notify: true, - }, - - /** Promise provider for the data. Useful for passing to subcomponents */ - dataProvider: - {type: Function, computed: '_getDataProvider(dataType, backend)'}, - - /** Has the dashboard loaded yet? */ - loadState: { - type: String, - value: 'noload', // [noload, pending, loaded, failure] - readOnly: true, - }, - - /** - * True if dashboard has loaded, and no tags were found. - * Persists through subsequent reloads (ie. still true while - * next load is pending) so warning won't flash away every reload - * when there is no data. - */ - dataNotFound: { - type: Boolean, - value: false, - readOnly: true, - } - - }, - observers: ['_do_autoLoad(dataType, backend, autoLoad)'], - /** - * Reloading works in two steps: - * Backend reload, which gets metadata on available runs, tags, etc from - * the backend. - * Frontend reload, which loads new data for each chart or visual display. - * Backend reload logic is provided by this behavior. The frontend reload - * logic should be provided elsewhere, since it is component-specific. - * To keep things simple and consistent, we do the backend reload first, - * and the frontend reload afterwards. - */ - reload() { - return this.backendReload().then((x) => { - return this.frontendReload(); - }); - }, - /** - * Load data from backend and then set run2tag, tags, runs, and loadState. - * Returns a promise that resolves/rejects when data is loaded. 
- */ - backendReload() { - if (this.dataType == null) { - throw new Error('BackendBehavior: Need a dataType to reload.'); - } - if (this.backend == null) { - throw new Error('BackendBehavior: Need a backend to reload.'); - } - const runsRoute = (this.backend[this.dataType + 'Runs'] || - this.backend[this.dataType + 'Tags']) - .bind(this.backend); - this._setLoadState('pending'); - return runsRoute().then( - (x) => { - this._setLoadState('loaded'); - if (_.isEqual(x, this.run2tag)) { - // If x and run2tag are equal, let's avoid updating everything - // since that can needlessly trigger run changes, reloads, etc - return x; - } - this._setRun2tag(x); - const tags = getTags(x); - this._setDataNotFound(tags.length === 0); - this._setTags(tags); - this._setRuns(getRuns(x)); - return x; - }, - (fail) => { - this._setLoadState('failure'); - return fail; - }); - }, - _do_autoLoad(type, backend, autoLoad) { - if (autoLoad) { - this.reload(); - } - }, - _getDataProvider(dataType, backend) { - return this.backend[this.dataType].bind(this.backend); - }, - _throwErrorOnUnrecognizedType(dataType) { - if (TYPES.indexOf(dataType) === -1) { - throw new Error('BackendBehavior: Unknown dataType ' + dataType); - } - }, -}; diff --git a/tensorflow/tensorboard/components/tf_backend/requestManager.ts b/tensorflow/tensorboard/components/tf_backend/requestManager.ts deleted file mode 100644 index 0fa198416e8..00000000000 --- a/tensorflow/tensorboard/components/tf_backend/requestManager.ts +++ /dev/null @@ -1,177 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the 'License'); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an 'AS IS' BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -interface ResolveReject { - resolve: Function; - reject: Function; -} -/** - * Manages many fetch requests. Launches up to nSimultaneousRequests - * simultaneously, and maintains a LIFO queue of requests to process when - * more urls are requested than can be handled at once. The queue can be - * cleared. - * - * When a request is made, a Promise is returned which resolves with the - * parsed JSON result from the request. - */ -export class RequestCancellationError extends Error { - public name = 'RequestCancellationError'; -} - -export class RequestNetworkError extends Error { - public name: string; - public req: XMLHttpRequest; - public url: string; - - constructor(req: XMLHttpRequest, url) { - super(); - this.message = `RequestNetworkError: ${req.status} at ${url}`; - this.name = 'RequestNetworkError'; - this.req = req; - this.url = url; - } -} - -export class RequestManager { - private _queue: ResolveReject[]; - private _maxRetries: number; - private _nActiveRequests: number; - private _nSimultaneousRequests: number; - - constructor(nSimultaneousRequests = 10, maxRetries = 3) { - this._queue = []; - this._nActiveRequests = 0; - this._nSimultaneousRequests = nSimultaneousRequests; - this._maxRetries = maxRetries; - } - - /** - * Gives a promise that loads assets from given url (respects queuing). If - * postData is provided, this request will use POST, not GET. This is an - * object mapping POST keys to string values. 
- */ - public request(url: string, postData?: {[key: string]: string}): - Promise { - const promise = - new Promise((resolve, reject) => { - const resolver = {resolve: resolve, reject: reject}; - this._queue.push(resolver); - this.launchRequests(); - }) - .then(() => { - return this.promiseWithRetries(url, this._maxRetries, postData); - }) - .then( - (response) => { - // Success - Let's free space for another active - // request, and launch it - this._nActiveRequests--; - this.launchRequests(); - return response; - }, - (rejection) => { - if (rejection.name === 'RequestNetworkError') { - // If we failed due to network error, we should - // decrement - // _nActiveRequests because this request was - // active - this._nActiveRequests--; - this.launchRequests(); - } - return Promise.reject(rejection); - }); - return promise; - } - - public clearQueue() { - while (this._queue.length > 0) { - this._queue.pop().reject( - new RequestCancellationError('Request cancelled by clearQueue')); - } - } - - /* Return number of currently pending requests */ - public activeRequests(): number { - return this._nActiveRequests; - } - - /* Return total number of outstanding requests (includes queue) */ - public outstandingRequests(): number { - return this._nActiveRequests + this._queue.length; - } - - private launchRequests() { - while (this._nActiveRequests < this._nSimultaneousRequests && - this._queue.length > 0) { - this._nActiveRequests++; - this._queue.pop().resolve(); - } - } - - /** - * Try to request a given URL using overwritable _promiseFromUrl method. - * If the request fails for any reason, we will retry up to maxRetries - * times. In practice, this will help us paper over transient network issues - * like '502 Bad Gateway'. - * By default, Chrome displays network errors in console, so - * the user will be able to tell when the requests are failing. 
I think this - * is a feature, if the request failures and retries are causing any - * pain to users, they can see it and file issues. - */ - private promiseWithRetries( - url: string, maxRetries: number, postData?: {[key: string]: string}) { - var success = (x) => x; - var failure = (x) => { - if (maxRetries > 0) { - return this.promiseWithRetries(url, maxRetries - 1, postData); - } else { - return Promise.reject(x); - } - }; - return this._promiseFromUrl(url, postData).then(success, failure); - } - - /* Actually get promise from url using XMLHttpRequest */ - protected _promiseFromUrl(url: string, postData?: {[key: string]: string}) { - return new Promise((resolve, reject) => { - let req = new XMLHttpRequest(); - req.open(postData ? 'POST' : 'GET', url); - - let formData; - if (postData) { - // We are to make a POST request. - formData = new FormData(); - for (let postKey in postData) { - if (postKey) { - // The linter requires 'for in' loops to be filtered by an if - // condition. - formData.append(postKey, postData[postKey]); - } - } - } - req.onload = function() { - if (req.status === 200) { - resolve(JSON.parse(req.responseText)); - } else { - reject(new RequestNetworkError(req, url)); - } - }; - req.onerror = function() { - reject(new RequestNetworkError(req, url)); - }; - req.send(formData); - }); - } -} diff --git a/tensorflow/tensorboard/components/tf_backend/router.ts b/tensorflow/tensorboard/components/tf_backend/router.ts deleted file mode 100644 index 598546004e1..00000000000 --- a/tensorflow/tensorboard/components/tf_backend/router.ts +++ /dev/null @@ -1,97 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the 'License'); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an 'AS IS' BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -import {demoify, queryEncoder} from './urlPathHelpers' - -export type RunTagUrlFn = (tag: string, run: string) => string; - -export interface Router { - logdir: () => string; - runs: () => string; - isDemoMode: () => boolean; - textRuns: () => string; - text: RunTagUrlFn; - healthPills: () => string; - pluginRoute: (pluginName: string, route: string) => string; - pluginRunTagRoute: (pluginName: string, route: string) => RunTagUrlFn; -} -; - -/** - * Create a router for communicating with the TensorBoard backend. You - * can pass this to `setRouter` to make it the global router. - * - * @param dataDir {string} The base prefix for finding data on server. - * @param demoMode {boolean} Whether to modify urls for filesystem demo usage. - */ -export function createRouter(dataDir = 'data', demoMode = false): Router { - var clean = demoMode ? 
demoify : (x) => x; - if (dataDir[dataDir.length - 1] === '/') { - dataDir = dataDir.slice(0, dataDir.length - 1); - } - function standardRoute(route: string, demoExtension = '.json'): - ((tag: string, run: string) => string) { - return function(tag: string, run: string): string { - var url = - dataDir + '/' + route + clean(queryEncoder({tag: tag, run: run})); - if (demoMode) { - url += demoExtension; - } - return url; - }; - } - function pluginRoute(pluginName: string, route: string): string { - return `${dataDir}/plugin/${pluginName}${route}`; - } - function pluginRunTagRoute(pluginName: string, route: string): - ((tag: string, run: string) => string) { - const base = pluginRoute(pluginName, route); - return (tag, run) => base + clean(queryEncoder({tag, run})); - } - return { - logdir: () => dataDir + '/logdir', - runs: () => dataDir + '/runs' + (demoMode ? '.json' : ''), - isDemoMode: () => demoMode, - healthPills: () => dataDir + '/plugin/debugger/health_pills', - textRuns: () => dataDir + '/plugin/text/runs' + (demoMode ? '.json' : ''), - text: standardRoute('plugin/text/text'), - pluginRoute, - pluginRunTagRoute, - }; -}; - -let _router: Router = createRouter(); - -/** - * @return {Router} the global router - */ -export function getRouter(): Router { - return _router; -} - -/** - * Set the global router, to be returned by future calls to `getRouter`. - * You may wish to invoke this if you are running a demo server with a - * custom path prefix, or if you have customized the TensorBoard backend - * to use a different path. 
- * - * @param {Router} router the new global router - */ -export function setRouter(router: Router): void { - if (router == null) { - throw new Error('Router required, but got: ' + router); - } - _router = router; -} diff --git a/tensorflow/tensorboard/components/tf_backend/runsStore.ts b/tensorflow/tensorboard/components/tf_backend/runsStore.ts deleted file mode 100644 index bcaff994ce8..00000000000 --- a/tensorflow/tensorboard/components/tf_backend/runsStore.ts +++ /dev/null @@ -1,67 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -import {RequestManager} from './requestManager'; -import {getRouter} from './router'; - -let runs: string[] = []; - -export type Listener = () => void; -const listeners = new Set(); - -const requestManager = new RequestManager(1 /* simultaneous request */); - -/** - * Register a listener (nullary function) to be called when new runs are - * available. - */ -export function addListener(listener: Listener): void { - listeners.add(listener); -} - -/** - * Remove a listener registered with `addListener`. - */ -export function removeListener(listener: Listener): void { - listeners.delete(listener); -} - -/** - * Asynchronously load or reload the runs data. Listeners will be - * invoked if this causes the runs data to change. 
- * - * @see addListener - * @return {Promise} a promise that resolves when the runs have - * loaded - */ -export function fetchRuns(): Promise { - const url = getRouter().runs(); - return requestManager.request(url).then(newRuns => { - if (!_.isEqual(runs, newRuns)) { - runs = newRuns; - listeners.forEach(listener => { - listener(); - }); - } - }); -} - -/** - * Get the current list of runs. If no data is available, this will be - * an empty array (i.e., there is no distinction between "no runs" and - * "no runs yet"). - */ -export function getRuns(): string[] { - return runs.slice(); -} diff --git a/tensorflow/tensorboard/components/tf_backend/test/BUILD b/tensorflow/tensorboard/components/tf_backend/test/BUILD deleted file mode 100644 index da70f8a9daa..00000000000 --- a/tensorflow/tensorboard/components/tf_backend/test/BUILD +++ /dev/null @@ -1,32 +0,0 @@ -package( - default_testonly = True, - default_visibility = ["//tensorflow/tensorboard:internal"], -) - -load("//tensorflow/tensorboard/defs:web.bzl", "ts_web_library") - -licenses(["notice"]) # Apache 2.0 - -ts_web_library( - name = "test", - srcs = [ - "tests.html", - "backendTests.ts", - "behaviorTests.ts", - "requestManagerTests.ts", - ] + glob(["data/**"]), - path = "/tf-backend/test", - deps = [ - "//tensorflow/tensorboard/components/tf_backend", - "//tensorflow/tensorboard/components/tf_imports:polymer", - "//tensorflow/tensorboard/components/tf_imports:web_component_tester", - "//tensorflow/tensorboard/components/tf_imports:webcomponentsjs", - ], -) - -filegroup( - name = "all_files", - testonly = 0, - srcs = glob(["**"]), - tags = ["notsan"], -) diff --git a/tensorflow/tensorboard/components/tf_backend/test/backendTests.ts b/tensorflow/tensorboard/components/tf_backend/test/backendTests.ts deleted file mode 100644 index 029c8359125..00000000000 --- a/tensorflow/tensorboard/components/tf_backend/test/backendTests.ts +++ /dev/null @@ -1,294 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. 
All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the 'License'); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an 'AS IS' BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -import {Backend, convertBins, filterTags, getRuns, getTags, RunToTag, TYPES} from '../backend'; -import {RequestManager} from '../requestManager'; -import {createRouter, setRouter} from '../router'; -import {BAD_CHARACTERS, demoify, queryEncoder} from '../urlPathHelpers'; - -describe('urlPathHelpers', () => { - it('demoify works as expected', () => { - const demoified = demoify(BAD_CHARACTERS); - let allClean = ''; - for (let i = 0; i < BAD_CHARACTERS.length; i++) { - allClean += '_'; - } - chai.assert.equal(demoified, allClean, 'cleaning the BAD_CHARACTERS works'); - chai.assert.equal(demoify('foozod'), 'foozod', 'doesnt change safe string'); - chai.assert.equal(demoify('foo zod (2)'), 'foo_zod__2_', 'simple case'); - }); - - it('queryEncoder works with demoify on spaces and parens', () => { - const params = {foo: 'something with spaces and (parens)'}; - const actual = demoify(queryEncoder(params)); - const expected = '_foo_something_with_spaces_and__28parens_29'; - chai.assert.equal(actual, expected); - }); -}); - -function assertIsDatum(x) { - chai.assert.isNumber(x.step); - chai.assert.instanceOf(x.wall_time, Date); -} - -describe('backend tests', () => { - let backend: Backend; - let rm: RequestManager; - const base = 'data'; - const demoRouter = createRouter(base, /*demoMode=*/true); - beforeEach(() => { - // Construct a demo Backend 
(third param is true) - setRouter(demoRouter); - backend = new Backend(); - rm = new RequestManager(); - }); - - it('runs are loaded properly', (done) => { - const runsResponse = backend.runs(); - const actualRuns = rm.request(demoRouter.runs()); - Promise.all([runsResponse, actualRuns]).then((values) => { - chai.assert.deepEqual(values[0], values[1]); - done(); - }); - }); - - it('scalars are loaded properly', (done) => { - backend.scalar('cross_entropy (1)', 'run1').then((s) => { - // just check the data got reformatted properly - const aScalar = s[s.length - 1]; - assertIsDatum(aScalar); - chai.assert.isNumber(aScalar.scalar); - // verify date conversion works - chai.assert.equal(aScalar.wall_time.valueOf(), 40000); - done(); - }); - }); - - it('histograms are loaded properly', (done) => { - backend.histogram('histo1', 'run1').then((histos) => { - const histo = histos[0]; - assertIsDatum(histo); - chai.assert.instanceOf(histo.bins, Array); - done(); - }); - }); - - it('all registered types have handlers', () => { - TYPES.forEach((t: string) => { - chai.assert.isDefined(backend[t], t); - chai.assert.isDefined(backend[t + 'Runs'], t + 'Runs'); - }); - }); - - it('images are loaded properly', (done) => { - backend.image('im1', 'run1').then((images) => { - const image = images[0]; - assertIsDatum(image); - chai.assert.isNumber(image.width); - chai.assert.isNumber(image.height); - done(); - }); - }); - - it('audio is loaded properly', (done) => { - backend.audio('audio1', 'run1').then((audioClips) => { - const audio = audioClips[0]; - assertIsDatum(audio); - chai.assert.equal(audio.content_type, 'audio/wav'); - done(); - }); - }); - - it('trailing slash removed from base route', () => { - const r = createRouter('foo/'); - chai.assert.equal(r.runs(), 'foo/runs'); - }); - - it('run helper methods work', (done) => { - const scalar = {run1: ['cross_entropy (1)'], fake_run_no_data: ['scalar2']}; - const image = {run1: ['im1'], fake_run_no_data: ['im1', 'im2']}; - const 
audio = {run1: ['audio1'], fake_run_no_data: ['audio1', 'audio2']}; - const runMetadata = {run1: ['step99'], fake_run_no_data: ['step99']}; - const graph = ['fake_run_no_data']; - let count = 0; - function next() { - count++; - if (count === 4) { - done(); - } - } - backend.scalarTags().then((x) => { - chai.assert.deepEqual(x, scalar); - next(); - }); - backend.imageTags().then((x) => { - chai.assert.deepEqual(x, image); - next(); - }); - backend.audioTags().then((x) => { - chai.assert.deepEqual(x, audio); - next(); - }); - backend.runMetadataTags().then((x) => { - chai.assert.deepEqual(x, runMetadata); - next(); - }); - backend.graphRuns().then((x) => { - chai.assert.deepEqual(x, graph); - next(); - }); - }); - - it('runToTag helpers work', () => { - const r2t: RunToTag = { - run1: ['foo', 'bar', 'zod'], - run2: ['zod', 'zoink'], - a: ['foo', 'zod'] - }; - const empty1: RunToTag = {}; - const empty2: RunToTag = {run1: [], run2: []}; - chai.assert.deepEqual(getRuns(r2t), ['a', 'run1', 'run2']); - chai.assert.deepEqual(getTags(r2t), ['bar', 'foo', 'zod', 'zoink']); - chai.assert.deepEqual(filterTags(r2t, ['run1', 'run2']), getTags(r2t)); - chai.assert.deepEqual(filterTags(r2t, ['run1']), ['bar', 'foo', 'zod']); - chai.assert.deepEqual( - filterTags(r2t, ['run2', 'a']), ['foo', 'zod', 'zoink']); - - chai.assert.deepEqual(getRuns(empty1), []); - chai.assert.deepEqual(getTags(empty1), []); - - chai.assert.deepEqual(getRuns(empty2), ['run1', 'run2']); - chai.assert.deepEqual(getTags(empty2), []); - }); -}); - -describe('Verify that the histogram format conversion works.', () => { - - function assertHistogramEquality(h1, h2) { - h1.forEach((b1, i) => { - const b2 = h2[i]; - chai.assert.closeTo(b1.x, b2.x, 1e-10); - chai.assert.closeTo(b1.dx, b2.dx, 1e-10); - chai.assert.closeTo(b1.y, b2.y, 1e-10); - }); - } - - it('Throws and error if the inputs are of different lengths', () => { - chai.assert.throws(() => { - convertBins( - {bucketRightEdges: [0], bucketCounts: [1, 2], 
min: 1, max: 2}, 1, 2, - 2); - }, 'Edges and counts are of different lengths.'); - }); - - it('Handles data with no bins', () => { - chai.assert.deepEqual( - convertBins( - {bucketRightEdges: [], bucketCounts: [], min: 0, max: 0}, 0, 0, 0), - []); - }); - - it('Handles data with one bin', () => { - const counts = [1]; - const rightEdges = [1.21e-12]; - const histogram = [{x: 1.1e-12, dx: 1.21e-12 - 1.1e-12, y: 1}]; - const newHistogram = convertBins( - { - bucketRightEdges: rightEdges, - bucketCounts: counts, - min: 1.1e-12, - max: 1.21e-12 - }, - 1.1e-12, 1.21e-12, 1); - assertHistogramEquality(newHistogram, histogram); - }); - - it('Handles data with two bins.', () => { - const counts = [1, 2]; - const rightEdges = [1.1e-12, 1.21e-12]; - const histogram = [ - {x: 1.0e-12, dx: 1.05e-13, y: 1.09090909090909}, - {x: 1.105e-12, dx: 1.05e-13, y: 1.9090909090909} - ]; - const newHistogram = convertBins( - { - bucketRightEdges: rightEdges, - bucketCounts: counts, - min: 1.0e-12, - max: 1.21e-12 - }, - 1.0e-12, 1.21e-12, 2); - assertHistogramEquality(newHistogram, histogram); - }); - - it('Handles a domain that crosses zero, but doesn\'t include zero as ' + - 'an edge.', - () => { - const counts = [1, 2]; - const rightEdges = [-1.0e-12, 1.0e-12]; - const histogram = [ - {x: -1.1e-12, dx: 1.05e-12, y: 1.95}, - {x: -0.5e-13, dx: 1.05e-12, y: 1.05} - ]; - const newHistogram = convertBins( - { - bucketRightEdges: rightEdges, - bucketCounts: counts, - min: -1.1e-12, - max: 1.0e-12 - }, - -1.1e-12, 1.0e-12, 2); - assertHistogramEquality(newHistogram, histogram); - }); - - it('Handles a histogram of all zeros', () => { - const h = { - min: 0, - max: 0, - nItems: 51200, - sum: 0, - sumSquares: 0, - bucketRightEdges: [0, 1e-12, 1.7976931348623157e+308], - bucketCounts: [0, 51200, 0], - wall_time: '2017-01-25T02:30:11.257Z', - step: 0 - }; - const newHistogram = convertBins(h, 0, 0, 5); - const expectedHistogram = [ - {x: -1, dx: 0.4, y: 0}, {x: -0.6, dx: 0.4, y: 0}, - {x: -0.2, 
dx: 0.4, y: 51200}, {x: 0.2, dx: 0.4, y: 0}, - {x: 0.6, dx: 0.4, y: 0} - ]; - assertHistogramEquality(newHistogram, expectedHistogram); - }); - - it('Handles a right-most right edge that extends to very large number.', - () => { - const counts = [1, 2, 3]; - const rightEdges = [0, 1.0e-12, 1.0e14]; - const histogram = [ - {x: -1.0e-12, dx: 0.7e-12, y: 0.7}, {x: -0.3e-12, dx: 0.7e-12, y: 1.1}, - {x: 0.4e-12, dx: 0.7e-12, y: 4.2} - ]; - const newHistogram = convertBins( - { - bucketRightEdges: rightEdges, - bucketCounts: counts, - min: -1.0e-12, - max: 1.1e-12 - }, - -1.0e-12, 1.1e-12, 3); - assertHistogramEquality(newHistogram, histogram); - }); -}); diff --git a/tensorflow/tensorboard/components/tf_backend/test/behaviorTests.ts b/tensorflow/tensorboard/components/tf_backend/test/behaviorTests.ts deleted file mode 100644 index 6bf328140e2..00000000000 --- a/tensorflow/tensorboard/components/tf_backend/test/behaviorTests.ts +++ /dev/null @@ -1,165 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the 'License'); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an 'AS IS' BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -import {Backend, getRuns, getTags, RunToTag} from '../backend' -import {BackendBehavior} from '../behavior' - -declare function fixture(id: string): void; - -window.addEventListener('WebComponentsReady', function() { - Polymer({ - is: 'test-element', - behaviors: [BackendBehavior], - frontendReload: function() { - // no-op - }, - }); -}); - -describe('data-behavior', function() { - let testElement; - let resolve; - let reject; - const fakeBackend = { - scalarTags() { - return new Promise((_resolve, _reject) => { - resolve = (x) => _resolve(x); - reject = (x) => _reject(x); - }); - }, - scalar(x) { - return this; - }, - }; - beforeEach(function() { - testElement = fixture('testElementFixture'); - testElement.autoLoad = false; - testElement.backend = fakeBackend; - testElement.dataType = 'scalar'; - }); - - it('load states work as expected', function(done) { - chai.assert.equal(testElement.loadState, 'noload'); - var reloaded = testElement.reload(); - chai.assert.equal(testElement.loadState, 'pending'); - resolve(); - reloaded - .then(function() { - chai.assert.equal(testElement.loadState, 'loaded'); - var reloaded2 = testElement.reload(); - chai.assert.equal(testElement.loadState, 'pending'); - reject(); - return reloaded2; - }) - .then(function() { - chai.assert.equal(testElement.loadState, 'failure'); - done(); - }); - }); - - it('data provider set appropriately', function() { - chai.assert.deepEqual(testElement.dataProvider(), testElement.backend); - }); - - it('loads data as expected', function(done) { - var r2t: RunToTag = { - run1: ['foo', 'bar', 'zod'], - run2: ['zoink', 'zow'], - run3: ['.'], - }; - var tags = getTags(r2t); - var runs = getRuns(r2t); - testElement.backend = fakeBackend; - testElement.dataType = 'scalar'; - testElement.reload().then(function(x) { - chai.assert.deepEqual(testElement.run2tag, r2t); - chai.assert.deepEqual(testElement.runs, runs); - 
chai.assert.deepEqual(testElement.tags, tags); - done(); - }); - resolve(r2t); - }); - - it('errors thrown on bad data types', function() { - testElement.backend = undefined; - chai.assert.throws(function() { - testElement.dataType = 'foo'; - }); - testElement.dataType = 'scalar'; - testElement.dataType = 'graph'; - testElement.dataType = 'histogram'; - }); - - it('dataNotFound flag works', function(done) { - chai.assert.isFalse(testElement.dataNotFound, 'initially false'); - var next = testElement.reload(); - chai.assert.isFalse(testElement.dataNotFound, 'still false while pending'); - resolve({foo: [], bar: []}); - next.then(() => { - chai.assert.isTrue(testElement.dataNotFound, 'true on empty data'); - var last = testElement.reload(); - chai.assert.isTrue(testElement.dataNotFound, 'still true while pending'); - resolve({foo: ['bar'], bar: ['zod']}); - last.then(() => { - chai.assert.isFalse( - testElement.dataNotFound, 'false now that we have data'); - done(); - }); - }); - }); - - it('reloads as soon as setup, if autoReload is true', function(done) { - var r2t = {foo: [], bar: []}; - var fakeBackend = { - scalarTags: () => Promise.resolve(r2t), - scalar: () => null, - }; - testElement = fixture('testElementFixture'); - testElement.dataType = 'scalar'; - testElement.backend = fakeBackend; - setTimeout(() => { - chai.assert.equal(testElement.run2tag, r2t); - done(); - }); - }); - - it('doesn\'t mutate props if backend returns same data', function(done) { - var r2t_1 = {foo: ['1', '2'], bar: ['3', '4']}; - var r2t_2 = {foo: ['1', '2'], bar: ['3', '4']}; - var fakeBackend = { - scalarTags: () => Promise.resolve(r2t_1), - scalar: () => null, - }; - testElement.backend = fakeBackend; - testElement.reload().then(() => { - fakeBackend.scalarTags = () => Promise.resolve(r2t_2); - var tags = testElement.tags; - testElement.reload().then(() => { - // shallow equality ensures it wasn't recomputed - chai.assert.equal(tags, testElement.tags, 'tags was not recomputed'); - 
done(); - }); - }); - }); - - // TODO(dandelion): Fix this test. - it('reload calls frontendReload', function(done) { - testElement.frontendReload = function() { - done(); - }; - testElement.reload(); - }); - -}); diff --git a/tensorflow/tensorboard/components/tf_backend/test/data/audio_run_run1_tag_audio1.json b/tensorflow/tensorboard/components/tf_backend/test/data/audio_run_run1_tag_audio1.json deleted file mode 100644 index 21a00f198d6..00000000000 --- a/tensorflow/tensorboard/components/tf_backend/test/data/audio_run_run1_tag_audio1.json +++ /dev/null @@ -1 +0,0 @@ -[{"wall_time": 0, "step": 0, "query": "index=0&tag=audio1&run=run1", "content_type": "audio/wav"}] diff --git a/tensorflow/tensorboard/components/tf_backend/test/data/compressedHistograms_run_run1_tag_histo1.json b/tensorflow/tensorboard/components/tf_backend/test/data/compressedHistograms_run_run1_tag_histo1.json deleted file mode 100644 index 8b4c088392d..00000000000 --- a/tensorflow/tensorboard/components/tf_backend/test/data/compressedHistograms_run_run1_tag_histo1.json +++ /dev/null @@ -1 +0,0 @@ -[[0, 0, [[0, -2.3150592308536755], [668, -2.0967547155036605], [1587, -1.4326244423655616], [3085, -0.8871306575801902], [5000, -0.09312398815580714], [6915, 0.2584093405812282], [8413, 0.8895470642005087], [9332, 1.3198979614453679], [10000, 1.6793308878855118]]], [100.0, 10, [[0, -1.3417572789138936], [668, -1.183563374619141], [1587, -0.48920418783271574], [3085, 0.29326906896076954], [5000, 0.56953784145381], [6915, 0.8684655583499333], [8413, 1.4133127368907181], [9332, 1.906140650457873], [10000, 2.135771998171255]]], [200.0, 20, [[0, -1.5066917525035333], [668, -1.3910909571770793], [1587, -0.902737218885874], [3085, -0.3807791904765027], [5000, 0.38900200905253046], [6915, 0.8209734209339482], [8413, 1.302385856695965], [9332, 1.9324626053521639], [10000, 2.957505317875451]]], [300.0, 30, [[0, -0.5430457051469562], [668, -0.4626161834245273], [1587, 0.21573949543027715], [3085, 
0.37353741100174215], [5000, 0.6891407881591103], [6915, 1.0927156232630852], [8413, 1.2745337159550916], [9332, 1.4321116832891605], [10000, 2.1913774993059034]]], [400.0, 40, [[0, -0.3584790755077172], [668, -0.33301611509753215], [1587, -0.1089466072951948], [3085, 0.5792199847585249], [5000, 1.220854943811942], [6915, 1.759829438421432], [8413, 2.3072559906741614], [9332, 2.753036118353921], [10000, 3.0267252195784047]]]] \ No newline at end of file diff --git a/tensorflow/tensorboard/components/tf_backend/test/data/example.json b/tensorflow/tensorboard/components/tf_backend/test/data/example.json deleted file mode 100644 index 8adc6fb896a..00000000000 --- a/tensorflow/tensorboard/components/tf_backend/test/data/example.json +++ /dev/null @@ -1,4 +0,0 @@ -{ - "foo": 3, - "bar": "zoidberg" -} diff --git a/tensorflow/tensorboard/components/tf_backend/test/data/histograms_run_run1_tag_histo1.json b/tensorflow/tensorboard/components/tf_backend/test/data/histograms_run_run1_tag_histo1.json deleted file mode 100644 index a5600a356e8..00000000000 --- a/tensorflow/tensorboard/components/tf_backend/test/data/histograms_run_run1_tag_histo1.json +++ /dev/null @@ -1 +0,0 @@ -[[400.0, 40, [-0.3584790755077172, 3.0267252195784047, 20.0, 24.012225532303315, 48.29045006426564, [-0.35363819004775493, -0.29226296698161564, -0.19961953895336082, 0.3214892636797772, 0.5177616740489182, 0.56953784145381, 0.6264916255991911, 0.7580548669750213, 0.8338603536725235, 1.220854943811942, 1.3429404381931362, 1.47723448201245, 1.624957930213695, 1.7874537232350647, 1.9661990955585713, 2.379100905625872, 2.6170109961884593, 3.1665833053880363], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 1.0, 1.0, 2.0]]]] \ No newline at end of file diff --git a/tensorflow/tensorboard/components/tf_backend/test/data/images_run_run1_tag_im1.json b/tensorflow/tensorboard/components/tf_backend/test/data/images_run_run1_tag_im1.json deleted file mode 100644 index 
fd2a96b62fe..00000000000 --- a/tensorflow/tensorboard/components/tf_backend/test/data/images_run_run1_tag_im1.json +++ /dev/null @@ -1 +0,0 @@ -[{"wall_time": 0, "step": 0, "query": "index=0&tag=im1&run=run1", "width": 1, "height": 1}] diff --git a/tensorflow/tensorboard/components/tf_backend/test/data/individualImage_index_0_tag_im1_run_run1.png b/tensorflow/tensorboard/components/tf_backend/test/data/individualImage_index_0_tag_im1_run_run1.png deleted file mode 100644 index f191b280ce9..00000000000 Binary files a/tensorflow/tensorboard/components/tf_backend/test/data/individualImage_index_0_tag_im1_run_run1.png and /dev/null differ diff --git a/tensorflow/tensorboard/components/tf_backend/test/data/run_metadata_run_step99_tag_train.pbtxt b/tensorflow/tensorboard/components/tf_backend/test/data/run_metadata_run_step99_tag_train.pbtxt deleted file mode 100644 index 07ce4fad539..00000000000 --- a/tensorflow/tensorboard/components/tf_backend/test/data/run_metadata_run_step99_tag_train.pbtxt +++ /dev/null @@ -1,17 +0,0 @@ -step_stats { - dev_stats { - device: "/job:localhost/replica:0/task:0/cpu:0" - node_stats { - node_name: "_SOURCE" - all_start_micros: 1459365298611334 - op_start_rel_micros: 29 - op_end_rel_micros: 30 - all_end_rel_micros: 52 - memory { - allocator_name: "cpu" - } - timeline_label: "_SOURCE = NoOp()" - scheduled_micros: 1459365298611291 - } - } -} diff --git a/tensorflow/tensorboard/components/tf_backend/test/data/runs.json b/tensorflow/tensorboard/components/tf_backend/test/data/runs.json deleted file mode 100644 index 413ddb9ab34..00000000000 --- a/tensorflow/tensorboard/components/tf_backend/test/data/runs.json +++ /dev/null @@ -1,32 +0,0 @@ -{ - "run1": { - "images": [ - "im1" - ], - "audio": [ - "audio1" - ], - "scalars": [ - "cross_entropy (1)" - ], - "histograms": [ - "histo1" - ], - "compressedHistograms": [ - "histo1" - ], - "run_metadata": [ - "step99" - ], - "graph": false - }, - "fake_run_no_data": { - "images": ["im1", "im2"], - 
"audio": ["audio1", "audio2"], - "scalars": ["scalar2"], - "histograms": ["histo1"], - "compressedHistograms": ["histo1"], - "run_metadata": ["step99"], - "graph": true - } -} diff --git a/tensorflow/tensorboard/components/tf_backend/test/data/scalars.json b/tensorflow/tensorboard/components/tf_backend/test/data/scalars.json deleted file mode 100644 index bc9d3353d5f..00000000000 --- a/tensorflow/tensorboard/components/tf_backend/test/data/scalars.json +++ /dev/null @@ -1 +0,0 @@ -{"run1": {"cross_entropy (1)": [[0, 0, 0.0], [10.0, 1, 1.0], [20.0, 2, 4.0], [30.0, 3, 9.0], [40.0, 4, 16.0]]}} \ No newline at end of file diff --git a/tensorflow/tensorboard/components/tf_backend/test/data/scalars_run_run1_tag_cross_entropy__281_29.json b/tensorflow/tensorboard/components/tf_backend/test/data/scalars_run_run1_tag_cross_entropy__281_29.json deleted file mode 100644 index 97b0062f0f0..00000000000 --- a/tensorflow/tensorboard/components/tf_backend/test/data/scalars_run_run1_tag_cross_entropy__281_29.json +++ /dev/null @@ -1 +0,0 @@ -[[0, 0, 0.0], [10.0, 1, 1.0], [20.0, 2, 4.0], [30.0, 3, 9.0], [40.0, 4, 16.0]] \ No newline at end of file diff --git a/tensorflow/tensorboard/components/tf_backend/test/requestManagerTests.ts b/tensorflow/tensorboard/components/tf_backend/test/requestManagerTests.ts deleted file mode 100644 index 3800e6e4021..00000000000 --- a/tensorflow/tensorboard/components/tf_backend/test/requestManagerTests.ts +++ /dev/null @@ -1,294 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the 'License'); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an 'AS IS' BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -import {RequestManager, RequestNetworkError} from '../requestManager'; - -interface MockRequest { - resolve: Function; - reject: Function; - id: number; - url: string; -} - -class MockedRequestManager extends RequestManager { - private resolvers: Function[]; - private rejectors: Function[]; - public requestsDispatched: number; - constructor(maxRequests = 10, maxRetries = 3) { - super(maxRequests, maxRetries); - this.resolvers = []; - this.rejectors = []; - this.requestsDispatched = 0; - } - protected _promiseFromUrl(url) { - return new Promise((resolve, reject) => { - const mockJSON = { - ok: true, - json() { - return url; - }, - url, - status: 200, - }; - const mockFailedRequest: any = { - ok: false, - url, - status: 502, - }; - const mockFailure = new RequestNetworkError(mockFailedRequest, url); - this.resolvers.push(() => { - resolve(mockJSON); - }); - this.rejectors.push(() => { - reject(mockFailure); - }); - this.requestsDispatched++; - }); - } - public resolveFakeRequest() { - this.resolvers.pop()(); - } - public rejectFakeRequest() { - this.rejectors.pop()(); - } - public dispatchAndResolve() { - // Wait for at least one request to be dispatched, then resolve it. - this.waitForDispatch(1).then(() => this.resolveFakeRequest()); - } - public waitForDispatch(num) { - return waitForCondition(() => { - return this.requestsDispatched >= num; - }); - } -} - -/** Create a promise that returns when *check* returns true. - * May cause a test timeout if check never becomes true. 
- */ - -function waitForCondition(check: () => boolean): Promise { - return new Promise((resolve, reject) => { - const go = () => { - if (check()) { - resolve(); - } - setTimeout(go, 2); - }; - go(); - }); -} - -describe('backend', () => { - describe('request manager', () => { - it('request loads JSON properly', (done) => { - const rm = new RequestManager(); - const promise = rm.request('data/example.json'); - promise.then( - (response) => { - chai.assert.deepEqual(response, {foo: 3, bar: 'zoidberg'}); - done(); - }, - (reject) => { - throw new Error(reject); - }); - }); - - it('rejects on bad url', (done) => { - const rm = new RequestManager(5, 0); - const badUrl = '_bad_url_which_doesnt_exist.json'; - const promise = rm.request(badUrl); - promise.then( - (success) => { - done(new Error('the promise should have rejected')); - }, - (reject: RequestNetworkError) => { - chai.assert.include(reject.message, '404'); - chai.assert.include(reject.message, badUrl); - chai.assert.equal(reject.req.status, 404); - done(); - }); - }); - - it('can retry if requests fail', (done) => { - const rm = new MockedRequestManager(3, 5); - const r = rm.request('foo'); - rm.waitForDispatch(1) - .then(() => { - rm.rejectFakeRequest(); - return rm.waitForDispatch(2); - }) - .then(() => rm.resolveFakeRequest()); - r.then((success) => done()); - }); - - it('retries at most maxRetries times', (done) => { - const MAX_RETRIES = 2; - const rm = new MockedRequestManager(3, MAX_RETRIES); - const r = rm.request('foo'); - rm.waitForDispatch(1) - .then(() => { - rm.rejectFakeRequest(); - return rm.waitForDispatch(2); - }) - .then(() => { - rm.rejectFakeRequest(); - return rm.waitForDispatch(3); - }) - .then(() => { - rm.rejectFakeRequest(); - }); - - r.then( - (success) => done(new Error('The request should have failed')), - (failure) => done()); - }); - - it('requestManager only sends maxRequests requests at a time', (done) => { - const rm = new MockedRequestManager(3); - const r0 = rm.request('1'); 
- const r1 = rm.request('2'); - const r2 = rm.request('3'); - const r3 = rm.request('4'); - chai.assert.equal(rm.activeRequests(), 3, 'three requests are active'); - chai.assert.equal( - rm.outstandingRequests(), 4, 'four requests are pending'); - rm.waitForDispatch(3) - .then(() => { - chai.assert.equal( - rm.activeRequests(), 3, 'three requests are still active (1)'); - chai.assert.equal( - rm.requestsDispatched, 3, 'three requests were dispatched'); - rm.resolveFakeRequest(); - return rm.waitForDispatch(4); - }) - .then(() => { - chai.assert.equal( - rm.activeRequests(), 3, 'three requests are still active (2)'); - chai.assert.equal( - rm.requestsDispatched, 4, 'four requests were dispatched'); - chai.assert.equal( - rm.outstandingRequests(), 3, 'three requests are pending'); - rm.resolveFakeRequest(); - rm.resolveFakeRequest(); - rm.resolveFakeRequest(); - return r3; - }) - .then(() => { - chai.assert.equal(rm.activeRequests(), 0, 'all requests finished'); - chai.assert.equal( - rm.outstandingRequests(), 0, 'no requests pending'); - done(); - }); - }); - - it('queue continues after failures', (done) => { - const rm = new MockedRequestManager(1, 0); - const r0 = rm.request('1'); - const r1 = rm.request('2'); - rm.waitForDispatch(1).then(() => { - rm.rejectFakeRequest(); - }); - - r0.then( - (success) => done(new Error('r0 should have failed')), - (failure) => 'unused_argument') - .then(() => rm.resolveFakeRequest()); - - // When the first request rejects, it should decrement nActiveRequests - // and then launch remaining requests in queue (i.e. this one) - r1.then((success) => done(), (failure) => done(new Error(failure))); - }); - - it('queue is LIFO', (done) => { - /* This test is a bit tricky. - * We want to verify that the RequestManager queue has LIFO semantics. - * So we construct three requests off the bat: A, B, C. - * So LIFO semantics ensure these will resolve in order A, C, B. 
- * (Because the A request launches immediately when we create it, it's - * not in queue) - * Then after resolving A, C moves out of queue, and we create X. - * So expected final order is A, C, X, B. - * We verify this with an external var that counts how many requests were - * resolved. - */ - const rm = new MockedRequestManager(1); - let nResolved = 0; - function assertResolutionOrder(expectedSpotInSequence) { - return () => { - nResolved++; - chai.assert.equal(expectedSpotInSequence, nResolved); - }; - } - - function launchThirdRequest() { - rm.request('started late but goes third') - .then(assertResolutionOrder(3)) - .then(() => rm.dispatchAndResolve()); - } - - rm.request('first') - .then( - assertResolutionOrder(1)) // Assert that this one resolved first - .then(launchThirdRequest) - .then(() => rm.dispatchAndResolve()); // then trigger the next one - - rm.request('this one goes fourth') // created second, will go last - .then(assertResolutionOrder( - 4)) // assert it was the fourth to get resolved - .then(done); // finish the test - - rm.request('second') - .then(assertResolutionOrder(2)) - .then(() => rm.dispatchAndResolve()); - - rm.dispatchAndResolve(); - }); - - it('requestManager can clear queue', (done) => { - const rm = new MockedRequestManager(1); - let requestsResolved = 0; - let requestsRejected = 0; - const success = () => requestsResolved++; - const failure = (err) => { - chai.assert.equal(err.name, 'RequestCancellationError'); - requestsRejected++; - }; - const finishTheTest = () => { - chai.assert.equal(rm.activeRequests(), 0, 'no requests still active'); - chai.assert.equal( - rm.requestsDispatched, 1, 'only one req was ever dispatched'); - chai.assert.equal(rm.outstandingRequests(), 0, 'no pending requests'); - chai.assert.equal(requestsResolved, 1, 'one request got resolved'); - chai.assert.equal( - requestsRejected, 4, 'four were cancelled and threw errors'); - done(); - }; - rm.request('0').then(success, failure).then(finishTheTest); - 
rm.request('1').then(success, failure); - rm.request('2').then(success, failure); - rm.request('3').then(success, failure); - rm.request('4').then(success, failure); - chai.assert.equal(rm.activeRequests(), 1, 'one req is active'); - rm.waitForDispatch(1).then(() => { - chai.assert.equal(rm.activeRequests(), 1, 'one req is active'); - chai.assert.equal(rm.requestsDispatched, 1, 'one req was dispatched'); - chai.assert.equal(rm.outstandingRequests(), 5, 'five reqs outstanding'); - rm.clearQueue(); - rm.resolveFakeRequest(); - // resolving the first request triggers finishTheTest - }); - }); - }); -}); diff --git a/tensorflow/tensorboard/components/tf_backend/test/tests.html b/tensorflow/tensorboard/components/tf_backend/test/tests.html deleted file mode 100644 index 58cb89a30b6..00000000000 --- a/tensorflow/tensorboard/components/tf_backend/test/tests.html +++ /dev/null @@ -1,37 +0,0 @@ - - - - - - - - - - - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/tf_backend/tf-backend.html b/tensorflow/tensorboard/components/tf_backend/tf-backend.html deleted file mode 100644 index c2a44b3b63f..00000000000 --- a/tensorflow/tensorboard/components/tf_backend/tf-backend.html +++ /dev/null @@ -1,28 +0,0 @@ - - - - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/tf_backend/urlPathHelpers.ts b/tensorflow/tensorboard/components/tf_backend/urlPathHelpers.ts deleted file mode 100644 index 62519dac5ca..00000000000 --- a/tensorflow/tensorboard/components/tf_backend/urlPathHelpers.ts +++ /dev/null @@ -1,40 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the 'License'); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an 'AS IS' BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -export const BAD_CHARACTERS = '#%&{}\\/<>*? $!\'":@+`|=() '; -/** Cleanup a url so that it can be loaded from a filesystem. */ -export function demoify(s) { - // for consistency with python's urllib.urlencode - s = s.replace(new RegExp('%20', 'g'), '+'); - for (let i = 0; i < BAD_CHARACTERS.length; i++) { - const c = BAD_CHARACTERS[i]; - s = s.replace(new RegExp('\\' + c, 'g'), '_'); - } - return s; -} - -export function queryEncoder(params?: any): string { - // It's important that the keys be sorted, so we always grab the right file - // if we are talking to the backend generated by serialze_tensorboard.py - if (params == null) { - return ''; - } - const components = _.keys(params) - .sort() - .filter((k) => params[k] !== undefined) - .map((k) => k + '=' + encodeURIComponent(params[k])); - const result = components.length ? '?' 
+ components.join('&') : ''; - // Replace parens for consistency with urllib.urlencode - return result.replace(/\(/g, '%28').replace(/\)/g, '%29'); -} diff --git a/tensorflow/tensorboard/components/tf_color_scale/BUILD b/tensorflow/tensorboard/components/tf_color_scale/BUILD deleted file mode 100644 index 730ab37d6f7..00000000000 --- a/tensorflow/tensorboard/components/tf_color_scale/BUILD +++ /dev/null @@ -1,39 +0,0 @@ -package(default_visibility = ["//tensorflow/tensorboard:internal"]) - -load("//tensorflow/tensorboard/defs:web.bzl", "ts_web_library") - -licenses(["notice"]) # Apache 2.0 - -ts_web_library( - name = "tf_color_scale", - srcs = [ - "colorScale.ts", - "palettes.ts", - "tf-color-scale.html", - ], - path = "/tf-color-scale", - deps = [ - "//tensorflow/tensorboard/components/tf_imports:d3", - "//tensorflow/tensorboard/components/tf_imports:polymer", - ], -) - -ts_web_library( - name = "demo", - srcs = ["index.html"], - path = "/tf-color-scale", - deps = [ - ":tf_color_scale", - "//tensorflow/tensorboard/components/tf_imports:d3", - "//tensorflow/tensorboard/components/tf_imports:webcomponentsjs", - "@org_polymer_iron_demo_helpers", - "@org_polymer_paper_button", - "@org_polymer_paper_styles", - ], -) - -filegroup( - name = "all_files", - srcs = glob(["**"]), - tags = ["notsan"], -) diff --git a/tensorflow/tensorboard/components/tf_color_scale/colorScale.ts b/tensorflow/tensorboard/components/tf_color_scale/colorScale.ts deleted file mode 100644 index e20a65cdd84..00000000000 --- a/tensorflow/tensorboard/components/tf_color_scale/colorScale.ts +++ /dev/null @@ -1,89 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// Example usage: -// runs = ["train", "test", "test1", "test2"] -// ccs = new ColorScale(); -// ccs.domain(runs); -// ccs.getColor("train"); -// ccs.getColor("test1"); - -import {palettes} from './palettes'; - -export class ColorScale { - private identifiers = d3.map(); - - /** - * Creates a color scale with optional custom palette. - * @param {Array} [palette=palettes.googleColorBlind] - The color - * palette you want as an Array of hex strings. - */ - constructor( - private readonly palette: string[] = palettes.googleColorBlindAssist) {} - - /** - * Set the domain of strings. - * @param {Array} strings - An array of possible strings to use as the - * domain for your scale. - */ - public domain(strings: string[]): this { - this.identifiers = d3.map(); - - // TODO(wchargin): Remove this call to `sort` once we have only a - // singleton ColorScale, linked directly to the RunsStore, which - // will always give sorted output. - strings = strings.slice(); - strings.sort(); - - strings.forEach((s, i) => { - this.identifiers.set(s, this.palette[i % this.palette.length]); - }); - return this; - } - - /** - * Use the color scale to transform an element in the domain into a color. - * @param {string} The input string to map to a color. - * @return {string} The color corresponding to that input string. - * @throws Will error if input string is not in the scale's domain. 
- */ - public scale(s: string): string { - if (!this.identifiers.has(s)) { - throw new Error('String was not in the domain.'); - } - return this.identifiers.get(s) as string; - } -} - -Polymer({ - is: 'tf-color-scale', - properties: { - runs: { - type: Array, - }, - outColorScale: { - type: Object, - readOnly: true, - notify: true, - value() { - return new ColorScale(); - }, - }, - }, - observers: ['updateColorScale(runs.*)'], - updateColorScale(runsChange) { - this.outColorScale.domain(this.runs); - }, -}); diff --git a/tensorflow/tensorboard/components/tf_color_scale/index.html b/tensorflow/tensorboard/components/tf_color_scale/index.html deleted file mode 100644 index 81dfab098c6..00000000000 --- a/tensorflow/tensorboard/components/tf_color_scale/index.html +++ /dev/null @@ -1,94 +0,0 @@ - - - - - -tf-color-scale demo - - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/tf_color_scale/palettes.ts b/tensorflow/tensorboard/components/tf_color_scale/palettes.ts deleted file mode 100644 index ce42a115458..00000000000 --- a/tensorflow/tensorboard/components/tf_color_scale/palettes.ts +++ /dev/null @@ -1,76 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -export const palettes = { - googleStandard: [ - '#db4437', // google red 500 - '#ff7043', // deep orange 400 - '#f4b400', // google yellow 500 - '#0f9d58', // google green 500 - '#00796b', // teal 700 - '#00acc1', // cyan 600 - '#4285f4', // google blue 500 - '#5c6bc0', // indigo 400 - '#ab47bc' // purple 400 - ], - googleCool: [ - '#9e9d24', // lime 800 - '#0f9d58', // google green 500 - '#00796b', // teal 700 - '#00acc1', // cyan 600 - '#4285f4', // google blue 500 - '#5c6bc0', // indigo 400 - '#607d8b' // blue gray 500 - ], - googleWarm: [ - '#795548', // brown 500 - '#ab47bc', // purple 400 - '#f06292', // pink 300 - '#c2185b', // pink 700 - '#db4437', // google red 500 - '#ff7043', // deep orange 400 - '#f4b400' // google yellow 700 - ], - googleColorBlindAssist: [ - '#ff7043', // orange - '#00ACC1', // dark cyan - '#AB47BC', // bright purple - '#2A56C6', // dark blue - '#0b8043', // green - '#F7CB4D', // yellow - '#c0ca33', // lime - '#5e35b1', // purple - '#A52714', // red - ], - // These palettes try to be better for color differentiation. 
- // https://personal.sron.nl/~pault/ - colorBlindAssist1: - ['#4477aa', '#44aaaa', '#aaaa44', '#aa7744', '#aa4455', '#aa4488'], - colorBlindAssist2: [ - '#88ccee', '#44aa99', '#117733', '#999933', '#ddcc77', '#cc6677', '#882255', - '#aa4499' - ], - colorBlindAssist3: [ - '#332288', '#6699cc', '#88ccee', '#44aa99', '#117733', '#999933', '#ddcc77', - '#cc6677', '#aa4466', '#882255', '#661100', '#aa4499' - ], - // based on this palette: http://mkweb.bcgsc.ca/biovis2012/ - colorBlindAssist4: [ - '#FF6DB6', '#920000', '#924900', '#DBD100', '#24FF24', '#006DDB', '#490092' - ], - mldash: [ - '#E47EAD', '#F4640D', '#FAA300', '#F5E636', '#00A077', '#0077B8', '#00B7ED' - ] -}; diff --git a/tensorflow/tensorboard/components/tf_color_scale/test/BUILD b/tensorflow/tensorboard/components/tf_color_scale/test/BUILD deleted file mode 100644 index 331783f3c76..00000000000 --- a/tensorflow/tensorboard/components/tf_color_scale/test/BUILD +++ /dev/null @@ -1,30 +0,0 @@ -package( - default_testonly = True, - default_visibility = ["//tensorflow/tensorboard:internal"], -) - -load("//tensorflow/tensorboard/defs:web.bzl", "ts_web_library") - -licenses(["notice"]) # Apache 2.0 - -ts_web_library( - name = "test", - srcs = [ - "colorScaleTests.ts", - "tests.html", - ], - path = "/tf-color-scale/test", - deps = [ - "//tensorflow/tensorboard/components/tf_color_scale", - "//tensorflow/tensorboard/components/tf_imports:polymer", - "//tensorflow/tensorboard/components/tf_imports:web_component_tester", - "//tensorflow/tensorboard/components/tf_imports:webcomponentsjs", - ], -) - -filegroup( - name = "all_files", - testonly = 0, - srcs = glob(["**"]), - tags = ["notsan"], -) diff --git a/tensorflow/tensorboard/components/tf_color_scale/test/colorScaleTests.ts b/tensorflow/tensorboard/components/tf_color_scale/test/colorScaleTests.ts deleted file mode 100644 index 78824a772c3..00000000000 --- a/tensorflow/tensorboard/components/tf_color_scale/test/colorScaleTests.ts +++ /dev/null @@ -1,48 +0,0 @@ 
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the 'License'); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an 'AS IS' BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -let assert = chai.assert; - -import {ColorScale} from '../colorScale' - -describe('ColorScale', function() { - let ccs: ColorScale; - - beforeEach(function() { - ccs = new ColorScale(); - }); - - it('Returns consistent colors', function() { - ccs.domain(['train', 'eval', 'test']); - let trainColor = ccs.scale('train'); - let trainColor2 = ccs.scale('train'); - assert.equal(trainColor, trainColor2); - }); - - it('Returns consistent colors after new domain', function() { - ccs.domain(['train', 'eval']); - let trainColor = ccs.scale('train'); - ccs.domain(['train', 'eval', 'test']); - let trainColor2 = ccs.scale('train'); - assert.equal(trainColor, trainColor2); - }); - - it('Throws an error if string is not in the domain', function() { - ccs.domain(['red', 'yellow', 'green']); - assert.throws(function() { - ccs.scale('not in domain'); - }, 'String was not in the domain.'); - }); -}); diff --git a/tensorflow/tensorboard/components/tf_color_scale/test/tests.html b/tensorflow/tensorboard/components/tf_color_scale/test/tests.html deleted file mode 100644 index 59c802d02bf..00000000000 --- a/tensorflow/tensorboard/components/tf_color_scale/test/tests.html +++ /dev/null @@ -1,24 +0,0 @@ - - - - - - - - - diff --git a/tensorflow/tensorboard/components/tf_dashboard_common/BUILD 
b/tensorflow/tensorboard/components/tf_dashboard_common/BUILD deleted file mode 100644 index 7471da3144a..00000000000 --- a/tensorflow/tensorboard/components/tf_dashboard_common/BUILD +++ /dev/null @@ -1,107 +0,0 @@ -package(default_visibility = ["//tensorflow/tensorboard:internal"]) - -load("//tensorflow/tensorboard/defs:defs.bzl", "tensorboard_webcomponent_library") -load("//tensorflow/tensorboard/defs:web.bzl", "ts_web_library") - -licenses(["notice"]) # Apache 2.0 - -ts_web_library( - name = "tf_dashboard_common", - srcs = [ - "dashboard-behavior.ts", - "dashboard-style.html", - "reload-behavior.ts", - "run-color-style.html", - "scrollbar-style.html", - "tensorboard-color.html", - "tf-categorizer.html", - "tf-categorizer.ts", - "tf-chart-scaffold.html", - "tf-collapsable-pane.html", - "tf-dashboard.html", - "tf-dashboard-layout.html", - "tf-downloader.html", - "tf-multi-checkbox.html", - "tf-multi-checkbox.ts", - "tf-no-data-warning.html", - "tf-option-selector.html", - "tf-panes-helper.html", - "tf-regex-group.html", - "tf-regex-group.ts", - "tf-run-selector.html", - "tf-sidebar-helper.html", - ], - path = "/tf-dashboard-common", - deps = [ - "//tensorflow/tensorboard/components/tf_imports:d3", - "//tensorflow/tensorboard/components/tf_imports:lodash", - "//tensorflow/tensorboard/components/tf_imports:polymer", - "//tensorflow/tensorboard/components/tf_storage", - "//tensorflow/tensorboard/components/vz_sorting", - "@org_polymer_iron_ajax", - "@org_polymer_iron_collapse", - "@org_polymer_iron_icons", - "@org_polymer_paper_button", - "@org_polymer_paper_checkbox", - "@org_polymer_paper_dialog", - "@org_polymer_paper_dropdown_menu", - "@org_polymer_paper_icon_button", - "@org_polymer_paper_input", - "@org_polymer_paper_item", - "@org_polymer_paper_menu", - "@org_polymer_paper_slider", - "@org_polymer_paper_spinner", - "@org_polymer_paper_styles", - "@org_polymer_paper_toggle_button", - ], -) - -ts_web_library( - name = "demo", - srcs = [ - 
"tf-categorizer-demo.html", - "tf-collapsable-pane-demo.html", - "tf-multi-checkbox-demo.html", - "tf-regex-group-demo.html", - ], - path = "/tf-dashboard-common", - deps = [ - ":tf_dashboard_common", - "//tensorflow/tensorboard/components/tf_color_scale", - "@org_polymer_iron_demo_helpers", - "@org_polymer_paper_styles", - ], -) - -tensorboard_webcomponent_library( - name = "legacy", - srcs = [":tf_dashboard_common"], - destdir = "tf-dashboard-common", - deps = [ - "//tensorflow/tensorboard/components/tf_imports_google:lib", - "//tensorflow/tensorboard/components/tf_storage:legacy", - "//tensorflow/tensorboard/components/vz_sorting:legacy", - "//third_party/javascript/polymer/v1/iron-ajax:lib", - "//third_party/javascript/polymer/v1/iron-collapse:lib", - "//third_party/javascript/polymer/v1/iron-icons:lib", - "//third_party/javascript/polymer/v1/paper-button:lib", - "//third_party/javascript/polymer/v1/paper-checkbox:lib", - "//third_party/javascript/polymer/v1/paper-dialog:lib", - "//third_party/javascript/polymer/v1/paper-dropdown-menu:lib", - "//third_party/javascript/polymer/v1/paper-icon-button:lib", - "//third_party/javascript/polymer/v1/paper-input:lib", - "//third_party/javascript/polymer/v1/paper-item:lib", - "//third_party/javascript/polymer/v1/paper-menu:lib", - "//third_party/javascript/polymer/v1/paper-slider:lib", - "//third_party/javascript/polymer/v1/paper-spinner:lib", - "//third_party/javascript/polymer/v1/paper-styles:lib", - "//third_party/javascript/polymer/v1/paper-toggle-button:lib", - "//third_party/javascript/polymer/v1/polymer:lib", - ], -) - -filegroup( - name = "all_files", - srcs = glob(["**"]), - tags = ["notsan"], -) diff --git a/tensorflow/tensorboard/components/tf_dashboard_common/dashboard-behavior.ts b/tensorflow/tensorboard/components/tf_dashboard_common/dashboard-behavior.ts deleted file mode 100644 index aa063c74220..00000000000 --- a/tensorflow/tensorboard/components/tf_dashboard_common/dashboard-behavior.ts +++ /dev/null @@ 
-1,40 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -/** - * A behavior that TensorBoard dashboards must implement. This behavior serves - * the purpose of an interface. - * - * @polymerBehavior - */ -export function DashboardBehavior(dashboardName) { - return { - properties: { - name: { - type: String, - value: dashboardName, - readOnly: true, - }, - }, - // This method is called when the dashboard reloads, either when the - // dashboard is first visited, periodically reloaded, or manually reloaded - // via the user clicking the button. Note that dashboard custom elements - // that use TF.Dashboard.ReloadBehavior already implement a reload method. 
- reload() { - throw Error( - 'The ' + dashboardName + ' dashboard does not implement reload.'); - }, - }; -} diff --git a/tensorflow/tensorboard/components/tf_dashboard_common/dashboard-style.html b/tensorflow/tensorboard/components/tf_dashboard_common/dashboard-style.html deleted file mode 100644 index 6629e5bfc22..00000000000 --- a/tensorflow/tensorboard/components/tf_dashboard_common/dashboard-style.html +++ /dev/null @@ -1,53 +0,0 @@ - - - - - - - - diff --git a/tensorflow/tensorboard/components/tf_dashboard_common/reload-behavior.ts b/tensorflow/tensorboard/components/tf_dashboard_common/reload-behavior.ts deleted file mode 100644 index 61fe0c07812..00000000000 --- a/tensorflow/tensorboard/components/tf_dashboard_common/reload-behavior.ts +++ /dev/null @@ -1,41 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -/** - * ReloadBehavior: A simple behavior for dashboards where the - * frontendReload() function should find every child element with a - * given tag name (e.g. "tf-line-chart" or "tf-image-loader") - * and call a `reload` method on that child. - * May later extend it so it has more sophisticated logic, e.g. reloading - * only tags that are in view. 
- * - * @polymerBehavior - */ -export function ReloadBehavior(tagName) { - return { - properties: { - reloadTag: { - type: String, - value: tagName, - }, - }, - frontendReload: function() { - var elements = this.getElementsByTagName(this.reloadTag); - Array.prototype.forEach.call(elements, function(x) { - x.reload(); - }); - }, - }; -} diff --git a/tensorflow/tensorboard/components/tf_dashboard_common/run-color-style.html b/tensorflow/tensorboard/components/tf_dashboard_common/run-color-style.html deleted file mode 100644 index b15861694f5..00000000000 --- a/tensorflow/tensorboard/components/tf_dashboard_common/run-color-style.html +++ /dev/null @@ -1,79 +0,0 @@ - - - - - - - diff --git a/tensorflow/tensorboard/components/tf_dashboard_common/scrollbar-style.html b/tensorflow/tensorboard/components/tf_dashboard_common/scrollbar-style.html deleted file mode 100644 index bfd61f66191..00000000000 --- a/tensorflow/tensorboard/components/tf_dashboard_common/scrollbar-style.html +++ /dev/null @@ -1,46 +0,0 @@ - - - - - - - - diff --git a/tensorflow/tensorboard/components/tf_dashboard_common/test/BUILD b/tensorflow/tensorboard/components/tf_dashboard_common/test/BUILD deleted file mode 100644 index ef7a1562c65..00000000000 --- a/tensorflow/tensorboard/components/tf_dashboard_common/test/BUILD +++ /dev/null @@ -1,30 +0,0 @@ -package( - default_testonly = True, - default_visibility = ["//tensorflow/tensorboard:internal"], -) - -load("//tensorflow/tensorboard/defs:web.bzl", "ts_web_library") - -licenses(["notice"]) # Apache 2.0 - -ts_web_library( - name = "test", - srcs = [ - "tests.html", - "tf-categorizer-tests.ts", - ], - path = "/tf-dashboard-common/test", - deps = [ - "//tensorflow/tensorboard/components/tf_dashboard_common", - "//tensorflow/tensorboard/components/tf_imports:polymer", - "//tensorflow/tensorboard/components/tf_imports:web_component_tester", - "//tensorflow/tensorboard/components/tf_imports:webcomponentsjs", - ], -) - -filegroup( - name = "all_files", - 
testonly = 0, - srcs = glob(["**"]), - tags = ["notsan"], -) diff --git a/tensorflow/tensorboard/components/tf_dashboard_common/test/tests.html b/tensorflow/tensorboard/components/tf_dashboard_common/test/tests.html deleted file mode 100644 index c9ad14730f0..00000000000 --- a/tensorflow/tensorboard/components/tf_dashboard_common/test/tests.html +++ /dev/null @@ -1,24 +0,0 @@ - - - - - - - - - diff --git a/tensorflow/tensorboard/components/tf_dashboard_common/test/tf-categorizer-tests.ts b/tensorflow/tensorboard/components/tf_dashboard_common/test/tf-categorizer-tests.ts deleted file mode 100644 index a786f39b4fb..00000000000 --- a/tensorflow/tensorboard/components/tf_dashboard_common/test/tf-categorizer-tests.ts +++ /dev/null @@ -1,144 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the 'License'); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an 'AS IS' BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -import * as cat from '../tf-categorizer'; - -let assert = chai.assert; - -describe('categorizer', () => { - describe('topLevelNamespaceCategorizer', () => { - it('returns empty array on empty tags', () => { - assert.lengthOf(cat.topLevelNamespaceCategorizer([]), 0); - }); - - it('handles a simple case', () => { - let simple = [ - 'foo1/bar', 'foo1/zod', 'foo2/bar', 'foo2/zod', 'gosh/lod/mar', - 'gosh/lod/ned' - ]; - let expected = [ - {name: 'foo1', tags: ['foo1/bar', 'foo1/zod']}, - {name: 'foo2', tags: ['foo2/bar', 'foo2/zod']}, - {name: 'gosh', tags: ['gosh/lod/mar', 'gosh/lod/ned']}, - ]; - assert.deepEqual(cat.topLevelNamespaceCategorizer(simple), expected); - }); - - it('orders the categories', () => { - let test = ['e', 'f', 'g', 'a', 'b', 'c']; - let expected = [ - {name: 'a', tags: ['a']}, - {name: 'b', tags: ['b']}, - {name: 'c', tags: ['c']}, - {name: 'e', tags: ['e']}, - {name: 'f', tags: ['f']}, - {name: 'g', tags: ['g']}, - ]; - assert.deepEqual(cat.topLevelNamespaceCategorizer(test), expected); - }); - - it('handles cases where category names overlap node names', () => { - let test = ['a', 'a/a', 'a/b', 'a/c', 'b', 'b/a']; - const actual = cat.topLevelNamespaceCategorizer(test); - let expected = [ - {name: 'a', tags: ['a', 'a/a', 'a/b', 'a/c']}, - {name: 'b', tags: ['b', 'b/a']}, - ]; - assert.deepEqual(actual, expected); - }); - - it('handles singleton case', () => { - assert.deepEqual( - cat.topLevelNamespaceCategorizer(['a']), [{name: 'a', tags: ['a']}]); - }); - }); - - describe('customCategorizer', () => { - function noFallbackCategorizer(tags: string[]): cat.Category[] { - return []; - } - - function testCategorizer( - defs: string[], fallback: cat.Categorizer, - tags: string[]): cat.Category[] { - const catDefs = defs.map(cat.defineCategory); - return cat._categorizer(catDefs, fallback)(tags); - } - - it('categorizes by regular expression', () => { - let defs 
= ['foo..', 'bar..']; - let tags = ['fooab', 'fooxa', 'barts', 'barms']; - const actual = testCategorizer(defs, noFallbackCategorizer, tags); - let expected = [ - {name: 'foo..', tags: ['fooab', 'fooxa']}, - {name: 'bar..', tags: ['barms', 'barts']}, - ]; - assert.deepEqual(actual, expected); - }); - - it('matches non-exclusively', () => { - let tags = ['abc', 'bar', 'zod']; - const actual = - testCategorizer(['...', 'bar'], noFallbackCategorizer, tags); - let expected = [ - {name: '...', tags: ['abc', 'bar', 'zod']}, - {name: 'bar', tags: ['bar']}, - ]; - assert.deepEqual(actual, expected); - }); - - it('creates categories for unmatched rules', () => { - const actual = - testCategorizer(['a', 'b', 'c'], noFallbackCategorizer, []); - let expected = [ - {name: 'a', tags: []}, - {name: 'b', tags: []}, - {name: 'c', tags: []}, - ]; - assert.deepEqual(actual, expected); - }); - - it('category regexs work with special characters', () => { - let defs = ['^\\w+$', '^\\d+$', '^\\/..$']; - let tags = ['foo', '3243', '/xa']; - const actual = testCategorizer(defs, noFallbackCategorizer, tags); - let expected = [ - {name: '^\\w+$', tags: ['3243', 'foo']}, - {name: '^\\d+$', tags: ['3243']}, - {name: '^\\/..$', tags: ['/xa']}, - ]; - assert.deepEqual(actual, expected); - }); - - it('category tags are sorted', () => { - let tags = ['a', 'z', 'c', 'd', 'e', 'x', 'f', 'y', 'g']; - let sorted = tags.slice().sort(); - let expected = [{name: '.*', tags: sorted}]; - const actual = testCategorizer(['.*'], noFallbackCategorizer, tags); - assert.deepEqual(actual, expected); - }); - - it('if nonexclusive: all tags passed to fallback', () => { - let passedToDefault = null; - function defaultCategorizer(tags: string[]): cat.Category[] { - passedToDefault = tags; - return []; - } - let tags = ['foo', 'bar', 'foo123']; - testCategorizer(['foo'], defaultCategorizer, tags); - assert.deepEqual(passedToDefault, tags); - }); - }); -}); diff --git 
a/tensorflow/tensorboard/components/tf_dashboard_common/tf-categorizer-demo.html b/tensorflow/tensorboard/components/tf_dashboard_common/tf-categorizer-demo.html deleted file mode 100644 index 23babaaecc4..00000000000 --- a/tensorflow/tensorboard/components/tf_dashboard_common/tf-categorizer-demo.html +++ /dev/null @@ -1,106 +0,0 @@ - - - - - - - - - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/tf_dashboard_common/tf-categorizer.html b/tensorflow/tensorboard/components/tf_dashboard_common/tf-categorizer.html deleted file mode 100644 index f09eb03582d..00000000000 --- a/tensorflow/tensorboard/components/tf_dashboard_common/tf-categorizer.html +++ /dev/null @@ -1,63 +0,0 @@ - - - - - - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/tf_dashboard_common/tf-categorizer.ts b/tensorflow/tensorboard/components/tf_dashboard_common/tf-categorizer.ts deleted file mode 100644 index 0eaf852ff13..00000000000 --- a/tensorflow/tensorboard/components/tf_dashboard_common/tf-categorizer.ts +++ /dev/null @@ -1,189 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the 'License'); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an 'AS IS' BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -import {compareTagNames} from '../vz-sorting/sorting'; - -/** - * This module contains methods that allow sorting tags into 'categories'. - * A category contains a name and a list of tags. 
- * The sorting strategy is defined by a 'CustomCategorization', which contains - * 'categoryDefinitions' which are regex rules used to construct a category. - * E.g. the regex rule 'xent' will create a category called 'xent' that - * contains values whose tags match the regex. - * - * After custom categories are evaluated, the tags are sorted by a hardcoded - * fallback categorizer, which may, for example, group tags into categories - * based on their top namespace. - */ - -export interface Category { - // Categories that data is sorted into - name: string; - tags: string[]; -} - -export interface CustomCategorization { - // Defines a categorization strategy - categoryDefinitions: string[]; - fallbackCategorizer: string; - /* {'TopLevelNamespaceCategorizer', - 'LegacyUnderscoreCategorizer'} */ -} - -export interface Categorizer { - // Function that generates categories - (tags: string[]): Category[]; -} - -/* Canonical TensorFlow ops are namespaced using forward slashes. - * This fallback categorizer categorizes by the top-level namespace. - */ -export var topLevelNamespaceCategorizer: Categorizer = splitCategorizer(/\//); - -export function fallbackCategorizer(s: string): Categorizer { - switch (s) { - case 'TopLevelNamespaceCategorizer': - return topLevelNamespaceCategorizer; - default: - throw new Error('Unrecognized categorization strategy: ' + s); - } -} - -/* An 'extractor' is a function that takes a tag name, and 'extracts' a - * category name. - * This function takes an extractor, and produces a categorizer. - * Currently, it is just used for the fallbackCategorizer, but we may want to - * refactor the general categorization logic to use the concept of extractors. - */ -function extractorToCategorizer(extractor: (s: string) => string): Categorizer { - return (tags: string[]): Category[] => { - if (tags.length === 0) { - return []; - } - - // Maps between top-level name and category. We use the mapping to avoid - // duplicating categories per run. 
- const categoryMapping: {[key: string]: Category} = {}; - - tags.forEach((t: string) => { - const topLevel = extractor(t); - if (!categoryMapping[topLevel]) { - const newCategory = { - name: topLevel, - tags: [], - }; - categoryMapping[topLevel] = newCategory; - } - - categoryMapping[topLevel].tags.push(t); - }); - - // Sort categories into alphabetical order. - const categories = - _.map(_.keys(categoryMapping).sort(), key => categoryMapping[key]); - _.forEach(categories, (category) => { - // Sort the tags within each category. - category.tags.sort(compareTagNames); - }); - return categories; - }; -} - -function splitCategorizer(r: RegExp): Categorizer { - let extractor = (t: string) => { - return t.split(r)[0]; - }; - return extractorToCategorizer(extractor); -} - -export interface CategoryDefinition { - name: string; - matches: (t: string) => boolean; -} - -export function defineCategory(ruledef: string): CategoryDefinition { - let r = new RegExp(ruledef); - let f = function(tag: string): boolean { - return r.test(tag); - }; - return {name: ruledef, matches: f}; -} - -export function _categorizer( - rules: CategoryDefinition[], fallback: Categorizer) { - return function(tags: string[]): Category[] { - let remaining: d3.Set = d3.set(tags); - let userSpecified = rules.map((def: CategoryDefinition) => { - let tags: string[] = []; - remaining.each((t: string) => { - if (def.matches(t)) { - tags.push(t); - } - }); - let cat = {name: def.name, tags: tags.sort(compareTagNames)}; - return cat; - }); - let defaultCategories = fallback(remaining.values()); - return userSpecified.concat(defaultCategories); - }; -} - -export function categorizer(s: CustomCategorization): Categorizer { - let rules = s.categoryDefinitions.map(defineCategory); - let fallback = fallbackCategorizer(s.fallbackCategorizer); - return _categorizer(rules, fallback); -}; - -Polymer({ - is: 'tf-categorizer', - properties: { - regexes: {type: Array}, - tags: {type: Array}, - categoriesAreExclusive: 
{type: Boolean, value: true}, - fallbackCategorizer: { - type: String, - value: 'TopLevelNamespaceCategorizer', - }, - categorizer: { - type: Object, - computed: - 'computeCategorization(regexes.*, categoriesAreExclusive, fallbackCategorizer)', - }, - categories: { - type: Array, - value: function() { - return []; - }, - notify: true, - readOnly: true - }, - }, - observers: ['recategorize(tags.*, categorizer)'], - computeCategorization: function( - regexes, categoriesAreExclusive, fallbackCategorizer) { - var categorizationStrategy = { - categoryDefinitions: regexes.base, - categoriesAreExclusive: categoriesAreExclusive, - fallbackCategorizer: fallbackCategorizer, - }; - return categorizer(categorizationStrategy); - }, - recategorize: function() { - this.debounce('tf-categorizer-recategorize', function() { - var categories = this.categorizer(this.tags); - this._setCategories(categories); - }) - }, -}); diff --git a/tensorflow/tensorboard/components/tf_dashboard_common/tf-chart-scaffold.html b/tensorflow/tensorboard/components/tf_dashboard_common/tf-chart-scaffold.html deleted file mode 100644 index a39fb9462ba..00000000000 --- a/tensorflow/tensorboard/components/tf_dashboard_common/tf-chart-scaffold.html +++ /dev/null @@ -1,152 +0,0 @@ - - - - - - - - - diff --git a/tensorflow/tensorboard/components/tf_dashboard_common/tf-collapsable-pane-demo.html b/tensorflow/tensorboard/components/tf_dashboard_common/tf-collapsable-pane-demo.html deleted file mode 100644 index efa990b11cf..00000000000 --- a/tensorflow/tensorboard/components/tf_dashboard_common/tf-collapsable-pane-demo.html +++ /dev/null @@ -1,31 +0,0 @@ - - - - - - - - - - - -

This is content inside the pane.

-
- - diff --git a/tensorflow/tensorboard/components/tf_dashboard_common/tf-collapsable-pane.html b/tensorflow/tensorboard/components/tf_dashboard_common/tf-collapsable-pane.html deleted file mode 100644 index e82540127fa..00000000000 --- a/tensorflow/tensorboard/components/tf_dashboard_common/tf-collapsable-pane.html +++ /dev/null @@ -1,109 +0,0 @@ - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/tf_dashboard_common/tf-dashboard-layout.html b/tensorflow/tensorboard/components/tf_dashboard_common/tf-dashboard-layout.html deleted file mode 100644 index e0e8a2b52c3..00000000000 --- a/tensorflow/tensorboard/components/tf_dashboard_common/tf-dashboard-layout.html +++ /dev/null @@ -1,67 +0,0 @@ - - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/tf_dashboard_common/tf-dashboard.html b/tensorflow/tensorboard/components/tf_dashboard_common/tf-dashboard.html deleted file mode 100644 index 9e2f6b9589b..00000000000 --- a/tensorflow/tensorboard/components/tf_dashboard_common/tf-dashboard.html +++ /dev/null @@ -1,26 +0,0 @@ - - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/tf_dashboard_common/tf-downloader.html b/tensorflow/tensorboard/components/tf_dashboard_common/tf-downloader.html deleted file mode 100644 index 71914259598..00000000000 --- a/tensorflow/tensorboard/components/tf_dashboard_common/tf-downloader.html +++ /dev/null @@ -1,99 +0,0 @@ - - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/tf_dashboard_common/tf-multi-checkbox-demo.html b/tensorflow/tensorboard/components/tf_dashboard_common/tf-multi-checkbox-demo.html deleted file mode 100644 index d0f5aa6f27d..00000000000 --- a/tensorflow/tensorboard/components/tf_dashboard_common/tf-multi-checkbox-demo.html +++ /dev/null @@ -1,176 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/tf_dashboard_common/tf-multi-checkbox.html 
b/tensorflow/tensorboard/components/tf_dashboard_common/tf-multi-checkbox.html deleted file mode 100644 index fad4642963f..00000000000 --- a/tensorflow/tensorboard/components/tf_dashboard_common/tf-multi-checkbox.html +++ /dev/null @@ -1,160 +0,0 @@ - - - - - - - - - - - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/tf_dashboard_common/tf-multi-checkbox.ts b/tensorflow/tensorboard/components/tf_dashboard_common/tf-multi-checkbox.ts deleted file mode 100644 index 4b38d82b14e..00000000000 --- a/tensorflow/tensorboard/components/tf_dashboard_common/tf-multi-checkbox.ts +++ /dev/null @@ -1,205 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the 'License'); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an 'AS IS' BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -import * as storage from '../tf-storage/storage'; - -Polymer({ - is: 'tf-multi-checkbox', - properties: { - names: { - type: Array, - value: function() { - return []; - }, - }, // All the runs in consideration - regexInput: { - type: String, - value: storage.getStringInitializer('regexInput', ''), - observer: '_regexInputObserver', - }, // Regex for filtering the runs - regex: {type: Object, computed: '_makeRegex(regexInput)'}, - namesMatchingRegex: { - type: Array, - computed: 'computeNamesMatchingRegex(names.*, regex)' - }, // Runs that match the regex - runSelectionState: { - // if a run is explicitly enabled, True, if explicitly disabled, False. 
- // if undefined, default value (enable for first k runs, disable after). - type: Object, - value: storage.getObjectInitializer('runSelectionState', {}), - observer: '_storeRunToIsCheckedMapping', - }, - // (Allows state to persist across regex filtering) - outSelected: { - type: Array, - notify: true, - computed: 'computeOutSelected(namesMatchingRegex.*, runSelectionState.*)' - }, - colorScale: { - type: Object, - observer: 'synchronizeColors', - }, // map from run name to css class - maxRunsToEnableByDefault: { - // When TB first loads, if it has k or fewer runs, they are all enabled - // by default. If there are more, then they are all disabled. - type: Number, - value: 40, - }, - _debouncedRegexChange: { - type: Object, - // Updating the regex can be slow, because it involves updating styles - // on a large number of Polymer paper-checkboxes. We don't want to do - // this while the user is typing, as it may make a bad, laggy UI. - // So we debounce the updates that come from user typing. - value: function() { - const _this = this; - var debounced = _.debounce(function(r) { - _this.regexInput = r; - }, 150, {leading: false}); - return function() { - var r = this.$$('#runs-regex').value; - if (r == '') { - // If the user cleared the field, they may be done typing, so - // update more quickly. 
- this.async(function() { - _this.regexInput = r; - }, 30); - } else { - debounced(r); - }; - }; - }, - }, - }, - listeners: { - 'dom-change': 'synchronizeColors', - }, - observers: [ - '_setIsolatorIcon(runSelectionState, names)', - ], - _storeRunToIsCheckedMapping: - storage.getObjectObserver('runSelectionState', {}), - _makeRegex: function(regex) { - try { - return new RegExp(regex) - } catch (e) { - return null; - } - }, - _setIsolatorIcon: function() { - var runMap = this.runSelectionState; - var numChecked = _.filter(_.values(runMap)).length; - var buttons = - Array.prototype.slice.call(this.querySelectorAll('.isolator')); - - buttons.forEach(function(b) { - if (numChecked === 1 && runMap[b.name]) { - b.icon = 'radio-button-checked'; - } else { - b.icon = 'radio-button-unchecked'; - } - }); - }, - computeNamesMatchingRegex: function(__, ___) { - var regex = this.regex; - return this.names.filter(function(n) { - return regex == null || regex.test(n); - }); - }, - computeOutSelected: function(__, ___) { - var runSelectionState = this.runSelectionState; - var num = this.maxRunsToEnableByDefault; - var allEnabled = this.namesMatchingRegex.length <= num; - return this.namesMatchingRegex.filter(function(n, i) { - return runSelectionState[n] == null ? 
allEnabled : runSelectionState[n]; - }); - }, - synchronizeColors: function(e) { - if (!this.colorScale) return; - - this._setIsolatorIcon(); - - var checkboxes = - Array.prototype.slice.call(this.querySelectorAll('paper-checkbox')); - var scale = this.colorScale; - checkboxes.forEach(function(p) { - var color = scale.scale(p.name); - p.customStyle['--paper-checkbox-checked-color'] = color; - p.customStyle['--paper-checkbox-checked-ink-color'] = color; - p.customStyle['--paper-checkbox-unchecked-color'] = color; - p.customStyle['--paper-checkbox-unchecked-ink-color'] = color; - }); - var buttons = - Array.prototype.slice.call(this.querySelectorAll('.isolator')); - buttons.forEach(function(p) { - var color = scale.scale(p.name); - p.style['color'] = color; - }); - // The updateStyles call fails silently if the browser doesn't have focus, - // e.g. if TensorBoard was opened into a new tab that isn't visible. - // So we wait for requestAnimationFrame. - var _this = this; - window.requestAnimationFrame(function() { - _this.updateStyles(); - }); - }, - _isolateRun: function(e) { - // If user clicks on the label for one run, enable it and disable all other - // runs. - - var name = (Polymer.dom(e) as any).localTarget.name; - var selectionState = {}; - this.names.forEach(function(n) { - selectionState[n] = n == name; - }); - this.runSelectionState = selectionState; - }, - _checkboxChange: function(e) { - var target = (Polymer.dom(e) as any).localTarget; - this.runSelectionState[target.name] = target.checked; - // n.b. notifyPath won't work because run names may have periods. 
- this.runSelectionState = _.clone(this.runSelectionState); - }, - _isChecked: function(item, outSelectedChange) { - return this.outSelected.indexOf(item) != -1; - }, - _regexInputObserver: storage.getStringObserver('regexInput', ''), - toggleAll: function() { - var _this = this; - var anyToggledOn = this.namesMatchingRegex.some(function(n) { - return _this.runSelectionState[n] - }); - - - var runSelectionStateIsDefault = - Object.keys(this.runSelectionState).length == 0; - - var defaultOff = - this.namesMatchingRegex.length > this.maxRunsToEnableByDefault; - // We have runs toggled either if some were explicitly toggled on, or if - // we are in the default state, and there are few enough that we default - // to toggling on. - anyToggledOn = anyToggledOn || runSelectionStateIsDefault && !defaultOff; - - // If any are toggled on, we turn everything off. Or, if none are toggled - // on, we turn everything on. - - var newRunsDisabled = {}; - this.names.forEach(function(n) { - newRunsDisabled[n] = !anyToggledOn; - }); - this.runSelectionState = newRunsDisabled; - }, -}); diff --git a/tensorflow/tensorboard/components/tf_dashboard_common/tf-no-data-warning.html b/tensorflow/tensorboard/components/tf_dashboard_common/tf-no-data-warning.html deleted file mode 100644 index c90efac1d6b..00000000000 --- a/tensorflow/tensorboard/components/tf_dashboard_common/tf-no-data-warning.html +++ /dev/null @@ -1,129 +0,0 @@ - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/tf_dashboard_common/tf-option-selector.html b/tensorflow/tensorboard/components/tf_dashboard_common/tf-option-selector.html deleted file mode 100644 index 547a558ad0b..00000000000 --- a/tensorflow/tensorboard/components/tf_dashboard_common/tf-option-selector.html +++ /dev/null @@ -1,94 +0,0 @@ - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/tf_dashboard_common/tf-panes-helper.html b/tensorflow/tensorboard/components/tf_dashboard_common/tf-panes-helper.html deleted file mode 100644 
index 155259d3294..00000000000 --- a/tensorflow/tensorboard/components/tf_dashboard_common/tf-panes-helper.html +++ /dev/null @@ -1,352 +0,0 @@ - - - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/tf_dashboard_common/tf-regex-group-demo.html b/tensorflow/tensorboard/components/tf_dashboard_common/tf-regex-group-demo.html deleted file mode 100644 index 3565fec1791..00000000000 --- a/tensorflow/tensorboard/components/tf_dashboard_common/tf-regex-group-demo.html +++ /dev/null @@ -1,45 +0,0 @@ - - - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/tf_dashboard_common/tf-regex-group.html b/tensorflow/tensorboard/components/tf_dashboard_common/tf-regex-group.html deleted file mode 100644 index c1d3cf06aea..00000000000 --- a/tensorflow/tensorboard/components/tf_dashboard_common/tf-regex-group.html +++ /dev/null @@ -1,99 +0,0 @@ - - - - - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/tf_dashboard_common/tf-regex-group.ts b/tensorflow/tensorboard/components/tf_dashboard_common/tf-regex-group.ts deleted file mode 100644 index 92a0eb6a0b9..00000000000 --- a/tensorflow/tensorboard/components/tf_dashboard_common/tf-regex-group.ts +++ /dev/null @@ -1,86 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -import * as storage from '../tf-storage/storage'; - -Polymer({ - is: 'tf-regex-group', - properties: { - rawRegexes: { - type: Array, - value: storage.getObjectInitializer( - 'rawRegexes', [{regex: '', valid: true}]), - }, - regexes: - {type: Array, computed: 'usableRegexes(rawRegexes.*)', notify: true}, - }, - observers: [ - 'addNewRegexIfNeeded(rawRegexes.*)', - 'checkValidity(rawRegexes.*)', - '_uriStoreRegexes(rawRegexes.*)', - ], - _uriStoreRegexes: - storage.getObjectObserver('rawRegexes', [{regex: '', valid: true}]), - checkValidity: function(x) { - var match = x.path.match(/rawRegexes\.(\d+)\.regex/); - if (match) { - var idx = match[1]; - this.set('rawRegexes.' + idx + '.valid', this.isValid(x.value)); - } - }, - isValid: function(s) { - try { - new RegExp(s); - return true; - } catch (e) { - return false; - } - }, - usableRegexes: function(regexes) { - var isValid = this.isValid; - return regexes.base - .filter(function(r) { - // Checking validity here (rather than using the data property) - // is necessary because otherwise we might send invalid regexes due - // to the fact that this function can call before the observer does - return r.regex !== '' && isValid(r.regex); - }) - .map(function(r) { - return r.regex; - }); - }, - addNewRegexIfNeeded: function() { - var last = this.rawRegexes[this.rawRegexes.length - 1]; - if (last.regex !== '') { - this.push('rawRegexes', {regex: '', valid: true}); - } - }, - deleteRegex: function(e) { - if (this.rawRegexes.length > 1) { - this.splice('rawRegexes', e.model.index, 1); - } - }, - moveFocus: function(e) { - if (e.keyCode === 13) { - var idx = e.model.index; - var inputs = Polymer.dom(this.root).querySelectorAll('.regex-input'); - if (idx < this.rawRegexes.length - 1) { - (inputs[idx + 1] as any).$.input.focus(); - } else { - (document.activeElement as HTMLElement).blur(); - } - } - } -}); diff --git 
a/tensorflow/tensorboard/components/tf_dashboard_common/tf-run-selector.html b/tensorflow/tensorboard/components/tf_dashboard_common/tf-run-selector.html deleted file mode 100644 index e3d8a91fd0c..00000000000 --- a/tensorflow/tensorboard/components/tf_dashboard_common/tf-run-selector.html +++ /dev/null @@ -1,188 +0,0 @@ - - - - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/tf_dashboard_common/tf-sidebar-helper.html b/tensorflow/tensorboard/components/tf_dashboard_common/tf-sidebar-helper.html deleted file mode 100644 index 5eb8537040c..00000000000 --- a/tensorflow/tensorboard/components/tf_dashboard_common/tf-sidebar-helper.html +++ /dev/null @@ -1,165 +0,0 @@ - - - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/tf_distribution_dashboard/BUILD b/tensorflow/tensorboard/components/tf_distribution_dashboard/BUILD deleted file mode 100644 index 5ddd6ba5bb9..00000000000 --- a/tensorflow/tensorboard/components/tf_distribution_dashboard/BUILD +++ /dev/null @@ -1,41 +0,0 @@ -package(default_visibility = ["//tensorflow/tensorboard:internal"]) - -load("//tensorflow/tensorboard/defs:web.bzl", "ts_web_library") - -licenses(["notice"]) # Apache 2.0 - -ts_web_library( - name = "tf_distribution_dashboard", - srcs = ["tf-distribution-dashboard.html"], - path = "/tf-distribution-dashboard", - deps = [ - "//tensorflow/tensorboard/components/tf_backend", - "//tensorflow/tensorboard/components/tf_color_scale", - "//tensorflow/tensorboard/components/tf_dashboard_common", - "//tensorflow/tensorboard/components/tf_imports:lodash", - "//tensorflow/tensorboard/components/tf_imports:polymer", - "//tensorflow/tensorboard/components/vz_distribution_chart", - "@org_polymer_iron_collapse", - "@org_polymer_paper_icon_button", - "@org_polymer_paper_styles", - ], -) - -ts_web_library( - name = "demo", - srcs = ["index.html"] + glob(["data/**"]), - path = "/tf-distribution-dashboard", - deps = [ - ":tf_distribution_dashboard", - 
"//tensorflow/tensorboard/components/tf_imports:d3", - "//tensorflow/tensorboard/components/tf_imports:webcomponentsjs", - "@org_polymer_iron_demo_helpers", - "@org_polymer_paper_styles", - ], -) - -filegroup( - name = "all_files", - srcs = glob(["**"]), - tags = ["notsan"], -) diff --git a/tensorflow/tensorboard/components/tf_distribution_dashboard/data/compressedHistograms_run_run1_tag_histo1.json b/tensorflow/tensorboard/components/tf_distribution_dashboard/data/compressedHistograms_run_run1_tag_histo1.json deleted file mode 100644 index a6765285b14..00000000000 --- a/tensorflow/tensorboard/components/tf_distribution_dashboard/data/compressedHistograms_run_run1_tag_histo1.json +++ /dev/null @@ -1,212 +0,0 @@ -[ - [ - 0.0, - 0, - [ - [ - 0, - -2.3150592308536755 - ], - [ - 668, - -2.0967547155036605 - ], - [ - 1587, - -1.4326244423655616 - ], - [ - 3085, - -0.8871306575801902 - ], - [ - 5000, - -0.09312398815580714 - ], - [ - 6915, - 0.2584093405812282 - ], - [ - 8413, - 0.8895470642005087 - ], - [ - 9332, - 1.3198979614453679 - ], - [ - 10000, - 1.6793308878855118 - ] - ] - ], - [ - 100.0, - 10, - [ - [ - 0, - -1.3417572789138936 - ], - [ - 668, - -1.183563374619141 - ], - [ - 1587, - -0.48920418783271574 - ], - [ - 3085, - 0.29326906896076954 - ], - [ - 5000, - 0.56953784145381 - ], - [ - 6915, - 0.8684655583499333 - ], - [ - 8413, - 1.4133127368907181 - ], - [ - 9332, - 1.906140650457873 - ], - [ - 10000, - 2.135771998171255 - ] - ] - ], - [ - 200.0, - 20, - [ - [ - 0, - -1.5066917525035333 - ], - [ - 668, - -1.3910909571770793 - ], - [ - 1587, - -0.902737218885874 - ], - [ - 3085, - -0.3807791904765027 - ], - [ - 5000, - 0.38900200905253046 - ], - [ - 6915, - 0.8209734209339482 - ], - [ - 8413, - 1.302385856695965 - ], - [ - 9332, - 1.9324626053521639 - ], - [ - 10000, - 2.957505317875451 - ] - ] - ], - [ - 300.0, - 30, - [ - [ - 0, - -0.5430457051469562 - ], - [ - 668, - -0.4626161834245273 - ], - [ - 1587, - 0.21573949543027715 - ], - [ - 3085, - 
0.37353741100174215 - ], - [ - 5000, - 0.6891407881591103 - ], - [ - 6915, - 1.0927156232630852 - ], - [ - 8413, - 1.2745337159550916 - ], - [ - 9332, - 1.4321116832891605 - ], - [ - 10000, - 2.1913774993059034 - ] - ] - ], - [ - 400.0, - 40, - [ - [ - 0, - -0.3584790755077172 - ], - [ - 668, - -0.33301611509753215 - ], - [ - 1587, - -0.1089466072951948 - ], - [ - 3085, - 0.5792199847585249 - ], - [ - 5000, - 1.220854943811942 - ], - [ - 6915, - 1.759829438421432 - ], - [ - 8413, - 2.3072559906741614 - ], - [ - 9332, - 2.753036118353921 - ], - [ - 10000, - 3.0267252195784047 - ] - ] - ] -] diff --git a/tensorflow/tensorboard/components/tf_distribution_dashboard/data/compressedHistograms_run_run2_tag_histo1.json b/tensorflow/tensorboard/components/tf_distribution_dashboard/data/compressedHistograms_run_run2_tag_histo1.json deleted file mode 100644 index 9e8a55b3f20..00000000000 --- a/tensorflow/tensorboard/components/tf_distribution_dashboard/data/compressedHistograms_run_run2_tag_histo1.json +++ /dev/null @@ -1,212 +0,0 @@ -[ - [ - 0.0, - 0, - [ - [ - 0, - -3.6801669545044846 - ], - [ - 668, - -3.192188140974744 - ], - [ - 1587, - -2.3414678549368806 - ], - [ - 3085, - -0.9632173471995873 - ], - [ - 5000, - -0.3214892636797772 - ], - [ - 6915, - 0.11870794142185205 - ], - [ - 8413, - 0.8895470642005087 - ], - [ - 9332, - 1.183563374619141 - ], - [ - 10000, - 2.665663810418372 - ] - ] - ], - [ - 100.0, - 10, - [ - [ - 0, - -3.564793583751807 - ], - [ - 668, - -3.376844436865802 - ], - [ - 1587, - -1.0366615731293798 - ], - [ - 3085, - -0.27318696312672563 - ], - [ - 5000, - 0.9718642422053263 - ], - [ - 6915, - 2.5765662807928194 - ], - [ - 8413, - 3.1415385101545126 - ], - [ - 9332, - 4.085981768607621 - ], - [ - 10000, - 4.623079406808927 - ] - ] - ], - [ - 200.0, - 20, - [ - [ - 0, - -2.235172510433281 - ], - [ - 668, - -2.004569042815611 - ], - [ - 1587, - -1.2015432383370985 - ], - [ - 3085, - 0.11835464933202625 - ], - [ - 5000, - 0.56953784145381 - ], - [ - 
6915, - 1.202844810963146 - ], - [ - 8413, - 2.689066032283515 - ], - [ - 9332, - 2.8494015726499944 - ], - [ - 10000, - 3.481377676013788 - ] - ] - ], - [ - 300.0, - 30, - [ - [ - 0, - -3.360113978269659 - ], - [ - 668, - -2.8293185004961043 - ], - [ - 1587, - -1.5992540502266783 - ], - [ - 3085, - 0.14393860259807117 - ], - [ - 5000, - 1.47723448201245 - ], - [ - 6915, - 1.9510057389110733 - ], - [ - 8413, - 2.833176104473626 - ], - [ - 9332, - 4.142405216576347 - ], - [ - 10000, - 4.706937777668589 - ] - ] - ], - [ - 400.0, - 40, - [ - [ - 0, - -2.599286228987632 - ], - [ - 668, - -2.240365897443259 - ], - [ - 1587, - -1.5992540502266783 - ], - [ - 3085, - -0.9101893288861387 - ], - [ - 5000, - 0.7580548669750213 - ], - [ - 6915, - 1.6009864433919474 - ], - [ - 8413, - 2.3504002974280036 - ], - [ - 9332, - 2.7907805263353733 - ], - [ - 10000, - 3.5098048900144323 - ] - ] - ] -] diff --git a/tensorflow/tensorboard/components/tf_distribution_dashboard/data/compressedHistograms_run_run2_tag_histo2.json b/tensorflow/tensorboard/components/tf_distribution_dashboard/data/compressedHistograms_run_run2_tag_histo2.json deleted file mode 100644 index 7c8836f6246..00000000000 --- a/tensorflow/tensorboard/components/tf_distribution_dashboard/data/compressedHistograms_run_run2_tag_histo2.json +++ /dev/null @@ -1,212 +0,0 @@ -[ - [ - 0.0, - 0, - [ - [ - 0, - -1.9291158122759586 - ], - [ - 668, - -1.5970765333488954 - ], - [ - 1587, - -1.0923120348519078 - ], - [ - 3085, - -0.6688082872192093 - ], - [ - 5000, - 0.09312398815580714 - ], - [ - 6915, - 0.44532789251701854 - ], - [ - 8413, - 0.8238009655877649 - ], - [ - 9332, - 1.0357232383581656 - ], - [ - 10000, - 1.2741043689144438 - ] - ] - ], - [ - 100.0, - 10, - [ - [ - 0, - -0.7780725642449806 - ], - [ - 668, - -0.7138496178727424 - ], - [ - 1587, - -0.5448932415735014 - ], - [ - 3085, - -0.24370397454796228 - ], - [ - 5000, - 0.42790220995778355 - ], - [ - 6915, - 0.6191730643365096 - ], - [ - 8413, - 0.752059342118037 - 
], - [ - 9332, - 1.0451472255274825 - ], - [ - 10000, - 2.5559479569222825 - ] - ] - ], - [ - 200.0, - 20, - [ - [ - 0, - -1.3876904425996377 - ], - [ - 668, - -1.1464188862638496 - ], - [ - 1587, - -0.4049955219067526 - ], - [ - 3085, - 0.04721394862139682 - ], - [ - 5000, - 0.56953784145381 - ], - [ - 6915, - 1.3221859041483333 - ], - [ - 8413, - 1.6188495656305735 - ], - [ - 9332, - 1.7613953069723651 - ], - [ - 10000, - 2.3257482385477384 - ] - ] - ], - [ - 300.0, - 30, - [ - [ - 0, - -1.600772629982185 - ], - [ - 668, - -1.1548516185367033 - ], - [ - 1587, - -0.260387173785447 - ], - [ - 3085, - 0.17416570914366614 - ], - [ - 5000, - 0.47069243095356195 - ], - [ - 6915, - 1.1559276581637614 - ], - [ - 8413, - 2.0474031182051404 - ], - [ - 9332, - 2.18821711651116 - ], - [ - 10000, - 2.2393193406467518 - ] - ] - ], - [ - 400.0, - 40, - [ - [ - 0, - -0.8286852465281818 - ], - [ - 668, - -0.7815041529866706 - ], - [ - 1587, - -0.3334896444053469 - ], - [ - 3085, - 0.21085213041026643 - ], - [ - 5000, - 0.5177616740489182 - ], - [ - 6915, - 1.077122434649409 - ], - [ - 8413, - 1.5898009703967424 - ], - [ - 9332, - 1.8859097291499742 - ], - [ - 10000, - 2.0954239138728523 - ] - ] - ] -] diff --git a/tensorflow/tensorboard/components/tf_distribution_dashboard/data/logdir b/tensorflow/tensorboard/components/tf_distribution_dashboard/data/logdir deleted file mode 100644 index b6362b45d77..00000000000 --- a/tensorflow/tensorboard/components/tf_distribution_dashboard/data/logdir +++ /dev/null @@ -1 +0,0 @@ -{"logdir": "/foo/some/fake/logdir"} \ No newline at end of file diff --git a/tensorflow/tensorboard/components/tf_distribution_dashboard/data/runs.json b/tensorflow/tensorboard/components/tf_distribution_dashboard/data/runs.json deleted file mode 100644 index 739262a9fb6..00000000000 --- a/tensorflow/tensorboard/components/tf_distribution_dashboard/data/runs.json +++ /dev/null @@ -1,4 +0,0 @@ -{ - "run1": {"compressedHistograms": ["histo1"]}, - "run2": 
{"compressedHistograms": ["histo2", "histo1"]} -} diff --git a/tensorflow/tensorboard/components/tf_distribution_dashboard/index.html b/tensorflow/tensorboard/components/tf_distribution_dashboard/index.html deleted file mode 100644 index fe899a0ba8c..00000000000 --- a/tensorflow/tensorboard/components/tf_distribution_dashboard/index.html +++ /dev/null @@ -1,69 +0,0 @@ - - - - - - - - -Distribution Dashboard Demo - - - - diff --git a/tensorflow/tensorboard/components/tf_distribution_dashboard/tf-distribution-dashboard.html b/tensorflow/tensorboard/components/tf_distribution_dashboard/tf-distribution-dashboard.html deleted file mode 100644 index 76de74273f2..00000000000 --- a/tensorflow/tensorboard/components/tf_distribution_dashboard/tf-distribution-dashboard.html +++ /dev/null @@ -1,131 +0,0 @@ - - - - - - - - - - - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/tf_globals/BUILD b/tensorflow/tensorboard/components/tf_globals/BUILD deleted file mode 100644 index c5b0cfbaa55..00000000000 --- a/tensorflow/tensorboard/components/tf_globals/BUILD +++ /dev/null @@ -1,27 +0,0 @@ -package(default_visibility = ["//tensorflow/tensorboard:internal"]) - -load("//tensorflow/tensorboard/defs:defs.bzl", "tensorboard_webcomponent_library") -load("//tensorflow/tensorboard/defs:web.bzl", "ts_web_library") - -licenses(["notice"]) # Apache 2.0 - -ts_web_library( - name = "tf_globals", - srcs = [ - "globals.ts", - "tf-globals.html", - ], - path = "/tf-globals", -) - -tensorboard_webcomponent_library( - name = "legacy", - srcs = [":tf_globals"], - destdir = "tf-globals", -) - -filegroup( - name = "all_files", - srcs = glob(["**"]), - tags = ["notsan"], -) diff --git a/tensorflow/tensorboard/components/tf_globals/globals.ts b/tensorflow/tensorboard/components/tf_globals/globals.ts deleted file mode 100644 index fb6bb83b97f..00000000000 --- a/tensorflow/tensorboard/components/tf_globals/globals.ts +++ /dev/null @@ -1,43 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. 
All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// The names of TensorBoard tabs. -export const TABS = [ - 'scalars', 'images', 'audio', 'graphs', 'distributions', 'histograms', - 'embeddings', 'text' -]; - -// If true, TensorBoard stores its hash in the URI state. -// If false, tab switching in TensorBoard will not update location hash, -// because hash updates interfere with wct_tests. -let _useHash = false; - -export function setUseHash(shouldUseHash: boolean): void { - _useHash = shouldUseHash; -} - -export function useHash(): boolean { - return _useHash; -} - -let _fakeHash = ''; - -export function setFakeHash(h: string) { - _fakeHash = h; -} - -export function getFakeHash() { - return _fakeHash; -} diff --git a/tensorflow/tensorboard/components/tf_globals/tf-globals.html b/tensorflow/tensorboard/components/tf_globals/tf-globals.html deleted file mode 100644 index efb8e92e080..00000000000 --- a/tensorflow/tensorboard/components/tf_globals/tf-globals.html +++ /dev/null @@ -1,19 +0,0 @@ - - - - diff --git a/tensorflow/tensorboard/components/tf_graph/BUILD b/tensorflow/tensorboard/components/tf_graph/BUILD deleted file mode 100644 index 4c0894f1925..00000000000 --- a/tensorflow/tensorboard/components/tf_graph/BUILD +++ /dev/null @@ -1,56 +0,0 @@ -package(default_visibility = ["//tensorflow/tensorboard:internal"]) - -load("//tensorflow/tensorboard/defs:defs.bzl", 
"tensorboard_webcomponent_library") -load("//tensorflow/tensorboard/defs:web.bzl", "ts_web_library") - -licenses(["notice"]) # Apache 2.0 - -ts_web_library( - name = "tf_graph", - srcs = [ - "tf-graph.html", - "tf-graph-minimap.html", - "tf-graph-scene.html", - ], - path = "/tf-graph", - deps = [ - "//tensorflow/tensorboard/components/tf_dashboard_common", - "//tensorflow/tensorboard/components/tf_graph_common", - "//tensorflow/tensorboard/components/tf_imports:polymer", - "@org_polymer_iron_flex_layout", - "@org_polymer_iron_icons", - "@org_polymer_paper_button", - "@org_polymer_paper_dropdown_menu", - "@org_polymer_paper_input", - "@org_polymer_paper_menu", - "@org_polymer_paper_radio_group", - "@org_polymer_paper_toggle_button", - "@org_polymer_paper_tooltip", - ], -) - -tensorboard_webcomponent_library( - name = "legacy", - srcs = [":tf_graph"], - destdir = "tf-graph", - deps = [ - "//tensorflow/tensorboard/components/tf_dashboard_common:legacy", - "//tensorflow/tensorboard/components/tf_graph_common:legacy", - "//third_party/javascript/polymer/v1/iron-flex-layout:lib", - "//third_party/javascript/polymer/v1/iron-icons:lib", - "//third_party/javascript/polymer/v1/paper-button:lib", - "//third_party/javascript/polymer/v1/paper-dropdown-menu:lib", - "//third_party/javascript/polymer/v1/paper-input:lib", - "//third_party/javascript/polymer/v1/paper-menu:lib", - "//third_party/javascript/polymer/v1/paper-radio-group:lib", - "//third_party/javascript/polymer/v1/paper-toggle-button:lib", - "//third_party/javascript/polymer/v1/paper-tooltip:lib", - "//third_party/javascript/polymer/v1/polymer:lib", - ], -) - -filegroup( - name = "all_files", - srcs = glob(["**"]), - tags = ["notsan"], -) diff --git a/tensorflow/tensorboard/components/tf_graph/demo/BUILD b/tensorflow/tensorboard/components/tf_graph/demo/BUILD deleted file mode 100644 index 02f3bf64bbc..00000000000 --- a/tensorflow/tensorboard/components/tf_graph/demo/BUILD +++ /dev/null @@ -1,26 +0,0 @@ 
-package(default_visibility = ["//tensorflow/tensorboard:internal"]) - -load("//tensorflow/tensorboard/defs:web.bzl", "ts_web_library") - -licenses(["notice"]) # Apache 2.0 - -# bazel run //third_party/tensorflow/tensorboard/components/tf_graph/demo -ts_web_library( - name = "demo", - srcs = ["index.html"] + glob(["data/**"]), - path = "/tf-graph/demo", - deps = [ - "//tensorflow/tensorboard/components/tf_graph", - "//tensorflow/tensorboard/components/tf_graph_common", - "//tensorflow/tensorboard/components/tf_graph_loader", - "//tensorflow/tensorboard/components/tf_imports:webcomponentsjs", - "@org_polymer_iron_demo_helpers", - "@org_polymer_paper_styles", - ], -) - -filegroup( - name = "all_files", - srcs = glob(["**"]), - tags = ["notsan"], -) diff --git a/tensorflow/tensorboard/components/tf_graph/demo/data/graph.pbtxt b/tensorflow/tensorboard/components/tf_graph/demo/data/graph.pbtxt deleted file mode 100644 index 30b20645346..00000000000 --- a/tensorflow/tensorboard/components/tf_graph/demo/data/graph.pbtxt +++ /dev/null @@ -1,4606 +0,0 @@ -node { - name: "GradientDescent/learning_rate" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_3" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.1 - } - } - } -} -node { - name: "gradients/add_grad/Shape_1" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 100 - } - } - } -} -node { - name: 
"gradients/add_grad/Shape" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\310\000\000\000d\000\000\000" - } - } - } -} -node { - name: "gradients/add_grad/BroadcastGradientArgs" - op: "BroadcastGradientArgs" - input: "gradients/add_grad/Shape" - input: "gradients/add_grad/Shape_1" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - shape { - dim { - size: -1 - } - } - } - } - } -} -node { - name: "gradients/add_1_grad/Shape_1" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 10 - } - } - } -} -node { - name: "gradients/add_1_grad/Shape" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\310\000\000\000\n\000\000\000" - } - } - } -} -node { - name: 
"gradients/add_1_grad/BroadcastGradientArgs" - op: "BroadcastGradientArgs" - input: "gradients/add_1_grad/Shape" - input: "gradients/add_1_grad/Shape_1" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - shape { - dim { - size: -1 - } - } - } - } - } -} -node { - name: "gradients/Reshape_1_grad/Shape" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\310\000\000\000\n\000\000\000" - } - } - } -} -node { - name: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims/dim" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: -1 - } - } - } -} -node { - name: "gradients/Reshape_3_grad/Shape" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 200 - } - } - } -} -node { - name: "gradients/Mean_grad/Maximum/y" - 
op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "gradients/Mean_grad/Const_1" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "gradients/Mean_grad/Const" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "gradients/Mean_grad/Shape_1" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - } - } - } - } - } -} -node { - name: "gradients/Mean_grad/Prod_1" - op: "Prod" - input: "gradients/Mean_grad/Shape_1" - input: "gradients/Mean_grad/Const_1" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - 
key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "gradients/Mean_grad/Maximum" - op: "Maximum" - input: "gradients/Mean_grad/Prod_1" - input: "gradients/Mean_grad/Maximum/y" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "gradients/Mean_grad/Shape" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 200 - } - } - } -} -node { - name: "gradients/Mean_grad/Prod" - op: "Prod" - input: "gradients/Mean_grad/Shape" - input: "gradients/Mean_grad/Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "gradients/Mean_grad/floordiv" - op: "FloorDiv" - input: "gradients/Mean_grad/Prod" - input: "gradients/Mean_grad/Maximum" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: 
"cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "gradients/Mean_grad/Cast" - op: "Cast" - input: "gradients/Mean_grad/floordiv" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "gradients/Mean_grad/Tile/multiples" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 200 - } - } - } -} -node { - name: "gradients/Mean_grad/Reshape/shape" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } -} -node { - name: "gradients/Const" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1 - } - } - } -} -node { - name: "gradients/Shape" - op: "Const" - device: 
"/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - } - } - } - } - } -} -node { - name: "gradients/Fill" - op: "Fill" - input: "gradients/Shape" - input: "gradients/Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "gradients/Mean_grad/Reshape" - op: "Reshape" - input: "gradients/Fill" - input: "gradients/Mean_grad/Reshape/shape" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } -} -node { - name: "gradients/Mean_grad/Tile" - op: "Tile" - input: "gradients/Mean_grad/Reshape" - input: "gradients/Mean_grad/Tile/multiples" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tmultiples" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - } - } - } - } -} -node { - name: "gradients/Mean_grad/truediv" - op: "RealDiv" - input: "gradients/Mean_grad/Tile" - input: "gradients/Mean_grad/Cast" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - 
s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - } - } - } - } -} -node { - name: "gradients/Reshape_3_grad/Reshape" - op: "Reshape" - input: "gradients/Mean_grad/truediv" - input: "gradients/Reshape_3_grad/Shape" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - } - } - } - } -} -node { - name: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims" - op: "ExpandDims" - input: "gradients/Reshape_3_grad/Reshape" - input: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims/dim" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "Const" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "Slice_2/begin" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: 
"value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "Sub_2/y" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "concat_1/axis" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "concat_1/values_0" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: -1 - } - } - } -} -node { - name: "Slice_1/size" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } -} -node { - name: "Sub_1/y" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - 
key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "Shape_2" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\310\000\000\000\n\000\000\000" - } - } - } -} -node { - name: "Rank_2" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 2 - } - } - } -} -node { - name: "Sub_1" - op: "Sub" - input: "Rank_2" - input: "Sub_1/y" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "Slice_1/begin" - op: "Pack" - input: "Sub_1" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "N" - value { - i: 1 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "axis" - value { - i: 0 - } - } -} -node { - 
name: "Slice_1" - op: "Slice" - input: "Shape_2" - input: "Slice_1/begin" - input: "Slice_1/size" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "Index" - value { - type: DT_INT32 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } -} -node { - name: "concat_1" - op: "ConcatV2" - input: "concat_1/values_0" - input: "Slice_1" - input: "concat_1/axis" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } -} -node { - name: "concat/axis" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "concat/values_0" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: -1 - } - } - } -} -node { - name: "Slice/size" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - 
} - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } -} -node { - name: "Sub/y" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "Shape_1" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\310\000\000\000\n\000\000\000" - } - } - } -} -node { - name: "Rank_1" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 2 - } - } - } -} -node { - name: "Sub" - op: "Sub" - input: "Rank_1" - input: "Sub/y" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "Slice/begin" - op: "Pack" - 
input: "Sub" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "N" - value { - i: 1 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "axis" - value { - i: 0 - } - } -} -node { - name: "Slice" - op: "Slice" - input: "Shape_1" - input: "Slice/begin" - input: "Slice/size" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "Index" - value { - type: DT_INT32 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } -} -node { - name: "concat" - op: "ConcatV2" - input: "concat/values_0" - input: "Slice" - input: "concat/axis" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } -} -node { - name: "Shape" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\310\000\000\000\n\000\000\000" - } - } - } -} -node { - name: "Rank" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - 
key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 2 - } - } - } -} -node { - name: "Sub_2" - op: "Sub" - input: "Rank" - input: "Sub_2/y" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "Slice_2/size" - op: "Pack" - input: "Sub_2" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "N" - value { - i: 1 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "axis" - value { - i: 0 - } - } -} -node { - name: "Slice_2" - op: "Slice" - input: "Shape" - input: "Slice_2/begin" - input: "Slice_2/size" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "Index" - value { - type: DT_INT32 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } -} -node { - name: "logits_biases" - op: "VariableV2" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "_class" - value { - list { - s: "loc:@logits_biases" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 10 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 10 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node 
{ - name: "logits_biases/read" - op: "Identity" - input: "logits_biases" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@logits_biases" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 10 - } - } - } - } - } -} -node { - name: "logits_weights" - op: "VariableV2" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "_class" - value { - list { - s: "loc:@logits_weights" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 100 - } - dim { - size: 10 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 100 - } - dim { - size: 10 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "logits_weights/read" - op: "Identity" - input: "logits_weights" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@logits_weights" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 100 - } - dim { - size: 10 - } - } - } - } - } -} -node { - name: "hidden_biases" - op: "VariableV2" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "_class" - value { - list { - s: "loc:@hidden_biases" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 100 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 100 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "hidden_biases/read" - op: "Identity" - input: "hidden_biases" - device: 
"/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@hidden_biases" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 100 - } - } - } - } - } -} -node { - name: "hidden_weights" - op: "VariableV2" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "_class" - value { - list { - s: "loc:@hidden_weights" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 784 - } - dim { - size: 100 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 784 - } - dim { - size: 100 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "hidden_weights/read" - op: "Identity" - input: "hidden_weights" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@hidden_weights" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 784 - } - dim { - size: 100 - } - } - } - } - } -} -node { - name: "Reshape/shape" - op: "Const" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\310\000\000\000\377\377\377\377" - } - } - } -} -node { - name: "mnist_dataset_train_2/one_hot/depth" - op: "Const" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - 
dtype: DT_INT32 - tensor_shape { - } - int_val: 10 - } - } - } -} -node { - name: "mnist_dataset_train_2/one_hot/off_value" - op: "Const" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0 - } - } - } -} -node { - name: "mnist_dataset_train_2/one_hot/on_value" - op: "Const" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1 - } - } - } -} -node { - name: "mnist_dataset_train_2/random_shuffle_queue_DequeueMany/n" - op: "Const" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 200 - } - } - } -} -node { - name: "mnist_dataset_train_1/random_shuffle_queue" - op: "RandomShuffleQueueV2" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "capacity" - value { - i: 20000 - } - } - attr { - key: "component_types" - value { - list { - type: DT_FLOAT - type: DT_INT64 - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "min_after_dequeue" - value { - i: 4000 - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } - attr { - key: "shapes" - value { - list { - shape { - dim { - size: 28 - } - dim { - size: 28 - } - dim { - size: 1 - } - } - shape { - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} 
-node { - name: "mnist_dataset_train_2/random_shuffle_queue_DequeueMany" - op: "QueueDequeueManyV2" - input: "mnist_dataset_train_1/random_shuffle_queue" - input: "mnist_dataset_train_2/random_shuffle_queue_DequeueMany/n" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "_output_shapes" - value { - list { - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - } - } - } - attr { - key: "component_types" - value { - list { - type: DT_FLOAT - type: DT_INT64 - } - } - } - attr { - key: "timeout_ms" - value { - i: -1 - } - } -} -node { - name: "Reshape" - op: "Reshape" - input: "mnist_dataset_train_2/random_shuffle_queue_DequeueMany" - input: "Reshape/shape" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: -1 - } - } - } - } - } -} -node { - name: "MatMul" - op: "MatMul" - input: "Reshape" - input: "hidden_weights/read" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 100 - } - } - } - } - } - attr { - key: "transpose_a" - value { - b: false - } - } - attr { - key: "transpose_b" - value { - b: false - } - } -} -node { - name: "add" - op: "Add" - input: "MatMul" - input: "hidden_biases/read" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 100 - } - } - } - } - } -} -node { - name: "Relu" - op: "Relu" - input: "add" - device: 
"/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 100 - } - } - } - } - } -} -node { - name: "MatMul_1" - op: "MatMul" - input: "Relu" - input: "logits_weights/read" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 10 - } - } - } - } - } - attr { - key: "transpose_a" - value { - b: false - } - } - attr { - key: "transpose_b" - value { - b: false - } - } -} -node { - name: "add_1" - op: "Add" - input: "MatMul_1" - input: "logits_biases/read" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 10 - } - } - } - } - } -} -node { - name: "Reshape_1" - op: "Reshape" - input: "add_1" - input: "concat" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 10 - } - } - } - } - } -} -node { - name: "mnist_dataset_train_2/one_hot" - op: "OneHot" - input: "mnist_dataset_train_2/random_shuffle_queue_DequeueMany:1" - input: "mnist_dataset_train_2/one_hot/depth" - input: "mnist_dataset_train_2/one_hot/on_value" - input: "mnist_dataset_train_2/one_hot/off_value" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - 
key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "TI" - value { - type: DT_INT64 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - unknown_rank: true - } - } - } - } - attr { - key: "axis" - value { - i: -1 - } - } -} -node { - name: "Reshape_2" - op: "Reshape" - input: "mnist_dataset_train_2/one_hot" - input: "concat_1" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: 10 - } - } - } - } - } -} -node { - name: "SoftmaxCrossEntropyWithLogits" - op: "SoftmaxCrossEntropyWithLogits" - input: "Reshape_1" - input: "Reshape_2" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - } - shape { - dim { - size: 200 - } - dim { - size: 10 - } - } - } - } - } -} -node { - name: "gradients/SoftmaxCrossEntropyWithLogits_grad/PreventGradient" - op: "PreventGradient" - input: "SoftmaxCrossEntropyWithLogits:1" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 10 - } - } - } - } - } - attr { - key: "message" - value { - s: "Currently there is no way to take the second derivative of softmax_cross_entropy_with_logits due to the fused implementation\'s interaction with tf.gradients()" - } - } -} -node { - name: "gradients/SoftmaxCrossEntropyWithLogits_grad/mul" - op: "Mul" - input: 
"gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims" - input: "gradients/SoftmaxCrossEntropyWithLogits_grad/PreventGradient" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 10 - } - } - } - } - } -} -node { - name: "gradients/Reshape_1_grad/Reshape" - op: "Reshape" - input: "gradients/SoftmaxCrossEntropyWithLogits_grad/mul" - input: "gradients/Reshape_1_grad/Shape" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 10 - } - } - } - } - } -} -node { - name: "gradients/add_1_grad/Sum_1" - op: "Sum" - input: "gradients/Reshape_1_grad/Reshape" - input: "gradients/add_1_grad/BroadcastGradientArgs:1" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 10 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "gradients/add_1_grad/Reshape_1" - op: "Reshape" - input: "gradients/add_1_grad/Sum_1" - input: "gradients/add_1_grad/Shape_1" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 10 
- } - } - } - } - } -} -node { - name: "gradients/add_1_grad/Sum" - op: "Sum" - input: "gradients/Reshape_1_grad/Reshape" - input: "gradients/add_1_grad/BroadcastGradientArgs" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 10 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "gradients/add_1_grad/Reshape" - op: "Reshape" - input: "gradients/add_1_grad/Sum" - input: "gradients/add_1_grad/Shape" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 10 - } - } - } - } - } -} -node { - name: "gradients/add_1_grad/tuple/group_deps" - op: "NoOp" - input: "^gradients/add_1_grad/Reshape" - input: "^gradients/add_1_grad/Reshape_1" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - } - } - } -} -node { - name: "gradients/add_1_grad/tuple/control_dependency_1" - op: "Identity" - input: "gradients/add_1_grad/Reshape_1" - input: "^gradients/add_1_grad/tuple/group_deps" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_class" - value { - list { - s: "loc:@gradients/add_1_grad/Reshape_1" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 10 - } - } - } 
- } - } -} -node { - name: "GradientDescent/update_logits_biases/ApplyGradientDescent" - op: "ApplyGradientDescent" - input: "logits_biases" - input: "GradientDescent/learning_rate" - input: "gradients/add_1_grad/tuple/control_dependency_1" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@logits_biases" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 10 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: false - } - } -} -node { - name: "gradients/add_1_grad/tuple/control_dependency" - op: "Identity" - input: "gradients/add_1_grad/Reshape" - input: "^gradients/add_1_grad/tuple/group_deps" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_class" - value { - list { - s: "loc:@gradients/add_1_grad/Reshape" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 10 - } - } - } - } - } -} -node { - name: "gradients/MatMul_1_grad/MatMul_1" - op: "MatMul" - input: "Relu" - input: "gradients/add_1_grad/tuple/control_dependency" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 100 - } - dim { - size: 10 - } - } - } - } - } - attr { - key: "transpose_a" - value { - b: true - } - } - attr { - key: "transpose_b" - value { - b: false - } - } -} -node { - name: "gradients/MatMul_1_grad/MatMul" - op: "MatMul" - input: "gradients/add_1_grad/tuple/control_dependency" - input: "logits_weights/read" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - 
} - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 100 - } - } - } - } - } - attr { - key: "transpose_a" - value { - b: false - } - } - attr { - key: "transpose_b" - value { - b: true - } - } -} -node { - name: "gradients/MatMul_1_grad/tuple/group_deps" - op: "NoOp" - input: "^gradients/MatMul_1_grad/MatMul" - input: "^gradients/MatMul_1_grad/MatMul_1" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - } - } - } -} -node { - name: "gradients/MatMul_1_grad/tuple/control_dependency_1" - op: "Identity" - input: "gradients/MatMul_1_grad/MatMul_1" - input: "^gradients/MatMul_1_grad/tuple/group_deps" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_class" - value { - list { - s: "loc:@gradients/MatMul_1_grad/MatMul_1" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 100 - } - dim { - size: 10 - } - } - } - } - } -} -node { - name: "GradientDescent/update_logits_weights/ApplyGradientDescent" - op: "ApplyGradientDescent" - input: "logits_weights" - input: "GradientDescent/learning_rate" - input: "gradients/MatMul_1_grad/tuple/control_dependency_1" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@logits_weights" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 100 - } - dim { - size: 10 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: false - } - } -} -node { - name: "gradients/MatMul_1_grad/tuple/control_dependency" - op: "Identity" - input: "gradients/MatMul_1_grad/MatMul" - 
input: "^gradients/MatMul_1_grad/tuple/group_deps" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_class" - value { - list { - s: "loc:@gradients/MatMul_1_grad/MatMul" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 100 - } - } - } - } - } -} -node { - name: "gradients/Relu_grad/ReluGrad" - op: "ReluGrad" - input: "gradients/MatMul_1_grad/tuple/control_dependency" - input: "Relu" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 100 - } - } - } - } - } -} -node { - name: "gradients/add_grad/Sum_1" - op: "Sum" - input: "gradients/Relu_grad/ReluGrad" - input: "gradients/add_grad/BroadcastGradientArgs:1" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 100 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "gradients/add_grad/Reshape_1" - op: "Reshape" - input: "gradients/add_grad/Sum_1" - input: "gradients/add_grad/Shape_1" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 100 - } - } - } - } - } -} -node { - name: "gradients/add_grad/Sum" - op: "Sum" - 
input: "gradients/Relu_grad/ReluGrad" - input: "gradients/add_grad/BroadcastGradientArgs" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 100 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "gradients/add_grad/Reshape" - op: "Reshape" - input: "gradients/add_grad/Sum" - input: "gradients/add_grad/Shape" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 100 - } - } - } - } - } -} -node { - name: "gradients/add_grad/tuple/group_deps" - op: "NoOp" - input: "^gradients/add_grad/Reshape" - input: "^gradients/add_grad/Reshape_1" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - } - } - } -} -node { - name: "gradients/add_grad/tuple/control_dependency_1" - op: "Identity" - input: "gradients/add_grad/Reshape_1" - input: "^gradients/add_grad/tuple/group_deps" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_class" - value { - list { - s: "loc:@gradients/add_grad/Reshape_1" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 100 - } - } - } - } - } -} -node { - name: "GradientDescent/update_hidden_biases/ApplyGradientDescent" - op: 
"ApplyGradientDescent" - input: "hidden_biases" - input: "GradientDescent/learning_rate" - input: "gradients/add_grad/tuple/control_dependency_1" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@hidden_biases" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 100 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: false - } - } -} -node { - name: "gradients/add_grad/tuple/control_dependency" - op: "Identity" - input: "gradients/add_grad/Reshape" - input: "^gradients/add_grad/tuple/group_deps" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_class" - value { - list { - s: "loc:@gradients/add_grad/Reshape" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 100 - } - } - } - } - } -} -node { - name: "gradients/MatMul_grad/MatMul_1" - op: "MatMul" - input: "Reshape" - input: "gradients/add_grad/tuple/control_dependency" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: 100 - } - } - } - } - } - attr { - key: "transpose_a" - value { - b: true - } - } - attr { - key: "transpose_b" - value { - b: false - } - } -} -node { - name: "gradients/MatMul_grad/MatMul" - op: "MatMul" - input: "gradients/add_grad/tuple/control_dependency" - input: "hidden_weights/read" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - 
value { - list { - shape { - dim { - size: 200 - } - dim { - size: 784 - } - } - } - } - } - attr { - key: "transpose_a" - value { - b: false - } - } - attr { - key: "transpose_b" - value { - b: true - } - } -} -node { - name: "gradients/MatMul_grad/tuple/group_deps" - op: "NoOp" - input: "^gradients/MatMul_grad/MatMul" - input: "^gradients/MatMul_grad/MatMul_1" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - } - } - } -} -node { - name: "gradients/MatMul_grad/tuple/control_dependency_1" - op: "Identity" - input: "gradients/MatMul_grad/MatMul_1" - input: "^gradients/MatMul_grad/tuple/group_deps" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_class" - value { - list { - s: "loc:@gradients/MatMul_grad/MatMul_1" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: 100 - } - } - } - } - } -} -node { - name: "GradientDescent/update_hidden_weights/ApplyGradientDescent" - op: "ApplyGradientDescent" - input: "hidden_weights" - input: "GradientDescent/learning_rate" - input: "gradients/MatMul_grad/tuple/control_dependency_1" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@hidden_weights" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 784 - } - dim { - size: 100 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: false - } - } -} -node { - name: "GradientDescent" - op: "NoOp" - input: "^GradientDescent/update_hidden_weights/ApplyGradientDescent" - input: "^GradientDescent/update_hidden_biases/ApplyGradientDescent" - input: 
"^GradientDescent/update_logits_weights/ApplyGradientDescent" - input: "^GradientDescent/update_logits_biases/ApplyGradientDescent" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_2" - } - } - attr { - key: "_output_shapes" - value { - list { - } - } - } -} -node { - name: "Reshape_3" - op: "Reshape" - input: "SoftmaxCrossEntropyWithLogits" - input: "Slice_2" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - } - } - } - } -} -node { - name: "Mean" - op: "Mean" - input: "Reshape_3" - input: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "_send_Mean_0" - op: "_Send" - input: "Mean" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "client_terminated" - value { - b: true - } - } - attr { - key: "recv_device" - value { - s: "/job:localhost/replica:0/task:0/cpu:0" - } - } - attr { - key: "send_device" - value { - s: "/job:localhost/replica:0/task:0/cpu:0" - } - } - attr { - key: "send_device_incarnation" - value { - i: -5924635994370253548 - } - } - attr { - key: "tensor_name" - value { - s: "Mean:0" - } - } -} -library { -} -versions { - producer: 21 -} diff --git a/tensorflow/tensorboard/components/tf_graph/demo/index.html b/tensorflow/tensorboard/components/tf_graph/demo/index.html deleted file mode 100644 
index 52e2f0b9340..00000000000 --- a/tensorflow/tensorboard/components/tf_graph/demo/index.html +++ /dev/null @@ -1,92 +0,0 @@ - - - - - - - -TF Graph Demo - - - - diff --git a/tensorflow/tensorboard/components/tf_graph/tf-graph-minimap.html b/tensorflow/tensorboard/components/tf_graph/tf-graph-minimap.html deleted file mode 100644 index 5fc16c05207..00000000000 --- a/tensorflow/tensorboard/components/tf_graph/tf-graph-minimap.html +++ /dev/null @@ -1,88 +0,0 @@ - - - - - - - - - diff --git a/tensorflow/tensorboard/components/tf_graph/tf-graph-scene.html b/tensorflow/tensorboard/components/tf_graph/tf-graph-scene.html deleted file mode 100644 index fb2bc13f9a1..00000000000 --- a/tensorflow/tensorboard/components/tf_graph/tf-graph-scene.html +++ /dev/null @@ -1,1081 +0,0 @@ - - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/tf_graph/tf-graph.html b/tensorflow/tensorboard/components/tf_graph/tf-graph.html deleted file mode 100644 index efbf065a40a..00000000000 --- a/tensorflow/tensorboard/components/tf_graph/tf-graph.html +++ /dev/null @@ -1,316 +0,0 @@ - - - - - - - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/tf_graph_app/BUILD b/tensorflow/tensorboard/components/tf_graph_app/BUILD deleted file mode 100644 index d0b6d79640d..00000000000 --- a/tensorflow/tensorboard/components/tf_graph_app/BUILD +++ /dev/null @@ -1,47 +0,0 @@ -package(default_visibility = ["//tensorflow/tensorboard:internal"]) - -load("//tensorflow/tensorboard/defs:defs.bzl", "tensorboard_webcomponent_library") -load("//tensorflow/tensorboard/defs:web.bzl", "ts_web_library") - -licenses(["notice"]) # Apache 2.0 - -ts_web_library( - name = "tf_graph_app", - srcs = [ - "index.html", - "tf-graph-app.html", - ], - path = "/tf-graph-app", - deps = [ - "//tensorflow/tensorboard/components/tf_graph_board", - "//tensorflow/tensorboard/components/tf_graph_controls", - "//tensorflow/tensorboard/components/tf_graph_loader", - 
"//tensorflow/tensorboard/components/tf_imports:polymer", - "//tensorflow/tensorboard/components/tf_imports:webcomponentsjs", - "@org_polymer_iron_component_page", - ], -) - -tensorboard_webcomponent_library( - name = "legacy", - srcs = [":tf_graph_app"], - visibility = [ - "//learning/brain/python/client/colab:__pkg__", - "//learning/vis/vz_elements/catalog:__pkg__", - ], - destdir = "tf-graph-app", - deps = [ - "//tensorflow/tensorboard/components/tf_graph_board:legacy", - "//tensorflow/tensorboard/components/tf_graph_controls:legacy", - "//tensorflow/tensorboard/components/tf_graph_loader:legacy", - "//third_party/javascript/polymer/v1/iron-component-page:lib", - "//third_party/javascript/polymer/v1/polymer:lib", - "//third_party/javascript/polymer/v1/webcomponentsjs:lib", - ], -) - -filegroup( - name = "all_files", - srcs = glob(["**"]), - tags = ["notsan"], -) diff --git a/tensorflow/tensorboard/components/tf_graph_app/demo/BUILD b/tensorflow/tensorboard/components/tf_graph_app/demo/BUILD deleted file mode 100644 index 0205e2fd92c..00000000000 --- a/tensorflow/tensorboard/components/tf_graph_app/demo/BUILD +++ /dev/null @@ -1,23 +0,0 @@ -package(default_visibility = ["//tensorflow/tensorboard:internal"]) - -load("//tensorflow/tensorboard/defs:web.bzl", "ts_web_library") - -licenses(["notice"]) # Apache 2.0 - -# bazel run //third_party/tensorflow/tensorboard/components/tf_graph_app/demo -ts_web_library( - name = "demo", - srcs = ["index.html"] + glob(["data/**"]), - path = "/tf-graph-app/demo", - deps = [ - "//tensorflow/tensorboard/components/tf_graph_app", - "@org_polymer_iron_demo_helpers", - "@org_polymer_paper_styles", - ], -) - -filegroup( - name = "all_files", - srcs = glob(["**"]), - tags = ["notsan"], -) diff --git a/tensorflow/tensorboard/components/tf_graph_app/demo/data/graph.pbtxt b/tensorflow/tensorboard/components/tf_graph_app/demo/data/graph.pbtxt deleted file mode 100644 index 8b95b258df4..00000000000 --- 
a/tensorflow/tensorboard/components/tf_graph_app/demo/data/graph.pbtxt +++ /dev/null @@ -1,90 +0,0 @@ -node { - name: "life" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 2 - } - } - } -} -node { - name: "universe" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 40 - } - } - } -} -node { - name: "everything" - op: "Const" - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "Add" - op: "Add" - input: "life" - input: "universe" - attr { - key: "T" - value { - type: DT_INT32 - } - } -} -node { - name: "answer" - op: "Add" - input: "Add" - input: "everything" - attr { - key: "T" - value { - type: DT_INT32 - } - } -} -versions { - producer: 10 -} diff --git a/tensorflow/tensorboard/components/tf_graph_app/demo/index.html b/tensorflow/tensorboard/components/tf_graph_app/demo/index.html deleted file mode 100644 index f71feea390a..00000000000 --- a/tensorflow/tensorboard/components/tf_graph_app/demo/index.html +++ /dev/null @@ -1,34 +0,0 @@ - - - - - - - -

Answer to the Ultimate Question of Life, the Universe, and Everything

- - - diff --git a/tensorflow/tensorboard/components/tf_graph_app/index.html b/tensorflow/tensorboard/components/tf_graph_app/index.html deleted file mode 100644 index c80fbf4f632..00000000000 --- a/tensorflow/tensorboard/components/tf_graph_app/index.html +++ /dev/null @@ -1,30 +0,0 @@ - - - - - - vz-vega - - - - - - - - - diff --git a/tensorflow/tensorboard/components/tf_graph_app/tf-graph-app.html b/tensorflow/tensorboard/components/tf_graph_app/tf-graph-app.html deleted file mode 100644 index 915b54a06a9..00000000000 --- a/tensorflow/tensorboard/components/tf_graph_app/tf-graph-app.html +++ /dev/null @@ -1,152 +0,0 @@ - - - - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/tf_graph_board/BUILD b/tensorflow/tensorboard/components/tf_graph_board/BUILD deleted file mode 100644 index 866112e0212..00000000000 --- a/tensorflow/tensorboard/components/tf_graph_board/BUILD +++ /dev/null @@ -1,38 +0,0 @@ -package(default_visibility = ["//tensorflow/tensorboard:internal"]) - -load("//tensorflow/tensorboard/defs:defs.bzl", "tensorboard_webcomponent_library") -load("//tensorflow/tensorboard/defs:web.bzl", "ts_web_library") - -licenses(["notice"]) # Apache 2.0 - -ts_web_library( - name = "tf_graph_board", - srcs = ["tf-graph-board.html"], - path = "/tf-graph-board", - deps = [ - "//tensorflow/tensorboard/components/tf_graph", - "//tensorflow/tensorboard/components/tf_graph_common", - "//tensorflow/tensorboard/components/tf_graph_info", - "//tensorflow/tensorboard/components/tf_imports:polymer", - "@org_polymer_paper_progress", - ], -) - -tensorboard_webcomponent_library( - name = "legacy", - srcs = [":tf_graph_board"], - destdir = "tf-graph-board", - deps = [ - "//tensorflow/tensorboard/components/tf_graph:legacy", - "//tensorflow/tensorboard/components/tf_graph_common:legacy", - "//tensorflow/tensorboard/components/tf_graph_info:legacy", - "//third_party/javascript/polymer/v1/paper-progress:lib", - "//third_party/javascript/polymer/v1/polymer:lib", - ], -) 
- -filegroup( - name = "all_files", - srcs = glob(["**"]), - tags = ["notsan"], -) diff --git a/tensorflow/tensorboard/components/tf_graph_board/demo/BUILD b/tensorflow/tensorboard/components/tf_graph_board/demo/BUILD deleted file mode 100644 index 07e8d43dbee..00000000000 --- a/tensorflow/tensorboard/components/tf_graph_board/demo/BUILD +++ /dev/null @@ -1,26 +0,0 @@ -package(default_visibility = ["//tensorflow/tensorboard:internal"]) - -load("//tensorflow/tensorboard/defs:web.bzl", "ts_web_library") - -licenses(["notice"]) # Apache 2.0 - -# bazel run //third_party/tensorflow/tensorboard/components/tf_graph_board/demo -ts_web_library( - name = "demo", - srcs = ["index.html"] + glob(["data/**"]), - path = "/tf-graph-board/demo", - deps = [ - "//tensorflow/tensorboard/components/tf_graph_board", - "//tensorflow/tensorboard/components/tf_graph_common", - "//tensorflow/tensorboard/components/tf_graph_loader", - "//tensorflow/tensorboard/components/tf_imports:webcomponentsjs", - "@org_polymer_iron_demo_helpers", - "@org_polymer_paper_styles", - ], -) - -filegroup( - name = "all_files", - srcs = glob(["**"]), - tags = ["notsan"], -) diff --git a/tensorflow/tensorboard/components/tf_graph_board/demo/data/graph.pbtxt b/tensorflow/tensorboard/components/tf_graph_board/demo/data/graph.pbtxt deleted file mode 100644 index 30b20645346..00000000000 --- a/tensorflow/tensorboard/components/tf_graph_board/demo/data/graph.pbtxt +++ /dev/null @@ -1,4606 +0,0 @@ -node { - name: "GradientDescent/learning_rate" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_3" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.1 - } - } - } -} -node { - name: "gradients/add_grad/Shape_1" - op: "Const" - device: 
"/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 100 - } - } - } -} -node { - name: "gradients/add_grad/Shape" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\310\000\000\000d\000\000\000" - } - } - } -} -node { - name: "gradients/add_grad/BroadcastGradientArgs" - op: "BroadcastGradientArgs" - input: "gradients/add_grad/Shape" - input: "gradients/add_grad/Shape_1" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - shape { - dim { - size: -1 - } - } - } - } - } -} -node { - name: "gradients/add_1_grad/Shape_1" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 10 - } - } - } -} -node { - name: "gradients/add_1_grad/Shape" - op: "Const" - device: 
"/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\310\000\000\000\n\000\000\000" - } - } - } -} -node { - name: "gradients/add_1_grad/BroadcastGradientArgs" - op: "BroadcastGradientArgs" - input: "gradients/add_1_grad/Shape" - input: "gradients/add_1_grad/Shape_1" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - shape { - dim { - size: -1 - } - } - } - } - } -} -node { - name: "gradients/Reshape_1_grad/Shape" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\310\000\000\000\n\000\000\000" - } - } - } -} -node { - name: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims/dim" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: -1 - } - } - } -} -node { - name: "gradients/Reshape_3_grad/Shape" - op: "Const" - device: 
"/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 200 - } - } - } -} -node { - name: "gradients/Mean_grad/Maximum/y" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "gradients/Mean_grad/Const_1" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "gradients/Mean_grad/Const" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "gradients/Mean_grad/Shape_1" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { 
- key: "_output_shapes" - value { - list { - shape { - dim { - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - } - } - } - } - } -} -node { - name: "gradients/Mean_grad/Prod_1" - op: "Prod" - input: "gradients/Mean_grad/Shape_1" - input: "gradients/Mean_grad/Const_1" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "gradients/Mean_grad/Maximum" - op: "Maximum" - input: "gradients/Mean_grad/Prod_1" - input: "gradients/Mean_grad/Maximum/y" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "gradients/Mean_grad/Shape" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 200 - } - } - } -} -node { - name: "gradients/Mean_grad/Prod" - op: "Prod" - input: "gradients/Mean_grad/Shape" - input: "gradients/Mean_grad/Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: 
"cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "gradients/Mean_grad/floordiv" - op: "FloorDiv" - input: "gradients/Mean_grad/Prod" - input: "gradients/Mean_grad/Maximum" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "gradients/Mean_grad/Cast" - op: "Cast" - input: "gradients/Mean_grad/floordiv" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "gradients/Mean_grad/Tile/multiples" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 200 - } - } - } -} -node { - name: "gradients/Mean_grad/Reshape/shape" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } -} -node { - name: "gradients/Const" - op: "Const" 
- device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1 - } - } - } -} -node { - name: "gradients/Shape" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - } - } - } - } - } -} -node { - name: "gradients/Fill" - op: "Fill" - input: "gradients/Shape" - input: "gradients/Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "gradients/Mean_grad/Reshape" - op: "Reshape" - input: "gradients/Fill" - input: "gradients/Mean_grad/Reshape/shape" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } -} -node { - name: "gradients/Mean_grad/Tile" - op: "Tile" - input: "gradients/Mean_grad/Reshape" - input: "gradients/Mean_grad/Tile/multiples" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tmultiples" - value { - type: DT_INT32 - } - } - attr { - key: 
"_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - } - } - } - } -} -node { - name: "gradients/Mean_grad/truediv" - op: "RealDiv" - input: "gradients/Mean_grad/Tile" - input: "gradients/Mean_grad/Cast" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - } - } - } - } -} -node { - name: "gradients/Reshape_3_grad/Reshape" - op: "Reshape" - input: "gradients/Mean_grad/truediv" - input: "gradients/Reshape_3_grad/Shape" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - } - } - } - } -} -node { - name: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims" - op: "ExpandDims" - input: "gradients/Reshape_3_grad/Reshape" - input: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims/dim" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "Const" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { 
- tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "Slice_2/begin" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "Sub_2/y" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "concat_1/axis" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "concat_1/values_0" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: -1 - } - } - } -} -node { - name: "Slice_1/size" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: 
"_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } -} -node { - name: "Sub_1/y" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "Shape_2" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\310\000\000\000\n\000\000\000" - } - } - } -} -node { - name: "Rank_2" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 2 - } - } - } -} -node { - name: "Sub_1" - op: "Sub" - input: "Rank_2" - input: "Sub_1/y" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } 
-} -node { - name: "Slice_1/begin" - op: "Pack" - input: "Sub_1" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "N" - value { - i: 1 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "axis" - value { - i: 0 - } - } -} -node { - name: "Slice_1" - op: "Slice" - input: "Shape_2" - input: "Slice_1/begin" - input: "Slice_1/size" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "Index" - value { - type: DT_INT32 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } -} -node { - name: "concat_1" - op: "ConcatV2" - input: "concat_1/values_0" - input: "Slice_1" - input: "concat_1/axis" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } -} -node { - name: "concat/axis" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "concat/values_0" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { 
- key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: -1 - } - } - } -} -node { - name: "Slice/size" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } -} -node { - name: "Sub/y" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "Shape_1" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\310\000\000\000\n\000\000\000" - } - } - } -} -node { - name: "Rank_1" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - 
tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 2 - } - } - } -} -node { - name: "Sub" - op: "Sub" - input: "Rank_1" - input: "Sub/y" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "Slice/begin" - op: "Pack" - input: "Sub" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "N" - value { - i: 1 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "axis" - value { - i: 0 - } - } -} -node { - name: "Slice" - op: "Slice" - input: "Shape_1" - input: "Slice/begin" - input: "Slice/size" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "Index" - value { - type: DT_INT32 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } -} -node { - name: "concat" - op: "ConcatV2" - input: "concat/values_0" - input: "Slice" - input: "concat/axis" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } -} -node { - name: "Shape" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { 
- shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\310\000\000\000\n\000\000\000" - } - } - } -} -node { - name: "Rank" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 2 - } - } - } -} -node { - name: "Sub_2" - op: "Sub" - input: "Rank" - input: "Sub_2/y" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "Slice_2/size" - op: "Pack" - input: "Sub_2" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "N" - value { - i: 1 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "axis" - value { - i: 0 - } - } -} -node { - name: "Slice_2" - op: "Slice" - input: "Shape" - input: "Slice_2/begin" - input: "Slice_2/size" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "Index" - value { - type: DT_INT32 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } -} -node { - name: "logits_biases" - op: "VariableV2" - device: 
"/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "_class" - value { - list { - s: "loc:@logits_biases" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 10 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 10 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "logits_biases/read" - op: "Identity" - input: "logits_biases" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@logits_biases" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 10 - } - } - } - } - } -} -node { - name: "logits_weights" - op: "VariableV2" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "_class" - value { - list { - s: "loc:@logits_weights" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 100 - } - dim { - size: 10 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 100 - } - dim { - size: 10 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "logits_weights/read" - op: "Identity" - input: "logits_weights" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@logits_weights" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 100 - } - dim { - size: 10 - } - } - } - } - } -} -node { - name: "hidden_biases" - op: "VariableV2" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "_class" - value { - list { - s: "loc:@hidden_biases" - 
} - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 100 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 100 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "hidden_biases/read" - op: "Identity" - input: "hidden_biases" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@hidden_biases" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 100 - } - } - } - } - } -} -node { - name: "hidden_weights" - op: "VariableV2" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "_class" - value { - list { - s: "loc:@hidden_weights" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 784 - } - dim { - size: 100 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 784 - } - dim { - size: 100 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "hidden_weights/read" - op: "Identity" - input: "hidden_weights" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@hidden_weights" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 784 - } - dim { - size: 100 - } - } - } - } - } -} -node { - name: "Reshape/shape" - op: "Const" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" 
- value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\310\000\000\000\377\377\377\377" - } - } - } -} -node { - name: "mnist_dataset_train_2/one_hot/depth" - op: "Const" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 10 - } - } - } -} -node { - name: "mnist_dataset_train_2/one_hot/off_value" - op: "Const" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0 - } - } - } -} -node { - name: "mnist_dataset_train_2/one_hot/on_value" - op: "Const" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1 - } - } - } -} -node { - name: "mnist_dataset_train_2/random_shuffle_queue_DequeueMany/n" - op: "Const" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 200 - } - } - } -} -node { - name: "mnist_dataset_train_1/random_shuffle_queue" - op: "RandomShuffleQueueV2" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "capacity" - value { - i: 20000 - } - } - attr { - key: "component_types" - value { - list { - type: DT_FLOAT - 
type: DT_INT64 - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "min_after_dequeue" - value { - i: 4000 - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } - attr { - key: "shapes" - value { - list { - shape { - dim { - size: 28 - } - dim { - size: 28 - } - dim { - size: 1 - } - } - shape { - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "mnist_dataset_train_2/random_shuffle_queue_DequeueMany" - op: "QueueDequeueManyV2" - input: "mnist_dataset_train_1/random_shuffle_queue" - input: "mnist_dataset_train_2/random_shuffle_queue_DequeueMany/n" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "_output_shapes" - value { - list { - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - } - } - } - attr { - key: "component_types" - value { - list { - type: DT_FLOAT - type: DT_INT64 - } - } - } - attr { - key: "timeout_ms" - value { - i: -1 - } - } -} -node { - name: "Reshape" - op: "Reshape" - input: "mnist_dataset_train_2/random_shuffle_queue_DequeueMany" - input: "Reshape/shape" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: -1 - } - } - } - } - } -} -node { - name: "MatMul" - op: "MatMul" - input: "Reshape" - input: "hidden_weights/read" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 100 - } - } - } - } - } - attr { - key: "transpose_a" - value { - b: false - } - } - attr { - key: "transpose_b" - value { - b: false - } - } -} -node { - name: "add" - op: "Add" - 
input: "MatMul" - input: "hidden_biases/read" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 100 - } - } - } - } - } -} -node { - name: "Relu" - op: "Relu" - input: "add" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 100 - } - } - } - } - } -} -node { - name: "MatMul_1" - op: "MatMul" - input: "Relu" - input: "logits_weights/read" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 10 - } - } - } - } - } - attr { - key: "transpose_a" - value { - b: false - } - } - attr { - key: "transpose_b" - value { - b: false - } - } -} -node { - name: "add_1" - op: "Add" - input: "MatMul_1" - input: "logits_biases/read" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 10 - } - } - } - } - } -} -node { - name: "Reshape_1" - op: "Reshape" - input: "add_1" - input: "concat" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { 
- shape { - dim { - size: 200 - } - dim { - size: 10 - } - } - } - } - } -} -node { - name: "mnist_dataset_train_2/one_hot" - op: "OneHot" - input: "mnist_dataset_train_2/random_shuffle_queue_DequeueMany:1" - input: "mnist_dataset_train_2/one_hot/depth" - input: "mnist_dataset_train_2/one_hot/on_value" - input: "mnist_dataset_train_2/one_hot/off_value" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "TI" - value { - type: DT_INT64 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - unknown_rank: true - } - } - } - } - attr { - key: "axis" - value { - i: -1 - } - } -} -node { - name: "Reshape_2" - op: "Reshape" - input: "mnist_dataset_train_2/one_hot" - input: "concat_1" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: 10 - } - } - } - } - } -} -node { - name: "SoftmaxCrossEntropyWithLogits" - op: "SoftmaxCrossEntropyWithLogits" - input: "Reshape_1" - input: "Reshape_2" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - } - shape { - dim { - size: 200 - } - dim { - size: 10 - } - } - } - } - } -} -node { - name: "gradients/SoftmaxCrossEntropyWithLogits_grad/PreventGradient" - op: "PreventGradient" - input: "SoftmaxCrossEntropyWithLogits:1" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - 
shape { - dim { - size: 200 - } - dim { - size: 10 - } - } - } - } - } - attr { - key: "message" - value { - s: "Currently there is no way to take the second derivative of softmax_cross_entropy_with_logits due to the fused implementation\'s interaction with tf.gradients()" - } - } -} -node { - name: "gradients/SoftmaxCrossEntropyWithLogits_grad/mul" - op: "Mul" - input: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims" - input: "gradients/SoftmaxCrossEntropyWithLogits_grad/PreventGradient" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 10 - } - } - } - } - } -} -node { - name: "gradients/Reshape_1_grad/Reshape" - op: "Reshape" - input: "gradients/SoftmaxCrossEntropyWithLogits_grad/mul" - input: "gradients/Reshape_1_grad/Shape" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 10 - } - } - } - } - } -} -node { - name: "gradients/add_1_grad/Sum_1" - op: "Sum" - input: "gradients/Reshape_1_grad/Reshape" - input: "gradients/add_1_grad/BroadcastGradientArgs:1" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 10 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "gradients/add_1_grad/Reshape_1" - op: "Reshape" - input: 
"gradients/add_1_grad/Sum_1" - input: "gradients/add_1_grad/Shape_1" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 10 - } - } - } - } - } -} -node { - name: "gradients/add_1_grad/Sum" - op: "Sum" - input: "gradients/Reshape_1_grad/Reshape" - input: "gradients/add_1_grad/BroadcastGradientArgs" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 10 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "gradients/add_1_grad/Reshape" - op: "Reshape" - input: "gradients/add_1_grad/Sum" - input: "gradients/add_1_grad/Shape" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 10 - } - } - } - } - } -} -node { - name: "gradients/add_1_grad/tuple/group_deps" - op: "NoOp" - input: "^gradients/add_1_grad/Reshape" - input: "^gradients/add_1_grad/Reshape_1" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - } - } - } -} -node { - name: "gradients/add_1_grad/tuple/control_dependency_1" - op: "Identity" - input: "gradients/add_1_grad/Reshape_1" - input: 
"^gradients/add_1_grad/tuple/group_deps" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_class" - value { - list { - s: "loc:@gradients/add_1_grad/Reshape_1" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 10 - } - } - } - } - } -} -node { - name: "GradientDescent/update_logits_biases/ApplyGradientDescent" - op: "ApplyGradientDescent" - input: "logits_biases" - input: "GradientDescent/learning_rate" - input: "gradients/add_1_grad/tuple/control_dependency_1" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@logits_biases" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 10 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: false - } - } -} -node { - name: "gradients/add_1_grad/tuple/control_dependency" - op: "Identity" - input: "gradients/add_1_grad/Reshape" - input: "^gradients/add_1_grad/tuple/group_deps" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_class" - value { - list { - s: "loc:@gradients/add_1_grad/Reshape" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 10 - } - } - } - } - } -} -node { - name: "gradients/MatMul_1_grad/MatMul_1" - op: "MatMul" - input: "Relu" - input: "gradients/add_1_grad/tuple/control_dependency" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 100 - } - dim { - 
size: 10 - } - } - } - } - } - attr { - key: "transpose_a" - value { - b: true - } - } - attr { - key: "transpose_b" - value { - b: false - } - } -} -node { - name: "gradients/MatMul_1_grad/MatMul" - op: "MatMul" - input: "gradients/add_1_grad/tuple/control_dependency" - input: "logits_weights/read" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 100 - } - } - } - } - } - attr { - key: "transpose_a" - value { - b: false - } - } - attr { - key: "transpose_b" - value { - b: true - } - } -} -node { - name: "gradients/MatMul_1_grad/tuple/group_deps" - op: "NoOp" - input: "^gradients/MatMul_1_grad/MatMul" - input: "^gradients/MatMul_1_grad/MatMul_1" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - } - } - } -} -node { - name: "gradients/MatMul_1_grad/tuple/control_dependency_1" - op: "Identity" - input: "gradients/MatMul_1_grad/MatMul_1" - input: "^gradients/MatMul_1_grad/tuple/group_deps" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_class" - value { - list { - s: "loc:@gradients/MatMul_1_grad/MatMul_1" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 100 - } - dim { - size: 10 - } - } - } - } - } -} -node { - name: "GradientDescent/update_logits_weights/ApplyGradientDescent" - op: "ApplyGradientDescent" - input: "logits_weights" - input: "GradientDescent/learning_rate" - input: "gradients/MatMul_1_grad/tuple/control_dependency_1" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "T" - value { - type: 
DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@logits_weights" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 100 - } - dim { - size: 10 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: false - } - } -} -node { - name: "gradients/MatMul_1_grad/tuple/control_dependency" - op: "Identity" - input: "gradients/MatMul_1_grad/MatMul" - input: "^gradients/MatMul_1_grad/tuple/group_deps" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_class" - value { - list { - s: "loc:@gradients/MatMul_1_grad/MatMul" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 100 - } - } - } - } - } -} -node { - name: "gradients/Relu_grad/ReluGrad" - op: "ReluGrad" - input: "gradients/MatMul_1_grad/tuple/control_dependency" - input: "Relu" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 100 - } - } - } - } - } -} -node { - name: "gradients/add_grad/Sum_1" - op: "Sum" - input: "gradients/Relu_grad/ReluGrad" - input: "gradients/add_grad/BroadcastGradientArgs:1" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 100 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "gradients/add_grad/Reshape_1" - op: "Reshape" - input: "gradients/add_grad/Sum_1" - input: 
"gradients/add_grad/Shape_1" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 100 - } - } - } - } - } -} -node { - name: "gradients/add_grad/Sum" - op: "Sum" - input: "gradients/Relu_grad/ReluGrad" - input: "gradients/add_grad/BroadcastGradientArgs" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 100 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "gradients/add_grad/Reshape" - op: "Reshape" - input: "gradients/add_grad/Sum" - input: "gradients/add_grad/Shape" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 100 - } - } - } - } - } -} -node { - name: "gradients/add_grad/tuple/group_deps" - op: "NoOp" - input: "^gradients/add_grad/Reshape" - input: "^gradients/add_grad/Reshape_1" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - } - } - } -} -node { - name: "gradients/add_grad/tuple/control_dependency_1" - op: "Identity" - input: "gradients/add_grad/Reshape_1" - input: "^gradients/add_grad/tuple/group_deps" - device: 
"/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_class" - value { - list { - s: "loc:@gradients/add_grad/Reshape_1" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 100 - } - } - } - } - } -} -node { - name: "GradientDescent/update_hidden_biases/ApplyGradientDescent" - op: "ApplyGradientDescent" - input: "hidden_biases" - input: "GradientDescent/learning_rate" - input: "gradients/add_grad/tuple/control_dependency_1" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@hidden_biases" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 100 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: false - } - } -} -node { - name: "gradients/add_grad/tuple/control_dependency" - op: "Identity" - input: "gradients/add_grad/Reshape" - input: "^gradients/add_grad/tuple/group_deps" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_class" - value { - list { - s: "loc:@gradients/add_grad/Reshape" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 100 - } - } - } - } - } -} -node { - name: "gradients/MatMul_grad/MatMul_1" - op: "MatMul" - input: "Reshape" - input: "gradients/add_grad/tuple/control_dependency" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: 100 - } - } - } - } - } - attr { - key: "transpose_a" - 
value { - b: true - } - } - attr { - key: "transpose_b" - value { - b: false - } - } -} -node { - name: "gradients/MatMul_grad/MatMul" - op: "MatMul" - input: "gradients/add_grad/tuple/control_dependency" - input: "hidden_weights/read" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 784 - } - } - } - } - } - attr { - key: "transpose_a" - value { - b: false - } - } - attr { - key: "transpose_b" - value { - b: true - } - } -} -node { - name: "gradients/MatMul_grad/tuple/group_deps" - op: "NoOp" - input: "^gradients/MatMul_grad/MatMul" - input: "^gradients/MatMul_grad/MatMul_1" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - } - } - } -} -node { - name: "gradients/MatMul_grad/tuple/control_dependency_1" - op: "Identity" - input: "gradients/MatMul_grad/MatMul_1" - input: "^gradients/MatMul_grad/tuple/group_deps" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_class" - value { - list { - s: "loc:@gradients/MatMul_grad/MatMul_1" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: 100 - } - } - } - } - } -} -node { - name: "GradientDescent/update_hidden_weights/ApplyGradientDescent" - op: "ApplyGradientDescent" - input: "hidden_weights" - input: "GradientDescent/learning_rate" - input: "gradients/MatMul_grad/tuple/control_dependency_1" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: 
"loc:@hidden_weights" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 784 - } - dim { - size: 100 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: false - } - } -} -node { - name: "GradientDescent" - op: "NoOp" - input: "^GradientDescent/update_hidden_weights/ApplyGradientDescent" - input: "^GradientDescent/update_hidden_biases/ApplyGradientDescent" - input: "^GradientDescent/update_logits_weights/ApplyGradientDescent" - input: "^GradientDescent/update_logits_biases/ApplyGradientDescent" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_2" - } - } - attr { - key: "_output_shapes" - value { - list { - } - } - } -} -node { - name: "Reshape_3" - op: "Reshape" - input: "SoftmaxCrossEntropyWithLogits" - input: "Slice_2" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - } - } - } - } -} -node { - name: "Mean" - op: "Mean" - input: "Reshape_3" - input: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "_send_Mean_0" - op: "_Send" - input: "Mean" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "client_terminated" - value { - b: true - } - } - attr { - key: "recv_device" - value { - s: "/job:localhost/replica:0/task:0/cpu:0" - } - } - attr { - key: "send_device" - 
value { - s: "/job:localhost/replica:0/task:0/cpu:0" - } - } - attr { - key: "send_device_incarnation" - value { - i: -5924635994370253548 - } - } - attr { - key: "tensor_name" - value { - s: "Mean:0" - } - } -} -library { -} -versions { - producer: 21 -} diff --git a/tensorflow/tensorboard/components/tf_graph_board/demo/index.html b/tensorflow/tensorboard/components/tf_graph_board/demo/index.html deleted file mode 100644 index 2563e1595e9..00000000000 --- a/tensorflow/tensorboard/components/tf_graph_board/demo/index.html +++ /dev/null @@ -1,98 +0,0 @@ - - - - - - - -TF Graph Board Demo - - - - diff --git a/tensorflow/tensorboard/components/tf_graph_board/tf-graph-board.html b/tensorflow/tensorboard/components/tf_graph_board/tf-graph-board.html deleted file mode 100644 index 742bb63e045..00000000000 --- a/tensorflow/tensorboard/components/tf_graph_board/tf-graph-board.html +++ /dev/null @@ -1,264 +0,0 @@ - - - - - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/tf_graph_common/BUILD b/tensorflow/tensorboard/components/tf_graph_common/BUILD deleted file mode 100644 index e4e57149f3c..00000000000 --- a/tensorflow/tensorboard/components/tf_graph_common/BUILD +++ /dev/null @@ -1,54 +0,0 @@ -package(default_visibility = ["//tensorflow/tensorboard:internal"]) - -load("//tensorflow/tensorboard/defs:defs.bzl", "tensorboard_webcomponent_library") -load("//tensorflow/tensorboard/defs:web.bzl", "ts_web_library") - -licenses(["notice"]) # Apache 2.0 - -ts_web_library( - name = "tf_graph_common", - srcs = [ - "annotation.ts", - "colors.ts", - "common.ts", - "contextmenu.ts", - "edge.ts", - "externs.ts", - "graph.ts", - "hierarchy.ts", - "layout.ts", - "minimap.ts", - "node.ts", - "parser.ts", - "proto.ts", - "render.ts", - "scene.ts", - "template.ts", - "tf-graph-common.html", - "util.ts", - ], - path = "/tf-graph-common", - deps = [ - "//tensorflow/tensorboard/components/tf_imports:d3", - "//tensorflow/tensorboard/components/tf_imports:dagre", - 
"//tensorflow/tensorboard/components/tf_imports:graphlib", - "//tensorflow/tensorboard/components/tf_imports:lodash", - "//tensorflow/tensorboard/components/tf_imports:polymer", - ], -) - -tensorboard_webcomponent_library( - name = "legacy", - srcs = [":tf_graph_common"], - destdir = "tf-graph-common", - deps = [ - "//tensorflow/tensorboard/components/tf_imports_google:lib", - "//third_party/javascript/polymer/v1/polymer:lib", - ], -) - -filegroup( - name = "all_files", - srcs = glob(["**"]), - tags = ["notsan"], -) diff --git a/tensorflow/tensorboard/components/tf_graph_common/annotation.ts b/tensorflow/tensorboard/components/tf_graph_common/annotation.ts deleted file mode 100644 index bde38297785..00000000000 --- a/tensorflow/tensorboard/components/tf_graph_common/annotation.ts +++ /dev/null @@ -1,235 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the 'License'); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an 'AS IS' BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -module tf.graph.scene.annotation { - /** - * Populate a given annotation container group - * - * - * - * with annotation group of the following structure: - * - * - * - * - * - * - * - * @param container selection of the container. - * @param annotationData node.{in|out}Annotations - * @param d node to build group for. - * @param sceneElement polymer element. 
- * @return selection of appended objects - */ - export function buildGroup( - container, annotationData: render.AnnotationList, - d: render.RenderNodeInfo, sceneElement) { - // Select all children and join with data. - let annotationGroups = - container - .selectAll(function() { - // using d3's selector function - // See https://github.com/mbostock/d3/releases/tag/v2.0.0 - // (It's not listed in the d3 wiki.) - return this.childNodes; - }) - .data(annotationData.list, d => { return d.node.name; }); - - annotationGroups.enter() - .append('g') - .attr('data-name', a => { return a.node.name; }) - .each(function(a) { - let aGroup = d3.select(this); - - // Add annotation to the index in the scene - sceneElement.addAnnotationGroup(a, d, aGroup); - // Append annotation edge - let edgeType = Class.Annotation.EDGE; - let metaedge = a.renderMetaedgeInfo && a.renderMetaedgeInfo.metaedge; - if (metaedge && !metaedge.numRegularEdges) { - edgeType += ' ' + Class.Annotation.CONTROL_EDGE; - } - // If any edges are reference edges, add the reference edge class. 
- if (metaedge && metaedge.numRefEdges) { - edgeType += ' ' + Class.Edge.REF_LINE; - } - edge.appendEdge(aGroup, a, sceneElement, edgeType); - - if (a.annotationType !== render.AnnotationType.ELLIPSIS) { - addAnnotationLabelFromNode(aGroup, a); - buildShape(aGroup, a); - } else { - addAnnotationLabel( - aGroup, a.node.name, a, Class.Annotation.ELLIPSIS); - } - }).merge(annotationGroups) - .attr( - 'class', - a => { - return Class.Annotation.GROUP + ' ' + - annotationToClassName(a.annotationType) + ' ' + - node.nodeClass(a); - }) - .each(function(a) { - let aGroup = d3.select(this); - update(aGroup, d, a, sceneElement); - if (a.annotationType !== render.AnnotationType.ELLIPSIS) { - addInteraction(aGroup, d, a, sceneElement); - } - }); - - annotationGroups.exit() - .each(function(a) { - let aGroup = d3.select(this); - - // Remove annotation from the index in the scene - sceneElement.removeAnnotationGroup(a, d, aGroup); - }) - .remove(); - return annotationGroups; -}; - -/** - * Maps an annotation enum to a class name used in css rules. 
- */ -function annotationToClassName(annotationType: render.AnnotationType) { - return (render.AnnotationType[annotationType] || '').toLowerCase() || null; -} - -function buildShape(aGroup, a: render.Annotation) { - if (a.annotationType === render.AnnotationType.SUMMARY) { - let summary = selectOrCreateChild(aGroup, 'use'); - summary - .attr('class', 'summary') - .attr('xlink:href', '#summary-icon') - .attr('cursor', 'pointer'); - } else { - let shape = node.buildShape(aGroup, a, Class.Annotation.NODE); - // add title tag to get native tooltips - selectOrCreateChild(shape, 'title').text(a.node.name); - } -} - -function addAnnotationLabelFromNode(aGroup, a: render.Annotation) { - let namePath = a.node.name.split('/'); - let text = namePath[namePath.length - 1]; - return addAnnotationLabel(aGroup, text, a, null); -} - -function addAnnotationLabel( - aGroup, label: string, a: render.Annotation, additionalClassNames) { - let classNames = Class.Annotation.LABEL; - if (additionalClassNames) { - classNames += ' ' + additionalClassNames; - } - let txtElement = aGroup.append('text') - .attr('class', classNames) - .attr('dy', '.35em') - .attr('text-anchor', a.isIn ? 'end' : 'start') - .text(label); - - return tf.graph.scene.node.enforceLabelWidth(txtElement, -1); -} - -function addInteraction(selection, d: render.RenderNodeInfo, - annotation: render.Annotation, sceneElement) { - selection - .on('mouseover', - a => { - sceneElement.fire( - 'annotation-highlight', - {name: a.node.name, hostName: d.node.name}); - }) - .on('mouseout', - a => { - sceneElement.fire( - 'annotation-unhighlight', - {name: a.node.name, hostName: d.node.name}); - }) - .on('click', a => { - // Stop this event's propagation so that it isn't also considered a - // graph-select. 
- (d3.event).stopPropagation(); - sceneElement.fire( - 'annotation-select', {name: a.node.name, hostName: d.node.name}); - }); - if (annotation.annotationType !== render.AnnotationType.SUMMARY && - annotation.annotationType !== render.AnnotationType.CONSTANT) { - selection.on( - 'contextmenu', contextmenu.getMenu( - node.getContextMenu(annotation.node, sceneElement))); - } -}; - -/** - * Adjust annotation's position. - * - * @param aGroup selection of a 'g.annotation' element. - * @param d Host node data. - * @param a annotation node data. - * @param sceneElement polymer element. - */ -function update(aGroup, d: render.RenderNodeInfo, a: render.Annotation, - sceneElement) { - let cx = layout.computeCXPositionOfNodeShape(d); - // Annotations that point to embedded nodes (constants,summary) - // don't have a render information attached so we don't stylize these. - // Also we don't stylize ellipsis annotations (the string '... and X more'). - if (a.renderNodeInfo && - a.annotationType !== render.AnnotationType.ELLIPSIS) { - node.stylize(aGroup, a.renderNodeInfo, sceneElement, - Class.Annotation.NODE); - } - - if (a.annotationType === render.AnnotationType.SUMMARY) { - // Update the width of the annotation to give space for the image. - a.width += 10; - } - - // label position - aGroup.select('text.' + Class.Annotation.LABEL).transition() - .attr('x', cx + a.dx + (a.isIn ? -1 : 1) * (a.width / 2 + a.labelOffset)) - .attr('y', d.y + a.dy); - - // Some annotations (such as summary) are represented using a 12x12 image tag. - // Purposely omitted units (e.g. pixels) since the images are vector graphics. - // If there is an image, we adjust the location of the image to be vertically - // centered with the node and horizontally centered between the arrow and the - // text label. - aGroup.select('use.summary').transition() - .attr('x', cx + a.dx - 3) - .attr('y', d.y + a.dy - 6); - - // Node position (only one of the shape selection will be non-empty.) 
- positionEllipse( - aGroup.select('.' + Class.Annotation.NODE + ' ellipse'), cx + a.dx, - d.y + a.dy, a.width, a.height); - positionRect( - aGroup.select('.' + Class.Annotation.NODE + ' rect'), cx + a.dx, - d.y + a.dy, a.width, a.height); - positionRect( - aGroup.select('.' + Class.Annotation.NODE + ' use'), cx + a.dx, - d.y + a.dy, a.width, a.height); - - // Edge position - aGroup.select('path.' + Class.Annotation.EDGE).transition().attr('d', a => { - // map relative position to absolute position - let points = a.points.map(p => { return {x: p.dx + cx, y: p.dy + d.y}; }); - return edge.interpolate(points); - }); -}; - -} // close module diff --git a/tensorflow/tensorboard/components/tf_graph_common/colors.ts b/tensorflow/tensorboard/components/tf_graph_common/colors.ts deleted file mode 100644 index 40f91f7d2db..00000000000 --- a/tensorflow/tensorboard/components/tf_graph_common/colors.ts +++ /dev/null @@ -1,130 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the 'License'); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an 'AS IS' BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -module tf { - /** - * Mapping from color palette name to color palette, which contains - * exact colors for multiple states of a single color palette. 
- */ - export let COLORS = [ - { - 'name': 'Google Blue', - 'color': '#4184f3', - 'active': '#3a53c5', - 'disabled': '#cad8fc' - }, - { - 'name': 'Google Red', - 'color': '#db4437', - 'active': '#8f2a0c', - 'disabled': '#e8c6c1' - }, - { - 'name': 'Google Yellow', - 'color': '#f4b400', - 'active': '#db9200', - 'disabled': '#f7e8b0' - }, - { - 'name': 'Google Green', - 'color': '#0f9d58', - 'active': '#488046', - 'disabled': '#c2e1cc' - }, - { - 'name': 'Purple', - 'color': '#aa46bb', - 'active': '#5c1398', - 'disabled': '#d7bce6' - }, - { - 'name': 'Teal', - 'color': '#00abc0', - 'active': '#47828e', - 'disabled': '#c2eaf2' - }, - { - 'name': 'Deep Orange', - 'color': '#ff6f42', - 'active': '#ca4a06', - 'disabled': '#f2cbba' - }, - { - 'name': 'Lime', - 'color': '#9d9c23', - 'active': '#7f771d', - 'disabled': '#f1f4c2' - }, - { - 'name': 'Indigo', - 'color': '#5b6abf', - 'active': '#3e47a9', - 'disabled': '#c5c8e8' - }, - { - 'name': 'Pink', - 'color': '#ef6191', - 'active': '#ca1c60', - 'disabled': '#e9b9ce' - }, - { - 'name': 'Deep Teal', - 'color': '#00786a', - 'active': '#2b4f43', - 'disabled': '#bededa' - }, - { - 'name': 'Deep Pink', - 'color': '#c1175a', - 'active': '#75084f', - 'disabled': '#de8cae' - }, - { - 'name': 'Gray', - 'color': '#9E9E9E', // 500 - 'active': '#424242', // 800 - 'disabled': 'F5F5F5' // 100 - } - ].reduce((m, c) => { - m[c.name] = c; - return m; - }, {}); - - /** - * Mapping from op category to color palette name - * e.g., OP_GROUP_COLORS['state_ops'] = 'Google Blue'; - */ - export let OP_GROUP_COLORS = [ - { - color: 'Google Red', - groups: [ - 'gen_legacy_ops', 'legacy_ops', 'legacy_flogs_input', - 'legacy_image_input', 'legacy_input_example_input', - 'legacy_sequence_input', 'legacy_seti_input_input' - ] - }, - {color: 'Deep Orange', groups: ['constant_ops']}, - {color: 'Indigo', groups: ['state_ops']}, - {color: 'Purple', groups: ['nn_ops', 'nn']}, - {color: 'Google Green', groups: ['math_ops']}, - {color: 'Lime', groups: 
['array_ops']}, - {color: 'Teal', groups: ['control_flow_ops', 'data_flow_ops']}, - {color: 'Pink', groups: ['summary_ops']}, - {color: 'Deep Pink', groups: ['io_ops']} - ].reduce((m, c) => { - c.groups.forEach(function(group) { m[group] = c.color; }); - return m; - }, {}); -} diff --git a/tensorflow/tensorboard/components/tf_graph_common/common.ts b/tensorflow/tensorboard/components/tf_graph_common/common.ts deleted file mode 100644 index e7eac54e58f..00000000000 --- a/tensorflow/tensorboard/components/tf_graph_common/common.ts +++ /dev/null @@ -1,31 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the 'License'); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an 'AS IS' BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -/** - * @fileoverview Common interfaces for the tensorflow graph visualizer. - */ - -module tf { - /** - * Tracks task progress. Each task being passed a progress tracker needs - * to call the below-defined methods to notify the caller about the gradual - * progress of the task. 
- */ - export interface ProgressTracker { - updateProgress(incrementValue: number): void; - setMessage(msg: string): void; - reportError(msg: string, err: Error): void; - } -} // close module tf diff --git a/tensorflow/tensorboard/components/tf_graph_common/contextmenu.ts b/tensorflow/tensorboard/components/tf_graph_common/contextmenu.ts deleted file mode 100644 index 8121cf9f6da..00000000000 --- a/tensorflow/tensorboard/components/tf_graph_common/contextmenu.ts +++ /dev/null @@ -1,75 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the 'License'); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an 'AS IS' BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -module tf.graph.scene.contextmenu { - -/** Function that converts data to a title string. */ -export interface TitleFunction { - (data: any): string; -} - -/** Function that takes action based on item clicked in the context menu. */ -export interface ActionFunction { - (elem: any, d: any, i: number): void; -} - -/** - * The interface for an item in the context menu - */ -export interface ContextMenuItem { - title: TitleFunction; - action: ActionFunction; -} - -/** - * Returns the event listener, which can be used as an argument for the d3 - * selection.on function. Renders the context menu that is to be displayed - * in response to the event. - */ -export function getMenu(menu: ContextMenuItem[]) { - let menuSelection = d3.select('.context-menu'); - // Close the menu when anything else is clicked. 
- d3.select('body').on( - 'click.context', function() { menuSelection.style('display', 'none'); }); - - // Function called to populate the context menu. - return function(data, index: number): void { - // Position and display the menu. - let event = d3.event; - menuSelection - .style('display', 'block') - .style('left', (event.layerX + 1) + 'px') - .style('top', (event.layerY + 1) + 'px'); - - // Stop the event from propagating further. - event.preventDefault(); - event.stopPropagation(); - - // Add provided items to the context menu. - menuSelection.html(''); - let list = menuSelection.append('ul'); - list.selectAll('li') - .data(menu) - .enter() - .append('li') - .html(function(d) { return d.title(data); }) - .on('click', (d, i) => { - d.action(this, data, index); - menuSelection.style('display', 'none'); - }); - }; -}; - -} // close module diff --git a/tensorflow/tensorboard/components/tf_graph_common/edge.ts b/tensorflow/tensorboard/components/tf_graph_common/edge.ts deleted file mode 100644 index 4a1182bb9fb..00000000000 --- a/tensorflow/tensorboard/components/tf_graph_common/edge.ts +++ /dev/null @@ -1,359 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the 'License'); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an 'AS IS' BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -module tf.graph.scene.edge { - -/** Delimiter between dimensions when showing sizes of tensors. */ -const TENSOR_SHAPE_DELIM = '×'; - -/** The minimum stroke width of an edge. 
*/ -export const MIN_EDGE_WIDTH = 0.75; - -/** The maximum stroke width of an edge. */ -export const MAX_EDGE_WIDTH = 12; - -/** The exponent used in the power scale for edge thickness. */ -const EDGE_WIDTH_SCALE_EXPONENT = 0.3; - -/** The domain (min and max value) for the edge width. */ -const DOMAIN_EDGE_WIDTH_SCALE = [1, 5E6]; - -export const EDGE_WIDTH_SCALE: d3.ScalePower = d3.scalePow() - .exponent(EDGE_WIDTH_SCALE_EXPONENT) - .domain(DOMAIN_EDGE_WIDTH_SCALE) - .range([MIN_EDGE_WIDTH, MAX_EDGE_WIDTH]) - .clamp(true); - -let arrowheadMap = - d3.scaleQuantize().domain([MIN_EDGE_WIDTH, MAX_EDGE_WIDTH]).range([ - 'small', 'medium', 'large', 'xlarge' - ]); - -/** Minimum stroke width to put edge labels in the middle of edges */ -const CENTER_EDGE_LABEL_MIN_STROKE_WIDTH = 2.5; - -export type EdgeData = {v: string, w: string, label: render.RenderMetaedgeInfo}; - -export function getEdgeKey(edgeObj: EdgeData) { - return edgeObj.v + EDGE_KEY_DELIM + edgeObj.w; -} - -/** - * Select or Create a 'g.edges' group to a given sceneGroup - * and builds a number of 'g.edge' groups inside the group. - * - * Structure Pattern: - * - * - * - * - * - * ... - * - * - * - * @param sceneGroup container - * @param graph - * @param sceneElement polymer element. - * @return selection of the created nodeGroups - */ -export function buildGroup(sceneGroup, - graph: graphlib.Graph, - sceneElement) { - let edges: EdgeData[] = []; - edges = _.reduce(graph.edges(), (edges, edgeObj) => { - let edgeLabel = graph.edge(edgeObj); - edges.push({ - v: edgeObj.v, - w: edgeObj.w, - label: edgeLabel - }); - return edges; - }, edges); - - let container = - scene.selectOrCreateChild(sceneGroup, 'g', Class.Edge.CONTAINER); - - // Select all children and join with data. 
- // (Note that all children of g.edges are g.edge) - let edgeGroups = (container as any).selectAll(function() {return this.childNodes;}).data(edges, getEdgeKey); - - // Make edges a group to support rendering multiple lines for metaedge - edgeGroups.enter() - .append('g') - .attr('class', Class.Edge.GROUP) - .attr('data-edge', getEdgeKey) - .each(function(d: EdgeData) { - let edgeGroup = d3.select(this); - d.label.edgeGroup = edgeGroup; - // index node group for quick highlighting - sceneElement._edgeGroupIndex[getEdgeKey(d)] = edgeGroup; - - // Add line during enter because we're assuming that type of line - // normally does not change. - appendEdge(edgeGroup, d, sceneElement); - }) - .merge(edgeGroups) - .each(position) - .each(function(d) { - stylize(d3.select(this), d, sceneElement); - }); - - edgeGroups.exit() - .each(d => { - delete sceneElement._edgeGroupIndex[getEdgeKey(d)]; - }) - .remove(); - return edgeGroups; -}; - -/** - * Returns the label for the given base edge. - * The label is the shape of the underlying tensor. - */ -export function getLabelForBaseEdge( - baseEdge: BaseEdge, renderInfo: render.RenderGraphInfo): string { - let node = renderInfo.getNodeByName(baseEdge.v); - if (node.outputShapes == null || node.outputShapes.length === 0) { - return null; - } - let shape = node.outputShapes[baseEdge.outputTensorIndex]; - if (shape == null) { - return null; - } - if (shape.length === 0) { - return 'scalar'; - } - return shape.map(size => { return size === -1 ? '?' : size; }) - .join(TENSOR_SHAPE_DELIM); -} - -/** - * Creates the label for the given metaedge. If the metaedge consists - * of only 1 tensor, and it's shape is known, the label will contain that - * shape. Otherwise, the label will say the number of tensors in the metaedge. - */ -export function getLabelForEdge(metaedge: Metaedge, - renderInfo: render.RenderGraphInfo): string { - let isMultiEdge = metaedge.baseEdgeList.length > 1; - return isMultiEdge ? 
- metaedge.baseEdgeList.length + ' tensors' : - getLabelForBaseEdge(metaedge.baseEdgeList[0], renderInfo); -} - -/** - * Shortens the path enought such that the tip of the start/end marker will - * point to the start/end of the path. The marker can be of arbitrary size. - * - * @param points Array of path control points. - * @param marker D3 selection of the svg element. - * @param isStart Is the marker a `start-marker`. If false, the marker is - * an `end-marker`. - * @return The new array of control points. - */ -function adjustPathPointsForMarker(points: render.Point[], - marker: d3.Selection, isStart: boolean): render.Point[] { - let lineFunc = d3.line() - .x(d => d.x) - .y(d => d.y); - let path = - d3.select(document.createElementNS('http://www.w3.org/2000/svg', 'path')) - .attr('d', lineFunc(points)); - let markerWidth = +marker.attr('markerWidth'); - let viewBox = marker.attr('viewBox').split(' ').map(Number); - let viewBoxWidth = viewBox[2] - viewBox[0]; - let refX = +marker.attr('refX'); - let pathNode = path.node(); - if (isStart) { - // The edge flows downwards. Do not make the edge go the whole way, lest we - // clobber the arrowhead. - const fractionStickingOut = 1 - refX / viewBoxWidth; - const length = markerWidth * fractionStickingOut; - const point = pathNode.getPointAtLength(length); - // Figure out how many segments of the path we need to remove in order - // to shorten the path. - const segIndex = pathNode.getPathSegAtLength(length); - // Update the very first segment. - points[segIndex - 1] = {x: point.x, y: point.y}; - // Ignore every point before segIndex - 1. - return points.slice(segIndex - 1); - } else { - // The edge flows upwards. Do not make the edge go the whole way, lest we - // clobber the arrowhead. 
- const fractionStickingOut = 1 - refX / viewBoxWidth; - const length = - pathNode.getTotalLength() - markerWidth * fractionStickingOut; - const point = pathNode.getPointAtLength(length); - // Figure out how many segments of the path we need to remove in order - // to shorten the path. - const segIndex = pathNode.getPathSegAtLength(length); - // Update the very last segment. - points[segIndex] = {x: point.x, y: point.y}; - // Ignore every point after segIndex. - return points.slice(0, segIndex + 1); - } -} - -/** - * For a given d3 selection and data object, create a path to represent the - * edge described in d.label. - * - * If d.label is defined, it will be a RenderMetaedgeInfo instance. It - * will sometimes be undefined, for example for some Annotation edges for which - * there is no underlying Metaedge in the hierarchical graph. - */ -export function appendEdge(edgeGroup, d: EdgeData, - sceneElement: {renderHierarchy: render.RenderGraphInfo}, - edgeClass?: string) { - let size = 1; - if (d.label != null && d.label.metaedge != null) { - // There is an underlying Metaedge. - size = d.label.metaedge.totalSize; - } - edgeClass = edgeClass || Class.Edge.LINE; // set default type - - if (d.label && d.label.structural) { - edgeClass += ' ' + Class.Edge.STRUCTURAL; - } - if (d.label && d.label.metaedge && d.label.metaedge.numRefEdges) { - edgeClass += ' ' + Class.Edge.REFERENCE_EDGE; - } - // Give the path a unique id, which will be used to link - // the textPath (edge label) to this path. - let pathId = 'path_' + getEdgeKey(d); - let strokeWidth = sceneElement.renderHierarchy.edgeWidthScale(size); - - let path = edgeGroup.append('path') - .attr('id', pathId) - .attr('class', edgeClass) - .style('stroke-width', strokeWidth + 'px'); - - // Check if there is a reference edge and add an arrowhead of the right size. - if (d.label && d.label.metaedge) { - if (d.label.metaedge.numRefEdges) { - // We have a reference edge. 
- const markerId = `reference-arrowhead-${arrowheadMap(strokeWidth)}`; - path.style('marker-start', `url(#${markerId})`); - d.label.startMarkerId = markerId; - } else { - // We have a dataflow edge. - const markerId = `dataflow-arrowhead-${arrowheadMap(strokeWidth)}`; - path.style('marker-end', `url(#${markerId})`); - d.label.endMarkerId = markerId; - } - } - - if (d.label == null || d.label.metaedge == null) { - // There is no associated metaedge, thus no text. - // This happens for annotation edges. - return; - } - let labelForEdge = getLabelForEdge(d.label.metaedge, - sceneElement.renderHierarchy); - if (labelForEdge == null) { - // We have no information to show on this edge. - return; - } - - // Put edge label in the middle of edge only if the edge is thick enough. - let baseline = strokeWidth > CENTER_EDGE_LABEL_MIN_STROKE_WIDTH ? - 'central' : - 'text-after-edge'; - - edgeGroup.append('text') - .append('textPath') - .attr('xlink:href', '#' + pathId) - .attr('startOffset', '50%') - .attr('text-anchor', 'middle') - .attr('dominant-baseline', 'central') - .text(labelForEdge); -}; - -export let interpolate: d3.Line<{x: number, y: number}> = d3.line<{x: number, y: number}>() - .curve(d3.curveBasis) - .x((d) => { return d.x;}) - .y((d) => { return d.y;}); - -/** - * Returns a tween interpolator for the endpoint of an edge path. - */ -function getEdgePathInterpolator(d: EdgeData, i: number, a: string) { - let renderMetaedgeInfo = d.label; - let adjoiningMetaedge = renderMetaedgeInfo.adjoiningMetaedge; - let points = renderMetaedgeInfo.points; - - // Adjust the path so that start/end markers point to the end - // of the path. 
- if (d.label.startMarkerId) { - points = adjustPathPointsForMarker( - points, d3.select('#' + d.label.startMarkerId), true); - } - if (d.label.endMarkerId) { - points = adjustPathPointsForMarker( - points, d3.select('#' + d.label.endMarkerId), false); - } - - if (!adjoiningMetaedge) { - return d3.interpolate(a, interpolate(points)); - } - - let renderPath = this; - - // Get the adjoining path that matches the adjoining metaedge. - let adjoiningPath = - ((adjoiningMetaedge.edgeGroup.node()) - .firstChild); - - // Find the desired SVGPoint along the adjoining path, then convert those - // coordinates into the space of the renderPath using its Current - // Transformation Matrix (CTM). - let inbound = renderMetaedgeInfo.metaedge.inbound; - - return function(t) { - let adjoiningPoint = adjoiningPath - .getPointAtLength(inbound ? adjoiningPath.getTotalLength() : 0) - .matrixTransform(adjoiningPath.getCTM()) - .matrixTransform(renderPath.getCTM().inverse()); - - // Update the relevant point in the renderMetaedgeInfo's points list, then - // re-interpolate the path. - let index = inbound ? 0 : points.length - 1; - points[index].x = adjoiningPoint.x; - points[index].y = adjoiningPoint.y; - let dPath = interpolate(points); - return dPath; - }; -} - -function position(d) { - d3.select(this) - .select('path.' + Class.Edge.LINE) - .transition() - .attrTween('d', getEdgePathInterpolator as any); -}; - -/** - * For a given d3 selection and data object, mark the edge as a control - * dependency if it contains only control edges. - * - * d's label property will be a RenderMetaedgeInfo object. - */ -function stylize(edgeGroup, d: EdgeData, stylize) { - edgeGroup.classed('faded', d.label.isFadedOut); - let metaedge = d.label.metaedge; - edgeGroup.select('path.' 
+ Class.Edge.LINE) - .classed('control-dep', metaedge && !metaedge.numRegularEdges); -}; - -} // close module diff --git a/tensorflow/tensorboard/components/tf_graph_common/externs.ts b/tensorflow/tensorboard/components/tf_graph_common/externs.ts deleted file mode 100644 index 7c0d168a429..00000000000 --- a/tensorflow/tensorboard/components/tf_graph_common/externs.ts +++ /dev/null @@ -1,85 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the 'License'); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an 'AS IS' BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -/** - * @fileoverview Extern declarations for tensorflow graph visualizer. - * This file contains compiler stubs for external dependencies whos - * implementations are defined at runtime. - */ - -declare module graphlib { - interface GraphOptions { - name?: string; - /** - * Direction for rank nodes. Can be TB, BT, LR, or RL, where T = top, - * B = bottom, L = left, and R = right. - */ - rankdir?: string; - type?: string|number; - /** Number of pixels between each rank in the layout. */ - ranksep?: number; - /** Number of pixels that separate nodes horizontally in the layout. 
*/ - nodesep?: number; - /** Number of pixels that separate edges horizontally in the layout */ - edgesep?: number; - } - - export interface EdgeObject { - v: string; - w: string; - name?: string; - } - - export class Graph { - constructor(opt?: Object); - setNode(name: string, value?: N): void; - hasNode(name: string): boolean; - setEdge(fromName: string, toName: string, value?: E): void; - hasEdge(fromName: string, toName: string): boolean; - edge(fromName: string, toName: string): E; - edge(edgeObject: EdgeObject): E; - removeEdge(v: string, w: string): void; - nodes(): string[]; - node(name: string): N; - removeNode(name: string): void; - setGraph(graphOptions: GraphOptions): void; - graph(): GraphOptions; - nodeCount(): number; - neighbors(name: string): string[]; - successors(name: string): string[]; - predecessors(name: string): string[]; - edges(): EdgeObject[]; - outEdges(name: string): E[]; - inEdges(name: string): E[]; - /** - * Returns those nodes in the graph that have no in-edges. - * Takes O(|V|) time. - */ - sources(): string[]; - /** - * Remove the node with the id v in the graph or do nothing if - * the node is not in the graph. If the node was removed this - * function also removes any incident edges. Returns the graph, - * allowing this to be chained with other functions. Takes O(|E|) time. - */ - removeNode(name: string): Graph; - setParent(name: string, parentName: string): void; - } -} - -/** - * Declaring dagre var used for dagre layout. - */ -declare var dagre: {layout(graph: graphlib.Graph): void;}; diff --git a/tensorflow/tensorboard/components/tf_graph_common/graph.ts b/tensorflow/tensorboard/components/tf_graph_common/graph.ts deleted file mode 100644 index cbd7b14539a..00000000000 --- a/tensorflow/tensorboard/components/tf_graph_common/graph.ts +++ /dev/null @@ -1,1257 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the 'License'); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an 'AS IS' BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -module tf.graph { - -/** Delimiter used in node names to denote namespaces. */ -export const NAMESPACE_DELIM = '/'; -export const ROOT_NAME = '__root__'; - -/** Attribute key used for storing attributes that are too large. */ -export const LARGE_ATTRS_KEY = '_too_large_attrs'; -/** - * Maximum allowed size in bytes, before the attribute is considered large - * and filtered out of the graph. - */ -export const LIMIT_ATTR_SIZE = 1024; - -// Separator between the source and the destination name of the edge. -export const EDGE_KEY_DELIM = '--'; - -export enum GraphType {FULL, EMBEDDED, META, SERIES, CORE, SHADOW, BRIDGE, - EDGE}; -export enum NodeType {META, OP, SERIES, BRIDGE, ELLIPSIS}; - -/** Indicates if a node is to be included in the main graph when rendered. */ -export enum InclusionType {INCLUDE, EXCLUDE, UNSPECIFIED}; - -/** Indicates if a series is to be grouped in the graph when rendered. */ -export enum SeriesGroupingType {GROUP, UNGROUP}; - -/** Attribute key reserved for the shapes of the output tensors. */ -const OUTPUT_SHAPES_KEY = '_output_shapes'; - -/** Attribute key reserved for the XLA cluster that an op runs on. */ -const _XLA_CLUSTER_KEY = '_XlaCluster'; - -/** - * A BaseEdge is the label object (in the graphlib sense) for an edge in the - * original, full graph produced after parsing. 
Subsequent graphs, like those - * which belong to Metanodes, should not use BaseEdge objects, but instead - * contain Metaedges (which in turn may contain any number of BaseEdges). - */ -export interface BaseEdge extends graphlib.EdgeObject { - isControlDependency: boolean; - isReferenceEdge: boolean; - /** The index of the output tensor of the source node. */ - outputTensorIndex: number; -} - -/** - * A SlimGraph is inspired by graphlib.Graph, but having only the functionality - * that we need. - */ -export class SlimGraph { - nodes: { [nodeName: string]: OpNode }; - edges: BaseEdge[]; - - constructor() { - this.nodes = {}; - this.edges = []; - } -} - -export interface NormalizedInput { - name: string; - /** The index of the output tensor of the source node. */ - outputTensorIndex: number; - isControlDependency: boolean; -} - -export interface BuildParams { - enableEmbedding: boolean; - inEmbeddingTypes: string[]; - outEmbeddingTypes: string[]; - refEdges: { [inputEdge: string]: boolean }; -} - -/** - * The most basic information about a node in the hierarchical graph. - */ -export interface Node { - /** The name of the node, used frequently to look up nodes by name. */ - name: string; - /** Which type of node this is. */ - type: NodeType; - /** - * Whether this node is a type that may contain other nodes. Those types - * should extend from GroupNode. - * - * For an OpNode, isGroupNode will be false, even though it may have - * embeddings. These embedding Nodes will have their parentNode set to the - * OpNode. However, embeddings are later rendered as annotations, not as - * children to be made visible on expansion (like a Metanode or SeriesNode). - */ - isGroupNode: boolean; - /** - * The number of nodes this node represents. For OpNodes, this will be 1, and - * for GroupNodes it will be a count of the total number of descendents it - * contains. - */ - cardinality: number; - /** - * The Node which is this Node's parent. 
This is of type Node and not - * GroupNode because of embeddings, which will have a parent OpNode. - */ - parentNode: Node; - /** Runtime execution stats for this node, if available */ - stats: NodeStats; - /** If the node is to be included or excluded from the main graph when - * rendered. Defaults to UNSPECIFIED, which means that the rendering - * algorithm determines if it will be included or not. Then can be set to - * INCLUDE or EXCLUDE manually by the user. - */ - include: InclusionType; - /** - * Node attributes specify customizable visual aspects of a node and - * application-specific metadata associated with a node. The name - * 'nodeAttributes' is meant to avoid naming-conflicts with the 'attr' in - * subclasses of Node. - */ - nodeAttributes: {[key: string]: any;}; -} - -export type TensorShape = number[]; - -export interface OpNode extends Node { - op: string; - // The device on which the op ran. Null if it is unknown. - device: string; - attr: {key: string, value: any}[]; - inputs: NormalizedInput[]; - inEmbeddings: OpNode[]; - outEmbeddings: OpNode[]; - // The name of the SeriesNode that can contain this node in its series. - // If there is no such node, then this is null. - owningSeries: string; - /** - * Array of tensor shapes. Null if the number of output tensors is unknown, - * otherwise the length will equal the number of output tensors. - * - * Each tensor shape is an array of numbers, or null. Details: - * - null means unknown rank, and therefore entire shape is unknown. - * - [4, 2, 1] means rank-3 tensor of size 4x2x1. - * - [] means a scalar (rank-0 tensor). - * - [1] means rank-1 tensor of size 1 (not the same as scalar). - * - [5, -1, 3] means rank-3 tensor of shape is 5x?x3. The size - * of the middle dimension is unknown (encoded as -1). - */ - outputShapes: TensorShape[]; - // The XLA Cluster on which the op ran. Null if it is unknown. 
- xlaCluster: string; -} - -export interface BridgeNode extends Node { - /** - * Whether this bridge node represents edges coming into its parent node. - */ - inbound: boolean; -} - -/** - * A node that is used when there are more than the maximum number of allowed - * annotations hanging off of a node. This node represents an ellipsis - * annotation, indicating a number of additional annotations. - */ -export interface EllipsisNode extends Node { - /** - * The number of nodes this ellipsis represents. - */ - numMoreNodes: number; - - /** - * Sets the number of nodes this ellipsis represents and changes the node - * name accordingly. - */ - setNumMoreNodes(numNodes: number); -} - -export interface GroupNode extends Node { - /** - * The metagraph contains nodes and metaedges between the immediate children - * of this group. The node label objects may be other GroupNodes (like - * SeriesNodes and Metanodes) or individual OpNodes. All edge label objects - * are Metaedges, each of which contains references to the original - * BaseEdge(s) from which it was created. - */ - metagraph: graphlib.Graph; - - /** - * The bridgegraph contains only edges which link immediate children of this - * group with nodes outside of the metagraph. As in the metagraph, all edge - * label objects are Metaedges which contain references to the original - * BaseEdge(s) that contribute to it. - * - * For a Metaedge in the bridgegraph, its external endpoint will be the same - * as the metagraph edge from which it came. This is most easily explained - * by example. - * - * Consider an original graph that contains a BaseEdge A/B/C->Z/Y/X. - * - * +-------+ (BaseEdge) +-------+ - * | A/B/C |>----------------->| Z/Y/X | - * +-------+ +-------+ - * - * When we construct the Root's metagraph, it will contain nodes for A and Z, - * and a Metaedge A->Z. The A->Z Metaedge will contain the original BaseEdge - * A/B/C->Z/Y/X in its baseEdgeGraph. The Root's bridgegraph will always be - * empty. 
- * - * +---+ (Root.metagraph edge) +---+ - * | A |>--------------------------->| Z | - * +---+ +---+ - * - * Now consider the Metanode A. Its metagraph will contain a Metanode for A/B - * and no edges. A's bridgegraph will have one Metaedge from A/B->Z, which - * was derived from the Root's Metaedge A->Z. That Metaedge will contain the - * original BaseEdge in its baseEdgeGraph. - * - * +---------+ - * | A | - * | +---+ | (A.bridgegraph edge) +---+ - * | | B |>---------------------------->| Z | - * | +---+ | +---+ - * +---------+ - * - * Finally, consider the Metanode A/B. Its metagraph will contain a Metanode - * for A/B/C and again no edges. A/B's bridgegraph will have one Metaedge - * from A/B/C->Z, which was derived from A's bridgegraph Metaedge A/B->Z. - * As before, the A/B/C->Z Metaedge will contain the original BaseEdge in its - * baseEdgeGraph. - * - * +---------------+ - * | A | - * | +---------+ | - * | | B | | - * | | +---+ | | (A/B.bridgegraph edge) +---+ - * | | | C |>----------------------------------->| Z | - * | | +---+ | | +---+ - * | +---------+ | - * +---------------+ - * - * Likewise, under the Metanode Z and Z/Y, to compute the bridgegraph, we'll - * end up with Metaedges A->Z/Y and A->Z/Y/X respectively. So the original - * BaseEdge A/B/C->Z/Y/X becomes four different Metaedges in four different - * bridgegraphs: - * - * + A/B->Z in GroupNode A's bridgegraph, - * + A/B/C->Z in GroupNode A/B's bridgegraph, - * + A->Z/Y in GroupNode Z's bridgegraph, and - * + A->Z/Y/X in GroupNode Z/Y's bridgegraph. - * - * Considering any BaseEdge then, if N is the number of path segments in the - * source and M is the number of path segments in the destination, then the - * total number of bridgegraph edges you could create would be (N-1)(M-1). - * - * For this reason, it is computationally expensive to generate all the - * bridgegraphs for all the Metanodes, and instead they should be computed - * on demand as needed. 
- */ - bridgegraph: graphlib.Graph; - - /** - * Stores how many times each device name appears in its children - * op nodes. Used to color group nodes by devices. - */ - deviceHistogram: {[device: string]: number}; - - /** - * Flag indicating whether this GroupNode's metagraph contains any edges that - * are not control edges. Used to quickly determine how to draw a collapsed - * series (vertically or horizontally). - */ - hasNonControlEdges: boolean; -} - -export interface Metanode extends GroupNode { - depth: number; - templateId: string; - opHistogram: {[op: string]: number}; - getFirstChild(): GroupNode|OpNode; - getRootOp(): OpNode; - /** Return name of all leaves inside a metanode. */ - leaves(): string[]; -} - -export interface SeriesNode extends GroupNode { - hasLoop: boolean; - prefix: string; - suffix: string; - clusterId: number; - ids: number[]; - parent: string; -} - -export class EllipsisNodeImpl implements EllipsisNode { - name: string; - numMoreNodes: number; - stats: NodeStats; - type: NodeType; - isGroupNode: boolean; - cardinality: number; - parentNode: Node; - include: InclusionType; - nodeAttributes: {[key: string]: any;}; - /** - * Constructs a new ellipsis annotation node. - * - * @param numNodes The number of additional annotations this node represents. - */ - constructor(numNodes: number) { - this.type = NodeType.ELLIPSIS; - this.isGroupNode = false; - this.cardinality = 1; - this.parentNode = null; - this.stats = null; - this.setNumMoreNodes(numNodes); - this.include = InclusionType.UNSPECIFIED; - } - - setNumMoreNodes(numNodes: number) { - this.numMoreNodes = numNodes; - this.name = '... ' + numNodes + ' more'; - } -}; - -/** - * A label object for nodes in the full graph and leaf nodes in the render - * graph. 
- */ -export class OpNodeImpl implements OpNode { - name: string; - op: string; - device: string; - stats: NodeStats; - attr: {key: string, value: any}[]; - inputs: NormalizedInput[]; - type: NodeType; - isGroupNode: boolean; - cardinality: number; - inEmbeddings: OpNode[]; - outEmbeddings: OpNode[]; - parentNode: Node; - include: InclusionType; - owningSeries: string; - outputShapes: TensorShape[]; - nodeAttributes: {[key: string]: any;}; - xlaCluster: string; - - /** - * Constructs a new Op node. - * - * @param rawNode The raw node. - */ - constructor(rawNode: tf.graph.proto.NodeDef) { - this.op = rawNode.op; - this.name = rawNode.name; - this.device = rawNode.device; - this.attr = rawNode.attr; - // An array of normalized inputs that denote the incoming edges to - // the current node. Each input contains the normalized name of the - // source node, whether it has a number part and whether it is a - // control dependency. - this.inputs = normalizeInputs(rawNode.input); - this.outputShapes = extractOutputShapes(rawNode.attr); - this.xlaCluster = extractXlaCluster(rawNode.attr); - // additional properties - this.type = NodeType.OP; - this.isGroupNode = false; - this.cardinality = 1; - this.inEmbeddings = []; - this.outEmbeddings = []; - this.parentNode = null; - this.include = InclusionType.UNSPECIFIED; - this.owningSeries = null; - } -}; - -export function createMetanode(name: string, opt = {}): Metanode { - return new MetanodeImpl(name, opt); -} - -/** - * Joins the information from the stats file (memory, compute time) with the - * graph information. - */ -export function joinStatsInfoWithGraph( - graph: SlimGraph, stats: tf.graph.proto.StepStats, - devicesForStats?: {[device: string]: boolean}): void { - // Reset stats for each node. - _.each(graph.nodes, node => { node.stats = null; }); - - _.each(stats.dev_stats, devStats => { - // Ignore devices that are not selected. 
- if (devicesForStats && !devicesForStats[devStats.device]) { - return; - } - _.each(devStats.node_stats, nodeStats => { - // Lookup the node in the graph by its original name, e.g. A. If not - // found, lookup by the rewritten name A/(A) in case the name is both - // a namespace and a node name. - let nodeName = nodeStats.node_name in graph.nodes ? nodeStats.node_name : - nodeStats.node_name + - NAMESPACE_DELIM + '(' + nodeStats.node_name + ')'; - - // Couldn't find a matching node. - if (!(nodeName in graph.nodes)) { - return; - } - - // Compute the total bytes used. - let totalBytes = 0; - if (nodeStats.memory) { - _.each(nodeStats.memory, alloc => { - if (alloc.total_bytes) { - if (alloc.total_bytes > 0) { - totalBytes += Number(alloc.total_bytes); - } else { - /* tslint:disable */ - console.log( - 'ignoring negative memory allocation for ' + nodeName); - /* tslint:enable */ - } - } - }); - } - let outputSize: number[][] = null; - if (nodeStats.output) { - outputSize = _.map(nodeStats.output, output => { - return _.map(output.tensor_description.shape.dim, - dim => Number(dim.size)); - }); - } - graph.nodes[nodeName].device = devStats.device; - if (graph.nodes[nodeName].stats == null) { - graph.nodes[nodeName].stats = new NodeStats(outputSize); - } - graph.nodes[nodeName].stats.addBytesAllocation(totalBytes); - if (nodeStats.all_end_rel_micros) { - if (nodeStats.all_end_rel_micros > 0) { - graph.nodes[nodeName].stats.addExecutionTime( - nodeStats.all_start_micros, - nodeStats.all_start_micros + nodeStats.all_end_rel_micros); - } else { - /* tslint:disable */ - console.log('ignoring negative runtime for ' + nodeName); - /* tslint:enable */ - } - } - }); - }); -} - -/** - * Execution stats for the node. - */ -export class NodeStats { - constructor(outputSize: number[][]) { this.outputSize = outputSize; } - - /** - * Add the start and end time for a particular kernel execution of this op. - * Ops can have multiple kernel executions within the same session run. 
- */ - addExecutionTime(startTime: number, endTime: number) { - if (this.startTime != null) { - this.startTime = Math.min(this.startTime, startTime); - } else { - this.startTime = startTime; - } - if (this.endTime != null) { - this.endTime = Math.max(this.endTime, endTime); - } else { - this.endTime = endTime; - } - } - - /** - * Add the bytes allocated for a particular kernel execution of this op. - * Ops can have multiple kernel executions within the same session run. - */ - addBytesAllocation(totalBytes: number) { - if (this.totalBytes != null) { - this.totalBytes = Math.max(this.totalBytes, totalBytes); - } else { - this.totalBytes = totalBytes; - } - } - - /** - * Absolute start time for the very first kernel execution of this op. - */ - startTime: number; - /** - * Absolute end time for the very last kernel execution of this op. - */ - endTime: number; - /** - * Total number of bytes used for the node. Sum of all children - * if it is a Group node. - */ - totalBytes = 0; - - /** - * The shape of each output tensors, if there are any. - * Empty if it is a Group node. - */ - outputSize: number[][]; - - /** - * Combines the specified stats with the current stats. - * Modifies the current object. This method is used to - * compute aggregate stats for group nodes. - */ - combine(stats: NodeStats): void { - if (stats.totalBytes != null) { - this.totalBytes += stats.totalBytes; - } - if (stats.getTotalMicros() != null) { - this.addExecutionTime(stats.startTime, stats.endTime); - } - } - - /** - * Total number of compute time in microseconds used for the node. - * Sum of all children if it is a Group node. Null if it is unknown. - * This method can not be scaffolded under a getter attribute because - * ECMAScript 5 does not support getter attributes. 
- */ - getTotalMicros(): number { - if (this.startTime == null || this.endTime == null) { - return null; - } - return this.endTime - this.startTime; - } -} - -export class MetanodeImpl implements Metanode { - name: string; - stats: NodeStats; - type: NodeType; - depth: number; - isGroupNode: boolean; - cardinality: number; - metagraph: graphlib.Graph; - bridgegraph: graphlib.Graph; - templateId: string; - opHistogram: {[op: string]: number}; - deviceHistogram: {[op: string]: number}; - parentNode: Node; - hasNonControlEdges: boolean; - include: InclusionType; - nodeAttributes: {[key: string]: any;}; - - /** A label object for meta-nodes in the graph hierarchy */ - constructor(name: string, opt = {}) { - this.name = name; - this.type = NodeType.META; - /** number of levels under this group */ - this.depth = 1; - this.isGroupNode = true; - /** # of leaf nodes (including embedded ones) */ - this.cardinality = 0; - /** graph contains metanodes, nodes, edges - * and metaedges for main items within this metanode - */ - this.metagraph = - createGraph(name, GraphType.META, opt); - /** bridgegraph must be constructed lazily-see hierarchy.getBridgegraph() */ - this.bridgegraph = null; - /** - * A dictionary that count ops type of nodes in this metanode - * (op type => count). - */ - this.opHistogram = {}; - this.deviceHistogram = {}; - /** unique id for a metanode of similar subgraph */ - this.templateId = null; - /** Metanode which contains this node, if any */ - this.parentNode = null; - this.hasNonControlEdges = false; - this.include = InclusionType.UNSPECIFIED; - } - - getFirstChild(): GroupNode|OpNode { - return this.metagraph.node(this.metagraph.nodes()[0]); - } - - /** - * Returns the op node associated with the metanode. - * For example, if the metanode is 'sgd', the associated - * op node is sgd/(sgd). 
- */ - getRootOp(): OpNode { - let nameSplit = this.name.split('/'); - let rootOpName = this.name + '/(' + nameSplit[nameSplit.length - 1] + ')'; - return this.metagraph.node(rootOpName); - } - - /** - * Return an array of the names of all the leaves (non-GroupNodes) inside - * this metanode. This performs a breadth-first search of the tree, so - * immediate child leaves will appear earlier in the output array than - * descendant leaves. - */ - leaves(): string[] { - let leaves = []; - let queue = [ this]; - let metagraph; // Defined here due to a limitation of ES6->5 compilation. - while (queue.length) { - let node = queue.shift(); - if (node.isGroupNode) { - metagraph = ( node).metagraph; - _.each(metagraph.nodes(), name => queue.push(metagraph.node(name))); - } else { - leaves.push(node.name); - } - } - return leaves; - } -}; - -export interface Metaedge extends graphlib.EdgeObject { - - /** - * Stores the original BaseEdges represented by this Metaedge. - */ - baseEdgeList: BaseEdge[]; - - /** - * Whether this edge represents a relationship that is inbound (or outbound) - * to the object which contains this information. For example, in a Metanode's - * bridgegraph, each edge connects an immediate child to something outside - * the Metanode. If the destination of the edge is inside the Metanode, then - * its inbound property should be true. If the destination is outside the - * Metanode, then its inbound property should be false. - * - * The property is optional because not all edges can be described as - * inbound/outbound. For example, in a Metanode's metagraph, all of the edges - * connect immediate children of the Metanode. None should have an inbound - * property, or they should be null/undefined. - */ - inbound?: boolean; - - /** - * Number of regular edges (not control dependency edges). - */ - numRegularEdges: number; - - /** - * Number of control dependency edges. 
- */ - numControlEdges: number; - - /** - * Number of reference edges, which is an edge to an operation - * that takes a reference to its input and changes its value. - */ - numRefEdges: number; - - /** - * Total size (number of units) of all the tensors flowing through this edge. - */ - totalSize: number; - - addBaseEdge(edge: BaseEdge, h: hierarchy.Hierarchy): void; -} - -export function createMetaedge(v: string, w: string): Metaedge { - return new MetaedgeImpl(v, w); -} - -/** - * A label object for edges between metanodes of subgraphs in the render graph. - */ -export class MetaedgeImpl implements Metaedge { - v: string; - w: string; - baseEdgeList: BaseEdge[]; - inbound: boolean; - numRegularEdges: number; - numControlEdges: number; - numRefEdges: number; - totalSize: number; - - constructor(v: string, w: string) { - this.v = v; - this.w = w; - this.baseEdgeList = []; - this.inbound = null; - this.numRegularEdges = 0; - this.numControlEdges = 0; - this.numRefEdges = 0; - this.totalSize = 0; - } - - addBaseEdge(edge: BaseEdge, h: hierarchy.Hierarchy): void { - this.baseEdgeList.push(edge); - if (edge.isControlDependency) { - this.numControlEdges += 1; - } else { - this.numRegularEdges += 1; - } - if (edge.isReferenceEdge) { - this.numRefEdges += 1; - } - // Compute the size of the tensor flowing through this - // base edge. - this.totalSize += MetaedgeImpl.computeSizeOfEdge(edge, h); - h.maxMetaEdgeSize = Math.max(h.maxMetaEdgeSize, this.totalSize); - } - - private static computeSizeOfEdge(edge: BaseEdge, h: hierarchy.Hierarchy): - number { - let opNode = h.node(edge.v); - if (opNode.outputShapes == null) { - // No shape information. Assume a single number. This gives - // a lower bound for the total size. - return 1; - } - h.hasShapeInfo = true; - // Sum the sizes of all output tensors. - return _(opNode.outputShapes).map(shape => { - // If the shape is unknown, treat it as 1 when computing - // total size. This gives a lower bound for the total size. 
- if (shape == null) { - return 1; - } - // Multiply all shapes to get the total size of the tensor. - // E.g. The total size of [4, 2, 1] is 4 * 2 * 1. - return _(shape).reduce((accumulated, currSize) => { - // If this particular dimension is unknown, treat - // it as 1 when computing total size. This gives a lower bound - // for the total size. - if (currSize === -1) { - currSize = 1; - } - return accumulated * currSize; - }, 1); - }).sum(); - } -} - -export function createSeriesNode(prefix: string, suffix: string, - parent: string, clusterId: number, name: string): SeriesNode { - return new SeriesNodeImpl(prefix, suffix, parent, clusterId, name); -} - -export function getSeriesNodeName(prefix: string, suffix: string, - parent: string, startId?: number, endId?: number): string { - let numRepresentation = - (typeof startId !== 'undefined' && typeof endId !== 'undefined') ? - '[' + startId + '-' + endId + ']' : - '#'; - let pattern = prefix + numRepresentation + suffix; - return (parent ? 
parent + '/' : '') + pattern; -} - -class SeriesNodeImpl implements SeriesNode { - name: string; - type: NodeType; - stats: NodeStats; - hasLoop: boolean; - prefix: string; - suffix: string; - clusterId: number; - ids: number[]; - parent: string; - isGroupNode: boolean; - cardinality: number; - metagraph: graphlib.Graph; - bridgegraph: graphlib.Graph; - parentNode: Node; - deviceHistogram: {[op: string]: number}; - hasNonControlEdges: boolean; - include: InclusionType; - nodeAttributes: {[key: string]: any;}; - - constructor(prefix: string, suffix: string, parent: string, - clusterId: number, name: string) { - this.name = name || getSeriesNodeName(prefix, suffix, parent); - this.type = NodeType.SERIES; - this.hasLoop = false; - this.prefix = prefix; - this.suffix = suffix; - this.clusterId = clusterId; - this.ids = []; - this.parent = parent; - this.isGroupNode = true; - this.cardinality = 0; - this.metagraph = createGraph(name, GraphType.SERIES); - // bridgegraph must be constructed lazily-see hierarchy.getBridgegraph() - this.bridgegraph = null; - this.parentNode = null; - this.deviceHistogram = {}; - this.hasNonControlEdges = false; - this.include = InclusionType.UNSPECIFIED; - } -} - -/** - * Extracts the shapes of the output tensors from the attr property in the - * node proto. - */ -// tslint:disable-next-line:no-any -function extractOutputShapes(attr: Array<{key: string, value: any}>): - TensorShape[] { - let result = null; - // We don't know anything about the output tensors. - if (!attr) { - return null; - } - for (let i = 0; i < attr.length; i++) { - let {key, value} = attr[i]; - if (key === OUTPUT_SHAPES_KEY) { - if (!value.list.shape) { - // The OUTPUT_SHAPES_KEY lacks a value. We know nothing about the shape. - return null; - } - - // Map all output tensors into array of numbers denoting their shape. - let result = value.list.shape.map(shape => { - if (shape.unknown_rank) { - // This output tensor is of unknown rank. 
We don't know if it is a - // scalar, or a tensor, or of what shape it is. - return null; - } - if (shape.dim == null || - (shape.dim.length === 1 && shape.dim[0].size == null)) { - // This output tensor is a scalar. - return []; - } - // This output tensor has a known rank. Map each dimension size - // into a number. - return shape.dim.map(dim => { - // Size can be -1 if this particular dimension is unknown. - return dim.size; - }); - }); - // Since we already processed it, remove the entry from the attribute - // list (saves memory). - attr.splice(i, 1); - return result; - } - } - // We didn't find OUTPUT_SHAPES_KEY in attributes, so we don't know anything - // about the output tensors. - return null; -} - -/** - * Extracts the XLA Cluster that an op runs on from the attrs of the OpNode. - * @param attr The attr property. - * @return A string that is the name of the cluster. Or null if it could not be - * determined. - */ -// tslint:disable-next-line:no-any -function extractXlaCluster(attr: Array<{key: string, value: any}>): string| - null { - if (!attr) { - return null; - } - - // Find the attribute for XLA cluster if there is one. - for (let i = 0; i < attr.length; i++) { - if (attr[i].key === _XLA_CLUSTER_KEY) { - return attr[i].value['s'] || null; - } - } - return null; -} - -/** - * Normalizes the inputs and extracts associated metadata: - * 1) Inputs can contain a colon followed by a number at the end - * (e.g. inputName:1) and we remove this from the input name, and take note - * that the input was numbered. - * 2) Control dependency inputs contain caret at the beginning and we - * remove this and annotate the edge as a control dependency. - * @param inputs Array of unnormalized names of input nodes. 
- */ -function normalizeInputs(inputs: string[]): NormalizedInput[] { - let normalizedInputs: NormalizedInput[] = []; - _.each(inputs, inputName => { - let start = inputName[0] === '^'; - let colon = inputName.lastIndexOf(':'); - let end = colon !== -1 && - inputName.length - colon > 1 && - !(/\D/).test(inputName.substring(colon + 1)) ? - colon : inputName.length; - let name = inputName.substring(start ? 1 : 0, end); - if (normalizedInputs.length === 0 || - name !== normalizedInputs[normalizedInputs.length - 1].name) { - normalizedInputs.push({ - name: name, - outputTensorIndex: - end === inputName.length ? 0 : Number(inputName.slice(colon + 1)), - isControlDependency: start - }); - } - }); - return normalizedInputs; -} - -function addEdgeToGraph( - graph: SlimGraph, inputName: string, outputNode: OpNode, - input: NormalizedInput, params: BuildParams, index: number) { - // Don't allow loops in the graph. - if (inputName === outputNode.name) { - return; - } - // Check if this op type and input number corresponds to a - // reference edge using the refEdges dictionary in the params. - let isRefEdge = params.refEdges[outputNode.op + ' ' + index] === true; - graph.edges.push({ - v: inputName, - w: outputNode.name, - outputTensorIndex: input.outputTensorIndex, - isControlDependency: input.isControlDependency, - isReferenceEdge: isRefEdge - }); -} - -export function build( - rawNodes: tf.graph.proto.NodeDef[], params: BuildParams, - tracker: ProgressTracker): Promise { - /** - * A dictionary that maps each in-embedding node name to the node - * object. - */ - let inEmbedding: {[nodeName: string]: OpNode} = {}; - /** - * A dictionary that maps each out-embedding node name to the node - * object. - */ - let outEmbedding: {[nodeName: string]: OpNode} = {}; - /** - * A dictionary that maps each node name to an array of the node's - * out-embedding node label objects. 
- */ - let outEmbeddings: {[inputName: string]: OpNode[]} = {}; - let isInEmbeddedPred = getEmbedPredicate(params.inEmbeddingTypes); - let isOutEmbeddedPred = getEmbedPredicate(params.outEmbeddingTypes); - let embeddingNodeNames: string[] = []; - /** - * A list of all the non-embedding node names which appear in the processed - * list of raw nodes. Here we pre-allocate enough room for all the rawNodes, - * even though there will some number of embeddings. The excess array length - * is spliced off later. - * - * Experimentation shows that around 30% of the array will go unused, and - * even for very large networks that amounts to less than 10k spaces. - */ - let nodeNames = new Array(rawNodes.length); - - return tf.graph.util - .runAsyncTask( - 'Normalizing names', 30, - () => { - let opNodes = new Array(rawNodes.length); - let index = 0; - _.each(rawNodes, rawNode => { - let opNode = new OpNodeImpl(rawNode); - if (isInEmbeddedPred(opNode)) { - embeddingNodeNames.push(opNode.name); - inEmbedding[opNode.name] = opNode; - return; - } - - if (isOutEmbeddedPred(opNode)) { - embeddingNodeNames.push(opNode.name); - outEmbedding[opNode.name] = opNode; - _.each(opNode.inputs, input => { - let inputName = input.name; - outEmbeddings[inputName] = outEmbeddings[inputName] || []; - outEmbeddings[inputName].push(opNode); - }); - return; - } - // The node is not an embedding, so add it to the names and nodes - // lists. - opNodes[index] = opNode; - nodeNames[index] = opNode.name; - index++; - }); - opNodes.splice(index); - nodeNames.splice(index); - return opNodes; - }, - tracker) - .then((opNodes) => { - // Create the graph data structure from the graphlib library. - return tf.graph.util.runAsyncTask( - 'Building the data structure', 70, () => { - let normalizedNameDict = - mapStrictHierarchy(nodeNames, embeddingNodeNames); - let graph = new SlimGraph; - - // Add the nodes to the graph. 
- _.each(opNodes, opNode => { - let normalizedName = - normalizedNameDict[opNode.name] || opNode.name; - graph.nodes[normalizedName] = opNode; - // Check if the node has out-embeddings. If yes, add them to the - // node. - if (opNode.name in outEmbeddings) { - opNode.outEmbeddings = outEmbeddings[opNode.name]; - // Normalize the names of the out-embeddings. - _.each(opNode.outEmbeddings, node => { - node.name = normalizedNameDict[node.name] || node.name; - }); - } - // Update the name of the node. - opNode.name = normalizedName; - }); - - // Visit each node's inputs to add the edges to the graph. If the - // input - // is an in-embedding, then add it to the node's in-embeddings - // instead. - _.each(opNodes, opNode => { - _.each(opNode.inputs, (input, i) => { - let inputName = input.name; - if (inputName in inEmbedding) { - let inEmbedNode = inEmbedding[inputName]; - opNode.inEmbeddings.push(inEmbedNode); - // Move the inputs of the in-embedding node into incoming - // edges of - // the main node. E.g. the control dependency of a constant - // node - // should be moved to the op node where the constant is - // embedded. - for (let embedInput of inEmbedNode.inputs) { - addEdgeToGraph( - graph, normalizedNameDict[embedInput.name] || - embedInput.name, - opNode, embedInput, params, i); - } - } else if (inputName in outEmbedding) { - // Move the inputs of the out-embedding node into inputs of - // the main node where the out-embedding points to. - let outEmbedNode = outEmbedding[inputName]; - for (let embedInput of outEmbedNode.inputs) { - addEdgeToGraph( - graph, normalizedNameDict[embedInput.name] || - embedInput.name, - opNode, input, params, i); - } - } else { - addEdgeToGraph( - graph, normalizedNameDict[inputName] || inputName, - opNode, input, params, i); - } - }); - }); - - // Normalize the names of in-embeddings. 
- _.each(inEmbedding, (node, name) => { - node.name = normalizedNameDict[node.name] || node.name; - }); - - return graph; - }, tracker); - }); -}; - -/** - * Create a new graphlib.Graph() instance with default parameters - */ -export function createGraph(name: string, type, opt = {}): - graphlib.Graph { - let graph = new graphlib.Graph(opt); - graph.setGraph({ - name: name, - rankdir: 'BT', // BT,TB,LR,RL - type: type - }); - return graph; -}; - -/** - * Create a predicate for checking whether a node should be embedded based on - * the specified types. - */ -function getEmbedPredicate(types: string[]) { - return function(node: OpNode) { - // check types - for (let i = 0; i < types.length; i++) { - let regExp = new RegExp(types[i]); - if (node.op.match(regExp)) { return true; } - } - return false; - }; -}; - -/** - * Returns a strict node name (name => name/(name)) to avoid conflicts - * where the node name is also a namespace. - */ -export function getStrictName(name: string): string { - let parts = name.split(NAMESPACE_DELIM); - return name + NAMESPACE_DELIM + '(' + parts[parts.length - 1] + ')'; -} - -/** - * For each op node (embedding or non-embedding), rename it if there is a - * non-embedding node under its namespace. For example, assume node name 'A'. - * If there is a non-embedding node under its namespace (e.g. 'A/B'), 'A' will - * be renamed to 'A/(A)'. Then the namespace 'A' will contain 2 nodes: '(A)' - * and 'B'. If all the nodes under 'A' are embedding nodes (e.g. constant and - * summary), keep 'A' as an Op node and don't create a namespace. - * - * @param nodeNames An array of regular (non-embedding) node names. - * @param embeddingNodeNames An array of embedding node names. - * @return Dictionary object mapping names that need to be renamed to - * new names. 
- */ -function mapStrictHierarchy(nodeNames: string[], - embeddingNodeNames: string[]): {[oldName: string]: string} { - /** Dictionary that maps the old new to the new name */ - let newNameDictionary: {[oldName: string]: string} = {}; - /** Set used to store all namespaces. */ - let namespaceSet: {[namespace: string]: boolean} = {}; - // sort the nodes to make prefix check faster - nodeNames.sort(); - // look for nodes with a prefix a,a/b -> a/(a),a/b - for (let i = 0; i < nodeNames.length - 1; ++i) { - let a = nodeNames[i]; - // Get all the parent namespaces of the current node - // and add them in the namespace set. - _.each(getHierarchicalPath(a).slice(0, -1), ns => { - namespaceSet[ns] = true; - }); - for (let j = i + 1; j < nodeNames.length; ++j) { - let b = nodeNames[j]; - if (_.startsWith(b, a)) { - if (b.length > a.length && b.charAt(a.length) === NAMESPACE_DELIM) { - newNameDictionary[a] = getStrictName(a); - break; - } - } else { - break; - } - } - } - // Go through all the embedding node names and rename them in case they - // collide with namespaces. - _.each(embeddingNodeNames, embeddingName => { - if (embeddingName in namespaceSet) { - // Rename to follow strict hierarchy. - newNameDictionary[embeddingName] = getStrictName(embeddingName); - } - }); - return newNameDictionary; -}; - -/** - * Returns a list of the degrees of each node in the graph. - */ -function degreeSequence(graph: graphlib.Graph): number[] { - let degrees = graph.nodes().map(function(name) { - return graph.neighbors(name).length; - }); - degrees.sort(); - return degrees; -}; - -/** - * Returns if the degree sequence of the two graphs is the same. 
- */ -export function hasSimilarDegreeSequence(graph1: graphlib.Graph, - graph2: graphlib.Graph): boolean { - let dg1 = degreeSequence(graph1); - let dg2 = degreeSequence(graph2); - - for (let i = 0; i < dg1.length; i++) { - if (dg1[i] !== dg2[i]) { - return false; - } - } - return true; -}; - -/** - * Returns the hierarchical path of the current node, based on the node's name. - * For example, if the name is 'a/b/c', the returned path is - * ['a', 'a/b', 'a/b/c']. - */ -export function getHierarchicalPath(name: string, - seriesNames?: { [name: string]: string }): string[] { - let path: string[] = []; - let i = name.indexOf(NAMESPACE_DELIM); - // Push all parent portions of the path. - while (i >= 0) { - path.push(name.substring(0, i)); - i = name.indexOf(NAMESPACE_DELIM, i + 1); - } - // If the node's path is under a series, then add the series node name to the - // hierarchical path as the parent of the leaf. - if (seriesNames) { - let seriesName = seriesNames[name]; - if (seriesName) { - path.push(seriesName); - } - } - // Push the leaf of the path. - path.push(name); - return path; -}; - -/** - * Returns the string for the node inclusion toggle button, dependant - * on the provided current InclusionType. - */ -export function getIncludeNodeButtonString(include: InclusionType) { - if (include === tf.graph.InclusionType.EXCLUDE) { - return 'Add to main graph'; - } else { - return 'Remove from main graph'; - } -}; - -/** - * Returns the string for the series node grouping toggle button, dependant - * on the provided current SeriesGroupingType. - */ -export function getGroupSeriesNodeButtonString(group: SeriesGroupingType) { - if (group === tf.graph.SeriesGroupingType.GROUP) { - return 'Ungroup this series of nodes'; - } else { - return 'Group this series of nodes'; - } -}; - -/** - * Toggle the node series grouping option in the provided map, setting it - * to ungroup if the series is not already in the map. 
- */ -export function toggleNodeSeriesGroup( - map: { [name: string]: tf.graph.SeriesGroupingType }, name: string) { - if (!(name in map) || map[name] === tf.graph.SeriesGroupingType.GROUP) { - map[name] = tf.graph.SeriesGroupingType.UNGROUP; - } else { - map[name] = tf.graph.SeriesGroupingType.GROUP; - } -}; - -} // close module tf.graph diff --git a/tensorflow/tensorboard/components/tf_graph_common/hierarchy.ts b/tensorflow/tensorboard/components/tf_graph_common/hierarchy.ts deleted file mode 100644 index 889607ac500..00000000000 --- a/tensorflow/tensorboard/components/tf_graph_common/hierarchy.ts +++ /dev/null @@ -1,807 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the 'License'); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an 'AS IS' BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -/** - * Package for the Graph Hierarchy for TensorFlow graph. - */ -module tf.graph.hierarchy { - -/** - * Class used as output for getPredecessors and getSuccessors methods - */ -export interface Edges { - control: Metaedge[]; - regular: Metaedge[]; -} - -export interface Hierarchy { - root: Metanode; - templates: {[templateId: string]: string[]}; - /** List of all device names */ - devices: string[]; - /** List of all XLA cluster names */ - xlaClusters: string[]; - /** True if at least one tensor in the graph has shape information */ - hasShapeInfo: boolean; - /** The maximum size across all meta edges. Used for scaling thickness. 
*/ - maxMetaEdgeSize: number; - getNodeMap(): {[nodeName: string]: GroupNode|OpNode}; - node(name: string): GroupNode|OpNode; - setNode(name: string, node: GroupNode|OpNode): void; - getBridgegraph(nodeName: string): graphlib.Graph; - getPredecessors(nodeName: string): Edges; - getSuccessors(nodeName: string): Edges; - getTopologicalOrdering(nodeName: string): { [childName: string]: number }; - getTemplateIndex(): (string) => number; -} - -/** - * Class for the Graph Hierarchy for TensorFlow graph. - */ -class HierarchyImpl implements Hierarchy { - root: Metanode; - templates: {[templateId: string]: string[]}; - private index: {[nodeName: string]: GroupNode|OpNode}; - devices: string[]; - xlaClusters: string[]; - hasShapeInfo = false; - maxMetaEdgeSize = 1; - orderings: { [nodeName: string]: { [childName: string]: number } }; - - constructor() { - this.root = createMetanode(ROOT_NAME, {compound: true}); - this.templates = null; - this.devices = null; - /** - * @type {Object} Dictionary object that maps node name to the node - * (could be op-node, metanode, or series-node) - */ - this.index = {}; - this.index[ROOT_NAME] = this.root; - this.orderings = {}; - } - - getNodeMap(): {[nodeName: string]: GroupNode|OpNode} { - return this.index; - } - - node(name: string): GroupNode|OpNode { - return this.index[name]; - } - - setNode(name: string, node: GroupNode|OpNode): void { - this.index[name] = node; - } - - /** - * Given the name of a node in this hierarchy, get its bridgegraph, creating - * it on the fly if necessary. If the node is not a GroupNode, then this - * method returns null. If the provided name does not map to a node in the - * hierarchy, an error will be thrown. 
- */ - getBridgegraph(nodeName: string): graphlib.Graph { - let node = this.index[nodeName]; - if (!node) { - throw Error('Could not find node in hierarchy: ' + nodeName); - } - if (!('metagraph' in node)) { - return null; - } - let groupNode = node; - if (groupNode.bridgegraph) { - return groupNode.bridgegraph; - } - let bridgegraph = groupNode.bridgegraph = - createGraph( - 'BRIDGEGRAPH', GraphType.BRIDGE); - if (!node.parentNode || !('metagraph' in node.parentNode)) { - return bridgegraph; - } - - let parentNode = node.parentNode; - let parentMetagraph = parentNode.metagraph; - let parentBridgegraph = this.getBridgegraph(parentNode.name); - - // For each of the parent node's two Metaedge containing graphs, process - // each Metaedge involving this node. - _.each([parentMetagraph, parentBridgegraph], parentGraph => { - _(parentGraph.edges()) - .filter(e => e.v === nodeName || e.w === nodeName) - .each(parentEdgeObj => { - - let inbound = parentEdgeObj.w === nodeName; - let parentMetaedge = parentGraph.edge(parentEdgeObj); - - // The parent's Metaedge represents some number of underlying - // BaseEdges from the original full graph. For each of those, we need - // to determine which immediate child is involved and make sure - // there's a Metaedge in the bridgegraph that covers it. - _.each(parentMetaedge.baseEdgeList, baseEdge => { - - // Based on the direction, figure out which is the descendant node - // and which is the 'other' node (sibling of parent or ancestor). - let [descendantName, otherName] = - inbound ? - [baseEdge.w, parentEdgeObj.v] : - [baseEdge.v, parentEdgeObj.w]; - - // Determine the immediate child containing this descendant node. - let childName = this.getChildName(nodeName, descendantName); - - // Look for an existing Metaedge in the bridgegraph (or create a - // new one) that covers the relationship between child and other. - let bridgeEdgeObj = { - v: inbound ? otherName : childName, - w: inbound ? 
childName : otherName, - }; - let bridgeMetaedge = bridgegraph.edge(bridgeEdgeObj); - if (!bridgeMetaedge) { - bridgeMetaedge = createMetaedge(bridgeEdgeObj.v, bridgeEdgeObj.w); - bridgeMetaedge.inbound = inbound; - bridgegraph.setEdge(bridgeEdgeObj.v, bridgeEdgeObj.w, - bridgeMetaedge); - } - - // Copy the BaseEdge from the parent's Metaedge into this - // bridgegraph Metaedge. - bridgeMetaedge.addBaseEdge(baseEdge, this); - }); - }) - .value(); // force lodash chain execution. - }); - - return bridgegraph; - } - - /** - * Utility function for determining the name of the immediate child under a - * node for a given descendant path. If the descendant corresponds to no - * immediate child, an error is thrown. - */ - getChildName(nodeName: string, descendantName: string): string { - // Walk up the hierarchy from the descendant to find the child. - let currentNode: Node = this.index[descendantName]; - while (currentNode) { - if (currentNode.parentNode && currentNode.parentNode.name === nodeName) { - return currentNode.name; - } - currentNode = currentNode.parentNode; - } - throw Error( - 'Could not find immediate child for descendant: ' + descendantName); - }; - - /** Given the name of a node, return its incoming metaedges. */ - getPredecessors(nodeName: string): Edges { - let node = this.index[nodeName]; - if (!node) { - throw Error('Could not find node with name: ' + nodeName); - } - - let predecessors = this.getOneWayEdges(node, true); - // Add embedded predecessors, such as constants. - if (!node.isGroupNode) { - _.each((node).inEmbeddings, embeddedNode => { - _.each((node).inputs, input => { - if (input.name === embeddedNode.name) { - // Make a new metaedge holding the edge between the - // node and the in-embedding. 
- let metaedge = new MetaedgeImpl(embeddedNode.name, nodeName); - metaedge.addBaseEdge( - { - isControlDependency: input.isControlDependency, - outputTensorIndex: input.outputTensorIndex, - isReferenceEdge: false, - v: embeddedNode.name, - w: nodeName - }, - this); - predecessors.regular.push(metaedge); - } - }); - }); - } - return predecessors; - } - - /** - * Given the name of a node, return its outgoing metaedges. - * - * This is the inverse of getPredecessors(). See that method's documentation - * for an in-depth example. - */ - getSuccessors(nodeName: string): Edges { - let node = this.index[nodeName]; - if (!node) { - throw Error('Could not find node with name: ' + nodeName); - } - - let successors = this.getOneWayEdges(node, false); - - // Add embedded successors, such as summaries. - if (!node.isGroupNode) { - _.each((node).outEmbeddings, embeddedNode => { - _.each(embeddedNode.inputs, input => { - if (input.name === nodeName) { - // Make a new metaedge holding the edge between the - // node and the out-embedding. - let metaedge = new MetaedgeImpl(nodeName, embeddedNode.name); - metaedge.addBaseEdge( - { - isControlDependency: input.isControlDependency, - outputTensorIndex: input.outputTensorIndex, - isReferenceEdge: false, - v: nodeName, - w: embeddedNode.name - }, - this); - successors.regular.push(metaedge); - } - }); - }); - } - return successors; - } - - /** Helper method for getPredecessors and getSuccessors */ - getOneWayEdges(node: GroupNode|OpNode, inEdges: boolean) { - let edges: Edges = {control: [], regular: []}; - // A node with no parent cannot have any edges. 
- if (!node.parentNode || !node.parentNode.isGroupNode) { - return edges; - } - let parentNode = node.parentNode; - let metagraph = parentNode.metagraph; - let bridgegraph = this.getBridgegraph(parentNode.name); - findEdgeTargetsInGraph(metagraph, node, inEdges, edges); - findEdgeTargetsInGraph(bridgegraph, node, inEdges, edges); - return edges; - } - - /** - * For a given GroupNode, get or calculate an object which describes a - * topological ordering of child nodes within that GroupNode's metagraph. - * - * This ordering is used when rendering bridge control edges which are - * sometimes backwards relative to the dataflow. - * - * For example, say we have a graph with two edges A->B and A->C, and we're - * interested in the ordering under ROOT. In this case, any of the following - * would be legitimate return values: - * - * - { 'A': 0, 'B': 1, 'C': 2 } -- most likely - * - { 'A': 0, 'B': 2, 'C': 1 } -- less likely - * - { 'A': 12, 'B': 100, 'C': 99 } -- unlikely, but still OK - * - * The algorithm does not guarantee that all numbers from 0-N (where N is - * the number of nodes) appear exactly once. Rather it guarantees that if - * there is a path between two nodes, the earlier one will have a lower - * number in the ordering hash. - * - * When generating the ordering, we ignore control Metaedges (those which - * represent only BaseEdges that have isControlDependency set to true). - * - * If there is no node with the specified name, an error is thrown. If the - * node with the specified name is not a group node, null is returned. - */ - getTopologicalOrdering(nodeName: string): { [childName: string]: number } { - let node = this.index[nodeName]; - if (!node) { - throw Error('Could not find node with name: ' + nodeName); - } - if (!node.isGroupNode) { - return null; - } - if (nodeName in this.orderings) { - return this.orderings[nodeName]; - } - - // Mapping of a child node names to lists of their successors. 
- let successors: { [childName: string]: string[] } = {}; - - // Set of node names which have appeared as a destination. - let destinations: { [childName: string]: boolean } = {}; - - let metagraph = ( node).metagraph; - _.each(metagraph.edges(), (e: graphlib.EdgeObject) => { - if (!metagraph.edge(e).numRegularEdges) { - return; // Skip control edges. - } - - // Keep track of successors and destinations. - if (!(e.v in successors)) { - successors[e.v] = []; - } - successors[e.v].push(e.w); - destinations[e.w] = true; - }); - - // Seed the queue with true sources (those that are not destinations). - let queue: string[] = - _.difference(_.keys(successors), _.keys(destinations)); - - // Produce an ordering by traversing the graph breadth first. - let ordering = this.orderings[nodeName] = {}; - let index = 0; - while (queue.length) { - let childName = queue.shift(); - ordering[childName] = index++; - _.each(successors[childName], succName => queue.push(succName)); - delete successors[childName]; // Prevent cycles from infinite looping. - } - return ordering; - } - - /** - * Returns a d3 Ordinal function that can be used to look up the index of - * a node based on its template id. - */ - getTemplateIndex(): (string) => number { - let templateNames = d3.keys(this.templates); - let templateIndex = d3.scaleOrdinal() - .domain(templateNames) - .range(d3.range(0, templateNames.length)); - return (templateId: string) => templateIndex(templateId); - } -} - -/** - * Internal utility function - given a graph (should be either a metagraph or a - * bridgegraph) and a node which is known to be in that graph, determine - * the other ends of edges that involve that node in the direction specified - * by whether it's inbound. - * - * For example if you wanted to find the predecessors of a node, you'd call - * this method for the parent's metagraph and bridgegraph, specifying inbound - * as true (look at the source of inbound edges to the specified node). 
- * - * Discovered target names are appended to the targets array. - */ -function findEdgeTargetsInGraph( - graph: graphlib.Graph, - node: Node, inbound: boolean, targets: Edges): void { - let edges = inbound ? graph.inEdges(node.name) : graph.outEdges(node.name); - _.each(edges, e => { - let metaedge = graph.edge(e); - let targetList = - metaedge.numRegularEdges ? targets.regular : targets.control; - targetList.push(metaedge); - }); -} - -export interface HierarchyParams { - verifyTemplate: boolean; - seriesNodeMinSize: number; - seriesMap: { [name: string]: tf.graph.SeriesGroupingType }; -} - -/** - * @param graph The raw graph. - * @param params Parameters used when building a hierarchy. - */ -export function build(graph: tf.graph.SlimGraph, params: HierarchyParams, - tracker: ProgressTracker): Promise { - let h = new HierarchyImpl(); - let seriesNames: { [name: string]: string } = {}; - return tf.graph.util - .runAsyncTask( - 'Adding nodes', 20, - () => { - // Get all the possible device and XLA cluster names. 
- let deviceNames = {}; - let xlaClusterNames = {}; - _.each(graph.nodes, (node, nodeName) => { - if (node.device) { - deviceNames[node.device] = true; - } - - if (node.xlaCluster) { - xlaClusterNames[node.xlaCluster] = true; - } - }); - - h.devices = _.keys(deviceNames); - h.xlaClusters = _.keys(xlaClusterNames); - - addNodes(h, graph); - }, - tracker) - .then(() => { - return tf.graph.util.runAsyncTask('Detect series', 20, () => { - if (params.seriesNodeMinSize > 0) { - groupSeries( - h.root, h, seriesNames, params.seriesNodeMinSize, - params.seriesMap); - } - }, tracker); - }) - .then(() => { - return tf.graph.util.runAsyncTask('Adding edges', 30, () => { - addEdges(h, graph, seriesNames); - }, tracker); - }) - .then(() => { - return tf.graph.util.runAsyncTask( - 'Finding similar subgraphs', 30, () => { - h.templates = template.detect(h, params.verifyTemplate); - }, tracker); - }) - .then(() => { - return h; - }); -}; - -export function joinAndAggregateStats( - h: Hierarchy, stats: tf.graph.proto.StepStats) { - // Get all the possible device names. - let deviceNames = {}; - _.each(h.root.leaves(), nodeName => { - let leaf = h.node(nodeName); - if (leaf.device != null) { - deviceNames[leaf.device] = true; - } - }); - h.devices = _.keys(deviceNames); - - // Reset stats for each group node. - _.each(h.getNodeMap(), (node, nodeName) => { - if (node.isGroupNode) { - node.stats = new NodeStats(null); - (node).deviceHistogram = {}; - } - }); - - // Bubble-up the stats and device distribution from leaves to parents. 
- _.each(h.root.leaves(), nodeName => { - let leaf = h.node(nodeName); - let node = leaf; - while (node.parentNode != null) { - if (leaf.device != null) { - let deviceHistogram = (node.parentNode).deviceHistogram; - deviceHistogram[leaf.device] = (deviceHistogram[leaf.device] || 0) + 1; - } - if (leaf.stats != null) { - node.parentNode.stats.combine(leaf.stats); - } - node = node.parentNode; - } - }); -} - -/** - * Creates the metanodes in the hierarchical graph and assigns parent-child - * relationship between them. - */ -function addNodes(h: Hierarchy, graph: SlimGraph) { - _.each(graph.nodes, (node, nodeName) => { - let path = getHierarchicalPath(node.name); - let parent: Metanode = h.root; - - parent.depth = Math.max(path.length, parent.depth); - - // Create parent metanodes for each depth. For example if the node name - // is 'a/b/c', then create metanodes 'a' and 'a/b', where 'a/b' is a child - // of a. - for (let i = 0; i < path.length; i++) { - parent.depth = Math.max(parent.depth, path.length - i); - parent.cardinality += node.cardinality; - parent.opHistogram[node.op] = (parent.opHistogram[node.op] || 0) + 1; - if (node.device != null) { - parent.deviceHistogram[node.device] = - (parent.deviceHistogram[node.device] || 0) + 1; - } - if (i === path.length - 1) { break; } - let name = path[i]; - let child = h.node(name); - if (!child) { - child = createMetanode(name); - child.parentNode = parent; - h.setNode(name, child); - parent.metagraph.setNode(name, child); - } - parent = child; - } - // Assuming node name is 'a/b/c', assign the OpNode as a child of the - // metanode 'a/b'. - h.setNode(node.name, node); - node.parentNode = parent; - parent.metagraph.setNode(node.name, node); - - // Add each of the in-embeddings and out-embeddings in the hierarchy. 
- _.each(node.inEmbeddings, function(embedding) { - h.setNode(embedding.name, embedding); - embedding.parentNode = node; - }); - _.each(node.outEmbeddings, function(embedding) { - h.setNode(embedding.name, embedding); - embedding.parentNode = node; - }); - }); -}; - -/** - * For each metanode in the hierarchical graph, this method adds: - * the edges in the metagraph. These are edges between nodes - * that share the same parent. - */ -function addEdges(h: Hierarchy, graph: SlimGraph, - seriesNames: { [name: string]: string }) { - - let nodeIndex = h.getNodeMap(); - - // Ancestor paths for the source and destination nodes of an edge. These are - // reused for each edge rather than allocating new ones. It's about 10% faster - // than allocating new ones on each pass through the loop. - let sourcePath: string[] = []; - let destPath: string[] = []; - - // Insert the ancestor path for a node into the provided array, including the - // node itself. Return the index of the last node inserted (always ROOT). - let getPath = (node: Node, path: string[]): number => { - let i = 0; - while (node) { - path[i++] = node.name; - node = node.parentNode; - } - return i - 1; - }; - - _.each(graph.edges, baseEdge => { - - // Get the hierarchical paths for the source and destination of the edge. - let sourceAncestorIndex = getPath(graph.nodes[baseEdge.v], sourcePath); - let destAncestorIndex = getPath(graph.nodes[baseEdge.w], destPath); - - // If the hierarchical path cannot be found for either endpoint, then we - // cannot create the edge. This happens for example when a node has a - // control dependency on a summary node, which are embedded. - if (sourceAncestorIndex === -1 || destAncestorIndex === -1) { - return; - } - - // Find the lowest shared ancestor between source and dest by looking for - // the highest nodes that differ between their ancestor paths. 
- while (sourcePath[sourceAncestorIndex] === destPath[destAncestorIndex]) { - sourceAncestorIndex--; - destAncestorIndex--; - if (sourceAncestorIndex < 0 || destAncestorIndex < 0) { - // This would only occur if the two nodes were the same (a cycle in the - // graph), or if one endpoint was a strict ancestor of the other. The - // latter shouldn't happen because we rename nodes which are both - // metanodes and op nodes. E.g. 'A/B' becomes 'A/B/(B)'. - throw Error('No difference found between ancestor paths.'); - } - } - - let sharedAncestorNode = - nodeIndex[sourcePath[sourceAncestorIndex + 1]]; - let sourceAncestorName = sourcePath[sourceAncestorIndex]; - let destAncestorName = destPath[destAncestorIndex]; - - // Find or create the Metaedge which should contain this BaseEdge inside - // the shared ancestor. - let metaedge = - sharedAncestorNode.metagraph.edge(sourceAncestorName, destAncestorName); - if (!metaedge) { - metaedge = createMetaedge(sourceAncestorName, destAncestorName); - sharedAncestorNode.metagraph - .setEdge(sourceAncestorName, destAncestorName, metaedge); - } - if (!sharedAncestorNode.hasNonControlEdges && - !baseEdge.isControlDependency) { - sharedAncestorNode.hasNonControlEdges = true; - } - metaedge.addBaseEdge(baseEdge, h); - }); -}; - -/** - * Using the hierarchy template information, detect series in the provided - * metanode. For each detected series, create a new SeriesNode - * and remove series members from the metanode's metagraph and move them to - * the new series node's metagraph. - * - * @param metanode - * @param hierarchy - * @param seriesNames Map of node names to their series they are contained in. - * This should be provided empty and is populated by this method. - * @param threshold If the series has this many nodes or more, then group them - * into a series. - * @param map Map of series names to their series grouping type, if one has - * been set. 
- * @return A dictionary from node name to series node name that contains the - * node. - */ -function groupSeries(metanode: Metanode, hierarchy: Hierarchy, - seriesNames: { [name: string]: string }, threshold: number, - map: { [name: string]: tf.graph.SeriesGroupingType }) { - let metagraph = metanode.metagraph; - _.each(metagraph.nodes(), n => { - let child = metagraph.node(n); - if (child.type === tf.graph.NodeType.META) { - groupSeries(child, hierarchy, seriesNames, threshold, map); - } - }); - - let clusters = clusterNodes(metagraph); - let seriesDict = detectSeries(clusters, metagraph); - - // Add each series node to the graph and add its grouped children to its own - // metagraph. - _.each(seriesDict, function(seriesNode: SeriesNode, seriesName: string) { - let nodeMemberNames = seriesNode.metagraph.nodes(); - _.each(nodeMemberNames, n => { - let child = metagraph.node(n); - if (!child.owningSeries) { - child.owningSeries = seriesName; - } - }); - // If the series contains less than the threshold number of nodes and - // this series has not been adding to the series map, then set this - // series to be shown ungrouped in the map. - if (nodeMemberNames.length < threshold && !(seriesNode.name in map)) { - map[seriesNode.name] = tf.graph.SeriesGroupingType.UNGROUP; - } - // If the series is in the map as ungrouped then do not group the series. 
- if (seriesNode.name in map - && map[seriesNode.name] === tf.graph.SeriesGroupingType.UNGROUP) { - return; - } - hierarchy.setNode(seriesName, seriesNode); // add to the index - metagraph.setNode(seriesName, seriesNode); - _.each(nodeMemberNames, n => { - let child = metagraph.node(n); - seriesNode.metagraph.setNode(n, child); - seriesNode.parentNode = child.parentNode; - seriesNode.cardinality++; - if (child.device != null) { - seriesNode.deviceHistogram[child.device] = - (seriesNode.deviceHistogram[child.device] || 0) + 1; - } - child.parentNode = seriesNode; - seriesNames[n] = seriesName; - // Remove now-grouped node from its original parent's metagraph. - metagraph.removeNode(n); - }); - }); -}; - -/** cluster op-nodes with similar op */ -function clusterNodes(metagraph: graphlib.Graph): - {[clusterId: string]: string[]} { - let result: {[clusterId: string]: string[]} = {}; - return _.reduce(metagraph.nodes(), - (clusters: {[clusterId: string]: string[]}, n: string) => { - let child = metagraph.node(n); - if (child.type === NodeType.META) { - // skip metanodes - return clusters; - } - let template = (child).op; - if (template) { - clusters[template] = clusters[template] || []; - clusters[template].push(child.name); - } - return clusters; - }, result); -} - -/** - * For each cluster of op-nodes based op type, try to detect groupings. - * Infer series name using by trying to find pattern '' in the node - * name. - * - * @param clusters Dictionary output from clusterNodes(). 
- * @param metagraph - * @return A dictionary from series name => seriesNode - */ -function detectSeries(clusters: {[clusterId: string]: string[]}, - metagraph: graphlib.Graph): - {[seriesName: string]: SeriesNode} { - let seriesDict: {[seriesName: string]: SeriesNode} = {}; - _.each(clusters, function(members, clusterId: string) { - if (members.length <= 1) { return; } // isolated clusters can't make series - - /** @type {Object} A dictionary mapping seriesName to seriesInfoArray, - * which is an array that contains objects with name, id, prefix, suffix, - * and parent properties. - */ - let candidatesDict: {[seriesName: string]: SeriesNode[]} = {}; - - // Group all nodes that have the same name, with the exception of a - // number at the end of the name after an underscore, which is allowed to - // vary. - _.each(members, function(name: string) { - let isGroup = name.charAt(name.length - 1) === '*'; - let namepath = name.split('/'); - let leaf = namepath[namepath.length - 1]; - let parent = namepath.slice(0, namepath.length - 1).join('/'); - let matches = leaf.match(/^(\D*)_(\d+)$/); - - let prefix; - let id; - let suffix = ''; - if (matches) { // if found '' in the name, assign id. - prefix = matches[1]; // the front non-numeric characters - id = matches[2]; // the digits - } else { // for node without '_', make them zero-th items. - prefix = isGroup ? leaf.substr(0, leaf.length - 1) : leaf; - id = 0; - suffix = isGroup ? '*' : ''; - } - let seriesName = getSeriesNodeName(prefix, suffix, parent); - candidatesDict[seriesName] = candidatesDict[seriesName] || []; - let seriesNode = createSeriesNode(prefix, suffix, parent, +id, name); - candidatesDict[seriesName].push(seriesNode); - }); - - // In each group of nodes, group nodes in bunches that have monotonically - // increasing numbers in their names. Each of these bunches is a series. 
- _.each(candidatesDict, function(seriesInfoArray: SeriesNode[], seriesName) { - if (seriesInfoArray.length < 2) { - return; - } - seriesInfoArray.sort(function(a, b) { - return (+a.clusterId) - (+b.clusterId); - }); - - // Loop through the nodes sorted by its detected series number, grouping - // all nodes with monotonically-increasing series numbers. - let seriesNodes = [seriesInfoArray[0]]; - for (let index = 1; index < seriesInfoArray.length; index++) { - let nextNode = seriesInfoArray[index]; - if (nextNode.clusterId === seriesNodes[seriesNodes.length - 1].clusterId - + 1) { - seriesNodes.push(nextNode); - continue; - } - addSeriesToDict(seriesNodes, seriesDict, +clusterId, metagraph); - seriesNodes = [nextNode]; - } - addSeriesToDict(seriesNodes, seriesDict, +clusterId, metagraph); - }); - }); - return seriesDict; -} - -/** - * Add a series to the provided dictionary mapping series names to series. - * - * @param seriesNodes the nodes in the series. Contains - * name, id, prefix, suffix and parent properties of the node. 
- * @param seriesDict the dictionary of series - * @param clusterId ID of the template of the nodes of the series - * @param metagraph - */ -function addSeriesToDict(seriesNodes: SeriesNode[], - seriesDict: {[seriesName: string]: SeriesNode}, - clusterId: number, - metagraph: graphlib.Graph) { - if (seriesNodes.length > 1) { - let curSeriesName = getSeriesNodeName( - seriesNodes[0].prefix, seriesNodes[0].suffix, - seriesNodes[0].parent, seriesNodes[0].clusterId, - seriesNodes[seriesNodes.length - 1].clusterId); - let curSeriesNode = createSeriesNode(seriesNodes[0].prefix, - seriesNodes[0].suffix, seriesNodes[0].parent, clusterId, - curSeriesName); - _.each(seriesNodes, function(node) { - curSeriesNode.ids.push(node.clusterId); - curSeriesNode.metagraph.setNode(node.name, metagraph.node(node.name)); - }); - seriesDict[curSeriesName] = curSeriesNode; - } -} - -} // close module tf.graph.hierarchy diff --git a/tensorflow/tensorboard/components/tf_graph_common/layout.ts b/tensorflow/tensorboard/components/tf_graph_common/layout.ts deleted file mode 100644 index 1019e4f2694..00000000000 --- a/tensorflow/tensorboard/components/tf_graph_common/layout.ts +++ /dev/null @@ -1,760 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the 'License'); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an 'AS IS' BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -module tf.graph.layout { - -/** Set of parameters that define the look and feel of the graph. 
*/ -export const PARAMS = { - animation: { - /** Default duration for graph animations in ms. */ - duration: 250 - }, - graph: { - /** Graph parameter for metanode. */ - meta: { - /** - * Dagre's nodesep param - number of pixels that - * separate nodes horizontally in the layout. - * - * See https://github.com/cpettitt/dagre/wiki#configuring-the-layout - */ - nodeSep: 5, - /** - * Dagre's ranksep param - number of pixels - * between each rank in the layout. - * - * See https://github.com/cpettitt/dagre/wiki#configuring-the-layout - */ - rankSep: 25, - /** - * Dagre's edgesep param - number of pixels that separate - * edges horizontally in the layout. - */ - edgeSep: 5, - }, - /** Graph parameter for metanode. */ - series: { - /** - * Dagre's nodesep param - number of pixels that - * separate nodes horizontally in the layout. - * - * See https://github.com/cpettitt/dagre/wiki#configuring-the-layout - */ - nodeSep: 5, - /** - * Dagre's ranksep param - number of pixels - * between each rank in the layout. - * - * See https://github.com/cpettitt/dagre/wiki#configuring-the-layout - */ - rankSep: 25, - /** - * Dagre's edgesep param - number of pixels that separate - * edges horizontally in the layout. - */ - edgeSep: 5 - }, - /** - * Padding is used to correctly position the graph SVG inside of its parent - * element. The padding amounts are applied using an SVG transform of X and - * Y coordinates. - */ - padding: {paddingTop: 40, paddingLeft: 20} - }, - subscene: { - meta: { - paddingTop: 10, - paddingBottom: 10, - paddingLeft: 10, - paddingRight: 10, - /** - * Used to leave room for the label on top of the highest node in - * the core graph. - */ - labelHeight: 20, - /** X-space between each extracted node and the core graph. */ - extractXOffset: 15, - /** Y-space between each extracted node. 
*/ - extractYOffset: 20 - }, - series: { - paddingTop: 10, - paddingBottom: 10, - paddingLeft: 10, - paddingRight: 10, - labelHeight: 10 - } - }, - nodeSize: { - /** Size of meta nodes. */ - meta: { - radius: 5, - width: 60, - maxLabelWidth: 52, - /** A scale for the node's height based on number of nodes inside */ - // Hack - set this as an any type to avoid issues in exporting a type - // from an external module. - height: (d3 as any).scaleLinear().domain([1, 200]).range([15, 60]).clamp(true), - /** The radius of the circle denoting the expand button. */ - expandButtonRadius: 3 - }, - /** Size of op nodes. */ - op: { - width: 15, - height: 6, - radius: 3, // for making annotation touching ellipse - labelOffset: -8, - maxLabelWidth: 30 - }, - /** Size of series nodes. */ - series: { - expanded: { - // For expanded series nodes, width and height will be - // computed to account for the subscene. - radius: 10, - labelOffset: 0, - }, - vertical: { - // When unexpanded, series whose underlying metagraphs contain - // one or more non-control edges will show as a vertical stack - // of ellipses. - width: 16, - height: 13, - labelOffset: -13, - }, - horizontal: { - // When unexpanded, series whose underlying metagraphs contain - // no non-control edges will show as a horizontal stack of - // ellipses. - width: 24, - height: 8, - radius: 10, // Forces annotations to center line. - labelOffset: -10, - }, - }, - /** Size of bridge nodes. */ - bridge: { - // NOTE: bridge nodes will normally be invisible, but they must - // take up some space so that the layout step leaves room for - // their edges. 
- width: 20, - height: 20, - radius: 2, - labelOffset: 0 - } - }, - shortcutSize: { - /** Size of shortcuts for op nodes */ - op: {width: 10, height: 4}, - /** Size of shortcuts for meta nodes */ - meta: {width: 12, height: 4, radius: 1}, - /** Size of shortcuts for series nodes */ - series: { - width: 14, - height: 4, - } - }, - annotations: { - /** Maximum possible width of the bounding box for in annotations */ - inboxWidth: 50, - /** Maximum possible width of the bounding box for out annotations */ - outboxWidth: 50, - /** X-space between the shape and each annotation-node. */ - xOffset: 10, - /** Y-space between each annotation-node. */ - yOffset: 3, - /** X-space between each annotation-node and its label. */ - labelOffset: 2, - /** Defines the max width for annotation label */ - maxLabelWidth: 120 - }, - constant: {size: {width: 4, height: 4}}, - series: { - /** Maximum number of repeated item for unexpanded series node. */ - maxStackCount: 3, - /** - * Positioning offset ratio for collapsed stack - * of parallel series (series without edges between its members). - */ - parallelStackOffsetRatio: 0.2, - /** - * Positioning offset ratio for collapsed stack - * of tower series (series with edges between its members). - */ - towerStackOffsetRatio: 0.5 - }, - minimap: { - /** The maximum width/height the minimap can have. */ - size: 150 - } -}; - -/** Calculate layout for a scene of a group node. */ -export function layoutScene(renderNodeInfo: render.RenderGroupNodeInfo): void { - // Update layout, size, and annotations of its children nodes and edges. 
- if (renderNodeInfo.node.isGroupNode) { - layoutChildren(renderNodeInfo); - } - - // Update position of its children nodes and edges - if (renderNodeInfo.node.type === NodeType.META) { - layoutMetanode(renderNodeInfo); - } else if (renderNodeInfo.node.type === NodeType.SERIES) { - layoutSeriesNode(renderNodeInfo); - } -}; - -/** - * Updates the total width of an unexpanded node which includes the size of its - * in and out annotations. - */ -function updateTotalWidthOfNode(renderInfo: render.RenderNodeInfo): void { - renderInfo.inboxWidth = renderInfo.inAnnotations.list.length > 0 ? - PARAMS.annotations.inboxWidth : 0; - renderInfo.outboxWidth = renderInfo.outAnnotations.list.length > 0 ? - PARAMS.annotations.outboxWidth : 0; - // Assign the width of the core box (the main shape of the node). - renderInfo.coreBox.width = renderInfo.width; - renderInfo.coreBox.height = renderInfo.height; - // TODO(jimbo): Account for font width rather than using a magic number. - let labelLength = renderInfo.node.name.length - - renderInfo.node.name.lastIndexOf(NAMESPACE_DELIM) - 1; - let charWidth = 3; // 3 pixels per character. - // Compute the total width of the node. - renderInfo.width = Math.max(renderInfo.coreBox.width + - renderInfo.inboxWidth + renderInfo.outboxWidth, - labelLength * charWidth); - -} - -/** - * Update layout, size, and annotations of its children nodes and edges. 
- */ -function layoutChildren(renderNodeInfo: render.RenderGroupNodeInfo): void { - let children = renderNodeInfo.coreGraph.nodes().map(n => { - return renderNodeInfo.coreGraph.node(n); - }).concat(renderNodeInfo.isolatedInExtract, - renderNodeInfo.isolatedOutExtract); - - _.each(children, childNodeInfo => { - // Set size of each child - switch (childNodeInfo.node.type) { - case NodeType.OP: - _.extend(childNodeInfo, PARAMS.nodeSize.op); - break; - case NodeType.BRIDGE: - _.extend(childNodeInfo, PARAMS.nodeSize.bridge); - break; - case NodeType.META: - if (!childNodeInfo.expanded) { - // Set fixed width and scalable height based on cardinality - _.extend(childNodeInfo, PARAMS.nodeSize.meta); - childNodeInfo.height = - PARAMS.nodeSize.meta.height(childNodeInfo.node.cardinality); - } else { - let childGroupNodeInfo = - childNodeInfo; - layoutScene(childGroupNodeInfo); // Recursively layout its subscene. - } - break; - case NodeType.SERIES: - if (childNodeInfo.expanded) { - _.extend(childNodeInfo, PARAMS.nodeSize.series.expanded); - let childGroupNodeInfo = - childNodeInfo; - layoutScene(childGroupNodeInfo); // Recursively layout its subscene. - } else { - let childGroupNodeInfo = - childNodeInfo; - let seriesParams = - childGroupNodeInfo.node.hasNonControlEdges ? - PARAMS.nodeSize.series.vertical : - PARAMS.nodeSize.series.horizontal; - _.extend(childNodeInfo, seriesParams); - } - break; - default: - throw Error('Unrecognized node type: ' + childNodeInfo.node.type); - } - // Compute total width of un-expanded nodes. Width of expanded nodes - // has already been computed. 
- if (!childNodeInfo.expanded) { - updateTotalWidthOfNode(childNodeInfo); - } - // Layout each child's annotations - layoutAnnotation(childNodeInfo); - }); -} - -/** - * Calculate layout for a graph using dagre - * @param graph the graph to be laid out - * @param params layout parameters - * @return width and height of the core graph - */ -function dagreLayout( - graph: graphlib.Graph, - params): {height: number, width: number} { - _.extend(graph.graph(), { - nodesep: params.nodeSep, - ranksep: params.rankSep, - edgesep: params.edgeSep - }); - let bridgeNodeNames = []; - let nonBridgeNodeNames = []; - - // Split out nodes into bridge and non-bridge nodes, and calculate the total - // width we should use for bridge nodes. - _.each(graph.nodes(), nodeName => { - let nodeInfo = graph.node(nodeName); - if (nodeInfo.node.type === NodeType.BRIDGE) { - bridgeNodeNames.push(nodeName); - } else { - nonBridgeNodeNames.push(nodeName); - } - }); - - // If there are no non-bridge nodes, then the graph has zero size. - if (!nonBridgeNodeNames.length) { - return { - width: 0, - height: 0, - }; - } - dagre.layout(graph); - - // Calculate the true bounding box of the graph by iterating over nodes and - // edges rather than accepting dagre's word for it. In particular, we should - // ignore the extra-wide bridge nodes and bridge edges, and allow for - // annotation boxes and labels. - let minX = Infinity; - let minY = Infinity; - let maxX = -Infinity; - let maxY = -Infinity; - _.each(nonBridgeNodeNames, nodeName => { - let nodeInfo = graph.node(nodeName); - let w = 0.5 * nodeInfo.width; - let x1 = nodeInfo.x - w; - let x2 = nodeInfo.x + w; - minX = x1 < minX ? x1 : minX; - maxX = x2 > maxX ? x2 : maxX; - // TODO(jimbo): Account for the height of labels above op nodes here. - let h = 0.5 * nodeInfo.height; - let y1 = nodeInfo.y - h; - let y2 = nodeInfo.y + h; - minY = y1 < minY ? y1 : minY; - maxY = y2 > maxY ? 
y2 : maxY; - }); - _.each(graph.edges(), edgeObj => { - let edgeInfo = graph.edge(edgeObj); - if (edgeInfo.structural) { - return; // Skip structural edges from min/max calculations. - } - - // Since the node size passed to dagre includes the in and out - // annotations, the endpoints of the edge produced by dagre may not - // point to the actual node shape (rectangle, ellipse). We correct the - // end-points by finding the intersection of a line between the - // next-to-last (next-to-first) point and the destination (source) - // rectangle. - let sourceNode = graph.node(edgeInfo.metaedge.v); - let destNode = graph.node(edgeInfo.metaedge.w); - - // Straight 3-points edges are special case, since they are curved after - // our default correction. To keep them straight, we remove the mid point - // and correct the first and the last point to be the center of the - // source and destination node respectively. - if (edgeInfo.points.length === 3 && isStraightLine(edgeInfo.points)) { - if (sourceNode != null) { - let cxSource = sourceNode.expanded ? - sourceNode.x : computeCXPositionOfNodeShape(sourceNode); - edgeInfo.points[0].x = cxSource; - } - if (destNode != null) { - let cxDest = destNode.expanded ? - destNode.x : computeCXPositionOfNodeShape(destNode); - edgeInfo.points[2].x = cxDest; - } - // Remove the middle point so the edge doesn't curve. - edgeInfo.points = [edgeInfo.points[0], edgeInfo.points[1]]; - } - // Correct the destination endpoint of the edge. - let nextToLastPoint = edgeInfo.points[edgeInfo.points.length - 2]; - // The destination node might be null if this is a bridge edge. - if (destNode != null) { - edgeInfo.points[edgeInfo.points.length - 1] = - intersectPointAndNode(nextToLastPoint, destNode); - } - // Correct the source endpoint of the edge. - let secondPoint = edgeInfo.points[1]; - // The source might be null if this is a bridge edge. 
- if (sourceNode != null) { - edgeInfo.points[0] = intersectPointAndNode(secondPoint, sourceNode); - } - - _.each(edgeInfo.points, (point: render.Point) => { - minX = point.x < minX ? point.x : minX; - maxX = point.x > maxX ? point.x : maxX; - minY = point.y < minY ? point.y : minY; - maxY = point.y > maxY ? point.y : maxY; - }); - }); - - // Shift all nodes and edge points to account for the left-padding amount, - // and the invisible bridge nodes. - _.each(graph.nodes(), nodeName => { - let nodeInfo = graph.node(nodeName); - nodeInfo.x -= minX; - nodeInfo.y -= minY; - }); - _.each(graph.edges(), edgeObj => { - _.each(graph.edge(edgeObj).points, (point: render.Point) => { - point.x -= minX; - point.y -= minY; - }); - }); - - return { - width: maxX - minX, - height: maxY - minY - }; -} - -/** Layout a metanode. Only called for an expanded node. */ -function layoutMetanode(renderNodeInfo: render.RenderGroupNodeInfo): void { - // First, copy params specific to meta nodes onto this render info object. - let params = PARAMS.subscene.meta; - _.extend(renderNodeInfo, params); - // Invoke dagre.layout() on the core graph and record the bounding box - // dimensions. - _.extend(renderNodeInfo.coreBox, - dagreLayout(renderNodeInfo.coreGraph, PARAMS.graph.meta)); - - // Calculate the position of nodes in isolatedInExtract relative to the - // top-left corner of inExtractBox (the bounding box for all inExtract nodes) - // and calculate the size of the inExtractBox. - let maxInExtractWidth = _.max(renderNodeInfo.isolatedInExtract, - renderNode => renderNode.width).width; - renderNodeInfo.inExtractBox.width = maxInExtractWidth != null ? - maxInExtractWidth : 0; - - renderNodeInfo.inExtractBox.height = - _.reduce(renderNodeInfo.isolatedInExtract, (height, child, i) => { - let yOffset = i > 0 ? 
params.extractYOffset : 0; - // use width/height here to avoid overlaps between extracts - child.x = 0; - child.y = height + yOffset + child.height / 2; - return height + yOffset + child.height; - }, 0); - - // Calculate the position of nodes in isolatedOutExtract relative to the - // top-left corner of outExtractBox (the bounding box for all outExtract - // nodes) and calculate the size of the outExtractBox. - let maxOutExtractWidth = _.max(renderNodeInfo.isolatedOutExtract, - renderNode => renderNode.width).width; - renderNodeInfo.outExtractBox.width = maxOutExtractWidth != null ? - maxOutExtractWidth : 0; - - renderNodeInfo.outExtractBox.height = - _.reduce(renderNodeInfo.isolatedOutExtract, (height, child, i) => { - let yOffset = i > 0 ? params.extractYOffset : 0; - // use width/height here to avoid overlaps between extracts - child.x = 0; - child.y = height + yOffset + child.height / 2; - return height + yOffset + child.height; - }, 0); - - // Compute the total padding between the core graph, in-extract and - // out-extract boxes. - let numParts = 0; - if (renderNodeInfo.isolatedInExtract.length > 0) { - numParts++; - } - if (renderNodeInfo.isolatedOutExtract.length > 0) { - numParts++; - } - if (renderNodeInfo.coreGraph.nodeCount() > 0) { - numParts++; - } - let offset = PARAMS.subscene.meta.extractXOffset; - let padding = numParts <= 1 ? 0 : (numParts <= 2 ? offset : 2 * offset); - - // Add the in-extract and out-extract width to the core box width. - renderNodeInfo.coreBox.width += renderNodeInfo.inExtractBox.width + - renderNodeInfo.outExtractBox.width + padding; - renderNodeInfo.coreBox.height = - params.labelHeight + - Math.max( - renderNodeInfo.inExtractBox.height, - renderNodeInfo.coreBox.height, - renderNodeInfo.outExtractBox.height - ); - // Determine the whole metanode's width (from left to right). 
- renderNodeInfo.width = renderNodeInfo.coreBox.width + - params.paddingLeft + params.paddingRight; - - // Determine the whole metanode's height (from top to bottom). - renderNodeInfo.height = - renderNodeInfo.paddingTop + - renderNodeInfo.coreBox.height + - renderNodeInfo.paddingBottom; -} - -/** - * Calculate layout for series node's core graph. Only called for an expanded - * series. - */ -function layoutSeriesNode(node: render.RenderGroupNodeInfo): void { - let graph = node.coreGraph; - - let params = PARAMS.subscene.series; - _.extend(node, params); - - // Layout the core. - _.extend(node.coreBox, dagreLayout(node.coreGraph, PARAMS.graph.series)); - - _.each(graph.nodes(), nodeName => { - graph.node(nodeName).excluded = false; - }); - - // Series do not have in/outExtractBox so no need to include them here. - node.width = node.coreBox.width + params.paddingLeft + params.paddingRight; - node.height = node.coreBox.height + params.paddingTop + params.paddingBottom; -} - -/** - * Calculate layout for annotations of a given node. - * This will modify positions of the given node and its annotations. - * - * @see tf.graph.render.Node and tf.graph.render.Annotation - * for description of each property of each render node. - * - */ -function layoutAnnotation(renderNodeInfo: render.RenderNodeInfo): void { - // If the render node is an expanded metanode, then its annotations will not - // be visible and we should skip the annotation calculations. 
- if (renderNodeInfo.expanded) { - return; - } - - let inAnnotations = renderNodeInfo.inAnnotations.list; - let outAnnotations = renderNodeInfo.outAnnotations.list; - - // Calculate size for in-annotations - _.each(inAnnotations, a => sizeAnnotation(a)); - - // Calculate size for out-annotations - _.each(outAnnotations, a => sizeAnnotation(a)); - - let params = PARAMS.annotations; - - // Calculate annotation node position (a.dx, a.dy) - // and total height for in-annotations - // After this chunk of code: - // inboxHeight = sum of annotation heights+ (annotation.length - 1 * yOffset) - let inboxHeight = _.reduce(inAnnotations, - (height, a, i) => { - let yOffset = i > 0 ? params.yOffset : 0; - a.dx = -(renderNodeInfo.coreBox.width + a.width) / 2 - params.xOffset; - a.dy = height + yOffset + a.height / 2; - return height + yOffset + a.height; - }, 0); - - _.each(inAnnotations, a => { - a.dy -= inboxHeight / 2; - - a.labelOffset = params.labelOffset; - }); - - // Calculate annotation node position (a.dx, a.dy) - // and total height for out-annotations - // After this chunk of code: - // outboxHeight = sum of annotation heights + - // (annotation.length - 1 * yOffset) - let outboxHeight = _.reduce(outAnnotations, - (height, a, i) => { - let yOffset = i > 0 ? params.yOffset : 0; - a.dx = (renderNodeInfo.coreBox.width + a.width) / 2 + params.xOffset; - a.dy = height + yOffset + a.height / 2; - return height + yOffset + a.height; - }, 0); - - _.each(outAnnotations, a => { - // adjust by (half of ) the total height - // so dy is relative to the host node's center. - a.dy -= outboxHeight / 2; - - a.labelOffset = params.labelOffset; - }); - - // Creating scales for touch point between the in-annotation edges - // and their hosts. - - let inTouchHeight = - Math.min(renderNodeInfo.height / 2 - renderNodeInfo.radius, - inboxHeight / 2); - inTouchHeight = inTouchHeight < 0 ? 
0 : inTouchHeight; - - let inY = d3.scaleLinear() - .domain([0, inAnnotations.length - 1]) - .range([-inTouchHeight, inTouchHeight]); - - // Calculate annotation edge position - _.each(inAnnotations, (a, i) => { - a.points = [ - // The annotation node end - { - dx: a.dx + a.width / 2, - dy: a.dy - }, - - // The host node end - { - dx: - renderNodeInfo.coreBox.width / 2, - // only use scale if there are more than one, - // otherwise center it vertically - dy: inAnnotations.length > 1 ? inY(i) : 0 - } - ]; - }); - - // Creating scales for touch point between the out-annotation edges - // and their hosts. - let outTouchHeight = - Math.min(renderNodeInfo.height / 2 - renderNodeInfo.radius, - outboxHeight / 2); - outTouchHeight = outTouchHeight < 0 ? 0 : outTouchHeight; - let outY = d3.scaleLinear() - .domain([0, outAnnotations.length - 1]) - .range([-outTouchHeight, outTouchHeight]); - - _.each(outAnnotations, (a, i) => { - // Add point from the border of the annotation node - a.points = [ - // The host node end - { - dx: renderNodeInfo.coreBox.width / 2, - // only use scale if there are more than one, - // otherwise center it vertically - dy: outAnnotations.length > 1 ? outY(i) : 0 - }, - // The annotation node end - { - dx: a.dx - a.width / 2, - dy: a.dy - } - ]; - }); - - renderNodeInfo.height = - Math.max(renderNodeInfo.height, inboxHeight, outboxHeight); -} - -/** - * Set size of an annotation node. 
- */ -function sizeAnnotation(a: render.Annotation): void { - switch (a.annotationType) { - case render.AnnotationType.CONSTANT: - _.extend(a, PARAMS.constant.size); - break; - case render.AnnotationType.SHORTCUT: - if (a.node.type === NodeType.OP) { - _.extend(a, PARAMS.shortcutSize.op); - } else if (a.node.type === NodeType.META) { - _.extend(a, PARAMS.shortcutSize.meta); - } else if (a.node.type === NodeType.SERIES) { - _.extend(a, PARAMS.shortcutSize.series); - } else { - throw Error('Invalid node type: ' + a.node.type); - } - break; - case render.AnnotationType.SUMMARY: - _.extend(a, PARAMS.constant.size); - break; - } -} - -/** - * Determines the center position of the node's shape. The position depends - * on if the node has in and out-annotations. - */ -export function computeCXPositionOfNodeShape(renderInfo: render.RenderNodeInfo): - number { - if (renderInfo.expanded) { - return renderInfo.x; - } - let dx = renderInfo.inAnnotations.list.length ? renderInfo.inboxWidth : 0; - return renderInfo.x - renderInfo.width / 2 + dx + - renderInfo.coreBox.width / 2; -} - -/** Returns the angle (in degrees) between two points. */ -function angleBetweenTwoPoints(a: render.Point, b: render.Point): number { - let dx = b.x - a.x; - let dy = b.y - a.y; - return 180 * Math.atan(dy / dx) / Math.PI; -} - -/** - * Returns if a line going through the specified points is a straight line. - */ -function isStraightLine(points: render.Point[]) { - let angle = angleBetweenTwoPoints(points[0], points[1]); - for (let i = 1; i < points.length - 1; i++) { - let newAngle = angleBetweenTwoPoints(points[i], points[i + 1]); - // Have a tolerance of 1 degree. - if (Math.abs(newAngle - angle) > 1) { - return false; - } - angle = newAngle; - } - return true; -} - -/** - * Returns the intersection of a line between the provided point - * and the provided rectangle. 
- */ -function intersectPointAndNode( - point: render.Point, node: render.RenderNodeInfo): render.Point { - // cx and cy are the center of the rectangle. - let cx = node.expanded ? - node.x : computeCXPositionOfNodeShape(node); - let cy = node.y; - // Calculate the slope - let dx = point.x - cx; - let dy = point.y - cy; - let w = node.expanded ? node.width : node.coreBox.width; - let h = node.expanded ? node.height : node.coreBox.height; - let deltaX, deltaY; - if (Math.abs(dy) * w / 2 > Math.abs(dx) * h / 2) { - // The intersection is above or below the rectangle. - if (dy < 0) { - h = -h; - } - deltaX = dy === 0 ? 0 : h / 2 * dx / dy; - deltaY = h / 2; - } else { - // The intersection is left or right of the rectangle. - if (dx < 0) { - w = -w; - } - deltaX = w / 2; - deltaY = dx === 0 ? 0 : w / 2 * dy / dx; - } - return {x: cx + deltaX, y: cy + deltaY}; -} - -} // close module diff --git a/tensorflow/tensorboard/components/tf_graph_common/minimap.ts b/tensorflow/tensorboard/components/tf_graph_common/minimap.ts deleted file mode 100644 index 8129df3a426..00000000000 --- a/tensorflow/tensorboard/components/tf_graph_common/minimap.ts +++ /dev/null @@ -1,328 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the 'License'); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an 'AS IS' BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -module tf.scene { - -/** Show minimap when the viewpoint area is less than X% of the whole area. 
*/ -const FRAC_VIEWPOINT_AREA: number = 0.8; - -export class Minimap { - /** The minimap container. */ - private minimap: HTMLElement; - /** The canvas used for drawing the mini version of the svg. */ - private canvas: HTMLCanvasElement; - /** A buffer canvas used for temporary drawing to avoid flickering. */ - private canvasBuffer: HTMLCanvasElement; - private download: HTMLLinkElement; - private downloadCanvas: HTMLCanvasElement; - - /** The minimap svg used for holding the viewpoint rectangle. */ - private minimapSvg: SVGSVGElement; - /** The rectangle showing the current viewpoint. */ - private viewpoint: SVGRectElement; - /** - * The scale factor for the minimap. The factor is determined automatically - * so that the minimap doesn't violate the maximum width/height specified - * in the constructor. The minimap maintains the same aspect ratio as the - * original svg. - */ - private scaleMinimap: number; - /** The main svg element. */ - private svg: SVGSVGElement; - /** The svg group used for panning and zooming the main svg. */ - private zoomG: SVGGElement; - /** The zoom behavior of the main svg. */ - private mainZoom: d3.ZoomBehavior; - /** The maximum width and height for the minimap. */ - private maxWandH: number; - /** The last translation vector used in the main svg. */ - private translate: [number, number]; - /** The last scaling factor used in the main svg. */ - private scaleMain: number; - /** The coordinates of the viewpoint rectangle. */ - private viewpointCoord: {x: number, y: number}; - /** The current size of the minimap */ - private minimapSize: {width: number, height: number}; - /** Padding (px) due to the main labels of the graph. */ - private labelPadding: number; - /** - * Constructs a new minimap. - * - * @param svg The main svg element. - * @param zoomG The svg group used for panning and zooming the main svg. - * @param mainZoom The main zoom behavior. - * @param minimap The minimap container. 
- * @param maxWandH The maximum width/height for the minimap. - * @param labelPadding Padding in pixels due to the main graph labels. - */ - constructor(svg: SVGSVGElement, zoomG: SVGGElement, - mainZoom: d3.ZoomBehavior, minimap: HTMLElement, - maxWandH: number, labelPadding: number) { - this.svg = svg; - this.labelPadding = labelPadding; - this.zoomG = zoomG; - this.mainZoom = mainZoom; - this.maxWandH = maxWandH; - let $minimap = d3.select(minimap); - // The minimap will have 2 main components: the canvas showing the content - // and an svg showing a rectangle of the currently zoomed/panned viewpoint. - let $minimapSvg = $minimap.select('svg'); - - // Make the viewpoint rectangle draggable. - let $viewpoint = $minimapSvg.select('rect'); - let dragmove = (d) => { - this.viewpointCoord.x = (d3.event).x; - this.viewpointCoord.y = (d3.event).y; - this.updateViewpoint(); - }; - this.viewpointCoord = {x: 0, y: 0}; - let drag = d3.drag().subject(Object).on('drag', dragmove); - $viewpoint.datum(this.viewpointCoord as any).call(drag); - - // Make the minimap clickable. - $minimapSvg.on('click', () => { - if ((d3.event).defaultPrevented) { - // This click was part of a drag event, so suppress it. - return; - } - // Update the coordinates of the viewpoint. 
- let width = Number($viewpoint.attr('width')); - let height = Number($viewpoint.attr('height')); - let clickCoords = d3.mouse($minimapSvg.node() as any); - this.viewpointCoord.x = clickCoords[0] - width / 2; - this.viewpointCoord.y = clickCoords[1] - height / 2; - this.updateViewpoint(); - }); - this.viewpoint = $viewpoint.node(); - this.minimapSvg = $minimapSvg.node(); - this.minimap = minimap; - this.canvas = $minimap.select('canvas.first').node(); - this.canvasBuffer = - $minimap.select('canvas.second').node(); - this.downloadCanvas = - $minimap.select('canvas.download').node(); - d3.select(this.downloadCanvas).style('display', 'none'); - this.update(); - } - - /** - * Updates the position and the size of the viewpoint rectangle. - * It also notifies the main svg about the new panned position. - */ - private updateViewpoint(): void { - // Update the coordinates of the viewpoint rectangle. - d3.select(this.viewpoint) - .attr('x', this.viewpointCoord.x) - .attr('y', this.viewpointCoord.y); - // Update the translation vector of the main svg to reflect the - // new viewpoint. - let mainX = - this.viewpointCoord.x * this.scaleMain / this.scaleMinimap; - let mainY = - this.viewpointCoord.y * this.scaleMain / this.scaleMinimap; - d3.select(this.svg).call( - this.mainZoom.transform, - d3.zoomIdentity.translate(mainX, mainY).scale(this.scaleMain)); - } - - /** - * Redraws the minimap. Should be called whenever the main svg - * was updated (e.g. when a node was expanded). - */ - update(): void { - let sceneSize = null; - try { - // Get the size of the entire scene. - sceneSize = this.zoomG.getBBox(); - if (sceneSize.width === 0) { - // There is no scene anymore. We have been detached from the dom. - return; - } - } catch (e) { - // Firefox produced NS_ERROR_FAILURE if we have been - // detached from the dom. 
- return; - } - let $download = d3.select('#graphdownload'); - this.download = $download.node(); - $download.on('click', d => { - this.download.href = this.downloadCanvas.toDataURL('image/png'); - }); - - let $svg = d3.select(this.svg); - // Read all the style rules in the document and embed them into the svg. - // The svg needs to be self contained, i.e. all the style rules need to be - // embedded so the canvas output matches the origin. - let stylesText = ''; - for (let k = 0; k < document.styleSheets.length; k++) { - try { - let cssRules = (document.styleSheets[k]).cssRules || - (document.styleSheets[k]).rules; - if (cssRules == null) { - continue; - } - for (let i = 0; i < cssRules.length; i++) { - // Remove tf-* selectors from the styles. - stylesText += - cssRules[i].cssText.replace(/ ?tf-[\w-]+ ?/g, '') + '\n'; - } - } catch (e) { - if (e.name !== 'SecurityError') { - throw e; - } - } - } - - // Temporarily add the css rules to the main svg. - let svgStyle = $svg.append('style'); - svgStyle.text(stylesText); - - // Temporarily remove the zoom/pan transform from the main svg since we - // want the minimap to show a zoomed-out and centered view. - let $zoomG = d3.select(this.zoomG); - let zoomTransform = $zoomG.attr('transform'); - $zoomG.attr('transform', null); - - // Since we add padding, account for that here. - sceneSize.height += this.labelPadding * 2; - sceneSize.width += this.labelPadding * 2; - - // Temporarily assign an explicit width/height to the main svg, since - // it doesn't have one (uses flex-box), but we need it for the canvas - // to work. - $svg - .attr('width', sceneSize.width) - .attr('height', sceneSize.height); - - // Since the content inside the svg changed (e.g. a node was expanded), - // the aspect ratio have also changed. Thus, we need to update the scale - // factor of the minimap. The scale factor is determined such that both - // the width and height of the minimap are <= maximum specified w/h. 
- this.scaleMinimap = - this.maxWandH / Math.max(sceneSize.width, sceneSize.height); - - this.minimapSize = { - width: sceneSize.width * this.scaleMinimap, - height: sceneSize.height * this.scaleMinimap - }; - - // Update the size of the minimap's svg, the buffer canvas and the - // viewpoint rect. - d3.select(this.minimapSvg).attr(this.minimapSize); - d3.select(this.canvasBuffer).attr(this.minimapSize); - - // Download canvas width and height are multiples of the style width and - // height in order to increase pixel density of the PNG for clarity. - d3.select(this.downloadCanvas).style( - { width: sceneSize.width, height: sceneSize.height }); - d3.select(this.downloadCanvas).attr( - { width: sceneSize.width * 3, height: sceneSize.height * 3 }); - - if (this.translate != null && this.zoom != null) { - // Update the viewpoint rectangle shape since the aspect ratio of the - // map has changed. - requestAnimationFrame(() => this.zoom()); - } - - // Serialize the main svg to a string which will be used as the rendering - // content for the canvas. - let svgXml = (new XMLSerializer()).serializeToString(this.svg); - - // Now that the svg is serialized for rendering, remove the temporarily - // assigned styles, explicit width and height and bring back the pan/zoom - // transform. - svgStyle.remove(); - $svg.attr('width', null).attr('height', null); - - $zoomG.attr('transform', zoomTransform); - let image = new Image(); - image.onload = () => { - // Draw the svg content onto the buffer canvas. - let context = this.canvasBuffer.getContext('2d'); - context.clearRect(0, 0, this.canvasBuffer.width, - this.canvasBuffer.height); - context.drawImage(image, 0, 0, - this.minimapSize.width, this.minimapSize.height); - requestAnimationFrame(() => { - // Hide the old canvas and show the new buffer canvas. - d3.select(this.canvasBuffer).style('display', null); - d3.select(this.canvas).style('display', 'none'); - // Swap the two canvases. 
- [this.canvas, this.canvasBuffer] = [this.canvasBuffer, this.canvas]; - }); - let downloadContext = this.downloadCanvas.getContext('2d'); - downloadContext.clearRect(0, 0, this.downloadCanvas.width, - this.downloadCanvas.height); - downloadContext.drawImage(image, 0, 0, - this.downloadCanvas.width, this.downloadCanvas.height); - }; - image.onerror = () => { - let blob = new Blob([svgXml], {type: 'image/svg+xml;charset=utf-8'}); - image.src = URL.createObjectURL(blob); - }; - image.src = - 'data:image/svg+xml;charset=utf-8,' + encodeURIComponent(svgXml); - } - - /** - * Handles changes in zooming/panning. Should be called from the main svg - * to notify that a zoom/pan was performed and this minimap will update it's - * viewpoint rectangle. - * - * @param translate The translate vector, or none to use the last used one. - * @param scale The scaling factor, or none to use the last used one. - */ - zoom(transform?: d3.ZoomTransform): void { - if (this.scaleMinimap == null) { - // Scene is not ready yet. - return; - } - // Update the new translate and scale params, only if specified. - if (transform) { - this.translate = [transform.x, transform.y]; - this.scaleMain = transform.k; - } - - // Update the location of the viewpoint rectangle. - let svgRect = this.svg.getBoundingClientRect(); - let $viewpoint = d3.select(this.viewpoint); - this.viewpointCoord.x = -this.translate[0] * this.scaleMinimap / - this.scaleMain; - this.viewpointCoord.y = -this.translate[1] * this.scaleMinimap / - this.scaleMain; - let viewpointWidth = svgRect.width * this.scaleMinimap / this.scaleMain; - let viewpointHeight = svgRect.height * this.scaleMinimap / this.scaleMain; - $viewpoint - .attr('x', this.viewpointCoord.x) - .attr('y', this.viewpointCoord.y) - .attr('width', viewpointWidth) - .attr('height', viewpointHeight); - // Show/hide the minimap depending on the viewpoint area as fraction of the - // whole minimap. 
- let mapWidth = this.minimapSize.width; - let mapHeight = this.minimapSize.height; - let x = this.viewpointCoord.x; - let y = this.viewpointCoord.y; - let w = Math.min(Math.max(0, x + viewpointWidth), mapWidth) - - Math.min(Math.max(0, x), mapWidth); - let h = Math.min(Math.max(0, y + viewpointHeight), mapHeight) - - Math.min(Math.max(0, y), mapHeight); - let fracIntersect = (w * h) / (mapWidth * mapHeight); - if (fracIntersect < FRAC_VIEWPOINT_AREA) { - this.minimap.classList.remove('hidden'); - } else { - this.minimap.classList.add('hidden'); - } - } -} - -} // close module tf.scene diff --git a/tensorflow/tensorboard/components/tf_graph_common/node.ts b/tensorflow/tensorboard/components/tf_graph_common/node.ts deleted file mode 100644 index f090a51fc4e..00000000000 --- a/tensorflow/tensorboard/components/tf_graph_common/node.ts +++ /dev/null @@ -1,1072 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the 'License'); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an 'AS IS' BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -module tf.graph.scene.node { - import RenderNodeInfo = tf.graph.render.RenderNodeInfo; - /** - * Select or Create a 'g.nodes' group to a given sceneGroup - * and builds a number of 'g.node' groups inside the group. - * - * Structure Pattern: - * - * - * - * - * ... - * - * - * ... - * - * - * - * - * node name - * - * - * - * - * ... 
- * - * - * - * @param sceneGroup selection of the container - * @param nodeData array of render node information to map - * @param sceneElement polymer element - * @return selection of the created nodeGroups - */ - export function buildGroup( - sceneGroup, nodeData: render.RenderNodeInfo[], sceneElement) { - let container = - scene.selectOrCreateChild(sceneGroup, 'g', Class.Node.CONTAINER); - // Select all children and join with data. - // (Note that all children of g.nodes are g.node) - let nodeGroups = - (container as any).selectAll(function() {return this.childNodes;}) - .data(nodeData, (d) => { - // make sure that we don't have to swap shape type - return d.node.name + ':' + d.node.type; - }); - - // ENTER - nodeGroups.enter() - .append('g') - .attr('data-name', d => { return d.node.name; }) - .each(function(d) { - let nodeGroup = d3.select(this); - // index node group for quick stylizing - sceneElement.addNodeGroup(d.node.name, nodeGroup); - }) - .merge(nodeGroups) - // ENTER + UPDATE - .attr('class', d => { return Class.Node.GROUP + ' ' + nodeClass(d); }) - .each(function(d) { - let nodeGroup = d3.select(this); - // Add g.in-annotations (always add -- to keep layer order - // consistent.) - let inAnnotationBox = - scene.selectOrCreateChild(nodeGroup, 'g', Class.Annotation.INBOX); - annotation.buildGroup( - inAnnotationBox, d.inAnnotations, d, sceneElement); - - // Add g.out-annotations (always add -- to keep layer order - // consistent.) - let outAnnotationBox = scene.selectOrCreateChild( - nodeGroup, 'g', Class.Annotation.OUTBOX); - annotation.buildGroup( - outAnnotationBox, d.outAnnotations, d, sceneElement); - - // Build .shape first (background of the node). - let shape = buildShape(nodeGroup, d, Class.Node.SHAPE); - if (d.node.isGroupNode) { - addButton(shape, d, sceneElement); - } - addInteraction(shape, d, sceneElement); - - // Build subscene on the top. - subsceneBuild(nodeGroup, d, sceneElement); - - // Build label last. 
Should be on top of everything else. - let label = labelBuild(nodeGroup, d, sceneElement); - // Do not add interaction to metanode labels as they live inside the - // metanode shape which already has the same interactions. - addInteraction(label, d, sceneElement, d.node.type === NodeType.META); - - stylize(nodeGroup, d, sceneElement); - position(nodeGroup, d); - }); - - // EXIT - nodeGroups.exit() - .each(function(d) { - // remove all indices on remove - sceneElement.removeNodeGroup(d.node.name); - - let nodeGroup = d3.select(this); - if (d.inAnnotations.list.length > 0) { - nodeGroup.select('.' + Class.Annotation.INBOX) - .selectAll('.' + Class.Annotation.GROUP) - .each(a => { sceneElement.removeAnnotationGroup(a, d); }); - } - if (d.outAnnotations.list.length > 0) { - nodeGroup.select('.' + Class.Annotation.OUTBOX) - .selectAll('.' + Class.Annotation.GROUP) - .each(a => { sceneElement.removeAnnotationGroup(a, d); }); - } - }) - .remove(); - return nodeGroups; -}; - -/** - * Update or remove the subscene of a render group node depending on whether it - * is a expanded. If the node is not a group node, this method has no effect. - * - * @param nodeGroup selection of the container - * @param renderNodeInfo the render information for the node. - * @param sceneElement polymer element. - * @return Selection of the subscene group, or null if node group does not have - * a subscene. Op nodes, bridge nodes and unexpanded group nodes will - * not have a subscene. - */ -function subsceneBuild(nodeGroup, - renderNodeInfo: render.RenderGroupNodeInfo, sceneElement) { - if (renderNodeInfo.node.isGroupNode) { - if (renderNodeInfo.expanded) { - // Recursively build the subscene. - return scene.buildGroup(nodeGroup, renderNodeInfo, sceneElement, - Class.Subscene.GROUP); - } - // Clean out existing subscene if the node is not expanded. 
- scene.selectChild(nodeGroup, 'g', Class.Subscene.GROUP).remove(); - } - return null; -}; - -/** - * Translate the subscene of the given node group - */ -function subscenePosition(nodeGroup, d: render.RenderNodeInfo) { - let x0 = d.x - d.width / 2.0 + d.paddingLeft; - let y0 = d.y - d.height / 2.0 + d.paddingTop; - - let subscene = scene.selectChild(nodeGroup, 'g', Class.Subscene.GROUP); - scene.translate(subscene, x0, y0); -}; - -/** - * Add an expand/collapse button to a group node - * - * @param selection The group node selection. - * @param d Info about the node being rendered. - * @param sceneElement polymer element. - */ -function addButton(selection, d: render.RenderNodeInfo, sceneElement) { - let group = - scene.selectOrCreateChild(selection, 'g', Class.Node.BUTTON_CONTAINER); - scene.selectOrCreateChild(group, 'circle', Class.Node.BUTTON_CIRCLE); - scene.selectOrCreateChild(group, 'path', Class.Node.EXPAND_BUTTON) - .attr('d', 'M0,-2.2 V2.2 M-2.2,0 H2.2'); - scene.selectOrCreateChild(group, 'path', Class.Node.COLLAPSE_BUTTON) - .attr('d', 'M-2.2,0 H2.2'); - (group as any).on('click', (d: any) => { - // Stop this event's propagation so that it isn't also considered a - // node-select. - (d3.event).stopPropagation(); - sceneElement.fire('node-toggle-expand', {name: d.node.name}); - }); - scene.positionButton(group, d); -}; - -/** - * Fire node-* events when the selection is interacted. - * - * @param disableInteraction When true, have the provided selection - * ignore all pointer events. Used for text labels inside of metanodes, which - * don't need interaction as their surrounding shape has interaction, and if - * given interaction would cause conflicts with the expand/collapse button. 
- */ -function addInteraction(selection, d: render.RenderNodeInfo, - sceneElement, disableInteraction?: boolean) { - if (disableInteraction) { - selection.attr('pointer-events', 'none'); - return; - } - - let contextMenuFunction = contextmenu.getMenu( - getContextMenu(d.node, sceneElement)); - selection - .on('dblclick', - d => { - sceneElement.fire('node-toggle-expand', {name: d.node.name}); - }) - .on('mouseover', - d => { - // don't send mouseover over expanded group, - // otherwise it is causing too much glitches - if (sceneElement.isNodeExpanded(d)) { - return; - } - - sceneElement.fire('node-highlight', {name: d.node.name}); - }) - .on('mouseout', - d => { - // don't send mouseover over expanded group, - // otherwise it is causing too much glitches - if (sceneElement.isNodeExpanded(d)) { - return; - } - - sceneElement.fire('node-unhighlight', {name: d.node.name}); - }) - .on('click', - d => { - // Stop this event's propagation so that it isn't also considered - // a graph-select. - (d3.event).stopPropagation(); - sceneElement.fire('node-select', {name: d.node.name}); - }) - .on('contextmenu', (d, i) => { - sceneElement.fire('node-select', {name: d.node.name}); - contextMenuFunction.call(d, i); - }); -}; - -/** - * Returns the d3 context menu specification for the provided node. 
- */ -export function getContextMenu(node: Node, sceneElement) { - let menu = [{ - title: (d): string => { - return getIncludeNodeButtonString(node.include); - }, - action: (elm, d, i) => { - sceneElement.fire('node-toggle-extract', {name: node.name}); - } - }]; - if (canBeInSeries(node)) { - menu.push({ - title: d => { return getGroupSettingLabel(node); }, - action: (elm, d, i) => { - sceneElement.fire( - 'node-toggle-seriesgroup', {name: getSeriesName(node)}); - } - }); - } - return menu; -} - -/** Returns if a node can be part of a grouped series */ -export function canBeInSeries(node: Node) { - return getSeriesName(node) !== null; -} - -/** - * Returns the name of the possible grouped series containing this node. - * Returns null if the node cannot be part of a grouped series of nodes. - */ -export function getSeriesName(node: Node) { - if (!node) { - return null; - } - if (node.type === NodeType.SERIES) { - return node.name; - } - if (node.type === NodeType.OP) { - let op = node; - return op.owningSeries; - } - return null; -} - -/** - * Returns the SeriesNode that represents the series that the provided node - * is contained in (or itself if the provided node is itself a SeriesNode). - * Returns null if the node is not rendered as part of a series. - */ -function getContainingSeries(node: Node) { - let s: SeriesNode = null; - if (!node) { - return null; - } else if (node.type === NodeType.SERIES) { - s = node; - } else if (node.parentNode && node.parentNode.type === NodeType.SERIES) { - s = node.parentNode; - } - return s; -} - -/** - * Returns the label for a button to toggle the group setting of the provided - * node. - */ -export function getGroupSettingLabel(node: Node) { - return tf.graph.getGroupSeriesNodeButtonString( - getContainingSeries(node) !== null ? tf.graph.SeriesGroupingType.GROUP : - tf.graph.SeriesGroupingType.UNGROUP); -} - -/** - * Append svg text for label and assign data. 
- * @param nodeGroup - * @param renderNodeInfo The render node information for the label. - * @param sceneElement polymer element. - */ -function labelBuild(nodeGroup, renderNodeInfo: render.RenderNodeInfo, - sceneElement) { - let namePath = renderNodeInfo.node.name.split('/'); - let text = namePath[namePath.length - 1]; - - // Truncate long labels for unexpanded Metanodes. - let useFontScale = renderNodeInfo.node.type === NodeType.META && - !renderNodeInfo.expanded; - - let label = scene.selectOrCreateChild(nodeGroup, 'text', Class.Node.LABEL); - - // Make sure the label is visually on top among its siblings. - let labelNode = label.node(); - labelNode.parentNode.appendChild(labelNode); - - label.attr('dy', '.35em').attr('text-anchor', 'middle'); - if (useFontScale) { - if (text.length > sceneElement.maxMetanodeLabelLength) { - text = text.substr(0, sceneElement.maxMetanodeLabelLength - 2) + '...'; - } - let scale = getLabelFontScale(sceneElement); - label.attr('font-size', scale(text.length) + 'px'); - } - - let txtElement = >label.text(text); - enforceLabelWidth(txtElement, renderNodeInfo.node.type, renderNodeInfo); - return label; -} -/** - * This function shortens text which would exceed the maximum pixel width of - * a label. - * - * @param txtElementSelection The text element containing the label's text as d3 - * selection. - * @param nodeType The type of the node the label belongs to. If the node is - * an annotation, the value is -1. Label widths are defined in - * layout.PARAMS.nodeSize.{meta|op|...}.maxLabelWidth for nodes and - * layout.PARAMS.annotations.labelWidth for annotations. - * @param renderNodeInfo The render information about the node, required to - * determine whether META nodes are collapsed or expanded. - */ -export function enforceLabelWidth( - txtElementSelection: d3.Selection, nodeType: NodeType | number, - renderNodeInfo?: render.RenderNodeInfo): any { - // Get text element itself and its on-screen width. 
- let txtNode = txtElementSelection.node(); - let computedTxtLength = txtNode.getComputedTextLength(); - let labelContent = txtNode.textContent; - - // Get maximum length from settings. - let maxLength = null; - switch (nodeType) { - case NodeType.META: - if (renderNodeInfo && !renderNodeInfo.expanded) { // Only trim text if - // node expanded. - maxLength = layout.PARAMS.nodeSize.meta.maxLabelWidth; - } - break; - - case NodeType.OP: - maxLength = layout.PARAMS.nodeSize.op.maxLabelWidth; - break; - - case -1: - maxLength = layout.PARAMS.annotations.maxLabelWidth; - break; - - default: - break; - } - - // Return if no max length provided for node type, or current label length is - // less than or equal to the provided length limit. - if (maxLength === null || computedTxtLength <= maxLength) { - return; - } - - // Find the index of the character which exceeds the width. - // getSubStringLength performs far better than getComputedTextLength, and - // results in a 3x speed-up on average. - let index = 1; - while (txtNode.getSubStringLength(0, index) < maxLength) { - index++; - } - - // Shorten the label starting at the string length known to be one - // character above max pixel length. - // When shortened the original label's substring is concatenated with - // '...', baseText contains the substring not including the '...'. - let baseText = txtNode.textContent.substr(0, index); - do { - baseText = baseText.substr(0, baseText.length - 1); - - // Recompute text length. - txtNode.textContent = baseText + '...'; - computedTxtLength = txtNode.getComputedTextLength(); - } while (computedTxtLength > maxLength && baseText.length > 0); - - // Add tooltip with full name and return. - return txtElementSelection.append('title').text(labelContent); -} - -/** - * d3 scale used for sizing font of labels, used by labelBuild, - * initialized once by getLabelFontScale. 
- */ -let fontScale = null; -function getLabelFontScale(sceneElement) { - if (!fontScale) { - fontScale = d3.scaleLinear() - .domain([sceneElement.maxMetanodeLabelLengthLargeFont, - sceneElement.maxMetanodeLabelLength]) - .range([sceneElement.maxMetanodeLabelLengthFontSize, - sceneElement.minMetanodeLabelLengthFontSize]).clamp(true); - } - return fontScale; -} - -/** - * Set label position of a given node group - */ -function labelPosition(nodeGroup, cx: number, cy: number, - yOffset: number) { - scene.selectChild(nodeGroup, 'text', Class.Node.LABEL) - .transition() - .attr('x', cx) - .attr('y', cy + yOffset); -}; - -/** - * Select or append/insert shape for a node and assign renderNode - * as the shape's data. - * - * @param nodeGroup - * @param d Render node information. - * @param nodeClass class for the element. - * @return Selection of the shape. - */ -export function buildShape(nodeGroup, d, nodeClass: string): d3.Selection { - // Create a group to house the underlying visual elements. - let shapeGroup = scene.selectOrCreateChild(nodeGroup, 'g', nodeClass); - // TODO(jimbo): DOM structure should be templated in HTML somewhere, not JS. - switch (d.node.type) { - case NodeType.OP: - scene.selectOrCreateChild(shapeGroup, 'ellipse', Class.Node.COLOR_TARGET); - break; - case NodeType.SERIES: - // Choose the correct stamp to use to represent this series. - let stampType = 'annotation'; - let groupNodeInfo = d; - if (groupNodeInfo.coreGraph) { - stampType = - groupNodeInfo.node.hasNonControlEdges ? 
'vertical' : 'horizontal'; - } - let classList = [Class.Node.COLOR_TARGET]; - if (groupNodeInfo.isFadedOut) { - classList.push('faded-ellipse'); - } - scene.selectOrCreateChild(shapeGroup, 'use', classList) - .attr('xlink:href', '#op-series-' + stampType + '-stamp'); - scene.selectOrCreateChild(shapeGroup, 'rect', Class.Node.COLOR_TARGET) - .attr('rx', d.radius).attr('ry', d.radius); - break; - case NodeType.BRIDGE: - scene.selectOrCreateChild(shapeGroup, 'rect', Class.Node.COLOR_TARGET) - .attr('rx', d.radius).attr('ry', d.radius); - break; - case NodeType.META: - scene.selectOrCreateChild(shapeGroup, 'rect', Class.Node.COLOR_TARGET) - .attr('rx', d.radius).attr('ry', d.radius); - break; - default: - throw Error('Unrecognized node type: ' + d.node.type); - } - return shapeGroup; -}; - -export function nodeClass(d: render.RenderNodeInfo) { - switch (d.node.type) { - case NodeType.OP: - return Class.OPNODE; - case NodeType.META: - return Class.METANODE; - case NodeType.SERIES: - return Class.SERIESNODE; - case NodeType.BRIDGE: - return Class.BRIDGENODE; - case NodeType.ELLIPSIS: - return Class.ELLIPSISNODE; - }; - throw Error('Unrecognized node type: ' + d.node.type); -}; - -/** Modify node and its subscene and its label's positional attributes */ -function position(nodeGroup, d: render.RenderNodeInfo) { - let shapeGroup = scene.selectChild(nodeGroup, 'g', Class.Node.SHAPE); - let cx = layout.computeCXPositionOfNodeShape(d); - switch (d.node.type) { - case NodeType.OP: { - // position shape - let shape = scene.selectChild(shapeGroup, 'ellipse'); - scene.positionEllipse(shape, cx, d.y, d.coreBox.width, d.coreBox.height); - labelPosition(nodeGroup, cx, d.y, d.labelOffset); - break; - } - case NodeType.META: { - // position shape - let shape = scene.selectChild(shapeGroup, 'rect'); - if (d.expanded) { - scene.positionRect(shape, d.x, d.y, d.width, d.height); - subscenePosition(nodeGroup, d); - // put label on top - labelPosition(nodeGroup, cx, d.y, - - d.height / 2 + 
d.labelHeight / 2); - } else { - scene.positionRect(shape, cx, d.y, d.coreBox.width, d.coreBox.height); - labelPosition(nodeGroup, cx, d.y, 0); - } - break; - } - case NodeType.SERIES: { - let shape = scene.selectChild(shapeGroup, 'use'); - if (d.expanded) { - scene.positionRect(shape, d.x, d.y, d.width, d.height); - subscenePosition(nodeGroup, d); - // put label on top - labelPosition(nodeGroup, cx, d.y, - - d.height / 2 + d.labelHeight / 2); - } else { - scene.positionRect(shape, cx, d.y, d.coreBox.width, d.coreBox.height); - labelPosition(nodeGroup, cx, d.y, d.labelOffset); - } - break; - } - case NodeType.BRIDGE: { - // position shape - // NOTE: In reality, these will not be visible, but it helps to put them - // in the correct position for debugging purposes. - let shape = scene.selectChild(shapeGroup, 'rect'); - scene.positionRect(shape, d.x, d.y, d.width, d.height); - break; - } - default: { throw Error('Unrecognized node type: ' + d.node.type); } - } -}; - -/** Enum specifying the options to color nodes by */ -export enum ColorBy {STRUCTURE, DEVICE, XLA_CLUSTER, COMPUTE_TIME, MEMORY} -; - -/** - * Returns the fill color for the node given its state and the 'color by' - * option. - */ -export function getFillForNode(templateIndex, colorBy, - renderInfo: render.RenderNodeInfo, isExpanded: boolean): string { - let colorParams = render.MetanodeColors; - switch (colorBy) { - case ColorBy.STRUCTURE: - if (renderInfo.node.type === NodeType.META) { - let tid = (renderInfo.node).templateId; - return tid === null ? - colorParams.UNKNOWN : - colorParams.STRUCTURE_PALETTE(templateIndex(tid), isExpanded); - } else if (renderInfo.node.type === NodeType.SERIES) { - // If expanded, we're showing the background rect, which we want to - // appear gray. Otherwise we're showing a stack of ellipses which we - // want to show white. - return isExpanded ? 
colorParams.EXPANDED_COLOR : 'white'; - } else if (renderInfo.node.type === NodeType.BRIDGE) { - return renderInfo.structural ? - '#f0e' : - (renderInfo.node).inbound ? '#0ef' : '#fe0'; - } else { - // Op nodes are white. - return 'white'; - } - case ColorBy.DEVICE: - if (renderInfo.deviceColors == null) { - // Return the hue for unknown device. - return colorParams.UNKNOWN; - } - let id = renderInfo.node.name; - let escapedId = tf.graph.util.escapeQuerySelector(id); - let gradientDefs = d3.select('svg#svg defs #linearGradients'); - let linearGradient = gradientDefs.select('linearGradient#' + escapedId); - // If the linear gradient is not there yet, create it. - if (linearGradient.size() === 0) { - linearGradient = gradientDefs.append('linearGradient').attr('id', id); - // Re-create the stops of the linear gradient. - linearGradient.selectAll('*').remove(); - let cumulativeProportion = 0; - // For each device, create a stop using the proportion of that device. - _.each(renderInfo.deviceColors, d => { - let color = d.color; - linearGradient.append('stop') - .attr('offset', cumulativeProportion) - .attr('stop-color', color); - linearGradient.append('stop') - .attr('offset', cumulativeProportion + d.proportion) - .attr('stop-color', color); - cumulativeProportion += d.proportion; - }); - } - return isExpanded ? colorParams.EXPANDED_COLOR : `url(#${escapedId})`; - case ColorBy.XLA_CLUSTER: - return isExpanded ? colorParams.EXPANDED_COLOR : - renderInfo.xlaClusterColor || colorParams.UNKNOWN; - case ColorBy.COMPUTE_TIME: - return isExpanded ? - colorParams.EXPANDED_COLOR : renderInfo.computeTimeColor || - colorParams.UNKNOWN; - case ColorBy.MEMORY: - return isExpanded ? - colorParams.EXPANDED_COLOR : renderInfo.memoryColor || - colorParams.UNKNOWN; - default: - throw new Error('Unknown case to color nodes by'); - } -} - -/** - * Modify node style by toggling class and assign attributes (only for things - * that can't be done in css). 
- */ -export function stylize(nodeGroup, renderInfo: render.RenderNodeInfo, - sceneElement, nodeClass?) { - nodeClass = nodeClass || Class.Node.SHAPE; - let isHighlighted = sceneElement.isNodeHighlighted(renderInfo.node.name); - let isSelected = sceneElement.isNodeSelected(renderInfo.node.name); - let isExtract = renderInfo.isInExtract || renderInfo.isOutExtract; - let isExpanded = renderInfo.expanded; - let isFadedOut = renderInfo.isFadedOut; - nodeGroup.classed('highlighted', isHighlighted); - nodeGroup.classed('selected', isSelected); - nodeGroup.classed('extract', isExtract); - nodeGroup.classed('expanded', isExpanded); - nodeGroup.classed('faded', isFadedOut); - - // Main node always exists here and it will be reached before subscene, - // so d3 selection is fine here. - let node = nodeGroup.select('.' + nodeClass + ' .' + Class.Node.COLOR_TARGET); - let fillColor = getFillForNode(sceneElement.templateIndex, - ColorBy[sceneElement.colorBy.toUpperCase()], - renderInfo, isExpanded); - node.style('fill', fillColor); - - // Choose outline to be darker version of node color if the node is a single - // color and is not selected. - node.style('stroke', isSelected ? null : getStrokeForFill(fillColor)); -}; - -/** - * Given a node's fill color/gradient, determine the stroke for the node. - */ -export function getStrokeForFill(fill: string) { - // If node is colored by a gradient, then use a dark gray outline. - return fill.substring(0, 3) === 'url' ? - render.MetanodeColors.GRADIENT_OUTLINE : - d3.rgb(fill).darker().toString(); -} - -/** - * Finds selected node and highlights all nodes which are providing direct - * or indirect input to the node and all edges connecting these nodes - * together and to the selected node. - * - * @param renderGraphInfo Information on the rendered state of the graph. - */ -export function traceInputs(renderGraphInfo: tf.graph.render.RenderGraphInfo) { - // Reset all styling. 
- d3.selectAll('.input-highlight').classed('input-highlight', false); - d3.selectAll('.non-input').classed('non-input', false); - d3.selectAll('.input-parent').classed('input-parent', false); - d3.selectAll('.input-child').classed('input-child', false); - d3.selectAll('.input-edge-highlight').classed('input-edge-highlight', false); - d3.selectAll('.non-input-edge-highlight') - .classed('non-input-edge-highlight', false); - d3.selectAll('.input-highlight-selected') - .classed('input-highlight-selected', false); - - // Extract currently selected node. Return if input tracing disabled or no - // node is selected. - const selectedNodeSelectorString = 'g.node.selected,g.op.selected'; - const nodeSelection = d3.select(selectedNodeSelectorString); - let currentNode = undefined; - if (renderGraphInfo && renderGraphInfo.traceInputs && - nodeSelection.nodes().length) { - currentNode = nodeSelection.nodes()[0]; - } else { - return; - } - let nodeName = currentNode.getAttribute('data-name'); - let opNodes = _getAllContainedOpNodes(nodeName, renderGraphInfo); - let allTracedNodes = {}; - _.each(opNodes, function(nodeInstance) { - allTracedNodes = - traceAllInputsOfOpNode(renderGraphInfo, nodeInstance, allTracedNodes); - }); - - d3.selectAll(selectedNodeSelectorString) - // Remove the input-highlight from the selected node. - .classed('input-highlight', false) - // Add input-highlight-selected class to selected node, which allows - // treating the selected not as a special case of an input node. - .classed('input-highlight-selected', true); - - // Highlight all parent nodes of each OpNode as input parent to allow - // specific highlighting. - let highlightedNodes = Object.keys(allTracedNodes); - let visibleNodes = - _findVisibleParentsFromOpNodes(renderGraphInfo, highlightedNodes); - _markParentsOfNodes(visibleNodes); - - // Attach class to all non-input nodes and edges for styling. 
- d3.selectAll( - 'g.node:not(.selected):not(.input-highlight)' + - ':not(.input-parent):not(.input-children)') - .classed('non-input', true) - .each(function(d: RenderNodeInfo) { - // Mark all nodes with the specified name as non-inputs. This - // results in Annotation nodes which are attached to inputs to be - // tagged as well. - let nodeName = d.node.name; - d3.selectAll(`[data-name="${nodeName}"]`).classed('non-input', true); - }); - d3.selectAll('g.edge:not(.input-edge-highlight)') - .classed('non-input-edge-highlight', true); -} - -/** - * Recursively find all op nodes contained by the node identified by the - * provided name. - * @param nodeName The meta or op node of which the OpNode instances are - * required. - * @param renderGraphInfo The rendered graph information object. - * @returns {Array} An array of OpNodeImpl instances. - */ -export function _getAllContainedOpNodes( - nodeName: string, renderGraphInfo: tf.graph.render.RenderGraphInfo) { - let opNodes = []; - - // Get current node. - let node = renderGraphInfo.getNodeByName(nodeName) as tf.graph.GroupNode | - tf.graph.OpNode; - - // If node is already OpNode then return the node plus its input embeddings. - if (node instanceof tf.graph.OpNodeImpl) { - return [node].concat(node.inEmbeddings); - } - - // Otherwise, make recursive call for each node contained by the GroupNode. - let childNodeNames = (node as tf.graph.GroupNode).metagraph.nodes(); - _.each(childNodeNames, function(childNodeName) { - opNodes = - opNodes.concat(_getAllContainedOpNodes(childNodeName, renderGraphInfo)); - }); - - return opNodes; -} - -/** - * When resolving inputs of a node the visible parent node of each input - * node (i.e. the first parent which is rendered to the screen) needs to be - * found, and since such a node may contain several input OpNodes a map - * of the visible parent to all the input OpNodes it contains is provided by - * opNodes. 
- */ -interface VisibleParent { - visibleParent: Node; - opNodes: OpNode[]; -} - -export function traceAllInputsOfOpNode( - renderGraphInfo: tf.graph.render.RenderGraphInfo, startNode: OpNode, - allTracedNodes: Object) { - // To prevent infinite loops due to cyclical relationships and improving - // performance by tracing OpNode which is input to 2+ nodes only once. - if (allTracedNodes[startNode.name]) { - return allTracedNodes; - } else { - allTracedNodes[startNode.name] = true; - } - // Extract the inputs. - let inputs = startNode.inputs; - // Get visible parent. - let currentVisibleParent = getVisibleParent(renderGraphInfo, startNode); - // Mark as input node. - d3.select(`.node[data-name="${currentVisibleParent.name}"]`) - .classed('input-highlight', true); - - // Find the visible parent of each input. - let visibleInputs = {}; - _.each(inputs, function(nodeInstance) { - let resolvedNode = renderGraphInfo.getNodeByName(nodeInstance.name); - if (resolvedNode === undefined) { - // Node could not be found in rendered Hierarchy, which happens when - // tracing inputs of a SummaryNode. - return; - } - // Ensure node is resolved to OpNode if name collision with Metanode exists. - if (resolvedNode instanceof MetanodeImpl) { - let resolvedNodeName = tf.graph.getStrictName(resolvedNode.name); - resolvedNode = renderGraphInfo.getNodeByName(resolvedNodeName) as OpNode; - } - - let visibleParent = getVisibleParent(renderGraphInfo, resolvedNode); - - // Append OpNode to visible parent entry. - let visibleInputsEntry = visibleInputs[visibleParent.name]; - if (visibleInputsEntry) { - visibleInputsEntry.opNodes.push(resolvedNode); - } else { // Create new entry. - visibleInputs[visibleParent.name] = { - visibleParent: visibleParent, - opNodes: [resolvedNode] - } as VisibleParent; - } - }); - - // Find all parents of the start node. 
- let startNodeParents = {}; - let indexedStartNodeParents = [currentVisibleParent]; - startNodeParents[currentVisibleParent.name] = { - traced: false, - index: 0, - connectionEndpoints: [] - }; - - let currentNode = currentVisibleParent as Node; - for (let index = 1; currentNode.name !== tf.graph.ROOT_NAME; index++) { - currentNode = currentNode.parentNode; - startNodeParents[currentNode.name] = { - traced: false, - index: index, - connectionEndpoints: [] - }; - indexedStartNodeParents[index] = currentNode; - } - - // Find first mutual parent of each input node and highlight connection. - _.forOwn(visibleInputs, function(visibleParentInfo: VisibleParent, key) { - let nodeInstance = visibleParentInfo.visibleParent; - // Make recursive call for each input-OpNode contained by the visible - // parent. - _.each(visibleParentInfo.opNodes, function(opNode: OpNode) { - allTracedNodes = - traceAllInputsOfOpNode(renderGraphInfo, opNode, allTracedNodes); - }); - - if (nodeInstance.name !== currentVisibleParent.name) { - _createVisibleTrace( - nodeInstance, startNodeParents, indexedStartNodeParents); - } - }); - - return allTracedNodes; -} - -/** - * Colors the edges to connect the passed node to the start node. This is - * done by: - * - * a) Finding the first (visible) common parent in the rendered - * hierarchy. - * NB: There are 2 types of connections: - * 1) Direct connections between node A - * and B, marked below as II, - * 2) Connections from any node A to its parent, A'. Marked below as I and III. - * For type 2 connection you need to know the inner-nested node, the - * direct parent, and the ultimate destination of the connection. - * - * A_parent B_parent - * +--------+ +---------+ - * | | | | - * | +--+ I| II |III+--+ | - * | |A +---------->+B | | - * | +--+ | | +--+ | - * | | | | - * +--------+ +---------+ - * - * - * b) Highlighting the direct connection between the parents of A and B, - * called A_parent and B_parent, s.t. 
A_parent and B_parent are children of the - * mutual parent of A and B found in a), marked above as II. - * - * c) Highlighting the connection from A to A_parent and B to B_parent - * (through all layers of parents between A and A_parent and B and B_parent, - * respectively). Marked above as I and III. - * - * @param nodeInstance The instance of the node to use as destination node, B. - * @param startNodeParents Map of startNodeParent names to information objects - * about the parent. - * @param indexedStartNodeParents An array of all parents of the start node. - * This is required to find the child of the mutual parent which is a parent - * of the start node. - * @private - */ -function _createVisibleTrace( - nodeInstance: Node, startNodeParents, indexedStartNodeParents: Node[]) { - let currentNode = nodeInstance; - let previousNode = nodeInstance; - - // Ascend through parents until a mutual parent is found with the start - // node. - let destinationParentPairs = []; - while (!startNodeParents[currentNode.name]) { - if (previousNode.name !== currentNode.name) { - destinationParentPairs.push([previousNode, currentNode]); - } - previousNode = currentNode; - currentNode = currentNode.parentNode; - } - - // Connection between nodes is drawn between the parents of each - // respective node, both of which share the mutual parent. - let startNodeIndex = startNodeParents[currentNode.name].index; - let startNodeName = - indexedStartNodeParents[Math.max(startNodeIndex - 1, 0)].name; - - let startNodeTopParentName = startNodeName; - let targetNodeTopParentName = previousNode.name; - - let endNodeName = previousNode.name; - d3.selectAll(`[data-edge="${endNodeName}--${startNodeName}"]`) - .classed('input-edge-highlight', true); - - // Trace up the parents of the input. 
- _.each(destinationParentPairs, function(value) { - let inner = value[0]; - let outer = value[1]; - let edgeSelector = `[data-edge="${inner.name}--${startNodeTopParentName}` + - `~~${outer.name}~~OUT"]`; - d3.selectAll(edgeSelector).classed('input-edge-highlight', true); - }); - - // Trace up the parents of the start node. - for (let index = 1; index < startNodeIndex; index++) { - let inner = indexedStartNodeParents[index - 1]; - let outer = indexedStartNodeParents[index]; - let edgeSelector = `[data-edge="${targetNodeTopParentName}~~${outer.name}` + - `~~IN--${inner.name}"]`; - d3.selectAll(edgeSelector).classed('input-edge-highlight', true); - } -} - -/** - * Creates map { [name: string] -> Node } of all visible / rendered parents - * of the nodes identified by the node names passed in. - * - * @param renderGraphInfo The information on the rendered graph. - * @param nodeNames String array of node names. - * @returns {[nodeName: string]: Node} - * @private - */ -function _findVisibleParentsFromOpNodes(renderGraphInfo, nodeNames: string[]) { - let visibleParents: {[nodeName: string]: Node} = {}; - _.each(nodeNames, function(nodeName) { - let currentNode = renderGraphInfo.getNodeByName(nodeName); - let visibleParent = getVisibleParent(renderGraphInfo, currentNode); - visibleParents[visibleParent.name] = visibleParent; - }); - - return visibleParents; -} - -/** - * Traverse through the parents of all nodes in the list and mark each - * encountered node as input-parent. - * @param visibleNodes Map of input nodes, have to be visible/rendered when - * called. - * @private - */ -function _markParentsOfNodes(visibleNodes: {[nodeName: string]: Node}) { - _.forOwn(visibleNodes, function(nodeInstance: Node) { - // Mark all parents of the node as input-parents. 
- let currentNode = nodeInstance; - - while (currentNode.name !== tf.graph.ROOT_NAME) { - const renderedElementSelection = - d3.select(`.node[data-name="${currentNode.name}"]`); - // Only mark the element as a parent node to an input if it is not - // marked as input node itself. - if (renderedElementSelection.nodes().length && - !renderedElementSelection.classed('input-highlight') && - !renderedElementSelection.classed('selected') && - // OpNode only parent if start node is embedded node, in which case - // the OpNode should be faded as well. - !renderedElementSelection.classed('op')) { - renderedElementSelection.classed('input-parent', true); - } - currentNode = currentNode.parentNode; - } - }); -} - -/** - * Find the parent of the passed in op node which is expanded. This is done - * by going through all parents until the parent's parent is expanded, thus - * finding the first unexpanded parent which is rendered on the screen. - * @param renderGraphInfo The graph info object used to gain access to the - * render info of the parents. - * @param currentNode The node whose parent is to be found. - * @returns Node - */ -export function getVisibleParent( - renderGraphInfo: tf.graph.render.RenderGraphInfo, - currentNode: tf.graph.Node) { - let found = false; - let currentParent = currentNode; - - while (!found) { - // Get parent element, to extract name. - currentNode = currentParent; - currentParent = currentNode.parentNode; - - if (currentParent === undefined) { - found = true; - } else { - let renderNode = renderGraphInfo.getRenderNodeByName(currentParent.name); - // Found if node is rendered on the screen (renderNode truthy), and - // the parent is either expanded (i.e. it is a metanode or seriesnode) - // or the parent is an OpNode in which case currentNode is an embedded - // node which has another OpNode as parent. - if (renderNode && - (renderNode.expanded || currentParent instanceof graph.OpNodeImpl)) { - found = true; - } - } - } // Close while loop. 
- return currentNode; -} -} // Close module. diff --git a/tensorflow/tensorboard/components/tf_graph_common/parser.ts b/tensorflow/tensorboard/components/tf_graph_common/parser.ts deleted file mode 100644 index 04d879ef910..00000000000 --- a/tensorflow/tensorboard/components/tf_graph_common/parser.ts +++ /dev/null @@ -1,284 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the 'License'); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an 'AS IS' BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -module tf.graph.parser { - -/** - * Parses a native js value, which can be either a string, boolean or number. - * - * @param value The value to be parsed. - */ -function parseValue(value: string): string|number|boolean { - if (value === 'true') { - return true; - } - if (value === 'false') { - return false; - } - let firstChar = value[0]; - if (firstChar === '"') { - return value.substring(1, value.length - 1); - } - let num = parseFloat(value); - return isNaN(num) ? value : num; -} - -/** - * Fetches a text file and returns a promise of the result. 
- */ -export function fetchPbTxt(filepath: string): Promise { - return new Promise(function(resolve, reject) { - const request = new XMLHttpRequest(); - request.open('GET', filepath); - request.responseType = 'arraybuffer'; - - request.onerror = () => reject(request.status); - request.onload = () => resolve(request.response); - - request.send(null); - }); -} - -/** - * Fetches the metadata file, parses it and returns a promise of the result. - */ -export function fetchAndParseMetadata(path: string, tracker: ProgressTracker) { - return tf.graph.util - .runTask( - 'Reading metadata pbtxt', 40, - () => { - if (path == null) { - return Promise.resolve(null); - } - return fetchPbTxt(path); - }, - tracker) - .then((arrayBuffer: ArrayBuffer) => { - return tf.graph.util.runAsyncPromiseTask( - 'Parsing metadata.pbtxt', 60, () => { - return arrayBuffer != null ? parseStatsPbTxt(arrayBuffer) : - Promise.resolve(null); - }, tracker); - }); -} - -/** - * Fetches the graph file, parses it and returns a promise of the result. The - * result will be undefined if the graph is empty. - */ -export function fetchAndParseGraphData(path: string, pbTxtFile: Blob, - tracker: ProgressTracker) { - return tf.graph.util - .runTask( - 'Reading graph pbtxt', 40, - () => { - if (pbTxtFile) { - return new Promise(function(resolve, reject) { - let fileReader = new FileReader(); - fileReader.onload = () => resolve(fileReader.result); - fileReader.onerror = () => reject(fileReader.error); - fileReader.readAsArrayBuffer(pbTxtFile); - }); - } else { - return fetchPbTxt(path); - } - }, - tracker) - .then((arrayBuffer: ArrayBuffer) => { - return tf.graph.util.runTask('Parsing graph.pbtxt', 60, () => { - return parseGraphPbTxt(arrayBuffer); - }, tracker); - }); -} - -/** - * Parse a file object in a streaming fashion line by line (or custom delim). - * Can handle very large files. - * @param input The file object as an array buffer. 
- * @param callback The callback called on each line - * @param chunkSize The size of each read chunk. (optional) - * @param delim The delimiter used to split a line. (optional) - * @returns A promise for when it is finished. - */ -export function streamParse( - arrayBuffer: ArrayBuffer, callback: (string) => void, - chunkSize: number = 1000000, delim: string = '\n'): Promise { - return new Promise(function(resolve, reject) { - let offset = 0; - let bufferSize = arrayBuffer.byteLength - 1; - let data = ''; - - function readHandler(str) { - offset += chunkSize; - let parts = str.split(delim); - let first = data + parts[0]; - if (parts.length === 1) { - data = first; - readChunk(offset, chunkSize); - return; - } - data = parts[parts.length - 1]; - callback(first); - for (let i = 1; i < parts.length - 1; i++) { - callback(parts[i]); - } - if (offset >= bufferSize) { - if (data) { - callback(data); - } - resolve(true); - return; - } - readChunk(offset, chunkSize); - } - - function readChunk(offset: number, size: number) { - const arrayBufferChunk = arrayBuffer.slice(offset, offset + size); - - const blob = new Blob([arrayBufferChunk]); - const file = new FileReader(); - file.onload = (e: any) => readHandler(e.target.result); - file.readAsText(blob); - } - - readChunk(offset, chunkSize); - }); -} - -/** - * Since proto-txt doesn't explicitly say whether an attribute is repeated - * (an array) or not, we keep a hard-coded list of attributes that are known - * to be repeated. This list is used in parsing time to convert repeated - * attributes into arrays even when the attribute only shows up once in the - * object. 
- */ -const GRAPH_REPEATED_FIELDS: {[attrPath: string]: boolean} = { - 'node': true, - 'node.input': true, - 'node.attr': true, - 'node.attr.value.list.type': true, - 'node.attr.value.shape.dim': true, - 'node.attr.value.tensor.string_val': true, - 'node.attr.value.tensor.tensor_shape.dim': true, - 'node.attr.value.list.shape': true, - 'node.attr.value.list.shape.dim': true, - 'node.attr.value.list.s': true -}; - -const METADATA_REPEATED_FIELDS: {[attrPath: string]: boolean} = { - 'step_stats.dev_stats': true, - 'step_stats.dev_stats.node_stats': true, - 'step_stats.dev_stats.node_stats.output': true, - 'step_stats.dev_stats.node_stats.memory': true, - 'step_stats.dev_stats.node_stats.output.tensor_description.shape.dim': true -}; - -/** - * Parses an ArrayBuffer of a proto txt file into a raw Graph object. - */ -export function parseGraphPbTxt(input: ArrayBuffer): - Promise { - return parsePbtxtFile(input, GRAPH_REPEATED_FIELDS).then(obj => obj['node']); -} - -/** - * Parses an ArrayBuffer of a proto txt file into a StepStats object. - */ -export function parseStatsPbTxt(input: ArrayBuffer): - Promise { - return parsePbtxtFile(input, METADATA_REPEATED_FIELDS) - .then(obj => obj['step_stats']); -} - -/** - * Parses a ArrayBuffer of a proto txt file into javascript object. - * - * @param input The ArrayBuffer or file object implementing slice. - * @param repeatedFields Map (Set) of all the repeated fields, since you can't - * tell directly from the pbtxt if a field is repeated or not. - * @returns The parsed object. 
- */ -function parsePbtxtFile( - input: ArrayBuffer, - repeatedFields: {[attrPath: string]: boolean}): Promise { - let output: { [name: string]: any; } = {}; - let stack = []; - let path: string[] = []; - let current: { [name: string]: any; } = output; - - function splitNameAndValueInAttribute(line: string) { - let colonIndex = line.indexOf(':'); - let name = line.substring(0, colonIndex).trim(); - let value = parseValue(line.substring(colonIndex + 2).trim()); - return { - name: name, - value: value - }; - } - - /** - * Adds a value, given the attribute name and the host object. If the - * attribute already exists, but is not an array, it will convert it to an - * array of values. - * - * @param obj The host object that holds the attribute. - * @param name The attribute name (key). - * @param value The attribute value. - * @param path A path that identifies the attribute. Used to check if - * an attribute is an array or not. - */ - function addAttribute(obj: Object, name: string, - value: Object|string|number|boolean, path: string[]): void { - // We treat 'node' specially since it is done so often. - let existingValue = obj[name]; - if (existingValue == null) { - obj[name] = path.join('.') in repeatedFields ? [value] : value; - } else if (Array.isArray(existingValue)) { - existingValue.push(value); - } else { - obj[name] = [existingValue, value]; - } - } - - // Run through the file a line at a time. 
- return streamParse(input, function(line: string) { - if (!line) { - return; - } - line = line.trim(); - - switch (line[line.length - 1]) { - case '{': // create new object - let name = line.substring(0, line.length - 2).trim(); - let newValue: { [name: string]: any; } = {}; - stack.push(current); - path.push(name); - addAttribute(current, name, newValue, path); - current = newValue; - break; - case '}': - current = stack.pop(); - path.pop(); - break; - default: - let x = splitNameAndValueInAttribute(line); - addAttribute(current, x.name, x.value, path.concat(x.name)); - break; - } - }).then(function() { - return output; - }); -} - -} // Close module tf.graph.parser. diff --git a/tensorflow/tensorboard/components/tf_graph_common/proto.ts b/tensorflow/tensorboard/components/tf_graph_common/proto.ts deleted file mode 100644 index eda73e45c3b..00000000000 --- a/tensorflow/tensorboard/components/tf_graph_common/proto.ts +++ /dev/null @@ -1,143 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the 'License'); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an 'AS IS' BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -/** - * @fileoverview Interfaces that parallel proto definitions in - * third_party/tensorflow/core/framework/... - * graph.proto - * step_stats.proto - * These should stay in sync. - */ -module tf.graph.proto { - /** - * TensorFlow node definition as defined in the graph.proto file. 
- */ - export interface NodeDef { - /** Name of the node */ - name: string; - /** List of nodes that are inputs for this node. */ - input: string[]; - /** The name of the device where the computation will run. */ - device: string; - /** The name of the operation associated with this node. */ - op: string; - /** List of attributes that describe/modify the operation. */ - attr: {key: string, value: Object}[]; - } - - /** - * Generic graph as defined in the graph_explorer.proto file. - */ - export interface GenericGraph { - /** List of nodes in the graph */ - node: GenericNode[]; - /** List of nodes in the graph */ - edge: GenericEdge[]; - /** List of attributes that describe/modify the operation. */ - attr: Array<{[key: string]: any}>; - } - - /** - * GenericEdge corresponds to the Edge message in graph_explorer.proto. - */ - export interface GenericEdge { - /** Name of the source node. */ - source: string; - /** Name of the target node. */ - target: string; - /** Attributes of the edge. */ - edge_attr: Array<{[key: string]: any}>; - } - - /** - * GenericNode corresponds to the Node message in graph_explorer.proto. - */ - export interface GenericNode { - /** Name of the node */ - name: string; - /** Attributes of a leaf node or leaf nodes within a metanode. */ - node_attr: Array<{[key: string]: any}>; - /** Attributes of a metanode. */ - metanode_attr: Array<{[key: string]: any}>; - } - - /** - * TensorFlow stats file definition as defined in the stats proto file. - */ - export interface StepStats { - dev_stats: {device: string, node_stats: NodeExecStats[]}[]; - } - - /** - * TensorFlow stats for a node as defined in the step_stats proto file. - */ - export interface NodeExecStats { - node_name: string; - // The next 4 properties are currently stored as string in json - // and must be parsed. 
- all_start_micros: number; - op_start_rel_micros: number; - op_end_rel_micros: number; - all_end_rel_micros: number; - memory: { - allocator_name: string; - total_bytes: number; // Stored as string in json and should be parsed. - peak_bytes: number; // Stored as string in json and should be parsed. - }[]; - /** Output sizes recorded for a single execution of a graph node */ - output: NodeOutput[]; - timeline_label: string; - scheduled_micros: string; - thread_id: string; - } - - /** - * Description for the output tensor(s) of an operation in the graph as - * defined in the step_stats.proto file. - */ - export interface NodeOutput { - slot: number; // Stored as string in json and should be parsed. - tensor_description: { - /** Data type of tensor elements */ - dtype: string; - /** Shape of the tensor */ - shape: { - /** - * Dimensions of the tensor, such as [{name: 'input', size: 30}, - * {name: 'output', size: 40}] for a 30 x 40 2D tensor. The names - * are optional. The order of entries in 'dim' matters: It indicates - * the layout of the values in the tensor in-memory representation. - */ - dim: { - /** Size of the tensor in that dimension */ - size: number, // Stored as string in json and should be parsed. - /** Optional name of the tensor dimension */ - name?: string - }[]; - }; - /** Information about the size and allocator used for the data */ - allocation_description: { - // The next 2 properties are stored as string in json and - // should be parsed. 
- /** Total number of bytes requested */ - requested_bytes: number; - /** Total number of bytes allocated, if known */ - allocated_bytes?: number; - /** Name of the allocator used */ - allocator_name: string; - }; - }; - } -} diff --git a/tensorflow/tensorboard/components/tf_graph_common/render.ts b/tensorflow/tensorboard/components/tf_graph_common/render.ts deleted file mode 100644 index 4f28af481d4..00000000000 --- a/tensorflow/tensorboard/components/tf_graph_common/render.ts +++ /dev/null @@ -1,1673 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the 'License'); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an 'AS IS' BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -/** - * Package for the Render Hierarchy for TensorFlow graph. - */ -module tf.graph.render { - -export type Point = {x: number, y: number}; - -/** - * Color parameters for op nodes. - */ -export let OpNodeColors = {DEFAULT_FILL: 'white', DEFAULT_STROKE: '#b2b2b2'}; - -/** - * Color parameters for node encoding. - * @type {Object} - */ -export let MetanodeColors = { - /** - * Default fill and stroke to use when no other information is available. - */ - DEFAULT_FILL: '#d9d9d9', - DEFAULT_STROKE: '#a6a6a6', - SATURATION: 0.6, - LIGHTNESS: 0.85, - /** - * Neutral color to use when the node is expanded (used when coloring by - * compute time, memory and device). - */ - EXPANDED_COLOR: '#f0f0f0', - /** - * Standard hue values for node color palette. 
- */ - HUES: [220, 100, 180, 40, 20, 340, 260, 300, 140, 60], - STRUCTURE_PALETTE(id: number, lightened?: boolean) { - // The code below is a flexible way to computationally create a set - // of colors that go well together. - let hues = MetanodeColors.HUES; - let n = hues.length; - let hue = hues[id % n]; - let m = Math.sin(hue * Math.PI / 360); - let sat = lightened ? 30 : 90 - 60 * m; - let light = lightened ? 95 : 80; - return d3.hsl(hue, .01 * sat, .01 * light).toString(); - }, - DEVICE_PALETTE(index: number): string { - return MetanodeColors.STRUCTURE_PALETTE(index); - }, - XLA_CLUSTER_PALETTE(index: number): string { - return MetanodeColors.STRUCTURE_PALETTE(index); - }, - UNKNOWN: '#eee', - GRADIENT_OUTLINE: '#888' -}; - -/** - * Color parameters for op nodes. - */ -export let SeriesNodeColors = { - DEFAULT_FILL: 'white', - DEFAULT_STROKE: '#b2b2b2' -}; - -/** - * Parameters that affect how the graph is rendered on the screen. - */ -const PARAMS = { - /** - * Whether to extract high degree nodes from the core part of the graph. - */ - enableExtraction: true, - /** - * The minimum number of nodes for a graph to have in order for high in and - * out degree nodes to be extracted in auxiliary. The aim here is to prevent - * nodes from being extracted from small graphs. - */ - minNodeCountForExtraction: 15, - /** - * The minimum in or out degree a node must have in order to be possibly - * extracted. - */ - minDegreeForExtraction: 5, - /** - * Maximum number of control edges a node can have before they aren't - * displayed. - */ - maxControlDegree: 4, - /** - * Maximum in (for outbound bridge paths) or out (for inbound bridge paths) - * degree of a node allowed for a bridge path to be rendered to it from a - * subhierarchy of nodes. Having a max prevents having too many nodes emanate - * from a subhierarchy and crowding up. 
- */ - maxBridgePathDegree: 4, - /** - * Types patterns for predefined out-extract nodes, which are - * sink-like nodes that will be extracted from the main graph. - */ - outExtractTypes: [ - 'NoOp' // NoOps are sink-like used for managing control dependencies. - ], - - /** - * Types patterns for predefined in-extract nodes, which are - * source-like nodes that will be extracted from the main graph. - */ - inExtractTypes: [], - - /** - * When removing edges from a high degree node, remove all of its edges if - * detachAllEdgesForHighDegree is true. Otherwise remove all in-edges if - * the node has high in-degree, or all out-edges if the node has high - * out-degree. - */ - detachAllEdgesForHighDegree: true, - - /** - * After extracting high in/out degree nodes and predefined - * source-like/sink-like, extract isolated nodes to the side - * if this extractIsolatedNodesWithAnnotationsOnOneSide is true. - */ - extractIsolatedNodesWithAnnotationsOnOneSide: true, - - /** - * Whether to add bridge nodes and edges to the core when building the - * subhierarchy of an expanded metanode. See buildSubhierarchy(). - */ - enableBridgegraph: true, - - /** - * 2 colors, for the minimum and maximum value respectively, whenever we - * have a gradient scale. - */ - minMaxColors: ['#fff5f0', '#fb6a4a'], - - /** - * Maximum number of annotations to be displayed on a node before an - * ellipsis is used. - */ - maxAnnotations: 5 -}; - -/** - * Stores the rendering information, such as x and y coordinates, - * for each node in the graph. - */ -export class RenderGraphInfo { - hierarchy: hierarchy.Hierarchy; - private displayingStats: boolean; - private index: {[nodeName: string]: RenderNodeInfo}; - private renderedOpNames: string[]; - private deviceColorMap: d3.ScaleOrdinal; - private xlaClusterColorMap: d3.ScaleOrdinal; - private memoryUsageScale: d3.ScaleLinear; - private computeTimeScale: d3.ScaleLinear; - /** Scale for the thickness of edges when there is no shape information. 
*/ - edgeWidthScale: - d3.ScaleLinear | d3.ScalePower; - // Since the rendering information for each node is constructed lazily, - // upon node's expansion by the user, we keep a map between the node's name - // and whether the rendering information was already constructed for that - // node. - private hasSubhierarchy: {[nodeName: string]: boolean}; - root: RenderGroupNodeInfo; - traceInputs: Boolean; - - constructor(hierarchy: hierarchy.Hierarchy, displayingStats: boolean) { - this.hierarchy = hierarchy; - this.displayingStats = displayingStats; - this.index = {}; - this.renderedOpNames = []; - - this.computeScales(); - // Maps node name to whether the rendering hierarchy was already - // constructed. - this.hasSubhierarchy = {}; - this.root = new RenderGroupNodeInfo(hierarchy.root); - this.index[hierarchy.root.name] = this.root; - this.renderedOpNames.push(hierarchy.root.name); - this.buildSubhierarchy(hierarchy.root.name); - this.root.expanded = true; - this.traceInputs = false; - } - - computeScales() { - this.deviceColorMap = d3.scaleOrdinal() - .domain(this.hierarchy.devices) - .range(_.map(d3.range(this.hierarchy.devices.length), - MetanodeColors.DEVICE_PALETTE)); - - this.xlaClusterColorMap = - d3.scaleOrdinal() - .domain(this.hierarchy.xlaClusters) - .range(_.map( - d3.range(this.hierarchy.xlaClusters.length), - MetanodeColors.XLA_CLUSTER_PALETTE)); - - let topLevelGraph = this.hierarchy.root.metagraph; - // Find the maximum and minimum memory usage. - let memoryExtent = d3.extent(topLevelGraph.nodes(), - (nodeName, index) => { - let node = topLevelGraph.node(nodeName); - // Some ops don't have stats at all. - if (node.stats != null) { - return node.stats.totalBytes; - } - }); - this.memoryUsageScale = d3.scaleLinear() - .domain(memoryExtent) - .range(PARAMS.minMaxColors); - - // Find also the minimum and maximum compute time. 
- let computeTimeExtent = d3.extent(topLevelGraph.nodes(), - (nodeName, index) => { - let node = topLevelGraph.node(nodeName); - // Some ops don't have stats at all. - if (node.stats != null) { - return node.stats.getTotalMicros(); - } - }); - this.computeTimeScale = d3.scaleLinear() - .domain(computeTimeExtent) - .range(PARAMS.minMaxColors); - - this.edgeWidthScale = this.hierarchy.hasShapeInfo ? - scene.edge.EDGE_WIDTH_SCALE : - d3.scaleLinear() - .domain([1, this.hierarchy.maxMetaEdgeSize]) - .range([scene.edge.MIN_EDGE_WIDTH, scene.edge.MAX_EDGE_WIDTH]); - } - - /** - * Get a previously created RenderNodeInfo by its node name. - */ - getRenderNodeByName(nodeName: string): RenderNodeInfo { - return this.index[nodeName]; - } - - /** - * Get the underlying node in the hierarchical graph by its name. - */ - getNodeByName(nodeName: string): Node { - return this.hierarchy.node(nodeName); - } - - /** - * Get a previously created RenderNodeInfo for the specified node name, - * or create one if it hasn't been created yet. - */ - getOrCreateRenderNodeByName(nodeName: string): RenderNodeInfo { - // Polymer may invoke this with null. - if (!nodeName) { - return null; - } - - if (nodeName in this.index) { - return this.index[nodeName]; - } - - let node = this.hierarchy.node(nodeName); - // Exit early if the node does not exist in the hierarchy. This can happen - // when a graph is reloaded while the infocard points to a node not visible - // at the top-level. - if (!node) { - return null; - } - let renderInfo = node.isGroupNode ? 
- new RenderGroupNodeInfo(node) : - new RenderNodeInfo(node); - this.index[nodeName] = renderInfo; - this.renderedOpNames.push(nodeName); - - if (node.stats) { - renderInfo.memoryColor = this.memoryUsageScale(node.stats.totalBytes); - renderInfo.computeTimeColor = - this.computeTimeScale(node.stats.getTotalMicros()); - } - - if (!node.isGroupNode) { - let clusterName = (node as OpNode).xlaCluster; - if (clusterName) { - renderInfo.xlaClusterColor = this.xlaClusterColorMap(clusterName); - } - } - - // We only fade nodes when we're displaying stats. - renderInfo.isFadedOut = this.displayingStats && - !tf.graph.util.hasDisplayableNodeStats(node.stats); - - if (node.isGroupNode) { - // Make a list of tuples (device, proportion), where proportion - // is the fraction of op nodes that have that device. - let pairs = _.pairs((node).deviceHistogram); - if (pairs.length > 0) { - // Compute the total # of devices. - let numDevices = _.sum(pairs, _.last); - renderInfo.deviceColors = _.map(pairs, pair => ({ - color: this.deviceColorMap(pair[0]), - // Normalize to a proportion of total # of devices. - proportion: pair[1] / numDevices - })); - } - } else { - let device = (renderInfo.node).device; - if (device) { - renderInfo.deviceColors = [{ - color: this.deviceColorMap(device), - proportion: 1.0 - }]; - } - } - - return this.index[nodeName]; - } - - /** - * Return the nearest ancestor node, including itself, that is visible - * in the visualization. This method is used so that we can select - * (highlight) a node that isn't drawn yet, by selecting (highlighting) - * its nearest ancestor that has been drawn. - */ - getNearestVisibleAncestor(name: string): string { - let path = getHierarchicalPath(name); - for (let i = 0; i < path.length; i++) { - let nodeName = path[i]; - // Op nodes have expanded set to false by default. - if (!this.getRenderNodeByName(nodeName).expanded) { - return nodeName; - } - } - // Fallthrough. If everything was expanded return the node. 
- return name; - } - - // TODO(jimbo): Delete this an any code it touches (all deprecated). - setDepth(depth: number): void { - setGroupNodeDepth(this.root, +depth); - } - - /** - * Returns true if the renderNode is an isolated node within its parent node. - */ - isNodeAuxiliary(renderNode: RenderNodeInfo): boolean { - let parentNode = this.getRenderNodeByName( - renderNode.node.parentNode.name); - let found = _.find(parentNode.isolatedInExtract, node => { - return node.node.name === renderNode.node.name; - }); - if (found) { - return true; - } - found = _.find(parentNode.isolatedOutExtract, node => { - return node.node.name === renderNode.node.name; - }); - return !!found; - } - - /** - * Returns a list of ops that have been rendered so far for this graph. More - * ops may later be rendered if the user expands nodes for instance. The list - * returned here can only stay the same size or grow on successive calls. - */ - getNamesOfRenderedOps(): string[] { - return this.renderedOpNames; - } - - buildSubhierarchy(nodeName: string): void { - // Terminate if the rendering hierarchy was already constructed - // for this node. - if (nodeName in this.hasSubhierarchy) { - return; - } - - let renderNodeInfo = this.index[nodeName]; - - // If it is not a meta node or a series node, don't do anything. - if (renderNodeInfo.node.type !== NodeType.META && - renderNodeInfo.node.type !== NodeType.SERIES) { - return; - } - - // At this point we know the rendering information is about a group node. - let renderGroupNodeInfo = renderNodeInfo; - let metagraph = renderGroupNodeInfo.node.metagraph; - let coreGraph = renderGroupNodeInfo.coreGraph; - - // Create render nodes to represent each child from the metagraph. Although - // these will initially be added to the coreGraph, they may later be - // extracted. Also, due to extraction, the coreGraph may contain disjoint - // groups between which there is no visible path (other than annotations). 
- _.each(metagraph.nodes(), childName => { - - let childRenderInfo = this.getOrCreateRenderNodeByName(childName); - let childNode = childRenderInfo.node; - - coreGraph.setNode(childName, childRenderInfo); - - if (!childNode.isGroupNode) { - _.each((childNode).inEmbeddings, embedding => { - let renderMetaedgeInfo = new RenderMetaedgeInfo(null); - addInAnnotation(childRenderInfo, embedding, null, renderMetaedgeInfo, - AnnotationType.CONSTANT); - this.index[embedding.name] = new RenderNodeInfo(embedding); - }); - _.each((childNode).outEmbeddings, embedding => { - let renderMetaedgeInfo = new RenderMetaedgeInfo(null); - addOutAnnotation(childRenderInfo, embedding, null, renderMetaedgeInfo, - AnnotationType.SUMMARY); - this.index[embedding.name] = new RenderNodeInfo(embedding); - }); - } - - }); - - // Add render metaedge info for edges in the metagraph. - _.each(metagraph.edges(), edgeObj => { - let metaedge = metagraph.edge(edgeObj); - let renderMetaedgeInfo = new RenderMetaedgeInfo(metaedge); - renderMetaedgeInfo.isFadedOut = - this.index[edgeObj.v].isFadedOut || this.index[edgeObj.w].isFadedOut; - coreGraph.setEdge(edgeObj.v, edgeObj.w, renderMetaedgeInfo); - }); - - if (PARAMS.enableExtraction && - renderGroupNodeInfo.node.type === NodeType.META) { - extractHighDegrees(renderGroupNodeInfo); - } - - // Record that we constructed the rendering hierarchy for this node, so we - // don't construct it another time. - this.hasSubhierarchy[nodeName] = true; - - // Look up the parent node's render information and short circuit if none. - let parentNode = renderGroupNodeInfo.node.parentNode; - if (!parentNode) { - return; - } - let parentNodeInfo = - this.index[parentNode.name]; - - // Utility function for computing the name of a bridge node. - let getBridgeNodeName = (inbound, ...rest) => - rest.concat([inbound ? 'IN' : 'OUT']).join('~~'); - - // Build out the bridgegraph. 
- let bridgegraph = this.hierarchy.getBridgegraph(nodeName); - - // Look for popular nodes so we can make annotations instead of paths. - let otherCounts = { - // Counts of edges coming INTO other nodes by name (outgoing from self). - in: <{[nodeName: string]: number}> {}, - // Counts of edges going OUT from other nodes by name (coming into self). - out: <{[nodeName: string]: number}> {}, - // Counts of all control edges involving other nodes by name. - control: <{[nodeName: string]: number}> {}, - }; - _.each(bridgegraph.edges(), e => { - // An edge is inbound if its destination node is in the metagraph. - let inbound = !!metagraph.node(e.w); - let otherName = inbound ? e.v : e.w; - let metaedge = bridgegraph.edge(e); - if (!metaedge.numRegularEdges) { - otherCounts.control[otherName] = - (otherCounts.control[otherName] || 0) + 1; - } else if (inbound) { - otherCounts.out[otherName] = (otherCounts.out[otherName] || 0) + 1; - } else { - otherCounts.in[otherName] = (otherCounts.in[otherName] || 0) + 1; - } - }); - - // Add annotations and edges for bridgegraph relationships. - let hierarchyNodeMap = this.hierarchy.getNodeMap(); - _.each(bridgegraph.edges(), bridgeEdgeObj => { - let bridgeMetaedge = bridgegraph.edge(bridgeEdgeObj); - - // Determine whether this bridge edge is incoming by checking the - // metagraph for a node that matches the destination end. - let inbound = !!metagraph.node(bridgeEdgeObj.w); - - // Based on the direction of the edge, one endpoint will be an immediate - // child of this renderNodeInfo, and the other endpoint will be a sibling - // of the parent (or an ancestor further up). - let [childName, otherName] = - inbound ? - [bridgeEdgeObj.w, bridgeEdgeObj.v] : - [bridgeEdgeObj.v, bridgeEdgeObj.w]; - - let childRenderInfo = this.index[childName]; - let otherRenderInfo = this.index[otherName]; - let otherNode = - otherRenderInfo ? 
- otherRenderInfo.node : - hierarchyNodeMap[otherName]; - - // Determine whether this edge is a control edge between nodes where - // either node is high-degree with respect to control edges. This will - // be a signal to show it as an annotation instead of a bridge edge. - let isHighDegreeControlEdge = !bridgeMetaedge.numRegularEdges && - otherCounts.control[otherName] > PARAMS.maxControlDegree; - - let [, childAnnotations] = - inbound ? - [renderNodeInfo.inAnnotations, childRenderInfo.inAnnotations] : - [renderNodeInfo.outAnnotations, childRenderInfo.outAnnotations]; - - // Don't render a bridge path if the other node has in or out degree above - // a threshold, lest bridge paths emanating out of a metagraph crowd up, - // as was the case for the Fatcat LSTM lstm_1 > lstm_1 metagraph. - let otherDegreeCount = - (inbound ? otherCounts.out : otherCounts.in)[otherName]; - let isOtherHighDegree = otherDegreeCount > PARAMS.maxBridgePathDegree; - - // The adjoining render metaedge info from the parent's coreGraph, if any. - // It will either be a Metaedge involving this node directly, if it - // previously came from a metagraph, or it'll be a Metaedge involving - // a previously created bridge node standing in for the other node. - let adjoiningMetaedge = null; - - // We can only hope to render a bridge path if: - // - bridgegraph paths are enabled, - // - the other node is not too high-degree, - // - the child is in the core (not extracted for being high-degree), and - // - there's a path (in the traversal sense) between child and other. - let canDrawBridgePath = false; - if (PARAMS.enableBridgegraph && - !isOtherHighDegree && - !isHighDegreeControlEdge && - childRenderInfo.isInCore()) { - - // Utility function for finding an adjoining metaedge. - let findAdjoiningMetaedge = targetName => { - let adjoiningEdgeObj: graphlib.EdgeObject = - inbound ? 
- { v: targetName, w: nodeName } : - { v: nodeName, w: targetName }; - return - parentNodeInfo.coreGraph.edge(adjoiningEdgeObj); - }; - - adjoiningMetaedge = findAdjoiningMetaedge(otherName); - if (!adjoiningMetaedge) { - adjoiningMetaedge = findAdjoiningMetaedge( - getBridgeNodeName(inbound, otherName, parentNode.name)); - } - - canDrawBridgePath = !!adjoiningMetaedge; - } - - // Although dataflow edges are acyclic, control dependency edges may - // actually point 'backwards' in the graph. If this bridgeMetaedge is - // a control dependency, we need to determine whether it's backwards - // pointing so that we render it appropriately. - // - // For instance, say we're rendering a graph with nodes named A/B and Z/Y, - // and we're currently rendering the bridgegraph for A. Further, let's say - // that there was an original BaseEdge from A/B->Z/Y and a CONTROL EDGE - // from Z/Y=>A/B. - // - // +----------------+ - // | A | - // | +-----+ | +------+ - // | | B |>----->|>------->| Z | - // | | | | | | - // | | | * | | | - // | | |<=====<|<=======<| | - // | +-----+ | +------+ - // +----------------+ - // - // When we render the subhierarchy for Metanode A, we'll come across a - // control-only Metaedge in the bridgegraph from Z=>A/B (*). The question - // is whether this edge is backwards. - // - // To answer that question, we follow the chain of adjoining metaedges - // until we reach the topmost one. In this case, that's the control-only - // Metaedge Z=>A in the ROOT's metagraph. We determine that this edge - // is backwards by looking at the topological ordering of ROOT's metagraph - // (which ignores control edges) and seeing that Z comes AFTER A. - // - // The property of being backwards is independent of whether the edge - // is inbound or outbound. In the preceding example, if we were building - // the subhierarchy for Z, we'd find bridge edge Z/Y=>A, walk to its - // topmost adjoining metaedge Z=>A and discover that it's backwards. 
- let backwards = false; - if (adjoiningMetaedge && !bridgeMetaedge.numRegularEdges) { - // Find the top-most adjoining render metaedge information, and the - // GroupNode whose metagraph must contain the associated metaedge. - let topAdjoiningMetaedge = adjoiningMetaedge; - let topGroupNode = parentNodeInfo.node; - while (topAdjoiningMetaedge.adjoiningMetaedge) { - topAdjoiningMetaedge = topAdjoiningMetaedge.adjoiningMetaedge; - topGroupNode = topGroupNode.parentNode; - } - - // Check against the topological ordering for the top node. The current - // bridge metaedge we're evaluating is backwards if its source comes - // after its destination. - let ordering = this.hierarchy.getTopologicalOrdering(topGroupNode.name); - let e = topAdjoiningMetaedge.metaedge; - backwards = ordering[e.v] > ordering[e.w]; - } - - // Render backwards control edges as annotations. - canDrawBridgePath = canDrawBridgePath && !backwards; - - // If we can't make a bridge path for any reason, then we add an - // annotation instead. - if (!canDrawBridgePath) { - childAnnotations.push(new Annotation( - otherNode, - otherRenderInfo, - new RenderMetaedgeInfo(bridgeMetaedge), - AnnotationType.SHORTCUT, - inbound)); - return; - } - - // At this point, all conditions have been met for drawing a bridge path. - - // Find or create the IN/OUT node representing otherNode. - let bridgeContainerName = getBridgeNodeName(inbound, nodeName); - let bridgeNodeName = getBridgeNodeName(inbound, otherName, nodeName); - let bridgeNodeRenderInfo = coreGraph.node(bridgeNodeName); - if (!bridgeNodeRenderInfo) { - - // Find or create the directional container for the bridge node. - let bridgeContainerInfo = coreGraph.node(bridgeContainerName); - if (!bridgeContainerInfo) { - let bridgeContainerNode: BridgeNode = { - // Important node properties. - name: bridgeContainerName, - type: NodeType.BRIDGE, - // Unused node properties. 
- isGroupNode: false, - cardinality: 0, - parentNode: null, - stats: null, - include: InclusionType.UNSPECIFIED, - // BridgeNode properties. - inbound: inbound, - nodeAttributes: {}, - }; - bridgeContainerInfo = - new RenderNodeInfo(bridgeContainerNode); - this.index[bridgeContainerName] = bridgeContainerInfo; - coreGraph.setNode(bridgeContainerName, bridgeContainerInfo); - } - - let bridgeNode: BridgeNode = { - // Important node properties. - name: bridgeNodeName, - type: NodeType.BRIDGE, - // Unimportant node properties. - isGroupNode: false, - cardinality: 1, - parentNode: null, - stats: null, - include: InclusionType.UNSPECIFIED, - // BridgeNode properties. - inbound: inbound, - nodeAttributes: {}, - }; - bridgeNodeRenderInfo = new RenderNodeInfo(bridgeNode); - this.index[bridgeNodeName] = bridgeNodeRenderInfo; - coreGraph.setNode(bridgeNodeName, bridgeNodeRenderInfo); - - // Set bridgeNode to be a graphlib child of the container node. - coreGraph.setParent(bridgeNodeName, bridgeContainerName); - bridgeContainerInfo.node.cardinality++; - } - - // Create and add a bridge render metaedge. - let bridgeRenderMetaedge = - new RenderMetaedgeInfo(bridgeMetaedge); - bridgeRenderMetaedge.adjoiningMetaedge = adjoiningMetaedge; - inbound ? - coreGraph.setEdge(bridgeNodeName, childName, bridgeRenderMetaedge) : - coreGraph.setEdge(childName, bridgeNodeName, bridgeRenderMetaedge); - - }); // End _.each(bridgegraph.edges). - - // For each bridge container (IN and/or OUT), add structural edges between - // terminal nodes and that container. A terminal node is one which has no - // non-bridge edges in the direction of the container. - // - // For example, consider a Metanode A which contains two child nodes A/B - // and A/C. Let's say it has one edge in the metagraph from A/B->A/C, and - // one edge in the bridgegraph from Z->A/C. - // - // At this point, we've added a container bridge node IN to house all - // incoming bridge nodes. 
We've also added a bridge node Z' (with parent IN) - // to A, and a bridge edge from Z'->C. - // - // +----------------------+ - // | A +---+ | - // | +------>| C | | - // | | +---+ | - // | | ^ | - // | | | | - // | | +----|----+ | - // | | | IN | | | - // | +---+ | +---+ | | - // | | B | | | Z'| | | - // | +---+ | +---+ | | - // | +---------+ | - // +----------------------+ - // - // With no other help, dagre would lay out B and Z' on the same level, - // because both of them have no incoming edges. In other words, B is a - // terminal node in the INCOMING direction. - // - // But we want to force dagre to lay out Z' (and everything in IN) lower - // than all non-bridge nodes, so that there's enough room for the bridge - // edges after they've been adjusted to meet up with paths coming in from - // outside. - // - // To force Z' (and all other bridge nodes) to be lowest in the graph, we - // identify terminal nodes like B and give them structural edges to - // a new structural bridge node S which we add to IN. - // - // +----------------------+ - // | A +---+ | - // | +--->| C | | - // | | +---+ | - // | +---+ ^ | - // | | B | | | - // | +---+ | | - // | ^ | | - // | | | | - // | +----|------|----+ | - // | |IN | | | | - // | | +---+ +---+ | | - // | | | S | | Z'| | | - // | | +---+ +---+ | | - // | +----------------+ | - // +----------------------+ - // - // This ensures that dagre will lay out the bridge containers strictly at - // the ends of the graph. The structural edges will never be seen in the - // visualization except as a debugging aid. - _.each([true, false], inbound => { - let bridgeContainerName = getBridgeNodeName(inbound, nodeName); - let bridgeContainerInfo = coreGraph.node(bridgeContainerName); - if (!bridgeContainerInfo) { - return; - } - _.each(coreGraph.nodes(), childName => { - // Short-circuit if this child is a bridge node or it's not a terminal - // node in the direction we're interested in. 
- let childNodeInfo = coreGraph.node(childName); - if (childNodeInfo.node.type === NodeType.BRIDGE) { - return; - } - let isTerminal = inbound ? - !coreGraph.predecessors(childName).length : - !coreGraph.successors(childName).length; - if (!isTerminal) { - return; - } - - // Find or create a bridge node in the container for all structural - // metaedges. It would have been nice to skip this step and simply - // set a metaedge between the terminal node and the container node, but - // in that case, something about the graph upsets dagre.layout()'s - // longestPath algorithm (was getting errors due to an undefined). - let structuralNodeName = - getBridgeNodeName(inbound, nodeName, 'STRUCTURAL_TARGET'); - let structuralRenderInfo = coreGraph.node(structuralNodeName); - if (!structuralRenderInfo) { - let bridgeNode: BridgeNode = { - // Important Node properties. - name: structuralNodeName, - type: NodeType.BRIDGE, - // Unimportant Node properties. - isGroupNode: false, - cardinality: 1, - parentNode: null, - stats: null, - include: InclusionType.UNSPECIFIED, - // BridgeNode properties. - inbound: inbound, - nodeAttributes: {}, - }; - structuralRenderInfo = new RenderNodeInfo(bridgeNode); - structuralRenderInfo.structural = true; - this.index[structuralNodeName] = structuralRenderInfo; - coreGraph.setNode(structuralNodeName, structuralRenderInfo); - bridgeContainerInfo.node.cardinality++; - coreGraph.setParent(structuralNodeName, bridgeContainerName); - } - - // Create the structural Metaedge and insert it. - let structuralMetaedgeInfo = new RenderMetaedgeInfo(null); - structuralMetaedgeInfo.structural = true; - structuralMetaedgeInfo.weight--; // Reduce weight for dagre layout. - inbound ? 
- coreGraph.setEdge( - structuralNodeName, childName, structuralMetaedgeInfo) : - coreGraph.setEdge( - childName, structuralNodeName, structuralMetaedgeInfo); - }); - }); - } -} - -/** - * A class for rendering annotation object which contains label - * about the node embedded as annotation, type of annotation and the location - * of both the annotation's node and edge. - * - * Annotation objects include embedded constants, embedded summary, and - * edge shortcuts. - */ -export class Annotation { - node: Node; - renderNodeInfo: RenderNodeInfo; - renderMetaedgeInfo: RenderMetaedgeInfo; - annotationType: AnnotationType; - /** - * Center position of annotation relative to the host - * node's center x. - */ - dx: number; - /** - * Center position of annotation relative to the host - * node's center y. - */ - dy: number; - width: number; - height: number; - /** - * The names of nodes on either side of this edge. - */ - v: string; - w: string; - /** - * A flag whether it is an in-annotation (if true) or - * out-annotation (if false). - */ - isIn: boolean; - /** Label horizontal offset from the end of the node shape */ - labelOffset: number; - /** - * Array of points for edges from the annotation to its host - * node. Each point contains the point location, relative to - * the host node's center. - */ - points: {dx: number, dy: number}[]; - - /** - * Creates a new Annotation. - * - * @param node The underlying node this annotation points to. - * @param renderNodeInfo The render information for the underlying node - * this annotation points to. This can be null if the annotation - * denotes an embedding (constant, summary), in which case we - * use the node property. - * @param renderMetaedgeInfo The render information for the edge associated - * with the annotation. - * @param type The type of the annotation. - * @param isIn True if it is an in-annotation. False if it is an - * out-annotation. 
- */ - constructor(node: Node, renderNodeInfo: RenderNodeInfo, - renderMetaedgeInfo: RenderMetaedgeInfo, type: AnnotationType, - isIn: boolean) { - this.node = node; - this.renderNodeInfo = renderNodeInfo; - this.renderMetaedgeInfo = renderMetaedgeInfo; - this.annotationType = type; - // Properties specified by layout - this.dx = 0; - this.dy = 0; - this.width = 0; - this.height = 0; - // Properties needed for generating an ID for the edge's path element if - // this annotation is associated with a metaedge. - if (renderMetaedgeInfo && renderMetaedgeInfo.metaedge) { - this.v = renderMetaedgeInfo.metaedge.v; - this.w = renderMetaedgeInfo.metaedge.w; - } - - this.isIn = isIn; - this.points = []; - } -}; - -export enum AnnotationType {SHORTCUT, CONSTANT, SUMMARY, ELLIPSIS}; - -/** - * Manages a list of annotations. Two will be used for each - * RenderNodeInfo, one for in annotations and one for out annotations. - */ -export class AnnotationList { - /** - * List of visually drawable annotations, may include an ellipses annotation - * if the number added exceeds the number specified by maxAnnotations. - */ - list: Annotation[]; - - /** - * Set of nodes which have been added as annotations to this list, so we can - * prevent duplicates. - */ - nodeNames: { [nodeName: string]: boolean }; - - constructor() { - this.list = []; - this.nodeNames = {}; - } - - /** - * Append an annotation to the list, or a stand-in ellipsis annotation instead - * if this would make it too many. - */ - push(annotation: Annotation): void { - if (annotation.node.name in this.nodeNames) { - return; // Skip duplicate annotation. 
- } - this.nodeNames[annotation.node.name] = true; - - if (this.list.length < PARAMS.maxAnnotations) { - this.list.push(annotation); - return; - } - - let lastAnnotation = this.list[this.list.length - 1]; - if (lastAnnotation.annotationType === AnnotationType.ELLIPSIS) { - let ellipsisNode = lastAnnotation.node; - ellipsisNode.setNumMoreNodes(++ellipsisNode.numMoreNodes); - return; - } - - let ellipsisNode = new tf.graph.EllipsisNodeImpl(1); - this.list.push(new Annotation(ellipsisNode, - new RenderNodeInfo(ellipsisNode), null, - AnnotationType.ELLIPSIS, annotation.isIn)); - } -} - -/** - * Contains rendering information about a node in the hierarchical graph. - */ -export class RenderNodeInfo { - /** Reference to the original underlying Node from the hierarchical graph. */ - node: Node; - /** Whether the node is expanded or not. */ - expanded: boolean; - /** - * List of rendering information about in-annotations like constants and - * shortcuts to high-degree nodes. - */ - inAnnotations: AnnotationList; - /** - * List of rendering information about out-annotations (e.g. summary nodes) - */ - outAnnotations: AnnotationList; - - // --- Params specified by layout --- // - - /** Center x position */ - x: number; - /** Center y position */ - y: number; - /** - * Total width of the node's shape, including in- and out-annotations. This - * property is used by dagre to layout the graph. - */ - width: number; - /** - * Total height of the node's shape, including in- and out-annotations. This - * property is used by dagre to layout the graph. - */ - height: number; - /** - * Size of the main box of the node, excluding in- and out-annotations. This - * property is used to draw the rectangle/ellipse shape denoting the node. - */ - coreBox: { - width: number, - height: number, - }; - - /** Width of the bounding box for all in-annotations. */ - inboxWidth: number; - /** Width of the bounding box for all out-annotations. 
*/ - outboxWidth: number; - /** - * Whether the node should be excluded from the scene. - * This is only used when there are too many items in a series so we only - * want to include top N ones. - */ - // TODO(jimbo): Now that series rendering is non-recursive, remove this and - // all its uses from the code base. - excluded: boolean; - - // --- Params used in drawing the bridge paths --- // - - /** - * All bridge nodes are meant to be invisible, but whereas most represent a - * relationship from the underlying graph hierarchy, some exist solely for - * layout reasons. Specifically, those bridge nodes which have only structural - * rendering metaedges. - */ - structural: boolean; - - // --- Params for the size of the node box --- // - - /** Label vertical offset from the center of node shape */ - labelOffset: number; - /** Rectangle radius (for making rounded rectangle) */ - radius: number; - - // --- Params for expanded node --- // - - /** Label height for expanded node. */ - labelHeight: number; - // Paddings between inner subscene and the border of the expanded node. - paddingTop: number; - paddingLeft: number; - paddingRight: number; - paddingBottom: number; - - /** - * Whether a node is extracted as source-like (having high out-degree or - * matching predefined in-extract pattern.) - */ - isInExtract: boolean; - /** - * Whether a node is extracted as sink-like (having high in-degree or matching - * predefined out-extract pattern.) - */ - isOutExtract: boolean; - - /** - * List of (color, proportion) tuples based on the proportion of devices of - * its children. If this node is an op node, this list will have only one - * color with proportion 1.0. - */ - deviceColors: Array<{color: string, proportion: number}>; - - /** - * Color according to the XLA cluster of this node. - */ - xlaClusterColor: string; - - /** - * Color according to the memory usage of this node. - */ - memoryColor: string; - - /** - * Color according to the compute time of this node. 
- */ - computeTimeColor: string; - - /** - * Whether this node is faded out. Used when displaying stats. - */ - isFadedOut: boolean; - - constructor(node: Node) { - this.node = node; - this.expanded = false; - this.inAnnotations = new AnnotationList(); - this.outAnnotations = new AnnotationList(); - // Params specified by layout - this.x = 0; - this.y = 0; - this.width = 0; - this.height = 0; - this.inboxWidth = 0; - this.outboxWidth = 0; - - this.excluded = false; - - // Params for bridge paths. - this.structural = false; - - // Params for node box. - this.labelOffset = 0; - this.radius = 0; - - // Params for expanded node - this.labelHeight = 0; - this.paddingTop = 0; - this.paddingLeft = 0; - this.paddingRight = 0; - this.paddingBottom = 0; - this.isInExtract = false; - this.isOutExtract = false; - this.coreBox = {width: 0, height: 0}; - - // By default, we don't fade nodes out. Default to false for safety. - this.isFadedOut = false; - } - - isInCore(): boolean { - return !this.isInExtract && !this.isOutExtract; - } -} - -/** - * Contains rendering information about a Metaedge from the underlying - * hierarchical graph. It may be from either a metagraph or a bridgegraph. - */ -export class RenderMetaedgeInfo { - /** - * Reference to the original underlying Metaedge from the hierarchical graph, - * if any. This will be null for the edges which connect OpNodes to their - * embeddings, for example. - */ - metaedge: Metaedge; - - /** - * Reference to the adjoining RenderMetaedgeInfo from the parent's - * coreGraph. This is used during layout to determine the point at which this - * edge should touch the node's bounding box. This property will be null for - * edges which terminate at a node on both ends (all non-bridge edges). - */ - adjoiningMetaedge: RenderMetaedgeInfo; - - /** - * Most of the time, a RenderMetaedgeInfo object represents a real - * edge between nodes in the underlying graph structure. But sometimes, an - * edge only exists for layout purposes. 
These structural edges are added - * during buildSubhierarchy() to force dagre.layout() to put bridge nodes - * at the ends of the flow. - * @see buildSubhierarchy() - */ - structural: boolean; - - /** - * Weight of the edge, used by dagre when deciding how important an edge is. - * Edges with higher weight are made shorter and straighter. The default - * dagre uses is 1. - */ - weight: number; - - /** - * X and Y coordinate pairs of the points in the path of the edge. - * @see tf.graph.node.subsceneAdjustPaths - */ - points: Point[]; - - /** - * D3 selection of the group containing the path that displays this edge. - */ - edgeGroup: d3.Selection; - - /** Id of the used as a start-marker for the edge path. */ - startMarkerId: string; - - /** Id of the used as an end-marker for the edge path. */ - endMarkerId: string; - - /** - * Whether this edge is faded out. Used for fading out unused edges when - * displaying run statistics. - */ - isFadedOut: boolean; - - constructor(metaedge: Metaedge) { - this.metaedge = metaedge; - this.adjoiningMetaedge = null; - this.structural = false; - this.weight = 1; - this.isFadedOut = false; - } -} - -function addInAnnotation(node: RenderNodeInfo, predecessor: Node, - predecessorRenderInfo: RenderNodeInfo, - edge: RenderMetaedgeInfo, type: AnnotationType): void { - let annotation = new Annotation(predecessor, predecessorRenderInfo, edge, - type, true); - node.inAnnotations.push(annotation); -} - -function addOutAnnotation(node: RenderNodeInfo, successor: Node, - successorRenderInfo: RenderNodeInfo, edge: RenderMetaedgeInfo, - type: AnnotationType): void { - let annotation = new Annotation(successor, successorRenderInfo, edge, - type, false); - node.outAnnotations.push(annotation); -} - -function setGraphDepth(graph: graphlib.Graph, - depth: number) { - _.each(graph.nodes(), nodeName => { - let child = graph.node(nodeName); - child.expanded = depth > 1; // set all child of depth 1 to collapsed - if (depth > 0) { - switch 
(child.node.type) { - case NodeType.META: - case NodeType.SERIES: - setGroupNodeDepth(child, depth - 1); - break; - // Do nothing for leaf - } - } - }); -}; - -export class RenderGroupNodeInfo extends RenderNodeInfo { - node: GroupNode; - /** - * The core graph is derived from the underlying node's metagraph, minus - * the extracted source-like and sink-like nodes. - */ - coreGraph: graphlib.Graph; - /** Size of the bounding box for a metanode's isolated in-extract children. */ - inExtractBox: {width: number, height: number}; - /** - * Size of the bounding box for a metanode's isolated out-extract children. - */ - outExtractBox: {width: number, height: number}; - /** Array of isolated in-extract nodes. */ - isolatedInExtract: RenderNodeInfo[]; - /** Array of isolated out-extract nodes. */ - isolatedOutExtract: RenderNodeInfo[]; - - constructor(groupNode: GroupNode) { - super(groupNode); - let metagraph = groupNode.metagraph; - let gl = metagraph.graph(); - this.coreGraph = - createGraph( - gl.name, GraphType.CORE, { compound: true }); - this.inExtractBox = {width: 0, height: 0}; - this.outExtractBox = {width: 0, height: 0}; - this.isolatedInExtract = []; - this.isolatedOutExtract = []; - } -} - -function setGroupNodeDepth(renderInfo: RenderGroupNodeInfo, - depth: number): void { - if (renderInfo.coreGraph) { - setGraphDepth(renderInfo.coreGraph, depth); - } -} - -/** - * Remove an edge from the graph and add annotations to both ends of the edge. - * - * @param The core graph. - * @param v Source name. - * @param w Sink name. - */ -function createShortcut( - graph: graphlib.Graph, - v: string, w: string) { - let src = graph.node(v); - let sink = graph.node(w); - let edge = graph.edge(v, w); - - // If either of the nodes is explicitly included in the main graph and - // both nodes are in the main graph then do not create the shortcut - // and instead keep the real edge. 
- if ((src.node.include === InclusionType.INCLUDE || - sink.node.include === InclusionType.INCLUDE) && - src.node.include !== InclusionType.EXCLUDE && - sink.node.include !== InclusionType.EXCLUDE) { - return; - } - - // Add each annotation. - addOutAnnotation(src, sink.node, sink, edge, AnnotationType.SHORTCUT); - addInAnnotation(sink, src.node, src, edge, AnnotationType.SHORTCUT); - - // Remove the edge from the core graph. - graph.removeEdge(v, w); -} - -/** - * Remove edges from a node, and set its isOutExtract property to true, - * and remove the node and move it to isolatedOutExtract. - * - * If detachAllEdgesForHighDegree or forceDetach is true, extract all of its - * edges. Otherwise, only extract all in-edges. - */ -function makeOutExtract(renderNode: RenderGroupNodeInfo, n: string, - forceDetach?: boolean) { - let graph = renderNode.coreGraph; - let child = graph.node(n); - child.isOutExtract = true; - - _.each(graph.predecessors(n), (p, index) => { - createShortcut(graph, p, n); - }); - - if (PARAMS.detachAllEdgesForHighDegree || forceDetach) { - _.each(graph.successors(n), (s, index) => { - createShortcut(graph, n, s); - }); - } - - // Remove the node from the core graph if it no longer has neighbors. - if (graph.neighbors(n).length === 0) { - child.node.include = InclusionType.EXCLUDE; - renderNode.isolatedOutExtract.push(child); - graph.removeNode(n); - } -} - -/** - * Remove edges from a node, set its isInExtract property to true, - * and remove the node and move it to isolatedInExtract. - * - * If detachAllEdgesForHighDegree or forceDetach is true, extract all of its - * edges. Otherwise, only remove all out-edges. 
- */ -export function makeInExtract(renderNode: RenderGroupNodeInfo, n: string, - forceDetach?: boolean) { - let graph = renderNode.coreGraph; - let child = graph.node(n); - child.isInExtract = true; - - _.each(graph.successors(n), (s, index) => { - createShortcut(graph, n, s); - }); - - if (PARAMS.detachAllEdgesForHighDegree || forceDetach) { - _.each(graph.predecessors(n), (p, index) => { - createShortcut(graph, p, n); - }); - } - - // Remove the node from the core graph if it no longer has neighbors. - if (graph.neighbors(n).length === 0) { - child.node.include = InclusionType.EXCLUDE; - renderNode.isolatedInExtract.push(child); - graph.removeNode(n); - } -} - -/** - * Check whether the node's type is a member of the given list of types. - * - * @param node Node. - * @param types List of type to match. - */ -function hasTypeIn(node: Node, types: string[]): boolean { - if (node.type === NodeType.OP) { - for (let i = 0; i < types.length; i++) { - if ((node).op === types[i]) { return true; } - } - } else if (node.type === NodeType.META) { - let rootOpNode = (node).getRootOp(); - if (rootOpNode) { - for (let i = 0; i < types.length; i++) { - if (rootOpNode.op === types[i]) { return true; } - } - } - } - return false; -} - -/** Move nodes that are specified to be excluded out of the core graph. 
*/ -function extractSpecifiedNodes(renderNode: RenderGroupNodeInfo) { - let graph = renderNode.coreGraph; - _.each(graph.nodes(), n => { - let renderInfo = graph.node(n); - if (renderInfo.node.include === InclusionType.EXCLUDE) { - if (renderNode.coreGraph.outEdges(n).length > - renderNode.coreGraph.inEdges(n).length) { - makeOutExtract(renderNode, n, true); - } else { - makeInExtract(renderNode, n, true); - } - } - }); -} - -/** Remove edges from pre-defined out-extract patterns */ -function extractPredefinedSink(renderNode: RenderGroupNodeInfo) { - let graph = renderNode.coreGraph; - _.each(graph.nodes(), n => { - let renderInfo = graph.node(n); - if (renderInfo.node.include !== InclusionType.UNSPECIFIED) { - return; - } - if (hasTypeIn(renderInfo.node, PARAMS.outExtractTypes)) { - makeOutExtract(renderNode, n); - } - }); -} - -/** Remove edges from pre-defined in-extract patterns */ -function extractPredefinedSource(renderNode) { - let graph = renderNode.coreGraph; - _.each(graph.nodes(), n => { - let renderInfo = graph.node(n); - if (renderInfo.node.include !== InclusionType.UNSPECIFIED) { - return; - } - if (hasTypeIn(renderInfo.node, PARAMS.inExtractTypes)) { - makeInExtract(renderNode, n); - } - }); -} - -/** Extract nodes deemed to have either high in-degree or high out-degree. */ -function extractHighInOrOutDegree(renderNode: RenderGroupNodeInfo) { - let graph = renderNode.coreGraph; - - // Create mappings from node to in and out degrees. Count the number of valid - // nodes along the way. - let nodeToInDegree = {}; - let nodeToOutDegree = {}; - let validNodeCount = 0; - _.each(graph.nodes(), currentNode => { - if (graph.node(currentNode).node.include !== InclusionType.UNSPECIFIED) { - // This node is not included in the first place. - return; - } - - // Count the in and out degrees based on only regular edges, unless there - // are no regular edges, in which case use the number of control edges. 
- // This is done so that control edges don't affect if nodes are extracted - // from the core graph, unless the node is only used for control. - let inDegree = - _.reduce(graph.predecessors(currentNode), (inDegree, pred) => { - let metaedge = graph.edge(pred, currentNode).metaedge; - return inDegree + (metaedge.numRegularEdges ? 1 : 0); - }, 0); - if (inDegree === 0 && graph.predecessors(currentNode).length > 0) { - inDegree = graph.predecessors(currentNode).length; - } - - let outDegree = - _.reduce(graph.successors(currentNode), (outDegree, succ) => { - let metaedge = graph.edge(currentNode, succ).metaedge; - return outDegree + (metaedge.numRegularEdges ? 1 : 0); - }, 0); - if (outDegree === 0 && graph.successors(currentNode).length > 0) { - outDegree = graph.successors(currentNode).length; - } - - // Store the in and out degrees of this node to avoid recomputing. - nodeToInDegree[currentNode] = inDegree; - nodeToOutDegree[currentNode] = outDegree; - validNodeCount++; - }); - - if (validNodeCount < PARAMS.minNodeCountForExtraction) { - // This graph has few nodes. Do not extract any nodes. - return; - } - - // We only extract if the node has a min in or out degree greater than this. - let minUpperBound = PARAMS.minDegreeForExtraction - 1; - - // Mark for extraction nodes with in-degree > Q3 + (Q3 - Q1). - let q3Index = Math.round(validNodeCount * 0.75); - let q1Index = Math.round(validNodeCount * 0.25); - let sortedByInDegree = Object.keys(nodeToInDegree).sort((node0, node1) => { - return nodeToInDegree[node0] - nodeToInDegree[node1]; - }); - let inDegreeQ3 = nodeToInDegree[sortedByInDegree[q3Index]]; - let inDegreeQ1 = nodeToInDegree[sortedByInDegree[q1Index]]; - let inDegreeUpperBound = inDegreeQ3 + inDegreeQ3 - inDegreeQ1; - // Only extract if the upper bound is high enough. 
- inDegreeUpperBound = Math.max(inDegreeUpperBound, minUpperBound); - for (let i = validNodeCount - 1; - nodeToInDegree[sortedByInDegree[i]] > inDegreeUpperBound; i--) { - // Extract a high in-degree node. - makeInExtract(renderNode, sortedByInDegree[i]); - } - - // Mark for extraction nodes with out-degree > Q3 + (Q3 - Q1) * 4. - let sortedByOutDegree = Object.keys(nodeToOutDegree).sort((node0, node1) => { - return nodeToOutDegree[node0] - nodeToOutDegree[node1]; - }); - let outDegreeQ3 = nodeToOutDegree[sortedByOutDegree[q3Index]]; - let outDegreeQ1 = nodeToOutDegree[sortedByOutDegree[q1Index]]; - // The upper bound for extracting out-degree nodes is higher than that for - // extracting in-degree ones (Note the "* 4") because, in practice, some - // graphs look worse with a smaller out-degree bound. For instance, a smaller - // out-degree bound removes the convolution nodes from cifar 10 train's graph. - let outDegreeUpperBound = outDegreeQ3 + (outDegreeQ3 - outDegreeQ1) * 4; - // Only extract if the upper bound is high enough. - outDegreeUpperBound = Math.max(outDegreeUpperBound, minUpperBound); - for (let i = validNodeCount - 1; - nodeToOutDegree[sortedByOutDegree[i]] > outDegreeUpperBound; i--) { - let node = graph.node(sortedByOutDegree[i]); - if (!node || node.isInExtract) { - // This node has already been extracted due to high in-degree. It might - // have been removed from the graph in general (during in-degree - // extraction) due to a lack of neighbors. Do not extract this node twice. - continue; - } - - // Extract a high out-degree node that has not already been extracted. - makeOutExtract(renderNode, sortedByOutDegree[i]); - } -} - -/** Remove control edges from nodes that have too many control edges */ -function removeControlEdges(renderNode: RenderGroupNodeInfo) { - let graph = renderNode.coreGraph; - - // Collect control edges into a map by node name. 
- let map = <{[nodeName: string]: graphlib.EdgeObject[]}>{}; - _.each(graph.edges(), e => { - if (!graph.edge(e).metaedge.numRegularEdges) { - (map[e.v] = map[e.v] || []).push(e); - (map[e.w] = map[e.w] || []).push(e); - } - }); - - // For each node with too many control edges, turn them into annotations. - _.each(map, (edges, nodeName) => { - if (edges.length > PARAMS.maxControlDegree) { - _.each(edges, e => createShortcut(graph, e.v, e.w)); - } - }); -} - -/** - * Given an integer, picks a hue that is far apart from other colors. - * The formula for picking color that avoid collision is: - * hue = (color range * golden ratio * index) % color range - */ -export function mapIndexToHue(id: number): number { - let GOLDEN_RATIO = 1.61803398875; - // Hue of 0 is reserved for the gray nodes. - let MIN_HUE = 1; - let MAX_HUE = 359; - let COLOR_RANGE = MAX_HUE - MIN_HUE; - return MIN_HUE + ((COLOR_RANGE * GOLDEN_RATIO * id) % COLOR_RANGE); -}; - -/** - * Remove edges and add to annotation instead. - * - * For root node, consider predefined types for source and sink. - * We do not extract predefined type from non-root so that Variables and the - * sgd node (op type = 'NoOp') do not get extract from inside own group. - * - * The order of extraction is important here as swapping the order can totally - * screw up the graph layout. - * - * @param {Render.Node} renderNode Node to manipulate. - */ -function extractHighDegrees(renderNode: RenderGroupNodeInfo) { - - extractSpecifiedNodes(renderNode); - - if (PARAMS.outExtractTypes) { - extractPredefinedSink(renderNode); - } - - // This has to come before extract high in-degree to protect the core part - // that takes many variables. 
- if (PARAMS.inExtractTypes) { - extractPredefinedSource(renderNode); - } - - extractHighInOrOutDegree(renderNode); - - if (PARAMS.maxControlDegree) { - removeControlEdges(renderNode); - } - - // Extract isolated nodes, which can be - // (1) source-like and sink-like nodes that are not originally isolated but - // become isolated after further removal. - // (2) isolated nodes with annotations on one-side. These might be either - // - nodes that originally have high out-degree but because we remove - // high in-degree nodes first, they no longer have high in-degree when - // we check. (Detecting all high-degree before removing also leads to - // another problem.) - // - nodes that do not have high degree, but their neighbors are all - // extracted, so it might make sense to extract them too. - - let graph = renderNode.coreGraph; - _.each(graph.nodes(), n => { - let child = graph.node(n); - let degree = graph.neighbors(n).length; - if (child.node.include !== InclusionType.UNSPECIFIED) { - return; - } - if (degree === 0) { - let hasOutAnnotations = child.outAnnotations.list.length > 0; - let hasInAnnotations = child.inAnnotations.list.length > 0; - - if (child.isInExtract) { // Is source-like. - // This case only happens if detachAllEdgesForHighDegree is false. - // (Otherwise all source-like nodes are all isolated already.) - renderNode.isolatedInExtract.push(child); - child.node.include = InclusionType.EXCLUDE; - graph.removeNode(n); - } else if (child.isOutExtract) { // Is sink-like. - // This case only happens if detachAllEdgesForHighDegree is false. - // // (Otherwise all sink-like nodes are all isolated already.) 
- renderNode.isolatedOutExtract.push(child); - child.node.include = InclusionType.EXCLUDE; - graph.removeNode(n); - } else if (PARAMS.extractIsolatedNodesWithAnnotationsOnOneSide) { - if (hasOutAnnotations && !hasInAnnotations) { - child.isInExtract = true; // for ones with high out-annotations - renderNode.isolatedInExtract.push(child); - child.node.include = InclusionType.EXCLUDE; - graph.removeNode(n); - } else if (hasInAnnotations && !hasOutAnnotations) { - child.isOutExtract = true; // for ones with high in-annotations - renderNode.isolatedOutExtract.push(child); - child.node.include = InclusionType.EXCLUDE; - graph.removeNode(n); - } else { - // if a low degree node has both in- & out- annotations, do nothing - // because it is unclear which side it should go to. - } - } - } - }); -} - -/** - * Expands nodes in the graph until the desired node is visible. - * - * @param scene The scene polymer component. - * @param renderHierarchy The render hierarchy. - * @param tensorName The name of a tensor. - * @return A string that is the name of the node representing the given tensor. - * Note that the original tensor name might differ from this returned node - * name. Specifically, for instance, the tensor name usually ends with an - * output slot index (such as :0), while the node name lacks that suffix. - */ -export function expandUntilNodeIsShown( - scene, renderHierarchy, tensorName: string) { - const splitTensorName = tensorName.split('/'); - - // Graph names do not take into account the output slot. Strip it. - const lastNodeNameMatch = - splitTensorName[splitTensorName.length - 1].match(/(.*):\d+/); - if (lastNodeNameMatch.length === 2) { - splitTensorName[splitTensorName.length - 1] = lastNodeNameMatch[1]; - } - - let nodeName = splitTensorName[0]; - let renderNode = renderHierarchy.getRenderNodeByName(nodeName); - for (let i = 1; i < splitTensorName.length; i++) { - // Op nodes are not expandable. 
- if (renderNode.node.type === tf.graph.NodeType.OP) { - break; - } - renderHierarchy.buildSubhierarchy(nodeName); - renderNode.expanded = true; - scene.setNodeExpanded(renderNode); - nodeName += '/' + splitTensorName[i]; - renderNode = renderHierarchy.getRenderNodeByName(nodeName); - } - - return renderNode.node.name; -} - -} // close module tf.graph.render diff --git a/tensorflow/tensorboard/components/tf_graph_common/scene.ts b/tensorflow/tensorboard/components/tf_graph_common/scene.ts deleted file mode 100644 index 14d35efd9ff..00000000000 --- a/tensorflow/tensorboard/components/tf_graph_common/scene.ts +++ /dev/null @@ -1,735 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the 'License'); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an 'AS IS' BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -module tf.graph.scene { - const svgNamespace = 'http://www.w3.org/2000/svg'; - - /** Enums element class of objects in the scene */ - export let Class = { - Node: { - // element that contains nodes. - CONTAINER: 'nodes', - // element that contains detail about a node. - GROUP: 'node', - // element that contains visual elements (like rect, ellipse). - SHAPE: 'nodeshape', - // <*> element(s) under SHAPE that should receive color updates. - COLOR_TARGET: 'nodecolortarget', - // element showing the node's label. - LABEL: 'nodelabel', - // element that contains all visuals for the expand/collapse - // button for expandable group nodes. 
- BUTTON_CONTAINER: 'buttoncontainer', - // element that surrounds expand/collapse buttons. - BUTTON_CIRCLE: 'buttoncircle', - // element of the expand button. - EXPAND_BUTTON: 'expandbutton', - // element of the collapse button. - COLLAPSE_BUTTON: 'collapsebutton' - }, - Edge: { - CONTAINER: 'edges', - GROUP: 'edge', - LINE: 'edgeline', - REFERENCE_EDGE: 'referenceedge', - REF_LINE: 'refline', - STRUCTURAL: 'structural' - }, - Annotation: { - OUTBOX: 'out-annotations', - INBOX: 'in-annotations', - GROUP: 'annotation', - NODE: 'annotation-node', - EDGE: 'annotation-edge', - CONTROL_EDGE: 'annotation-control-edge', - LABEL: 'annotation-label', - ELLIPSIS: 'annotation-ellipsis' - }, - Scene: { - GROUP: 'scene', - CORE: 'core', - INEXTRACT: 'in-extract', - OUTEXTRACT: 'out-extract' - }, - Subscene: {GROUP: 'subscene'}, - OPNODE: 'op', - METANODE: 'meta', - SERIESNODE: 'series', - BRIDGENODE: 'bridge', - ELLIPSISNODE: 'ellipsis' - }; - - /** - * A health pill encapsulates an overview of tensor element values. The value - * field is a list of 12 numbers that shed light on the status of the tensor. - * Visualized in health pills are the 3rd through 8th (inclusive) numbers of - * health pill values. Those 6 numbers are counts of tensor elements that fall - * under -Inf, negative, 0, positive, +Inf, NaN (in that order). - * - * Please keep this interface consistent with HealthPillDatum within - * backend.ts. - */ - export interface HealthPill { - device_name: string; - node_name: string; - output_slot: number; - dtype: string; - shape: number[]; - value: number[]; - wall_time: number; - step: number; - } - - interface HealthPillNumericStats { - min: number; - max: number; - mean: number; - stddev: number; - } - - /** - * Encapsulates how to render a single entry in a health pill. Each entry - * corresponds to a category of tensor element values. 
- */ - export interface HealthPillEntry { - background_color: string; - label: string; - } - ; - export let healthPillEntries: HealthPillEntry[] = [ - { - background_color: '#CC2F2C', - label: 'NaN', - }, - { - background_color: '#FF8D00', - label: '-∞', - }, - { - background_color: '#EAEAEA', - label: '-', - }, - { - background_color: '#A5A5A5', - label: '0', - }, - { - background_color: '#262626', - label: '+', - }, - { - background_color: '#003ED4', - label: '+∞', - }, - ]; - - /** - * Helper method for fitting the graph in the svg view. - * - * @param svg The main svg. - * @param zoomG The svg group used for panning and zooming. - * @param d3zoom The zoom behavior. - * @param callback Called when the fitting is done. - */ - export function fit(svg, zoomG, d3zoom, callback) { - let svgRect = svg.getBoundingClientRect(); - let sceneSize = null; - try { - sceneSize = zoomG.getBBox(); - if (sceneSize.width === 0) { - // There is no scene anymore. We have been detached from the dom. - return; - } - } catch (e) { - // Firefox produced NS_ERROR_FAILURE if we have been - // detached from the dom. - return; - } - let scale = 0.9 * - Math.min( - svgRect.width / sceneSize.width, svgRect.height / sceneSize.height, - 2); - let params = layout.PARAMS.graph; - const transform = d3.zoomIdentity - .scale(scale) - .translate(params.padding.paddingLeft, params.padding.paddingTop); - - d3.select(svg) - .transition() - .duration(500) - .call(d3zoom.transform, transform) - .on('end.fitted', () => { - // Remove the listener for the zoomend event, - // so we don't get called at the end of regular zoom events, - // just those that fit the graph to screen. - d3zoom.on('end.fitted', null); - callback(); - }); -}; - -/** - * Helper method for panning the graph to center on the provided node, - * if the node is currently off-screen. 
- * - * @param nodeName The node to center the graph on - * @param svg The root SVG element for the graph - * @param zoomG The svg group used for panning and zooming. - * @param d3zoom The zoom behavior. - * @return True if the graph had to be panned to display the - * provided node. - */ -export function panToNode(nodeName: String, svg, zoomG, d3zoom): boolean { - let node = d3 - .select('[data-name="' + nodeName + '"].' + Class.Node.GROUP) - .node(); - if (!node) { - return false; - } - - // Check if the selected node is off-screen in either - // X or Y dimension in either direction. - let nodeBox = node.getBBox(); - let nodeCtm = node.getScreenCTM(); - let pointTL = svg.createSVGPoint(); - let pointBR = svg.createSVGPoint(); - pointTL.x = nodeBox.x; - pointTL.y = nodeBox.y; - pointBR.x = nodeBox.x + nodeBox.width; - pointBR.y = nodeBox.y + nodeBox.height; - pointTL = pointTL.matrixTransform(nodeCtm); - pointBR = pointBR.matrixTransform(nodeCtm); - let isOutsideOfBounds = (start, end, bound) => { - return end < 0 || start > bound; - }; - let svgRect = svg.getBoundingClientRect(); - if (isOutsideOfBounds(pointTL.x, pointBR.x, svgRect.width) || - isOutsideOfBounds(pointTL.y, pointBR.y, svgRect.height)) { - // Determine the amount to translate the graph in both X and Y dimensions in - // order to center the selected node. This takes into account the position - // of the node, the size of the svg scene, the amount the scene has been - // scaled by through zooming, and any previous transforms already performed - // by this logic. - let centerX = (pointTL.x + pointBR.x) / 2; - let centerY = (pointTL.y + pointBR.y) / 2; - let dx = ((svgRect.width / 2) - centerX); - let dy = ((svgRect.height / 2) - centerY); - - // We translate by this amount. We divide the X and Y translations by the - // scale to undo how translateBy scales the translations (in d3 v4). 
- const svgTransform = d3.zoomTransform(svg); - d3.select(svg).transition().duration(500).call( - d3zoom.translateBy, dx / svgTransform.k, dy / svgTransform.k); - - return true; - } - return false; -}; - -/** - * Given a container d3 selection, select a child svg element of a given tag - * and class if exists or append / insert one otherwise. If multiple children - * matches the tag and class name, returns only the first one. - * - * @param container - * @param tagName tag name. - * @param className (optional) Class name or a list of class names. - * @param before (optional) reference DOM node for insertion. - * @return selection of the element - */ -export function selectOrCreateChild( - container, tagName: string, className?: string | string[], before?): d3.Selection { - let child = selectChild(container, tagName, className); - if (!child.empty()) { - return child; - } - let newElement = - document.createElementNS('http://www.w3.org/2000/svg', tagName); - - if (className instanceof Array) { - for (let i = 0; i < className.length; i++) { - newElement.classList.add(className[i]); - } - } else { - newElement.classList.add(className); - } - - if (before) { // if before exists, insert - container.node().insertBefore(newElement, before); - } else { // otherwise, append - container.node().appendChild(newElement); - } - return d3.select(newElement) - // need to bind data to emulate d3_selection.append - .datum(container.datum()); -}; - -/** - * Given a container d3 selection, select a child element of a given tag and - * class. If multiple children matches the tag and class name, returns only - * the first one. - * - * @param container - * @param tagName tag name. - * @param className (optional) Class name or list of class names. 
- * @return selection of the element, or an empty selection - */ -export function selectChild( - container, tagName: string, className?: string | string[]): d3.Selection { - let children = container.node().childNodes; - for (let i = 0; i < children.length; i++) { - let child = children[i]; - if (child.tagName === tagName) { - if (className instanceof Array) { - let hasAllClasses = true; - for (let j = 0; j < className.length; j++) { - hasAllClasses = - hasAllClasses && child.classList.contains(className[j]); - } - if (hasAllClasses) { - return d3.select(child); - } - } else if ((!className || child.classList.contains(className))) { - return d3.select(child); - } - } - } - return d3.select(null); -}; - -/** - * Select or create a sceneGroup and build/update its nodes and edges. - * - * Structure Pattern: - * - * - * - * - * ... stuff from tf.graph.scene.edges.build ... - * - * - * ... stuff from tf.graph.scene.nodes.build ... - * - * - * - * - * ... stuff from tf.graph.scene.nodes.build ... - * - * - * - * - * ... stuff from tf.graph.scene.nodes.build ... - * - * - * - * - * @param container D3 selection of the parent. - * @param renderNode render node of a metanode or series node. - * @param sceneElement polymer element. - * @param sceneClass class attribute of the scene (default='scene'). 
- */ -export function buildGroup(container, - renderNode: render.RenderGroupNodeInfo, - sceneElement, - sceneClass: string): d3.Selection { - sceneClass = sceneClass || Class.Scene.GROUP; - let isNewSceneGroup = selectChild(container, 'g', sceneClass).empty(); - let sceneGroup = selectOrCreateChild(container, 'g', sceneClass); - - // core - let coreGroup = selectOrCreateChild(sceneGroup, 'g', Class.Scene.CORE); - let coreNodes = _.reduce(renderNode.coreGraph.nodes(), (nodes, name) => { - let node = renderNode.coreGraph.node(name); - if (!node.excluded) { - nodes.push(node); - } - return nodes; - }, []); - - if (renderNode.node.type === NodeType.SERIES) { - // For series, we want the first item on top, so reverse the array so - // the first item in the series becomes last item in the top, and thus - // is rendered on the top. - coreNodes.reverse(); - } - - // Create the layer of edges for this scene (paths). - edge.buildGroup(coreGroup, renderNode.coreGraph, sceneElement); - - // Create the layer of nodes for this scene (ellipses, rects etc). - node.buildGroup(coreGroup, coreNodes, sceneElement); - - // In-extract - if (renderNode.isolatedInExtract.length > 0) { - let inExtractGroup = - selectOrCreateChild(sceneGroup, 'g', Class.Scene.INEXTRACT); - node.buildGroup(inExtractGroup, renderNode.isolatedInExtract, - sceneElement); - } else { - selectChild(sceneGroup, 'g', Class.Scene.INEXTRACT).remove(); - } - - // Out-extract - if (renderNode.isolatedOutExtract.length > 0) { - let outExtractGroup = - selectOrCreateChild(sceneGroup, 'g', Class.Scene.OUTEXTRACT); - node.buildGroup(outExtractGroup, renderNode.isolatedOutExtract, - sceneElement); - } else { - selectChild(sceneGroup, 'g', Class.Scene.OUTEXTRACT).remove(); - } - - position(sceneGroup, renderNode); - - // Fade in the scene group if it didn't already exist. 
- if (isNewSceneGroup) { - sceneGroup.attr('opacity', 0).transition().attr('opacity', 1); - } - - return sceneGroup; -}; - -/** - * Given a scene's svg group, set g.in-extract, g.coreGraph, g.out-extract svg - * groups' position relative to the scene. - * - * @param sceneGroup - * @param renderNode render node of a metanode or series node. - */ -function position(sceneGroup, renderNode: render.RenderGroupNodeInfo) { - // Translate scenes down by the label height so that when showing graphs in - // expanded metanodes, the graphs are below the labels. Do not shift them - // down for series nodes as series nodes don't have labels inside of their - // bounding boxes. - let yTranslate = renderNode.node.type === NodeType.SERIES ? - 0 : layout.PARAMS.subscene.meta.labelHeight; - - // core - translate(selectChild(sceneGroup, 'g', Class.Scene.CORE), 0, yTranslate); - - // in-extract - let hasInExtract = renderNode.isolatedInExtract.length > 0; - let hasOutExtract = renderNode.isolatedOutExtract.length > 0; - - if (hasInExtract) { - let offset = layout.PARAMS.subscene.meta.extractXOffset; - let inExtractX = renderNode.coreBox.width - - renderNode.inExtractBox.width / 2 - renderNode.outExtractBox.width - - (hasOutExtract ? offset : 0); - translate( - selectChild(sceneGroup, 'g', Class.Scene.INEXTRACT), inExtractX, - yTranslate); - } - - // out-extract - if (hasOutExtract) { - let outExtractX = renderNode.coreBox.width - - renderNode.outExtractBox.width / 2; - translate( - selectChild(sceneGroup, 'g', Class.Scene.OUTEXTRACT), outExtractX, - yTranslate); - } -}; - -/** Adds a click listener to a group that fires a graph-select event */ -export function addGraphClickListener(graphGroup, sceneElement) { - d3.select(graphGroup).on('click', () => { - sceneElement.fire('graph-select'); - }); -}; - -/** Helper for adding transform: translate(x0, y0) */ -export function translate(selection, x0: number, y0: number) { - // If it is already placed on the screen, make it a transition. 
- if (selection.attr('transform') != null) { - selection = selection.transition('position'); - } - selection.attr('transform', 'translate(' + x0 + ',' + y0 + ')'); -}; - -/** - * Helper for setting position of a svg rect - * @param rect rect to set position of. - * @param cx Center x. - * @param cy Center x. - * @param width Width to set. - * @param height Height to set. - */ -export function positionRect(rect, cx: number, cy: number, width: number, - height: number) { - rect.transition() - .attr('x', cx - width / 2) - .attr('y', cy - height / 2) - .attr('width', width) - .attr('height', height); -}; - -/** - * Helper for setting position of a svg expand/collapse button - * @param button container group - * @param renderNode the render node of the group node to position - * the button on. - */ -export function positionButton(button, renderNode: render.RenderNodeInfo) { - let cx = layout.computeCXPositionOfNodeShape(renderNode); - // Position the button in the top-right corner of the group node, - // with space given the draw the button inside of the corner. - let width = renderNode.expanded ? - renderNode.width : renderNode.coreBox.width; - let height = renderNode.expanded ? - renderNode.height : renderNode.coreBox.height; - let x = cx + width / 2 - 6; - let y = renderNode.y - height / 2 + 6; - // For unexpanded series nodes, the button has special placement due - // to the unique visuals of this group node. - if (renderNode.node.type === NodeType.SERIES && !renderNode.expanded) { - x += 10; - y -= 2; - } - let translateStr = 'translate(' + x + ',' + y + ')'; - button.selectAll('path').transition().attr('transform', translateStr); - button.select('circle').transition().attr( - {cx: x, cy: y, r: layout.PARAMS.nodeSize.meta.expandButtonRadius}); -}; - -/** - * Helper for setting position of a svg ellipse - * @param ellipse ellipse to set position of. - * @param cx Center x. - * @param cy Center x. - * @param width Width to set. - * @param height Height to set. 
- */ -export function positionEllipse(ellipse, cx: number, cy: number, - width: number, height: number) { - ellipse.transition() - .attr('cx', cx) - .attr('cy', cy) - .attr('rx', width / 2) - .attr('ry', height / 2); -}; - -/** - * @param {number} stat A stat for a health pill (such as mean or variance). - * @param {boolean} shouldRoundOnesDigit Whether to round this number to the - * ones digit. Useful for say int, uint, and bool output types. - * @return {string} A human-friendly string representation of that stat. - */ -export function humanizeHealthPillStat(stat, shouldRoundOnesDigit) { - if (shouldRoundOnesDigit) { - return stat.toFixed(0); - } - - if (Math.abs(stat) >= 1) { - return stat.toFixed(1); - } - return stat.toExponential(1); -} - -/** - * Get text content describing a health pill. - */ -function _getHealthPillTextContent(healthPill: HealthPill, - totalCount: number, - elementsBreakdown: number[], - numericStats: HealthPillNumericStats) { - let text = 'Device: ' + healthPill.device_name + '\n'; - text += 'dtype: ' + healthPill.dtype + '\n'; - - let shapeStr = '(scalar)'; - if (healthPill.shape.length > 0) { - shapeStr = '(' + healthPill.shape.join(',') + ')'; - } - text += '\nshape: ' + shapeStr + '\n\n'; - - text += '#(elements): ' + totalCount + '\n'; - const breakdownItems = []; - for (let i = 0; i < elementsBreakdown.length; i++) { - if (elementsBreakdown[i] > 0) { - breakdownItems.push( - '#(' + healthPillEntries[i].label + '): ' + elementsBreakdown[i]); - } - } - text += breakdownItems.join(', ') + '\n\n'; - - // In some cases (e.g., size-0 tensors; all elements are nan or inf) the - // min/max and mean/stddev stats are meaningless. - if (numericStats.max >= numericStats.min) { - text += 'min: ' + numericStats.min + ', max: ' + numericStats.max + '\n'; - text += 'mean: ' + numericStats.mean + ', stddev: ' + numericStats.stddev; - } - - return text; -} - -/** - * Renders a health pill for an op atop a node. 
- */ -function _addHealthPill( - nodeGroupElement: SVGElement, healthPill: HealthPill, - nodeInfo: render.RenderNodeInfo) { - // Check if text already exists at location. - d3.select(nodeGroupElement.parentNode as any).selectAll('.health-pill').remove(); - - if (!nodeInfo || !healthPill) { - return; - } - - let lastHealthPillData = healthPill.value; - - // For now, we only visualize the 6 values that summarize counts of tensor - // elements of various categories: -Inf, negative, 0, positive, Inf, and NaN. - const lastHealthPillElementsBreakdown = lastHealthPillData.slice(2, 8); - let totalCount = lastHealthPillData[1]; - const numericStats: HealthPillNumericStats = { - min: lastHealthPillData[8], - max: lastHealthPillData[9], - mean: lastHealthPillData[10], - stddev: Math.sqrt(lastHealthPillData[11]) - }; - - let healthPillWidth = 60; - let healthPillHeight = 10; - if (nodeInfo.node.type === tf.graph.NodeType.OP) { - // Use a smaller health pill for op nodes (rendered as smaller ellipses). - healthPillWidth /= 2; - healthPillHeight /= 2; - } - - let healthPillGroup = document.createElementNS(svgNamespace, 'g'); - healthPillGroup.classList.add('health-pill'); - - // Define the gradient for the health pill. - let healthPillDefs = document.createElementNS(svgNamespace, 'defs'); - healthPillGroup.appendChild(healthPillDefs); - let healthPillGradient = - document.createElementNS(svgNamespace, 'linearGradient'); - const healthPillGradientId = 'health-pill-gradient'; - healthPillGradient.setAttribute('id', healthPillGradientId); - - let cumulativeCount = 0; - let previousOffset = '0%'; - for (let i = 0; i < lastHealthPillElementsBreakdown.length; i++) { - if (!lastHealthPillElementsBreakdown[i]) { - // Exclude empty categories. - continue; - } - cumulativeCount += lastHealthPillElementsBreakdown[i]; - - // Create a color interval using 2 stop elements. 
- let stopElement0 = document.createElementNS(svgNamespace, 'stop'); - stopElement0.setAttribute('offset', previousOffset); - stopElement0.setAttribute( - 'stop-color', healthPillEntries[i].background_color); - healthPillGradient.appendChild(stopElement0); - - let stopElement1 = document.createElementNS(svgNamespace, 'stop'); - let percent = (cumulativeCount * 100 / totalCount) + '%'; - stopElement1.setAttribute('offset', percent); - stopElement1.setAttribute( - 'stop-color', healthPillEntries[i].background_color); - healthPillGradient.appendChild(stopElement1); - previousOffset = percent; - } - healthPillDefs.appendChild(healthPillGradient); - - // Create the rectangle for the health pill. - let rect = document.createElementNS(svgNamespace, 'rect'); - rect.setAttribute('fill', 'url(#' + healthPillGradientId + ')'); - rect.setAttribute('width', String(healthPillWidth)); - rect.setAttribute('height', String(healthPillHeight)); - healthPillGroup.appendChild(rect); - - // Show a title with specific counts on hover. - let titleSvg = document.createElementNS(svgNamespace, 'title'); - titleSvg.textContent = _getHealthPillTextContent( - healthPill, totalCount, lastHealthPillElementsBreakdown, numericStats); - healthPillGroup.appendChild(titleSvg); - // TODO(cais): Make the tooltip content prettier. - - // Center this health pill just right above the node for the op. - let healthPillX = nodeInfo.x - healthPillWidth / 2; - let healthPillY = nodeInfo.y - healthPillHeight - nodeInfo.height / 2 - 2; - if (nodeInfo.labelOffset < 0) { - // The label is positioned above the node. Do not occlude the label. - healthPillY += nodeInfo.labelOffset; - } - - if (lastHealthPillElementsBreakdown[2] || - lastHealthPillElementsBreakdown[3] || - lastHealthPillElementsBreakdown[4]) { - // At least 1 "non-Inf and non-NaN" value exists (a -, 0, or + value). Show - // stats on tensor values. - - // Determine if we should display the output range as integers. 
- let shouldRoundOnesDigit = false; - let node = nodeInfo.node as OpNode; - let attributes = node.attr; - if (attributes && attributes.length) { - // Find the attribute for output type if there is one. - for (let i = 0; i < attributes.length; i++) { - if (attributes[i].key === 'T') { - // Note whether the output type is an integer. - let outputType = attributes[i].value['type']; - shouldRoundOnesDigit = - outputType && /^DT_(BOOL|INT|UINT)/.test(outputType); - break; - } - } - } - - let statsSvg = document.createElementNS(svgNamespace, 'text'); - const minString = humanizeHealthPillStat(numericStats.min, shouldRoundOnesDigit); - const maxString = humanizeHealthPillStat(numericStats.max, shouldRoundOnesDigit); - if (totalCount > 1) { - statsSvg.textContent = minString + ' ~ ' + maxString; - } else { - statsSvg.textContent = minString; - } - statsSvg.classList.add('health-pill-stats'); - statsSvg.setAttribute('x', String(healthPillWidth / 2)); - statsSvg.setAttribute('y', '-2'); - healthPillGroup.appendChild(statsSvg); - } - - healthPillGroup.setAttribute( - 'transform', 'translate(' + healthPillX + ', ' + healthPillY + ')'); - - Polymer.dom(nodeGroupElement.parentNode).appendChild(healthPillGroup); -} - -/** - * Adds health pills (which visualize tensor summaries) to a graph group. - * @param svgRoot The root SVG element of the graph to add heath pills to. - * @param nodeNamesToHealthPills An object mapping node name to health pill. - * @param colors A list of colors to use. - */ -export function addHealthPills( - svgRoot: SVGElement, nodeNamesToHealthPills: {[key: string]: HealthPill[]}, - healthPillStepIndex: number) { - if (!nodeNamesToHealthPills) { - // No health pill information available. - return; - } - - let svgRootSelection = d3.select(svgRoot); - svgRootSelection.selectAll('g.nodeshape') - .each(function(nodeInfo: render.RenderNodeInfo) { - // Only show health pill data for this node if it is available. 
- let healthPills = nodeNamesToHealthPills[nodeInfo.node.name]; - let healthPill = healthPills ? healthPills[healthPillStepIndex] : null; - _addHealthPill((this as SVGElement), healthPill, nodeInfo); - }); -}; - -} // close module diff --git a/tensorflow/tensorboard/components/tf_graph_common/template.ts b/tensorflow/tensorboard/components/tf_graph_common/template.ts deleted file mode 100644 index 7800d46029b..00000000000 --- a/tensorflow/tensorboard/components/tf_graph_common/template.ts +++ /dev/null @@ -1,305 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the 'License'); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an 'AS IS' BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -module tf.graph.template { - -/** - * Detect repeating patterns of subgraphs. - * Assign templateId to each subgraph if it belongs to a template. - * Returns clusters of similar subgraphs . - * - * @param graph - * @param verifyTemplate whether to run the template verification algorithm - * @return a dict (template id => Array of node names) - */ -export function detect(h, verifyTemplate): {[templateId: string]: string[]} { - // In any particular subgraph, there are either - // - leaf nodes (which do not have subgraph) - // - metanode nodes - some of them have only one member (singular metanode) - // and some have multiple members (non-singular metanode) - - // First, generate a nearest neighbor hash of metanode nodes. 
- let nnGroups = clusterSimilarSubgraphs(h); - - // For each metanode, compare its subgraph (starting from shallower groups) - // and assign template id. - let templates = groupTemplateAndAssignId(nnGroups, verifyTemplate); - - // Sort the templates by minimum level in the graph at which they appear, - // as this leads to optimal setting of the colors of each template for - // maximum differentiation. - return <{[templateId: string]: string[]}>_(templates) - .pairs() - .sortBy(function(pair: {level: number, nodes: string[]}[]) { - return pair[1].level; - }) - .map(function(pair: {level: number, nodes: string[]}[]) { - return [pair[0], pair[1].nodes]; - }) - .object() - .value(); -}; - -/** - * @return Unique string for a metanode based on depth, |V|, |E| and - * op type histogram. - */ -function getSignature(metanode) { - // depth= |V|= |E|= - let props = _.map( - { - 'depth': metanode.depth, - '|V|': metanode.metagraph.nodes().length, - '|E|': metanode.metagraph.edges().length - }, - function(v, k) { return k + '=' + v; }) - .join(' '); - - // optype1=count1,optype2=count2 - let ops = _.map(metanode.opHistogram, function(count, op) { - return op + '=' + count; - }).join(','); - - return props + ' [ops] ' + ops; -} - -/** - * Generate a nearest neighbor hash of metanodes - * based on depth, |V|, |E|, and opHistogram of their subgraph - * (excluding leaf nodes and singular metanodes). - * @param graph The graph - * @return Array of pairs of [signature, - * Object with min level of the template and an Array of tf.graph.Group] - * sort by ascending order of minimum depth at which metanode appears. 
- */ -function clusterSimilarSubgraphs(h: hierarchy.Hierarchy) { - /** a dict from metanode.signature() => Array of tf.graph.Groups */ - let hashDict = _(h.getNodeMap()).reduce( - (hash, node: OpNode|Metanode, name) => { - if (node.type !== NodeType.META) { - return hash; - } - let levelOfMetaNode = name.split('/').length - 1; - let signature = getSignature(node); - let templateInfo = hash[signature] || - {nodes: [], level: levelOfMetaNode}; - hash[signature] = templateInfo; - templateInfo.nodes.push(node); - if (templateInfo.level > levelOfMetaNode) { - templateInfo.level = levelOfMetaNode; - } - return hash; - }, {}); - - return _(hashDict) - .pairs() - // filter nn metanode with only one member - .filter(function(pair: {level: number, nodes: string[]}) { - return pair[1].nodes.length > 1; - }) - .sortBy(function(pair: {level: number, nodes: string[]}) { - // sort by depth - // (all members in the same nnGroup has equal depth) - return pair[1].nodes[0].depth; - }) - .value(); -} - -function groupTemplateAndAssignId(nnGroups, verifyTemplate) { - // For each metanode, compare its subgraph (starting from shallower groups) - // and assign template id. 
- let result: {[templateId: string]: {level: number, nodes: string[]}} = {}; - return _.reduce(nnGroups, function(templates, nnGroupPair) { - let signature = nnGroupPair[0], - nnGroup = nnGroupPair[1].nodes, - clusters = []; - - nnGroup.forEach(function(metanode) { - // check with each existing cluster - for (let i = 0; i < clusters.length; i++) { - let similar = !verifyTemplate || - isSimilarSubgraph( - clusters[i].metanode.metagraph, - metanode.metagraph - ); - // if similar, just add this metanode to the cluster - if (similar) { - // get template from the first one - metanode.templateId = clusters[i].metanode.templateId; - clusters[i].members.push(metanode.name); - return; - } - } - // otherwise create a new cluster with id 'signature [count] ' - metanode.templateId = signature + '[' + clusters.length + ']'; - clusters.push({ - metanode: metanode, - members: [metanode.name] - }); - }); - - clusters.forEach(function(c) { - templates[c.metanode.templateId] = { - level: nnGroupPair[1].level, - nodes: c.members - }; - }); - return templates; - }, result); -} - -function sortNodes(names: string[], - graph: graphlib.Graph, prefix: string) { - return _.sortByAll(names, - function(name) { - let node = graph.node(name); - return (node).op; - }, - function(name) { - let node = graph.node(name); - return (node).templateId; - }, - function(name) { - return graph.neighbors(name).length; - }, - function(name) { - return graph.predecessors(name).length; - }, - function(name) { - return graph.successors(name).length; - }, - function(name) { - return name.substr(prefix.length); - }); -} - -function isSimilarSubgraph(g1: graphlib.Graph, - g2: graphlib.Graph) { - if (!tf.graph.hasSimilarDegreeSequence(g1, g2)) { - return false; - } - - // if we want to skip, just return true here. 
- // return true; - - // Verify sequence by running DFS - let g1prefix = g1.graph().name; - let g2prefix = g2.graph().name; - - let visited1 = {}; - let visited2 = {}; - let stack = []; - - /** - * push sources or successors into the stack - * if the visiting pattern has been similar. - */ - function stackPushIfNotDifferent(n1, n2) { - let sub1 = n1.substr(g1prefix.length), - sub2 = n2.substr(g2prefix.length); - - /* tslint:disable */ - if (visited1[sub1] ^ visited2[sub1]) { - console.warn( - 'different visit pattern', '[' + g1prefix + ']', sub1, - '[' + g2prefix + ']', sub2); - return true; - } - /* tslint:enable */ - if (!visited1[sub1]) { // implied && !visited2[sub2] - visited1[sub1] = visited2[sub2] = true; - stack.push({n1: n1, n2: n2}); - } - - return false; - } - - // check if have same # of sources then sort and push - let sources1 = g1.sources(); - let sources2 = g2.sources(); - if (sources1.length !== sources2.length) { - /* tslint:disable */ - console.log('different source length'); - /* tslint:enable */ - return false; - } - sources1 = sortNodes(sources1, g1, g1prefix); - sources2 = sortNodes(sources2, g2, g2prefix); - - for (let i = 0; i < sources1.length; i++) { - let different = stackPushIfNotDifferent(sources1[i], sources2[i]); - if (different) { - return false; - } - } - - while (stack.length > 0) { - let cur = stack.pop(); - - // check node - let similar = isSimilarNode(g1.node(cur.n1), g2.node(cur.n2)); - if (!similar) { - return false; - } - - // check if have same # of successors then sort and push - let succ1 = g1.successors(cur.n1), succ2 = g2.successors(cur.n2); - if (succ1.length !== succ2.length) { - /* tslint:disable */ - console.log('# of successors mismatch', succ1, succ2); - /* tslint:enable */ - return false; - } - succ1 = sortNodes(succ1, g1, g1prefix); - succ2 = sortNodes(succ2, g2, g2prefix); - - for (let j = 0; j < succ1.length; j++) { - let different = stackPushIfNotDifferent(succ1[j], succ2[j]); - if (different) { - return 
false; - } - } - } - - return true; -} - -/** - * Returns if two nodes have identical structure. - */ -function isSimilarNode(n1: OpNode|Metanode|SeriesNode, - n2: OpNode|Metanode|SeriesNode): boolean { - if (n1.type === NodeType.META) { - // compare metanode - let metanode1 = n1; - let metanode2 = n2; - return metanode1.templateId && metanode2.templateId && - metanode1.templateId === metanode2.templateId; - } else if (n1.type === NodeType.OP && n2.type === NodeType.OP) { - // compare leaf node - return (n1).op === (n2).op; - } else if (n1.type === NodeType.SERIES && n2.type === NodeType.SERIES) { - // compare series node sizes and operations - // (only need to check one op as all op nodes are identical in series) - let sn1 = n1; - let sn2 = n2; - let seriesnode1Count = sn1.metagraph.nodeCount(); - return (seriesnode1Count === sn2.metagraph.nodeCount() && - (seriesnode1Count === 0 || - ((sn1.metagraph.node(sn1.metagraph.nodes()[0])).op === - (sn2.metagraph.node(sn2.metagraph.nodes()[0])).op))); - } - return false; -} -} diff --git a/tensorflow/tensorboard/components/tf_graph_common/test/graph-test.ts b/tensorflow/tensorboard/components/tf_graph_common/test/graph-test.ts deleted file mode 100644 index af3030197e0..00000000000 --- a/tensorflow/tensorboard/components/tf_graph_common/test/graph-test.ts +++ /dev/null @@ -1,103 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the 'License'); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an 'AS IS' BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -suite('graph', () => { - let assert = chai.assert; - - test('graphlib exists', () => { assert.isTrue(graphlib != null); }); - - test('simple graph contruction', done => { - let pbtxt = tf.graph.test.util.stringToArrayBuffer(` - node { - name: "Q" - op: "Input" - } - node { - name: "W" - op: "Input" - } - node { - name: "X" - op: "MatMul" - input: "Q:2" - input: "W" - }`); - let statsPbtxt = tf.graph.test.util.stringToArrayBuffer(`step_stats { - dev_stats { - device: "cpu" - node_stats { - node_name: "Q" - all_start_micros: 10 - all_end_rel_micros: 4 - } - node_stats { - node_name: "Q" - all_start_micros: 12 - all_end_rel_micros: 4 - } - } - }`); - - let buildParams: tf.graph.BuildParams = { - enableEmbedding: true, - inEmbeddingTypes: ['Const'], - outEmbeddingTypes: ['^[a-zA-Z]+Summary$'], - refEdges: {} - }; - let dummyTracker = - tf.graph.util.getTracker({set: () => { return; }, progress: 0}); - tf.graph.parser.parseGraphPbTxt(pbtxt).then(nodes => { - tf.graph.build(nodes, buildParams, dummyTracker) - .then((slimGraph: tf.graph.SlimGraph) => { - assert.isTrue(slimGraph.nodes['X'] != null); - assert.isTrue(slimGraph.nodes['W'] != null); - assert.isTrue(slimGraph.nodes['Q'] != null); - - let firstInputOfX = slimGraph.nodes['X'].inputs[0]; - assert.equal(firstInputOfX.name, 'Q'); - assert.equal(firstInputOfX.outputTensorIndex, 2); - - let secondInputOfX = slimGraph.nodes['X'].inputs[1]; - assert.equal(secondInputOfX.name, 'W'); - assert.equal(secondInputOfX.outputTensorIndex, 0); - - tf.graph.parser.parseStatsPbTxt(statsPbtxt).then(stepStats => { - tf.graph.joinStatsInfoWithGraph(slimGraph, stepStats); - assert.equal(slimGraph.nodes['Q'].stats.getTotalMicros(), 6); - done(); - }); - }); - }); - }); - - test('health pill numbers round correctly', () => { - // Integers are rounded to the ones place. 
- assert.equal(tf.graph.scene.humanizeHealthPillStat(42.0, true), '42'); - - // Numbers with magnitude >= 1 are rounded to the tenths place. - assert.equal(tf.graph.scene.humanizeHealthPillStat(1, false), '1.0'); - assert.equal(tf.graph.scene.humanizeHealthPillStat(42.42, false), '42.4'); - assert.equal(tf.graph.scene.humanizeHealthPillStat(-42.42, false), '-42.4'); - - // Numbers with magnitude < 1 are written in scientific notation rounded to - // the tenths place. - assert.equal(tf.graph.scene.humanizeHealthPillStat(0, false), '0.0e+0'); - assert.equal(tf.graph.scene.humanizeHealthPillStat(0.42, false), '4.2e-1'); - assert.equal( - tf.graph.scene.humanizeHealthPillStat(-0.042, false), '-4.2e-2'); - }); - - // TODO(bp): write tests. -}); diff --git a/tensorflow/tensorboard/components/tf_graph_common/test/hierarchy-test.ts b/tensorflow/tensorboard/components/tf_graph_common/test/hierarchy-test.ts deleted file mode 100644 index fa62ffe2c70..00000000000 --- a/tensorflow/tensorboard/components/tf_graph_common/test/hierarchy-test.ts +++ /dev/null @@ -1,23 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the 'License'); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an 'AS IS' BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -suite('graph', () => { - let assert = chai.assert; - - test('graphlib exists', () => { assert.isTrue(graphlib != null); }); - - // TODO(bp): write tests. 
- -}); diff --git a/tensorflow/tensorboard/components/tf_graph_common/test/index.html b/tensorflow/tensorboard/components/tf_graph_common/test/index.html deleted file mode 100644 index 7564167129d..00000000000 --- a/tensorflow/tensorboard/components/tf_graph_common/test/index.html +++ /dev/null @@ -1,34 +0,0 @@ - - - - - - - - - - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/tf_graph_common/test/layout-test.ts b/tensorflow/tensorboard/components/tf_graph_common/test/layout-test.ts deleted file mode 100644 index b4884413c9d..00000000000 --- a/tensorflow/tensorboard/components/tf_graph_common/test/layout-test.ts +++ /dev/null @@ -1,23 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the 'License'); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an 'AS IS' BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -suite('layout', () => { - let assert = chai.assert; - - test('dagre exists', () => { assert.isTrue(dagre != null); }); - - // TODO(bp): write tests. - -}); diff --git a/tensorflow/tensorboard/components/tf_graph_common/test/parser-test.ts b/tensorflow/tensorboard/components/tf_graph_common/test/parser-test.ts deleted file mode 100644 index 7c73178c1ce..00000000000 --- a/tensorflow/tensorboard/components/tf_graph_common/test/parser-test.ts +++ /dev/null @@ -1,83 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the 'License'); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an 'AS IS' BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -suite('parser', () => { - let assert = chai.assert; - - test('simple pbtxt', done => { - let pbtxt = tf.graph.test.util.stringToArrayBuffer(`node { - name: "Q" - op: "Input" - } - node { - name: "W" - op: "Input" - } - node { - name: "X" - op: "MatMul" - input: "Q" - input: "W" - }`); - tf.graph.parser.parseGraphPbTxt(pbtxt).then(nodes => { - assert.isTrue(nodes != null && nodes.length === 3); - - assert.equal('Q', nodes[0].name); - assert.equal('Input', nodes[0].op); - - assert.equal('W', nodes[1].name); - assert.equal('Input', nodes[1].op); - - assert.equal('X', nodes[2].name); - assert.equal('MatMul', nodes[2].op); - assert.equal('Q', nodes[2].input[0]); - assert.equal('W', nodes[2].input[1]); - - done(); - }); - }); - - test('stats pbtxt parsing', done => { - let statsPbtxt = tf.graph.test.util.stringToArrayBuffer(`step_stats { - dev_stats { - device: "cpu" - node_stats { - node_name: "Q" - all_start_micros: 10 - all_end_rel_micros: 4 - } - node_stats { - node_name: "Q" - all_start_micros: 12 - all_end_rel_micros: 4 - } - } - }`); - tf.graph.parser.parseStatsPbTxt(statsPbtxt).then(stepStats => { - assert.equal(stepStats.dev_stats.length, 1); - assert.equal(stepStats.dev_stats[0].device, 'cpu'); - assert.equal(stepStats.dev_stats[0].node_stats.length, 2); - assert.equal(stepStats.dev_stats[0].node_stats[0].all_start_micros, 10); - 
assert.equal(stepStats.dev_stats[0].node_stats[1].node_name, 'Q'); - assert.equal(stepStats.dev_stats[0].node_stats[1].all_end_rel_micros, 4); - done(); - }); - }); - - test('d3 exists', () => { assert.isTrue(d3 != null); }); - - // TODO(nsthorat): write tests. - -}); diff --git a/tensorflow/tensorboard/components/tf_graph_common/test/util-test.ts b/tensorflow/tensorboard/components/tf_graph_common/test/util-test.ts deleted file mode 100644 index 4535d24888f..00000000000 --- a/tensorflow/tensorboard/components/tf_graph_common/test/util-test.ts +++ /dev/null @@ -1,56 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the 'License'); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an 'AS IS' BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -suite('util', () => { - let assert = chai.assert; - - test('remove common prefix', () => { - - // Empty array. - let result = tf.graph.util.removeCommonPrefix([]); - assert.deepEqual(result, []); - - // No common prefix. - result = tf.graph.util.removeCommonPrefix(['a', 'b', 'c']); - assert.deepEqual(result, ['a', 'b', 'c']); - - // One of the elements is empty string. - result = tf.graph.util.removeCommonPrefix(['a/b', '', 'a/c']); - assert.deepEqual(result, ['a/b', '', 'a/c']); - - // Only one string. - result = tf.graph.util.removeCommonPrefix(['a/b/c']); - assert.deepEqual(result, ['a/b/c']); - - // `q/w/` is the common prefix. Expect `q/w/` to be removed. 
- result = tf.graph.util.removeCommonPrefix(['q/w/a', 'q/w/b', 'q/w/c/f']); - assert.deepEqual(result, ['a', 'b', 'c/f']); - - // `q/w/` is the common prefix and also an element. Expect nothing to be - // removed since the common prefix is also an element in the array. - result = tf.graph.util.removeCommonPrefix(['q/w/', 'q/w/b', 'q/w/c/f']); - assert.deepEqual(result, ['q/w/', 'q/w/b', 'q/w/c/f']); - }); - - test('query params', () => { - // Starts with question mark. - let queryParams = tf.graph.util.getQueryParams('?foo=1&bar=2'); - assert.deepEqual(queryParams, {'foo': '1', 'bar': '2'}); - - // No question mark. - queryParams = tf.graph.util.getQueryParams('foo=1&bar=2'); - assert.deepEqual(queryParams, {'foo': '1', 'bar': '2'}); - }); -}); diff --git a/tensorflow/tensorboard/components/tf_graph_common/test/util.ts b/tensorflow/tensorboard/components/tf_graph_common/test/util.ts deleted file mode 100644 index bc73b735ed2..00000000000 --- a/tensorflow/tensorboard/components/tf_graph_common/test/util.ts +++ /dev/null @@ -1,31 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the 'License'); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an 'AS IS' BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - - -/* tslint:disable:no-namespace */ -module tf.graph.test.util { - /** - * Converts a utf-8 string to an ArrayBuffer. 
- */ - export function stringToArrayBuffer(str): ArrayBuffer { - let buf = new ArrayBuffer(str.length); - let bufView = new Uint8Array(buf); - for (let i = 0, strLen = str.length; i < strLen; i++) { - bufView[i] = str.charCodeAt(i); - } - return buf; - } - -} // module diff --git a/tensorflow/tensorboard/components/tf_graph_common/tf-graph-common.html b/tensorflow/tensorboard/components/tf_graph_common/tf-graph-common.html deleted file mode 100644 index a460072a38f..00000000000 --- a/tensorflow/tensorboard/components/tf_graph_common/tf-graph-common.html +++ /dev/null @@ -1,38 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/tf_graph_common/util.ts b/tensorflow/tensorboard/components/tf_graph_common/util.ts deleted file mode 100644 index 0b2df6545cc..00000000000 --- a/tensorflow/tensorboard/components/tf_graph_common/util.ts +++ /dev/null @@ -1,316 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the 'License'); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an 'AS IS' BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -/** - * @fileoverview Utility functions for the tensorflow graph visualizer. - */ - -module tf.graph.util { - /** - * Recommended delay (ms) when running an expensive task asynchronously - * that gives enough time for the progress bar to update its UI. 
- */ - const ASYNC_TASK_DELAY = 20; - - export function time(msg: string, task: () => T) { - let start = Date.now(); - let result = task(); - /* tslint:disable */ - console.log(msg, ':', Date.now() - start, 'ms'); - /* tslint:enable */ - return result; - } - - /** - * Creates a tracker that sets the progress property of the - * provided polymer component. The provided component must have - * a property called 'progress' that is not read-only. The progress - * property is an object with a numerical 'value' property and a - * string 'msg' property. - */ - export function getTracker(polymerComponent: any) { - return { - setMessage: function(msg) { - polymerComponent.set( - 'progress', {value: polymerComponent.progress.value, msg: msg}); - }, - updateProgress: function(value) { - polymerComponent.set('progress', { - value: polymerComponent.progress.value + value, - msg: polymerComponent.progress.msg - }); - }, - reportError: function(msg: string, err) { - // Log the stack trace in the console. - console.error(err.stack); - // And send a user-friendly message to the UI. - polymerComponent.set( - 'progress', - {value: polymerComponent.progress.value, msg: msg, error: true}); - }, - }; - } - - /** - * Creates a tracker for a subtask given the parent tracker, the total - * progress - * of the subtask and the subtask message. The parent task should pass a - * subtracker to its subtasks. The subtask reports its own progress which - * becomes relative to the main task. - */ - export function getSubtaskTracker( - parentTracker: ProgressTracker, impactOnTotalProgress: number, - subtaskMsg: string): ProgressTracker { - return { - setMessage: function(progressMsg) { - // The parent should show a concatenation of its message along with - // its subtask tracker message. - parentTracker.setMessage(subtaskMsg + ': ' + progressMsg); - }, - updateProgress: function(incrementValue) { - // Update the parent progress relative to the child progress. 
- // For example, if the sub-task progresses by 30%, and the impact on the - // total progress is 50%, then the task progresses by 30% * 50% = 15%. - parentTracker.updateProgress( - incrementValue * impactOnTotalProgress / 100); - }, - reportError: function(msg: string, err: Error) { - // The parent should show a concatenation of its message along with - // its subtask error message. - parentTracker.reportError(subtaskMsg + ': ' + msg, err); - } - }; - } - - /** - * Runs an expensive task and return the result. - */ - export function runTask( - msg: string, incProgressValue: number, task: () => T, - tracker: ProgressTracker): T { - // Update the progress message to say the current running task. - tracker.setMessage(msg); - // Run the expensive task with a delay that gives enough time for the - // UI to update. - try { - let result = tf.graph.util.time(msg, task); - // Update the progress value. - tracker.updateProgress(incProgressValue); - // Return the result to be used by other tasks. - return result; - } catch (e) { - // Errors that happen inside asynchronous tasks are - // reported to the tracker using a user-friendly message. - tracker.reportError('Failed ' + msg, e); - } - } - - /** - * Runs an expensive task asynchronously and returns a promise of the result. - */ - export function runAsyncTask( - msg: string, incProgressValue: number, task: () => T, - tracker: ProgressTracker): Promise { - return new Promise((resolve, reject) => { - // Update the progress message to say the current running task. - tracker.setMessage(msg); - // Run the expensive task with a delay that gives enough time for the - // UI to update. - setTimeout(function() { - try { - let result = tf.graph.util.time(msg, task); - // Update the progress value. - tracker.updateProgress(incProgressValue); - // Return the result to be used by other tasks. 
- resolve(result); - } catch (e) { - // Errors that happen inside asynchronous tasks are - // reported to the tracker using a user-friendly message. - tracker.reportError('Failed ' + msg, e); - } - }, ASYNC_TASK_DELAY); - }); - } - - /** - * Asynchronously runs an expensive task that returns a promise. Updates the - * tracker's progress after the promise resolves. Returns a new promise that - * resolves after the progress is updated. - */ - export function runAsyncPromiseTask( - msg: string, incProgressValue: number, task: () => Promise, - tracker: ProgressTracker): Promise { - return new Promise((resolve, reject) => { - let handleError = function(e) { - // Errors that happen inside asynchronous tasks are - // reported to the tracker using a user-friendly message. - tracker.reportError('Failed ' + msg, e); - reject(e); - }; - - // Update the progress message to say the current running task. - tracker.setMessage(msg); - // Run the expensive task with a delay that gives enough time for the - // UI to update. - setTimeout(function() { - try { - let start = Date.now(); - task() - .then(function(value) { - /* tslint:disable */ - console.log(msg, ':', Date.now() - start, 'ms'); - // Update the progress value. - tracker.updateProgress(incProgressValue); - // Return the result to be used by other tasks. - resolve(value); - }) - .catch(handleError); - } catch (e) { - handleError(e); - } - }, ASYNC_TASK_DELAY); - }); - } - - /** - * Returns a query selector with escaped special characters that are not - * allowed in a query selector. - */ - export function escapeQuerySelector(querySelector: string): string { - return querySelector.replace(/([:.\[\],/\\\(\)])/g, '\\$1'); - } - - // For unit conversion. - export const MEMORY_UNITS = [ - // Atomic unit. - {symbol: 'B'}, - // numUnits specifies how many previous units this unit contains. 
- {symbol: 'KB', numUnits: 1024}, {symbol: 'MB', numUnits: 1024}, - {symbol: 'GB', numUnits: 1024}, {symbol: 'TB', numUnits: 1024}, - {symbol: 'PB', numUnits: 1024} - ]; - export const TIME_UNITS = [ - // Atomic unit. Finest granularity in TensorFlow stat collection. - {symbol: 'µs'}, - // numUnits specifies how many previous units this unit contains. - {symbol: 'ms', numUnits: 1000}, {symbol: 's', numUnits: 1000}, - {symbol: 'min', numUnits: 60}, {symbol: 'hr', numUnits: 60}, - {symbol: 'days', numUnits: 24} - ]; - - /** - * Returns the human readable version of the unit. - * (e.g. 1.35 GB, 23 MB, 34 ms, 6.53 min etc). - */ - export function convertUnitsToHumanReadable(value, units, unitIndex) { - unitIndex = unitIndex == null ? 0 : unitIndex; - if (unitIndex + 1 < units.length && - value >= units[unitIndex + 1].numUnits) { - return tf.graph.util.convertUnitsToHumanReadable( - value / units[unitIndex + 1].numUnits, units, unitIndex + 1); - } - // toPrecision() has the tendency to return a number in scientific - // notation and (number - 0) brings it back to normal notation. - return (value.toPrecision(3) - 0) + ' ' + units[unitIndex].symbol; - } - - export function hasDisplayableNodeStats(stats: NodeStats) { - if (stats && - (stats.totalBytes > 0 || stats.getTotalMicros() > 0 || - stats.outputSize)) { - return true; - } - return false; - } - - /** - * Given a list of strings, it returns a new list of strings with the longest - * common prefix removed. If the common prefix is one of the strings in the - * list, it returns the original strings. - */ - export function removeCommonPrefix(strings: string[]) { - if (strings.length < 2) { - return strings; - } - - let index = 0; - let largestIndex = 0; - // Find the shortest name across all strings. - let minLength = _.min(_.map(strings, str => str.length)); - while (true) { - index++; - let prefixes = _.map(strings, str => str.substring(0, index)); - let allTheSame = prefixes.every((prefix, i) => { - return (i === 0 ? 
true : prefix === prefixes[i - 1]); - }); - if (allTheSame) { - if (index >= minLength) { - // There is a string whose whole name is a prefix to other string. - // In this case, we return the original list of string. - return strings; - } - largestIndex = index; - } else { - break; - } - } - return _.map(strings, str => str.substring(largestIndex)); - } - - /** - * Given a queryString, aka ?foo=1&bar=2, return the object representation. - */ - export function getQueryParams(queryString: string) { - if (queryString.charAt(0) === '?') { - queryString = queryString.slice(1); - } - - let queryParams = _.chain(queryString.split('&')) - .map((item) => { - if (item) { - return item.split('='); - } - }) - .compact() - .value(); - - return _.object(queryParams); - } - - /** - * Given a timestamp in microseconds, return a human-friendly string denoting - * how long ago the timestamp was. - */ - export function computeHumanFriendlyTime(timeInMicroseconds: number) { - var timeDifferenceInMs = - +(new Date()) - +(new Date(timeInMicroseconds / 1e3)); - if (timeDifferenceInMs < 30000) { - return 'just now'; - } else if (timeDifferenceInMs < 60000) { - return Math.floor(timeDifferenceInMs / 1000) + ' seconds ago'; - } else if (timeDifferenceInMs < 120000) { - return 'a minute ago'; - } else if (timeDifferenceInMs < 3600000) { - return Math.floor(timeDifferenceInMs / 60000) + ' minutes ago'; - } else if (Math.floor(timeDifferenceInMs / 3600000) == 1) { - return 'an hour ago'; - } else if (timeDifferenceInMs < 86400000) { - return Math.floor(timeDifferenceInMs / 3600000) + ' hours ago'; - } else if (timeDifferenceInMs < 172800000) { - return 'yesterday'; - } - return Math.floor(timeDifferenceInMs / 86400000) + ' days ago'; - } -} diff --git a/tensorflow/tensorboard/components/tf_graph_controls/BUILD b/tensorflow/tensorboard/components/tf_graph_controls/BUILD deleted file mode 100644 index ecca2ba4cb5..00000000000 --- a/tensorflow/tensorboard/components/tf_graph_controls/BUILD +++ 
/dev/null @@ -1,46 +0,0 @@ -package(default_visibility = ["//tensorflow/tensorboard:internal"]) - -load("//tensorflow/tensorboard/defs:defs.bzl", "tensorboard_webcomponent_library") -load("//tensorflow/tensorboard/defs:web.bzl", "ts_web_library") - -licenses(["notice"]) # Apache 2.0 - -ts_web_library( - name = "tf_graph_controls", - srcs = ["tf-graph-controls.html"], - path = "/tf-graph-controls", - deps = [ - "//tensorflow/tensorboard/components/tf_dashboard_common", - "//tensorflow/tensorboard/components/tf_graph_common", - "//tensorflow/tensorboard/components/tf_imports:polymer", - "@org_polymer_paper_button", - "@org_polymer_paper_dropdown_menu", - "@org_polymer_paper_menu", - "@org_polymer_paper_radio_group", - "@org_polymer_paper_toggle_button", - "@org_polymer_paper_tooltip", - ], -) - -tensorboard_webcomponent_library( - name = "legacy", - srcs = [":tf_graph_controls"], - destdir = "tf-graph-controls", - deps = [ - "//tensorflow/tensorboard/components/tf_dashboard_common:legacy", - "//tensorflow/tensorboard/components/tf_graph_common:legacy", - "//third_party/javascript/polymer/v1/paper-button:lib", - "//third_party/javascript/polymer/v1/paper-dropdown-menu:lib", - "//third_party/javascript/polymer/v1/paper-menu:lib", - "//third_party/javascript/polymer/v1/paper-radio-group:lib", - "//third_party/javascript/polymer/v1/paper-toggle-button:lib", - "//third_party/javascript/polymer/v1/paper-tooltip:lib", - "//third_party/javascript/polymer/v1/polymer:lib", - ], -) - -filegroup( - name = "all_files", - srcs = glob(["**"]), - tags = ["notsan"], -) diff --git a/tensorflow/tensorboard/components/tf_graph_controls/demo/BUILD b/tensorflow/tensorboard/components/tf_graph_controls/demo/BUILD deleted file mode 100644 index 0e120542132..00000000000 --- a/tensorflow/tensorboard/components/tf_graph_controls/demo/BUILD +++ /dev/null @@ -1,24 +0,0 @@ -package(default_visibility = ["//tensorflow/tensorboard:internal"]) - -load("//tensorflow/tensorboard/defs:web.bzl", 
"ts_web_library") - -licenses(["notice"]) # Apache 2.0 - -# bazel run //third_party/tensorflow/tensorboard/components/tf_graph_controls/demo -ts_web_library( - name = "demo", - srcs = ["index.html"], - path = "/tf-graph-controls/demo", - deps = [ - "//tensorflow/tensorboard/components/tf_graph_controls", - "//tensorflow/tensorboard/components/tf_imports:webcomponentsjs", - "@org_polymer_iron_demo_helpers", - "@org_polymer_paper_styles", - ], -) - -filegroup( - name = "all_files", - srcs = glob(["**"]), - tags = ["notsan"], -) diff --git a/tensorflow/tensorboard/components/tf_graph_controls/demo/index.html b/tensorflow/tensorboard/components/tf_graph_controls/demo/index.html deleted file mode 100644 index 8b12641b28e..00000000000 --- a/tensorflow/tensorboard/components/tf_graph_controls/demo/index.html +++ /dev/null @@ -1,49 +0,0 @@ - - - - - -TF Graph Controls Demo - - - - diff --git a/tensorflow/tensorboard/components/tf_graph_controls/tf-graph-controls.html b/tensorflow/tensorboard/components/tf_graph_controls/tf-graph-controls.html deleted file mode 100644 index 6d896357482..00000000000 --- a/tensorflow/tensorboard/components/tf_graph_controls/tf-graph-controls.html +++ /dev/null @@ -1,919 +0,0 @@ - - - - - - - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/tf_graph_dashboard/BUILD b/tensorflow/tensorboard/components/tf_graph_dashboard/BUILD deleted file mode 100644 index c69a7809035..00000000000 --- a/tensorflow/tensorboard/components/tf_graph_dashboard/BUILD +++ /dev/null @@ -1,44 +0,0 @@ -package(default_visibility = ["//tensorflow/tensorboard:internal"]) - -load("//tensorflow/tensorboard/defs:defs.bzl", "tensorboard_webcomponent_library") -load("//tensorflow/tensorboard/defs:web.bzl", "ts_web_library") - -licenses(["notice"]) # Apache 2.0 - -ts_web_library( - name = "tf_graph_dashboard", - srcs = ["tf-graph-dashboard.html"], - path = "/tf-graph-dashboard", - deps = [ - "//tensorflow/tensorboard/components/tf_backend", - 
"//tensorflow/tensorboard/components/tf_dashboard_common", - "//tensorflow/tensorboard/components/tf_graph", - "//tensorflow/tensorboard/components/tf_graph_board", - "//tensorflow/tensorboard/components/tf_graph_controls", - "//tensorflow/tensorboard/components/tf_graph_loader", - "//tensorflow/tensorboard/components/tf_imports:polymer", - "//tensorflow/tensorboard/components/vz_sorting", - ], -) - -tensorboard_webcomponent_library( - name = "legacy", - srcs = [":tf_graph_dashboard"], - destdir = "tf-graph-dashboard", - deps = [ - "//tensorflow/tensorboard/components/tf_backend:legacy", - "//tensorflow/tensorboard/components/tf_dashboard_common:legacy", - "//tensorflow/tensorboard/components/tf_graph:legacy", - "//tensorflow/tensorboard/components/tf_graph_board:legacy", - "//tensorflow/tensorboard/components/tf_graph_controls:legacy", - "//tensorflow/tensorboard/components/tf_graph_loader:legacy", - "//tensorflow/tensorboard/components/vz_sorting:legacy", - "//third_party/javascript/polymer/v1/polymer:lib", - ], -) - -filegroup( - name = "all_files", - srcs = glob(["**"]), - tags = ["notsan"], -) diff --git a/tensorflow/tensorboard/components/tf_graph_dashboard/demo/BUILD b/tensorflow/tensorboard/components/tf_graph_dashboard/demo/BUILD deleted file mode 100644 index 66a37b89785..00000000000 --- a/tensorflow/tensorboard/components/tf_graph_dashboard/demo/BUILD +++ /dev/null @@ -1,24 +0,0 @@ -package(default_visibility = ["//tensorflow/tensorboard:internal"]) - -load("//tensorflow/tensorboard/defs:web.bzl", "ts_web_library") - -licenses(["notice"]) # Apache 2.0 - -# bazel run //third_party/tensorflow/tensorboard/components/tf_graph_dashboard/demo -ts_web_library( - name = "demo", - srcs = ["index.html"] + glob(["data/**"]), - path = "/tf-graph-dashboard/demo", - deps = [ - "//tensorflow/tensorboard/components/tf_graph_dashboard", - "//tensorflow/tensorboard/components/tf_imports:webcomponentsjs", - "@org_polymer_iron_demo_helpers", - "@org_polymer_paper_styles", - 
], -) - -filegroup( - name = "all_files", - srcs = glob(["**"]), - tags = ["notsan"], -) diff --git a/tensorflow/tensorboard/components/tf_graph_dashboard/demo/data/graph_run_run1.pbtxt b/tensorflow/tensorboard/components/tf_graph_dashboard/demo/data/graph_run_run1.pbtxt deleted file mode 100644 index 30b20645346..00000000000 --- a/tensorflow/tensorboard/components/tf_graph_dashboard/demo/data/graph_run_run1.pbtxt +++ /dev/null @@ -1,4606 +0,0 @@ -node { - name: "GradientDescent/learning_rate" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_3" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.1 - } - } - } -} -node { - name: "gradients/add_grad/Shape_1" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 100 - } - } - } -} -node { - name: "gradients/add_grad/Shape" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\310\000\000\000d\000\000\000" - } - } - } -} -node { - name: "gradients/add_grad/BroadcastGradientArgs" - op: "BroadcastGradientArgs" - input: 
"gradients/add_grad/Shape" - input: "gradients/add_grad/Shape_1" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - shape { - dim { - size: -1 - } - } - } - } - } -} -node { - name: "gradients/add_1_grad/Shape_1" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 10 - } - } - } -} -node { - name: "gradients/add_1_grad/Shape" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\310\000\000\000\n\000\000\000" - } - } - } -} -node { - name: "gradients/add_1_grad/BroadcastGradientArgs" - op: "BroadcastGradientArgs" - input: "gradients/add_1_grad/Shape" - input: "gradients/add_1_grad/Shape_1" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - shape { - dim { - size: -1 - } - } - } - } - } -} -node { - name: "gradients/Reshape_1_grad/Shape" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" 
- attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\310\000\000\000\n\000\000\000" - } - } - } -} -node { - name: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims/dim" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: -1 - } - } - } -} -node { - name: "gradients/Reshape_3_grad/Shape" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 200 - } - } - } -} -node { - name: "gradients/Mean_grad/Maximum/y" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "gradients/Mean_grad/Const_1" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: 
"_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "gradients/Mean_grad/Const" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "gradients/Mean_grad/Shape_1" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - } - } - } - } - } -} -node { - name: "gradients/Mean_grad/Prod_1" - op: "Prod" - input: "gradients/Mean_grad/Shape_1" - input: "gradients/Mean_grad/Const_1" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "gradients/Mean_grad/Maximum" - op: "Maximum" - input: "gradients/Mean_grad/Prod_1" - input: "gradients/Mean_grad/Maximum/y" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" 
- value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "gradients/Mean_grad/Shape" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 200 - } - } - } -} -node { - name: "gradients/Mean_grad/Prod" - op: "Prod" - input: "gradients/Mean_grad/Shape" - input: "gradients/Mean_grad/Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "gradients/Mean_grad/floordiv" - op: "FloorDiv" - input: "gradients/Mean_grad/Prod" - input: "gradients/Mean_grad/Maximum" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "gradients/Mean_grad/Cast" - op: "Cast" - input: "gradients/Mean_grad/floordiv" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "gradients/Mean_grad/Tile/multiples" - op: "Const" - device: 
"/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 200 - } - } - } -} -node { - name: "gradients/Mean_grad/Reshape/shape" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } -} -node { - name: "gradients/Const" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1 - } - } - } -} -node { - name: "gradients/Shape" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - } - } - } - } - } -} -node { - name: "gradients/Fill" - op: "Fill" - input: "gradients/Shape" - input: "gradients/Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: 
"_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "gradients/Mean_grad/Reshape" - op: "Reshape" - input: "gradients/Fill" - input: "gradients/Mean_grad/Reshape/shape" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } -} -node { - name: "gradients/Mean_grad/Tile" - op: "Tile" - input: "gradients/Mean_grad/Reshape" - input: "gradients/Mean_grad/Tile/multiples" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tmultiples" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - } - } - } - } -} -node { - name: "gradients/Mean_grad/truediv" - op: "RealDiv" - input: "gradients/Mean_grad/Tile" - input: "gradients/Mean_grad/Cast" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - } - } - } - } -} -node { - name: "gradients/Reshape_3_grad/Reshape" - op: "Reshape" - input: "gradients/Mean_grad/truediv" - input: "gradients/Reshape_3_grad/Shape" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - } - 
} - } - } -} -node { - name: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims" - op: "ExpandDims" - input: "gradients/Reshape_3_grad/Reshape" - input: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims/dim" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "Const" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "Slice_2/begin" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "Sub_2/y" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "concat_1/axis" - op: "Const" - device: 
"/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "concat_1/values_0" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: -1 - } - } - } -} -node { - name: "Slice_1/size" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } -} -node { - name: "Sub_1/y" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "Shape_2" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: 
"dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\310\000\000\000\n\000\000\000" - } - } - } -} -node { - name: "Rank_2" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 2 - } - } - } -} -node { - name: "Sub_1" - op: "Sub" - input: "Rank_2" - input: "Sub_1/y" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "Slice_1/begin" - op: "Pack" - input: "Sub_1" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "N" - value { - i: 1 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "axis" - value { - i: 0 - } - } -} -node { - name: "Slice_1" - op: "Slice" - input: "Shape_2" - input: "Slice_1/begin" - input: "Slice_1/size" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "Index" - value { - type: DT_INT32 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } -} -node { - name: "concat_1" - op: "ConcatV2" - input: "concat_1/values_0" - input: "Slice_1" - input: "concat_1/axis" - device: 
"/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } -} -node { - name: "concat/axis" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "concat/values_0" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: -1 - } - } - } -} -node { - name: "Slice/size" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } -} -node { - name: "Sub/y" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value 
{ - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "Shape_1" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\310\000\000\000\n\000\000\000" - } - } - } -} -node { - name: "Rank_1" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 2 - } - } - } -} -node { - name: "Sub" - op: "Sub" - input: "Rank_1" - input: "Sub/y" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "Slice/begin" - op: "Pack" - input: "Sub" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "N" - value { - i: 1 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "axis" - value { - i: 0 - } - } -} -node { - name: "Slice" - op: "Slice" - input: "Shape_1" - input: "Slice/begin" - input: "Slice/size" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: 
"Index" - value { - type: DT_INT32 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } -} -node { - name: "concat" - op: "ConcatV2" - input: "concat/values_0" - input: "Slice" - input: "concat/axis" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } -} -node { - name: "Shape" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\310\000\000\000\n\000\000\000" - } - } - } -} -node { - name: "Rank" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 2 - } - } - } -} -node { - name: "Sub_2" - op: "Sub" - input: "Rank" - input: "Sub_2/y" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } 
-} -node { - name: "Slice_2/size" - op: "Pack" - input: "Sub_2" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "N" - value { - i: 1 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "axis" - value { - i: 0 - } - } -} -node { - name: "Slice_2" - op: "Slice" - input: "Shape" - input: "Slice_2/begin" - input: "Slice_2/size" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "Index" - value { - type: DT_INT32 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } -} -node { - name: "logits_biases" - op: "VariableV2" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "_class" - value { - list { - s: "loc:@logits_biases" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 10 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 10 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "logits_biases/read" - op: "Identity" - input: "logits_biases" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@logits_biases" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 10 - } - } - } - } - } -} -node { - name: "logits_weights" - op: "VariableV2" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "_class" - value { - list { - s: "loc:@logits_weights" - } - } - } - attr { - key: 
"_output_shapes" - value { - list { - shape { - dim { - size: 100 - } - dim { - size: 10 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 100 - } - dim { - size: 10 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "logits_weights/read" - op: "Identity" - input: "logits_weights" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@logits_weights" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 100 - } - dim { - size: 10 - } - } - } - } - } -} -node { - name: "hidden_biases" - op: "VariableV2" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "_class" - value { - list { - s: "loc:@hidden_biases" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 100 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 100 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "hidden_biases/read" - op: "Identity" - input: "hidden_biases" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@hidden_biases" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 100 - } - } - } - } - } -} -node { - name: "hidden_weights" - op: "VariableV2" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "_class" - value { - list { - s: "loc:@hidden_weights" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 784 - } - dim { - size: 100 - } - } - } - } - } 
- attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 784 - } - dim { - size: 100 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "hidden_weights/read" - op: "Identity" - input: "hidden_weights" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@hidden_weights" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 784 - } - dim { - size: 100 - } - } - } - } - } -} -node { - name: "Reshape/shape" - op: "Const" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\310\000\000\000\377\377\377\377" - } - } - } -} -node { - name: "mnist_dataset_train_2/one_hot/depth" - op: "Const" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 10 - } - } - } -} -node { - name: "mnist_dataset_train_2/one_hot/off_value" - op: "Const" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0 - } - } - } -} -node { - name: "mnist_dataset_train_2/one_hot/on_value" - op: "Const" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "_output_shapes" - 
value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1 - } - } - } -} -node { - name: "mnist_dataset_train_2/random_shuffle_queue_DequeueMany/n" - op: "Const" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 200 - } - } - } -} -node { - name: "mnist_dataset_train_1/random_shuffle_queue" - op: "RandomShuffleQueueV2" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "capacity" - value { - i: 20000 - } - } - attr { - key: "component_types" - value { - list { - type: DT_FLOAT - type: DT_INT64 - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "min_after_dequeue" - value { - i: 4000 - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } - attr { - key: "shapes" - value { - list { - shape { - dim { - size: 28 - } - dim { - size: 28 - } - dim { - size: 1 - } - } - shape { - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "mnist_dataset_train_2/random_shuffle_queue_DequeueMany" - op: "QueueDequeueManyV2" - input: "mnist_dataset_train_1/random_shuffle_queue" - input: "mnist_dataset_train_2/random_shuffle_queue_DequeueMany/n" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "_output_shapes" - value { - list { - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - } - } - } - attr { - key: "component_types" - value { - list { - type: DT_FLOAT - type: DT_INT64 - } - } - } - attr { - key: "timeout_ms" - value { - i: -1 - } - } -} -node { - name: "Reshape" - op: 
"Reshape" - input: "mnist_dataset_train_2/random_shuffle_queue_DequeueMany" - input: "Reshape/shape" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: -1 - } - } - } - } - } -} -node { - name: "MatMul" - op: "MatMul" - input: "Reshape" - input: "hidden_weights/read" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 100 - } - } - } - } - } - attr { - key: "transpose_a" - value { - b: false - } - } - attr { - key: "transpose_b" - value { - b: false - } - } -} -node { - name: "add" - op: "Add" - input: "MatMul" - input: "hidden_biases/read" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 100 - } - } - } - } - } -} -node { - name: "Relu" - op: "Relu" - input: "add" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 100 - } - } - } - } - } -} -node { - name: "MatMul_1" - op: "MatMul" - input: "Relu" - input: "logits_weights/read" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - 
dim { - size: 200 - } - dim { - size: 10 - } - } - } - } - } - attr { - key: "transpose_a" - value { - b: false - } - } - attr { - key: "transpose_b" - value { - b: false - } - } -} -node { - name: "add_1" - op: "Add" - input: "MatMul_1" - input: "logits_biases/read" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 10 - } - } - } - } - } -} -node { - name: "Reshape_1" - op: "Reshape" - input: "add_1" - input: "concat" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 10 - } - } - } - } - } -} -node { - name: "mnist_dataset_train_2/one_hot" - op: "OneHot" - input: "mnist_dataset_train_2/random_shuffle_queue_DequeueMany:1" - input: "mnist_dataset_train_2/one_hot/depth" - input: "mnist_dataset_train_2/one_hot/on_value" - input: "mnist_dataset_train_2/one_hot/off_value" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "TI" - value { - type: DT_INT64 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - unknown_rank: true - } - } - } - } - attr { - key: "axis" - value { - i: -1 - } - } -} -node { - name: "Reshape_2" - op: "Reshape" - input: "mnist_dataset_train_2/one_hot" - input: "concat_1" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - 
value { - list { - shape { - dim { - size: -1 - } - dim { - size: 10 - } - } - } - } - } -} -node { - name: "SoftmaxCrossEntropyWithLogits" - op: "SoftmaxCrossEntropyWithLogits" - input: "Reshape_1" - input: "Reshape_2" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - } - shape { - dim { - size: 200 - } - dim { - size: 10 - } - } - } - } - } -} -node { - name: "gradients/SoftmaxCrossEntropyWithLogits_grad/PreventGradient" - op: "PreventGradient" - input: "SoftmaxCrossEntropyWithLogits:1" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 10 - } - } - } - } - } - attr { - key: "message" - value { - s: "Currently there is no way to take the second derivative of softmax_cross_entropy_with_logits due to the fused implementation\'s interaction with tf.gradients()" - } - } -} -node { - name: "gradients/SoftmaxCrossEntropyWithLogits_grad/mul" - op: "Mul" - input: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims" - input: "gradients/SoftmaxCrossEntropyWithLogits_grad/PreventGradient" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 10 - } - } - } - } - } -} -node { - name: "gradients/Reshape_1_grad/Reshape" - op: "Reshape" - input: "gradients/SoftmaxCrossEntropyWithLogits_grad/mul" - input: "gradients/Reshape_1_grad/Shape" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { 
- key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 10 - } - } - } - } - } -} -node { - name: "gradients/add_1_grad/Sum_1" - op: "Sum" - input: "gradients/Reshape_1_grad/Reshape" - input: "gradients/add_1_grad/BroadcastGradientArgs:1" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 10 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "gradients/add_1_grad/Reshape_1" - op: "Reshape" - input: "gradients/add_1_grad/Sum_1" - input: "gradients/add_1_grad/Shape_1" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 10 - } - } - } - } - } -} -node { - name: "gradients/add_1_grad/Sum" - op: "Sum" - input: "gradients/Reshape_1_grad/Reshape" - input: "gradients/add_1_grad/BroadcastGradientArgs" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 10 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "gradients/add_1_grad/Reshape" - op: "Reshape" - input: 
"gradients/add_1_grad/Sum" - input: "gradients/add_1_grad/Shape" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 10 - } - } - } - } - } -} -node { - name: "gradients/add_1_grad/tuple/group_deps" - op: "NoOp" - input: "^gradients/add_1_grad/Reshape" - input: "^gradients/add_1_grad/Reshape_1" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - } - } - } -} -node { - name: "gradients/add_1_grad/tuple/control_dependency_1" - op: "Identity" - input: "gradients/add_1_grad/Reshape_1" - input: "^gradients/add_1_grad/tuple/group_deps" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_class" - value { - list { - s: "loc:@gradients/add_1_grad/Reshape_1" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 10 - } - } - } - } - } -} -node { - name: "GradientDescent/update_logits_biases/ApplyGradientDescent" - op: "ApplyGradientDescent" - input: "logits_biases" - input: "GradientDescent/learning_rate" - input: "gradients/add_1_grad/tuple/control_dependency_1" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@logits_biases" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 10 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: false - } - } -} -node { - name: "gradients/add_1_grad/tuple/control_dependency" - op: "Identity" 
- input: "gradients/add_1_grad/Reshape" - input: "^gradients/add_1_grad/tuple/group_deps" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_class" - value { - list { - s: "loc:@gradients/add_1_grad/Reshape" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 10 - } - } - } - } - } -} -node { - name: "gradients/MatMul_1_grad/MatMul_1" - op: "MatMul" - input: "Relu" - input: "gradients/add_1_grad/tuple/control_dependency" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 100 - } - dim { - size: 10 - } - } - } - } - } - attr { - key: "transpose_a" - value { - b: true - } - } - attr { - key: "transpose_b" - value { - b: false - } - } -} -node { - name: "gradients/MatMul_1_grad/MatMul" - op: "MatMul" - input: "gradients/add_1_grad/tuple/control_dependency" - input: "logits_weights/read" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 100 - } - } - } - } - } - attr { - key: "transpose_a" - value { - b: false - } - } - attr { - key: "transpose_b" - value { - b: true - } - } -} -node { - name: "gradients/MatMul_1_grad/tuple/group_deps" - op: "NoOp" - input: "^gradients/MatMul_1_grad/MatMul" - input: "^gradients/MatMul_1_grad/MatMul_1" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - } - } - } -} -node 
{ - name: "gradients/MatMul_1_grad/tuple/control_dependency_1" - op: "Identity" - input: "gradients/MatMul_1_grad/MatMul_1" - input: "^gradients/MatMul_1_grad/tuple/group_deps" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_class" - value { - list { - s: "loc:@gradients/MatMul_1_grad/MatMul_1" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 100 - } - dim { - size: 10 - } - } - } - } - } -} -node { - name: "GradientDescent/update_logits_weights/ApplyGradientDescent" - op: "ApplyGradientDescent" - input: "logits_weights" - input: "GradientDescent/learning_rate" - input: "gradients/MatMul_1_grad/tuple/control_dependency_1" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@logits_weights" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 100 - } - dim { - size: 10 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: false - } - } -} -node { - name: "gradients/MatMul_1_grad/tuple/control_dependency" - op: "Identity" - input: "gradients/MatMul_1_grad/MatMul" - input: "^gradients/MatMul_1_grad/tuple/group_deps" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_class" - value { - list { - s: "loc:@gradients/MatMul_1_grad/MatMul" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 100 - } - } - } - } - } -} -node { - name: "gradients/Relu_grad/ReluGrad" - op: "ReluGrad" - input: "gradients/MatMul_1_grad/tuple/control_dependency" - input: "Relu" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - 
key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 100 - } - } - } - } - } -} -node { - name: "gradients/add_grad/Sum_1" - op: "Sum" - input: "gradients/Relu_grad/ReluGrad" - input: "gradients/add_grad/BroadcastGradientArgs:1" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 100 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "gradients/add_grad/Reshape_1" - op: "Reshape" - input: "gradients/add_grad/Sum_1" - input: "gradients/add_grad/Shape_1" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 100 - } - } - } - } - } -} -node { - name: "gradients/add_grad/Sum" - op: "Sum" - input: "gradients/Relu_grad/ReluGrad" - input: "gradients/add_grad/BroadcastGradientArgs" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 100 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "gradients/add_grad/Reshape" - op: "Reshape" - input: "gradients/add_grad/Sum" - input: "gradients/add_grad/Shape" - device: 
"/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 100 - } - } - } - } - } -} -node { - name: "gradients/add_grad/tuple/group_deps" - op: "NoOp" - input: "^gradients/add_grad/Reshape" - input: "^gradients/add_grad/Reshape_1" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - } - } - } -} -node { - name: "gradients/add_grad/tuple/control_dependency_1" - op: "Identity" - input: "gradients/add_grad/Reshape_1" - input: "^gradients/add_grad/tuple/group_deps" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_class" - value { - list { - s: "loc:@gradients/add_grad/Reshape_1" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 100 - } - } - } - } - } -} -node { - name: "GradientDescent/update_hidden_biases/ApplyGradientDescent" - op: "ApplyGradientDescent" - input: "hidden_biases" - input: "GradientDescent/learning_rate" - input: "gradients/add_grad/tuple/control_dependency_1" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@hidden_biases" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 100 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: false - } - } -} -node { - name: "gradients/add_grad/tuple/control_dependency" - op: "Identity" - input: "gradients/add_grad/Reshape" - input: "^gradients/add_grad/tuple/group_deps" - 
device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_class" - value { - list { - s: "loc:@gradients/add_grad/Reshape" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 100 - } - } - } - } - } -} -node { - name: "gradients/MatMul_grad/MatMul_1" - op: "MatMul" - input: "Reshape" - input: "gradients/add_grad/tuple/control_dependency" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: 100 - } - } - } - } - } - attr { - key: "transpose_a" - value { - b: true - } - } - attr { - key: "transpose_b" - value { - b: false - } - } -} -node { - name: "gradients/MatMul_grad/MatMul" - op: "MatMul" - input: "gradients/add_grad/tuple/control_dependency" - input: "hidden_weights/read" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 784 - } - } - } - } - } - attr { - key: "transpose_a" - value { - b: false - } - } - attr { - key: "transpose_b" - value { - b: true - } - } -} -node { - name: "gradients/MatMul_grad/tuple/group_deps" - op: "NoOp" - input: "^gradients/MatMul_grad/MatMul" - input: "^gradients/MatMul_grad/MatMul_1" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - } - } - } -} -node { - name: "gradients/MatMul_grad/tuple/control_dependency_1" - op: "Identity" - input: 
"gradients/MatMul_grad/MatMul_1" - input: "^gradients/MatMul_grad/tuple/group_deps" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_class" - value { - list { - s: "loc:@gradients/MatMul_grad/MatMul_1" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: 100 - } - } - } - } - } -} -node { - name: "GradientDescent/update_hidden_weights/ApplyGradientDescent" - op: "ApplyGradientDescent" - input: "hidden_weights" - input: "GradientDescent/learning_rate" - input: "gradients/MatMul_grad/tuple/control_dependency_1" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@hidden_weights" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 784 - } - dim { - size: 100 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: false - } - } -} -node { - name: "GradientDescent" - op: "NoOp" - input: "^GradientDescent/update_hidden_weights/ApplyGradientDescent" - input: "^GradientDescent/update_hidden_biases/ApplyGradientDescent" - input: "^GradientDescent/update_logits_weights/ApplyGradientDescent" - input: "^GradientDescent/update_logits_biases/ApplyGradientDescent" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_2" - } - } - attr { - key: "_output_shapes" - value { - list { - } - } - } -} -node { - name: "Reshape_3" - op: "Reshape" - input: "SoftmaxCrossEntropyWithLogits" - input: "Slice_2" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: 
"_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - } - } - } - } -} -node { - name: "Mean" - op: "Mean" - input: "Reshape_3" - input: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "_send_Mean_0" - op: "_Send" - input: "Mean" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "client_terminated" - value { - b: true - } - } - attr { - key: "recv_device" - value { - s: "/job:localhost/replica:0/task:0/cpu:0" - } - } - attr { - key: "send_device" - value { - s: "/job:localhost/replica:0/task:0/cpu:0" - } - } - attr { - key: "send_device_incarnation" - value { - i: -5924635994370253548 - } - } - attr { - key: "tensor_name" - value { - s: "Mean:0" - } - } -} -library { -} -versions { - producer: 21 -} diff --git a/tensorflow/tensorboard/components/tf_graph_dashboard/demo/data/runs.json b/tensorflow/tensorboard/components/tf_graph_dashboard/demo/data/runs.json deleted file mode 100644 index 0429aa71f82..00000000000 --- a/tensorflow/tensorboard/components/tf_graph_dashboard/demo/data/runs.json +++ /dev/null @@ -1,6 +0,0 @@ -{ - "run1": { - "graph": true, - "scalars": ["foo/sin"] - } -} diff --git a/tensorflow/tensorboard/components/tf_graph_dashboard/demo/index.html b/tensorflow/tensorboard/components/tf_graph_dashboard/demo/index.html deleted file mode 100644 index ae84c547b48..00000000000 --- a/tensorflow/tensorboard/components/tf_graph_dashboard/demo/index.html +++ /dev/null @@ -1,62 +0,0 @@ - - - - - - - - -Graph Dashboard Demo - - - - diff --git a/tensorflow/tensorboard/components/tf_graph_dashboard/tf-graph-dashboard.html 
b/tensorflow/tensorboard/components/tf_graph_dashboard/tf-graph-dashboard.html deleted file mode 100644 index ba69882a232..00000000000 --- a/tensorflow/tensorboard/components/tf_graph_dashboard/tf-graph-dashboard.html +++ /dev/null @@ -1,321 +0,0 @@ - - - - - - - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/tf_graph_debugger_data_card/BUILD b/tensorflow/tensorboard/components/tf_graph_debugger_data_card/BUILD deleted file mode 100644 index c0d2bd5a46c..00000000000 --- a/tensorflow/tensorboard/components/tf_graph_debugger_data_card/BUILD +++ /dev/null @@ -1,44 +0,0 @@ -package(default_visibility = ["//tensorflow/tensorboard:internal"]) - -load("//tensorflow/tensorboard/defs:defs.bzl", "tensorboard_webcomponent_library") -load("//tensorflow/tensorboard/defs:web.bzl", "ts_web_library") - -licenses(["notice"]) # Apache 2.0 - -ts_web_library( - name = "tf_graph_debugger_data_card", - srcs = [ - "tf-graph-debugger-data-card.html", - ], - path = "/tf-graph-debugger-data-card", - deps = [ - "//tensorflow/tensorboard/components/tf_dashboard_common", - "//tensorflow/tensorboard/components/tf_graph_common", - "//tensorflow/tensorboard/components/tf_imports:polymer", - "@org_polymer_paper_slider", - "@org_polymer_paper_spinner", - ], -) - -tensorboard_webcomponent_library( - name = "legacy", - srcs = [":tf_graph_debugger_data_card"], - destdir = "tf-graph-debugger-data-card", - deps = [ - "//tensorflow/tensorboard/components/tf_dashboard_common:legacy", - "//tensorflow/tensorboard/components/tf_graph_common:legacy", - "//third_party/javascript/polymer/v1/iron-collapse:lib", - "//third_party/javascript/polymer/v1/iron-list:lib", - "//third_party/javascript/polymer/v1/paper-icon-button:lib", - "//third_party/javascript/polymer/v1/paper-item:lib", - "//third_party/javascript/polymer/v1/paper-slider:lib", - "//third_party/javascript/polymer/v1/paper-spinner:lib", - "//third_party/javascript/polymer/v1/polymer:lib", - ], -) - -filegroup( - name = "all_files", - 
srcs = glob(["**"]), - tags = ["notsan"], -) diff --git a/tensorflow/tensorboard/components/tf_graph_debugger_data_card/demo/BUILD b/tensorflow/tensorboard/components/tf_graph_debugger_data_card/demo/BUILD deleted file mode 100644 index 66cb1156188..00000000000 --- a/tensorflow/tensorboard/components/tf_graph_debugger_data_card/demo/BUILD +++ /dev/null @@ -1,24 +0,0 @@ -package(default_visibility = ["//tensorflow/tensorboard:internal"]) - -load("//tensorflow/tensorboard/defs:web.bzl", "ts_web_library") - -licenses(["notice"]) # Apache 2.0 - -# bazel run //third_party/tensorflow/tensorboard/components/tf_graph_debugger_data_card/demo -ts_web_library( - name = "demo", - srcs = ["index.html"] + glob(["data/**"]), - path = "/tf-graph-debugger-data-card/demo", - deps = [ - "//tensorflow/tensorboard/components/tf_graph_debugger_data_card", - "//tensorflow/tensorboard/components/tf_imports:webcomponentsjs", - "@org_polymer_iron_demo_helpers", - "@org_polymer_paper_styles", - ], -) - -filegroup( - name = "all_files", - srcs = glob(["**"]), - tags = ["notsan"], -) diff --git a/tensorflow/tensorboard/components/tf_graph_debugger_data_card/demo/index.html b/tensorflow/tensorboard/components/tf_graph_debugger_data_card/demo/index.html deleted file mode 100644 index 934e4f86a83..00000000000 --- a/tensorflow/tensorboard/components/tf_graph_debugger_data_card/demo/index.html +++ /dev/null @@ -1,36 +0,0 @@ - - - - - -TF Graph Info Demo - - - - diff --git a/tensorflow/tensorboard/components/tf_graph_debugger_data_card/tf-graph-debugger-data-card.html b/tensorflow/tensorboard/components/tf_graph_debugger_data_card/tf-graph-debugger-data-card.html deleted file mode 100644 index 6cc99a327cb..00000000000 --- a/tensorflow/tensorboard/components/tf_graph_debugger_data_card/tf-graph-debugger-data-card.html +++ /dev/null @@ -1,560 +0,0 @@ - - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/tf_graph_info/BUILD b/tensorflow/tensorboard/components/tf_graph_info/BUILD deleted 
file mode 100644 index 22e886d881e..00000000000 --- a/tensorflow/tensorboard/components/tf_graph_info/BUILD +++ /dev/null @@ -1,53 +0,0 @@ -package(default_visibility = ["//tensorflow/tensorboard:internal"]) - -load("//tensorflow/tensorboard/defs:defs.bzl", "tensorboard_webcomponent_library") -load("//tensorflow/tensorboard/defs:web.bzl", "ts_web_library") - -licenses(["notice"]) # Apache 2.0 - -ts_web_library( - name = "tf_graph_info", - srcs = [ - "tf-graph-icon.html", - "tf-graph-info.html", - "tf-node-info.html", - "tf-node-list-item.html", - ], - path = "/tf-graph-info", - deps = [ - "//tensorflow/tensorboard/components/tf_dashboard_common", - "//tensorflow/tensorboard/components/tf_graph_common", - "//tensorflow/tensorboard/components/tf_graph_debugger_data_card", - "//tensorflow/tensorboard/components/tf_imports:polymer", - "@org_polymer_iron_collapse", - "@org_polymer_iron_list", - "@org_polymer_paper_icon_button", - "@org_polymer_paper_item", - "@org_polymer_paper_slider", - "@org_polymer_paper_spinner", - ], -) - -tensorboard_webcomponent_library( - name = "legacy", - srcs = [":tf_graph_info"], - destdir = "tf-graph-info", - deps = [ - "//tensorflow/tensorboard/components/tf_dashboard_common:legacy", - "//tensorflow/tensorboard/components/tf_graph_common:legacy", - "//tensorflow/tensorboard/components/tf_graph_debugger_data_card:legacy", - "//third_party/javascript/polymer/v1/iron-collapse:lib", - "//third_party/javascript/polymer/v1/iron-list:lib", - "//third_party/javascript/polymer/v1/paper-icon-button:lib", - "//third_party/javascript/polymer/v1/paper-item:lib", - "//third_party/javascript/polymer/v1/paper-slider:lib", - "//third_party/javascript/polymer/v1/paper-spinner:lib", - "//third_party/javascript/polymer/v1/polymer:lib", - ], -) - -filegroup( - name = "all_files", - srcs = glob(["**"]), - tags = ["notsan"], -) diff --git a/tensorflow/tensorboard/components/tf_graph_info/demo/BUILD b/tensorflow/tensorboard/components/tf_graph_info/demo/BUILD 
deleted file mode 100644 index 2f1f7bf2761..00000000000 --- a/tensorflow/tensorboard/components/tf_graph_info/demo/BUILD +++ /dev/null @@ -1,26 +0,0 @@ -package(default_visibility = ["//tensorflow/tensorboard:internal"]) - -load("//tensorflow/tensorboard/defs:web.bzl", "ts_web_library") - -licenses(["notice"]) # Apache 2.0 - -# bazel run //third_party/tensorflow/tensorboard/components/tf_graph_info/demo -ts_web_library( - name = "demo", - srcs = ["index.html"] + glob(["data/**"]), - path = "/tf-graph-info/demo", - deps = [ - "//tensorflow/tensorboard/components/tf_graph_common", - "//tensorflow/tensorboard/components/tf_graph_info", - "//tensorflow/tensorboard/components/tf_graph_loader", - "//tensorflow/tensorboard/components/tf_imports:webcomponentsjs", - "@org_polymer_iron_demo_helpers", - "@org_polymer_paper_styles", - ], -) - -filegroup( - name = "all_files", - srcs = glob(["**"]), - tags = ["notsan"], -) diff --git a/tensorflow/tensorboard/components/tf_graph_info/demo/data/graph.pbtxt b/tensorflow/tensorboard/components/tf_graph_info/demo/data/graph.pbtxt deleted file mode 100644 index 30b20645346..00000000000 --- a/tensorflow/tensorboard/components/tf_graph_info/demo/data/graph.pbtxt +++ /dev/null @@ -1,4606 +0,0 @@ -node { - name: "GradientDescent/learning_rate" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_3" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.1 - } - } - } -} -node { - name: "gradients/add_grad/Shape_1" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - 
value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 100 - } - } - } -} -node { - name: "gradients/add_grad/Shape" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\310\000\000\000d\000\000\000" - } - } - } -} -node { - name: "gradients/add_grad/BroadcastGradientArgs" - op: "BroadcastGradientArgs" - input: "gradients/add_grad/Shape" - input: "gradients/add_grad/Shape_1" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - shape { - dim { - size: -1 - } - } - } - } - } -} -node { - name: "gradients/add_1_grad/Shape_1" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 10 - } - } - } -} -node { - name: "gradients/add_1_grad/Shape" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr 
{ - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\310\000\000\000\n\000\000\000" - } - } - } -} -node { - name: "gradients/add_1_grad/BroadcastGradientArgs" - op: "BroadcastGradientArgs" - input: "gradients/add_1_grad/Shape" - input: "gradients/add_1_grad/Shape_1" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - shape { - dim { - size: -1 - } - } - } - } - } -} -node { - name: "gradients/Reshape_1_grad/Shape" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\310\000\000\000\n\000\000\000" - } - } - } -} -node { - name: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims/dim" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: -1 - } - } - } -} -node { - name: "gradients/Reshape_3_grad/Shape" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr 
{ - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 200 - } - } - } -} -node { - name: "gradients/Mean_grad/Maximum/y" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "gradients/Mean_grad/Const_1" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "gradients/Mean_grad/Const" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "gradients/Mean_grad/Shape_1" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - } - } - } - } - } -} -node { - name: 
"gradients/Mean_grad/Prod_1" - op: "Prod" - input: "gradients/Mean_grad/Shape_1" - input: "gradients/Mean_grad/Const_1" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "gradients/Mean_grad/Maximum" - op: "Maximum" - input: "gradients/Mean_grad/Prod_1" - input: "gradients/Mean_grad/Maximum/y" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "gradients/Mean_grad/Shape" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 200 - } - } - } -} -node { - name: "gradients/Mean_grad/Prod" - op: "Prod" - input: "gradients/Mean_grad/Shape" - input: "gradients/Mean_grad/Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "gradients/Mean_grad/floordiv" - op: "FloorDiv" - input: "gradients/Mean_grad/Prod" - input: 
"gradients/Mean_grad/Maximum" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "gradients/Mean_grad/Cast" - op: "Cast" - input: "gradients/Mean_grad/floordiv" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "gradients/Mean_grad/Tile/multiples" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 200 - } - } - } -} -node { - name: "gradients/Mean_grad/Reshape/shape" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } -} -node { - name: "gradients/Const" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - 
attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1 - } - } - } -} -node { - name: "gradients/Shape" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - } - } - } - } - } -} -node { - name: "gradients/Fill" - op: "Fill" - input: "gradients/Shape" - input: "gradients/Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "gradients/Mean_grad/Reshape" - op: "Reshape" - input: "gradients/Fill" - input: "gradients/Mean_grad/Reshape/shape" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } -} -node { - name: "gradients/Mean_grad/Tile" - op: "Tile" - input: "gradients/Mean_grad/Reshape" - input: "gradients/Mean_grad/Tile/multiples" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tmultiples" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - } - } - } - } -} -node { - name: "gradients/Mean_grad/truediv" - op: "RealDiv" - input: "gradients/Mean_grad/Tile" - input: 
"gradients/Mean_grad/Cast" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - } - } - } - } -} -node { - name: "gradients/Reshape_3_grad/Reshape" - op: "Reshape" - input: "gradients/Mean_grad/truediv" - input: "gradients/Reshape_3_grad/Shape" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - } - } - } - } -} -node { - name: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims" - op: "ExpandDims" - input: "gradients/Reshape_3_grad/Reshape" - input: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims/dim" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "Const" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "Slice_2/begin" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: 
"cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "Sub_2/y" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "concat_1/axis" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "concat_1/values_0" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: -1 - } - } - } -} -node { - name: "Slice_1/size" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - 
dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } -} -node { - name: "Sub_1/y" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "Shape_2" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\310\000\000\000\n\000\000\000" - } - } - } -} -node { - name: "Rank_2" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 2 - } - } - } -} -node { - name: "Sub_1" - op: "Sub" - input: "Rank_2" - input: "Sub_1/y" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "Slice_1/begin" - op: "Pack" - input: "Sub_1" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "N" - value { - i: 1 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: 
"_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "axis" - value { - i: 0 - } - } -} -node { - name: "Slice_1" - op: "Slice" - input: "Shape_2" - input: "Slice_1/begin" - input: "Slice_1/size" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "Index" - value { - type: DT_INT32 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } -} -node { - name: "concat_1" - op: "ConcatV2" - input: "concat_1/values_0" - input: "Slice_1" - input: "concat_1/axis" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } -} -node { - name: "concat/axis" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "concat/values_0" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - 
size: 1 - } - } - int_val: -1 - } - } - } -} -node { - name: "Slice/size" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } -} -node { - name: "Sub/y" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "Shape_1" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\310\000\000\000\n\000\000\000" - } - } - } -} -node { - name: "Rank_1" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 2 - } - } - } -} -node { - name: "Sub" - op: "Sub" - input: "Rank_1" - input: "Sub/y" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: 
DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "Slice/begin" - op: "Pack" - input: "Sub" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "N" - value { - i: 1 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "axis" - value { - i: 0 - } - } -} -node { - name: "Slice" - op: "Slice" - input: "Shape_1" - input: "Slice/begin" - input: "Slice/size" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "Index" - value { - type: DT_INT32 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } -} -node { - name: "concat" - op: "ConcatV2" - input: "concat/values_0" - input: "Slice" - input: "concat/axis" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } -} -node { - name: "Shape" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: 
"\310\000\000\000\n\000\000\000" - } - } - } -} -node { - name: "Rank" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 2 - } - } - } -} -node { - name: "Sub_2" - op: "Sub" - input: "Rank" - input: "Sub_2/y" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "Slice_2/size" - op: "Pack" - input: "Sub_2" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "N" - value { - i: 1 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "axis" - value { - i: 0 - } - } -} -node { - name: "Slice_2" - op: "Slice" - input: "Shape" - input: "Slice_2/begin" - input: "Slice_2/size" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "Index" - value { - type: DT_INT32 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } -} -node { - name: "logits_biases" - op: "VariableV2" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "_class" - value { - list { - s: "loc:@logits_biases" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 10 - } - } - } - } - } - attr { - key: "container" - value 
{ - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 10 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "logits_biases/read" - op: "Identity" - input: "logits_biases" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@logits_biases" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 10 - } - } - } - } - } -} -node { - name: "logits_weights" - op: "VariableV2" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "_class" - value { - list { - s: "loc:@logits_weights" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 100 - } - dim { - size: 10 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 100 - } - dim { - size: 10 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "logits_weights/read" - op: "Identity" - input: "logits_weights" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@logits_weights" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 100 - } - dim { - size: 10 - } - } - } - } - } -} -node { - name: "hidden_biases" - op: "VariableV2" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "_class" - value { - list { - s: "loc:@hidden_biases" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 100 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { 
- shape { - dim { - size: 100 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "hidden_biases/read" - op: "Identity" - input: "hidden_biases" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@hidden_biases" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 100 - } - } - } - } - } -} -node { - name: "hidden_weights" - op: "VariableV2" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "_class" - value { - list { - s: "loc:@hidden_weights" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 784 - } - dim { - size: 100 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 784 - } - dim { - size: 100 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "hidden_weights/read" - op: "Identity" - input: "hidden_weights" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@hidden_weights" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 784 - } - dim { - size: 100 - } - } - } - } - } -} -node { - name: "Reshape/shape" - op: "Const" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\310\000\000\000\377\377\377\377" - } - } - } -} -node { - name: "mnist_dataset_train_2/one_hot/depth" - op: "Const" - device: 
"/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 10 - } - } - } -} -node { - name: "mnist_dataset_train_2/one_hot/off_value" - op: "Const" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0 - } - } - } -} -node { - name: "mnist_dataset_train_2/one_hot/on_value" - op: "Const" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1 - } - } - } -} -node { - name: "mnist_dataset_train_2/random_shuffle_queue_DequeueMany/n" - op: "Const" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 200 - } - } - } -} -node { - name: "mnist_dataset_train_1/random_shuffle_queue" - op: "RandomShuffleQueueV2" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "capacity" - value { - i: 20000 - } - } - attr { - key: "component_types" - value { - list { - type: DT_FLOAT - type: DT_INT64 - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "min_after_dequeue" - value { - i: 4000 - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - 
i: 0 - } - } - attr { - key: "shapes" - value { - list { - shape { - dim { - size: 28 - } - dim { - size: 28 - } - dim { - size: 1 - } - } - shape { - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "mnist_dataset_train_2/random_shuffle_queue_DequeueMany" - op: "QueueDequeueManyV2" - input: "mnist_dataset_train_1/random_shuffle_queue" - input: "mnist_dataset_train_2/random_shuffle_queue_DequeueMany/n" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "_output_shapes" - value { - list { - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - } - } - } - attr { - key: "component_types" - value { - list { - type: DT_FLOAT - type: DT_INT64 - } - } - } - attr { - key: "timeout_ms" - value { - i: -1 - } - } -} -node { - name: "Reshape" - op: "Reshape" - input: "mnist_dataset_train_2/random_shuffle_queue_DequeueMany" - input: "Reshape/shape" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: -1 - } - } - } - } - } -} -node { - name: "MatMul" - op: "MatMul" - input: "Reshape" - input: "hidden_weights/read" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 100 - } - } - } - } - } - attr { - key: "transpose_a" - value { - b: false - } - } - attr { - key: "transpose_b" - value { - b: false - } - } -} -node { - name: "add" - op: "Add" - input: "MatMul" - input: "hidden_biases/read" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - 
attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 100 - } - } - } - } - } -} -node { - name: "Relu" - op: "Relu" - input: "add" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 100 - } - } - } - } - } -} -node { - name: "MatMul_1" - op: "MatMul" - input: "Relu" - input: "logits_weights/read" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 10 - } - } - } - } - } - attr { - key: "transpose_a" - value { - b: false - } - } - attr { - key: "transpose_b" - value { - b: false - } - } -} -node { - name: "add_1" - op: "Add" - input: "MatMul_1" - input: "logits_biases/read" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 10 - } - } - } - } - } -} -node { - name: "Reshape_1" - op: "Reshape" - input: "add_1" - input: "concat" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 10 - } - } - } - } - } -} -node { - name: "mnist_dataset_train_2/one_hot" - op: "OneHot" - input: "mnist_dataset_train_2/random_shuffle_queue_DequeueMany:1" - input: 
"mnist_dataset_train_2/one_hot/depth" - input: "mnist_dataset_train_2/one_hot/on_value" - input: "mnist_dataset_train_2/one_hot/off_value" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "TI" - value { - type: DT_INT64 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - unknown_rank: true - } - } - } - } - attr { - key: "axis" - value { - i: -1 - } - } -} -node { - name: "Reshape_2" - op: "Reshape" - input: "mnist_dataset_train_2/one_hot" - input: "concat_1" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: 10 - } - } - } - } - } -} -node { - name: "SoftmaxCrossEntropyWithLogits" - op: "SoftmaxCrossEntropyWithLogits" - input: "Reshape_1" - input: "Reshape_2" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - } - shape { - dim { - size: 200 - } - dim { - size: 10 - } - } - } - } - } -} -node { - name: "gradients/SoftmaxCrossEntropyWithLogits_grad/PreventGradient" - op: "PreventGradient" - input: "SoftmaxCrossEntropyWithLogits:1" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 10 - } - } - } - } - } - attr { - key: "message" - value { - s: "Currently there is no way to take the second derivative of softmax_cross_entropy_with_logits due to the 
fused implementation\'s interaction with tf.gradients()" - } - } -} -node { - name: "gradients/SoftmaxCrossEntropyWithLogits_grad/mul" - op: "Mul" - input: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims" - input: "gradients/SoftmaxCrossEntropyWithLogits_grad/PreventGradient" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 10 - } - } - } - } - } -} -node { - name: "gradients/Reshape_1_grad/Reshape" - op: "Reshape" - input: "gradients/SoftmaxCrossEntropyWithLogits_grad/mul" - input: "gradients/Reshape_1_grad/Shape" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 10 - } - } - } - } - } -} -node { - name: "gradients/add_1_grad/Sum_1" - op: "Sum" - input: "gradients/Reshape_1_grad/Reshape" - input: "gradients/add_1_grad/BroadcastGradientArgs:1" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 10 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "gradients/add_1_grad/Reshape_1" - op: "Reshape" - input: "gradients/add_1_grad/Sum_1" - input: "gradients/add_1_grad/Shape_1" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: 
DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 10 - } - } - } - } - } -} -node { - name: "gradients/add_1_grad/Sum" - op: "Sum" - input: "gradients/Reshape_1_grad/Reshape" - input: "gradients/add_1_grad/BroadcastGradientArgs" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 10 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "gradients/add_1_grad/Reshape" - op: "Reshape" - input: "gradients/add_1_grad/Sum" - input: "gradients/add_1_grad/Shape" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 10 - } - } - } - } - } -} -node { - name: "gradients/add_1_grad/tuple/group_deps" - op: "NoOp" - input: "^gradients/add_1_grad/Reshape" - input: "^gradients/add_1_grad/Reshape_1" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - } - } - } -} -node { - name: "gradients/add_1_grad/tuple/control_dependency_1" - op: "Identity" - input: "gradients/add_1_grad/Reshape_1" - input: "^gradients/add_1_grad/tuple/group_deps" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_class" - 
value { - list { - s: "loc:@gradients/add_1_grad/Reshape_1" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 10 - } - } - } - } - } -} -node { - name: "GradientDescent/update_logits_biases/ApplyGradientDescent" - op: "ApplyGradientDescent" - input: "logits_biases" - input: "GradientDescent/learning_rate" - input: "gradients/add_1_grad/tuple/control_dependency_1" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@logits_biases" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 10 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: false - } - } -} -node { - name: "gradients/add_1_grad/tuple/control_dependency" - op: "Identity" - input: "gradients/add_1_grad/Reshape" - input: "^gradients/add_1_grad/tuple/group_deps" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_class" - value { - list { - s: "loc:@gradients/add_1_grad/Reshape" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 10 - } - } - } - } - } -} -node { - name: "gradients/MatMul_1_grad/MatMul_1" - op: "MatMul" - input: "Relu" - input: "gradients/add_1_grad/tuple/control_dependency" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 100 - } - dim { - size: 10 - } - } - } - } - } - attr { - key: "transpose_a" - value { - b: true - } - } - attr { - key: "transpose_b" - value { - b: false - } - } -} -node { - name: "gradients/MatMul_1_grad/MatMul" - op: "MatMul" - input: 
"gradients/add_1_grad/tuple/control_dependency" - input: "logits_weights/read" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 100 - } - } - } - } - } - attr { - key: "transpose_a" - value { - b: false - } - } - attr { - key: "transpose_b" - value { - b: true - } - } -} -node { - name: "gradients/MatMul_1_grad/tuple/group_deps" - op: "NoOp" - input: "^gradients/MatMul_1_grad/MatMul" - input: "^gradients/MatMul_1_grad/MatMul_1" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - } - } - } -} -node { - name: "gradients/MatMul_1_grad/tuple/control_dependency_1" - op: "Identity" - input: "gradients/MatMul_1_grad/MatMul_1" - input: "^gradients/MatMul_1_grad/tuple/group_deps" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_class" - value { - list { - s: "loc:@gradients/MatMul_1_grad/MatMul_1" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 100 - } - dim { - size: 10 - } - } - } - } - } -} -node { - name: "GradientDescent/update_logits_weights/ApplyGradientDescent" - op: "ApplyGradientDescent" - input: "logits_weights" - input: "GradientDescent/learning_rate" - input: "gradients/MatMul_1_grad/tuple/control_dependency_1" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@logits_weights" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 100 - } - dim { - size: 10 - } - } - } - } - } 
- attr { - key: "use_locking" - value { - b: false - } - } -} -node { - name: "gradients/MatMul_1_grad/tuple/control_dependency" - op: "Identity" - input: "gradients/MatMul_1_grad/MatMul" - input: "^gradients/MatMul_1_grad/tuple/group_deps" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_class" - value { - list { - s: "loc:@gradients/MatMul_1_grad/MatMul" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 100 - } - } - } - } - } -} -node { - name: "gradients/Relu_grad/ReluGrad" - op: "ReluGrad" - input: "gradients/MatMul_1_grad/tuple/control_dependency" - input: "Relu" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 100 - } - } - } - } - } -} -node { - name: "gradients/add_grad/Sum_1" - op: "Sum" - input: "gradients/Relu_grad/ReluGrad" - input: "gradients/add_grad/BroadcastGradientArgs:1" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 100 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "gradients/add_grad/Reshape_1" - op: "Reshape" - input: "gradients/add_grad/Sum_1" - input: "gradients/add_grad/Shape_1" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - 
value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 100 - } - } - } - } - } -} -node { - name: "gradients/add_grad/Sum" - op: "Sum" - input: "gradients/Relu_grad/ReluGrad" - input: "gradients/add_grad/BroadcastGradientArgs" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 100 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "gradients/add_grad/Reshape" - op: "Reshape" - input: "gradients/add_grad/Sum" - input: "gradients/add_grad/Shape" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 100 - } - } - } - } - } -} -node { - name: "gradients/add_grad/tuple/group_deps" - op: "NoOp" - input: "^gradients/add_grad/Reshape" - input: "^gradients/add_grad/Reshape_1" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - } - } - } -} -node { - name: "gradients/add_grad/tuple/control_dependency_1" - op: "Identity" - input: "gradients/add_grad/Reshape_1" - input: "^gradients/add_grad/tuple/group_deps" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_class" - value { - list { - s: "loc:@gradients/add_grad/Reshape_1" - } - } - } - 
attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 100 - } - } - } - } - } -} -node { - name: "GradientDescent/update_hidden_biases/ApplyGradientDescent" - op: "ApplyGradientDescent" - input: "hidden_biases" - input: "GradientDescent/learning_rate" - input: "gradients/add_grad/tuple/control_dependency_1" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@hidden_biases" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 100 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: false - } - } -} -node { - name: "gradients/add_grad/tuple/control_dependency" - op: "Identity" - input: "gradients/add_grad/Reshape" - input: "^gradients/add_grad/tuple/group_deps" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_class" - value { - list { - s: "loc:@gradients/add_grad/Reshape" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 100 - } - } - } - } - } -} -node { - name: "gradients/MatMul_grad/MatMul_1" - op: "MatMul" - input: "Reshape" - input: "gradients/add_grad/tuple/control_dependency" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: 100 - } - } - } - } - } - attr { - key: "transpose_a" - value { - b: true - } - } - attr { - key: "transpose_b" - value { - b: false - } - } -} -node { - name: "gradients/MatMul_grad/MatMul" - op: "MatMul" - input: "gradients/add_grad/tuple/control_dependency" - input: "hidden_weights/read" - device: 
"/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 784 - } - } - } - } - } - attr { - key: "transpose_a" - value { - b: false - } - } - attr { - key: "transpose_b" - value { - b: true - } - } -} -node { - name: "gradients/MatMul_grad/tuple/group_deps" - op: "NoOp" - input: "^gradients/MatMul_grad/MatMul" - input: "^gradients/MatMul_grad/MatMul_1" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - } - } - } -} -node { - name: "gradients/MatMul_grad/tuple/control_dependency_1" - op: "Identity" - input: "gradients/MatMul_grad/MatMul_1" - input: "^gradients/MatMul_grad/tuple/group_deps" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_class" - value { - list { - s: "loc:@gradients/MatMul_grad/MatMul_1" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: 100 - } - } - } - } - } -} -node { - name: "GradientDescent/update_hidden_weights/ApplyGradientDescent" - op: "ApplyGradientDescent" - input: "hidden_weights" - input: "GradientDescent/learning_rate" - input: "gradients/MatMul_grad/tuple/control_dependency_1" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@hidden_weights" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 784 - } - dim { - size: 100 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: false - } - } -} -node { - name: "GradientDescent" - op: 
"NoOp" - input: "^GradientDescent/update_hidden_weights/ApplyGradientDescent" - input: "^GradientDescent/update_hidden_biases/ApplyGradientDescent" - input: "^GradientDescent/update_logits_weights/ApplyGradientDescent" - input: "^GradientDescent/update_logits_biases/ApplyGradientDescent" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_2" - } - } - attr { - key: "_output_shapes" - value { - list { - } - } - } -} -node { - name: "Reshape_3" - op: "Reshape" - input: "SoftmaxCrossEntropyWithLogits" - input: "Slice_2" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - } - } - } - } -} -node { - name: "Mean" - op: "Mean" - input: "Reshape_3" - input: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "_send_Mean_0" - op: "_Send" - input: "Mean" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "client_terminated" - value { - b: true - } - } - attr { - key: "recv_device" - value { - s: "/job:localhost/replica:0/task:0/cpu:0" - } - } - attr { - key: "send_device" - value { - s: "/job:localhost/replica:0/task:0/cpu:0" - } - } - attr { - key: "send_device_incarnation" - value { - i: -5924635994370253548 - } - } - attr { - key: "tensor_name" - value { - s: "Mean:0" - } - } -} -library { -} -versions { - producer: 21 -} diff 
--git a/tensorflow/tensorboard/components/tf_graph_info/demo/index.html b/tensorflow/tensorboard/components/tf_graph_info/demo/index.html deleted file mode 100644 index f7d2ef7ee5e..00000000000 --- a/tensorflow/tensorboard/components/tf_graph_info/demo/index.html +++ /dev/null @@ -1,94 +0,0 @@ - - - - - - - -TF Graph Info Demo - - - - diff --git a/tensorflow/tensorboard/components/tf_graph_info/tf-graph-icon.html b/tensorflow/tensorboard/components/tf_graph_info/tf-graph-icon.html deleted file mode 100644 index a3e9dc59c5a..00000000000 --- a/tensorflow/tensorboard/components/tf_graph_info/tf-graph-icon.html +++ /dev/null @@ -1,296 +0,0 @@ - - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/tf_graph_info/tf-graph-info.html b/tensorflow/tensorboard/components/tf_graph_info/tf-graph-info.html deleted file mode 100644 index bac25b67f77..00000000000 --- a/tensorflow/tensorboard/components/tf_graph_info/tf-graph-info.html +++ /dev/null @@ -1,130 +0,0 @@ - - - - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/tf_graph_info/tf-node-info.html b/tensorflow/tensorboard/components/tf_graph_info/tf-node-info.html deleted file mode 100644 index 66a3034b5b2..00000000000 --- a/tensorflow/tensorboard/components/tf_graph_info/tf-node-info.html +++ /dev/null @@ -1,652 +0,0 @@ - - - - - - - - - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/tf_graph_info/tf-node-list-item.html b/tensorflow/tensorboard/components/tf_graph_info/tf-node-list-item.html deleted file mode 100644 index c15478d126c..00000000000 --- a/tensorflow/tensorboard/components/tf_graph_info/tf-node-list-item.html +++ /dev/null @@ -1,138 +0,0 @@ - - - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/tf_graph_loader/BUILD b/tensorflow/tensorboard/components/tf_graph_loader/BUILD deleted file mode 100644 index 41fbfb8ee85..00000000000 --- a/tensorflow/tensorboard/components/tf_graph_loader/BUILD +++ /dev/null @@ -1,32 +0,0 @@ -package(default_visibility = 
["//tensorflow/tensorboard:internal"]) - -load("//tensorflow/tensorboard/defs:defs.bzl", "tensorboard_webcomponent_library") -load("//tensorflow/tensorboard/defs:web.bzl", "ts_web_library") - -licenses(["notice"]) # Apache 2.0 - -ts_web_library( - name = "tf_graph_loader", - srcs = ["tf-graph-loader.html"], - path = "/tf-graph-loader", - deps = [ - "//tensorflow/tensorboard/components/tf_graph_common", - "//tensorflow/tensorboard/components/tf_imports:polymer", - ], -) - -tensorboard_webcomponent_library( - name = "legacy", - srcs = [":tf_graph_loader"], - destdir = "tf-graph-loader", - deps = [ - "//tensorflow/tensorboard/components/tf_graph_common:legacy", - "//third_party/javascript/polymer/v1/polymer:lib", - ], -) - -filegroup( - name = "all_files", - srcs = glob(["**"]), - tags = ["notsan"], -) diff --git a/tensorflow/tensorboard/components/tf_graph_loader/demo/BUILD b/tensorflow/tensorboard/components/tf_graph_loader/demo/BUILD deleted file mode 100644 index f109a19163b..00000000000 --- a/tensorflow/tensorboard/components/tf_graph_loader/demo/BUILD +++ /dev/null @@ -1,24 +0,0 @@ -package(default_visibility = ["//tensorflow/tensorboard:internal"]) - -load("//tensorflow/tensorboard/defs:web.bzl", "ts_web_library") - -licenses(["notice"]) # Apache 2.0 - -# bazel run //third_party/tensorflow/tensorboard/components/tf_graph_loader/demo -ts_web_library( - name = "demo", - srcs = ["index.html"] + glob(["data/**"]), - path = "/tf-graph-loader/demo", - deps = [ - "//tensorflow/tensorboard/components/tf_graph_loader", - "//tensorflow/tensorboard/components/tf_imports:webcomponentsjs", - "@org_polymer_iron_demo_helpers", - "@org_polymer_paper_styles", - ], -) - -filegroup( - name = "all_files", - srcs = glob(["**"]), - tags = ["notsan"], -) diff --git a/tensorflow/tensorboard/components/tf_graph_loader/demo/data/graph.pbtxt b/tensorflow/tensorboard/components/tf_graph_loader/demo/data/graph.pbtxt deleted file mode 100644 index 30b20645346..00000000000 --- 
a/tensorflow/tensorboard/components/tf_graph_loader/demo/data/graph.pbtxt +++ /dev/null @@ -1,4606 +0,0 @@ -node { - name: "GradientDescent/learning_rate" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_3" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0.1 - } - } - } -} -node { - name: "gradients/add_grad/Shape_1" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 100 - } - } - } -} -node { - name: "gradients/add_grad/Shape" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\310\000\000\000d\000\000\000" - } - } - } -} -node { - name: "gradients/add_grad/BroadcastGradientArgs" - op: "BroadcastGradientArgs" - input: "gradients/add_grad/Shape" - input: "gradients/add_grad/Shape_1" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - shape { - dim { - size: 
-1 - } - } - } - } - } -} -node { - name: "gradients/add_1_grad/Shape_1" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 10 - } - } - } -} -node { - name: "gradients/add_1_grad/Shape" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\310\000\000\000\n\000\000\000" - } - } - } -} -node { - name: "gradients/add_1_grad/BroadcastGradientArgs" - op: "BroadcastGradientArgs" - input: "gradients/add_1_grad/Shape" - input: "gradients/add_1_grad/Shape_1" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - } - shape { - dim { - size: -1 - } - } - } - } - } -} -node { - name: "gradients/Reshape_1_grad/Shape" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: 
"\310\000\000\000\n\000\000\000" - } - } - } -} -node { - name: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims/dim" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: -1 - } - } - } -} -node { - name: "gradients/Reshape_3_grad/Shape" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 200 - } - } - } -} -node { - name: "gradients/Mean_grad/Maximum/y" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "gradients/Mean_grad/Const_1" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "gradients/Mean_grad/Const" - op: "Const" - device: 
"/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "gradients/Mean_grad/Shape_1" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - } - } - } - } - } -} -node { - name: "gradients/Mean_grad/Prod_1" - op: "Prod" - input: "gradients/Mean_grad/Shape_1" - input: "gradients/Mean_grad/Const_1" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "gradients/Mean_grad/Maximum" - op: "Maximum" - input: "gradients/Mean_grad/Prod_1" - input: "gradients/Mean_grad/Maximum/y" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "gradients/Mean_grad/Shape" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - 
list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 200 - } - } - } -} -node { - name: "gradients/Mean_grad/Prod" - op: "Prod" - input: "gradients/Mean_grad/Shape" - input: "gradients/Mean_grad/Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "gradients/Mean_grad/floordiv" - op: "FloorDiv" - input: "gradients/Mean_grad/Prod" - input: "gradients/Mean_grad/Maximum" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "gradients/Mean_grad/Cast" - op: "Cast" - input: "gradients/Mean_grad/floordiv" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "DstT" - value { - type: DT_FLOAT - } - } - attr { - key: "SrcT" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "gradients/Mean_grad/Tile/multiples" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - 
tensor_shape { - dim { - size: 1 - } - } - int_val: 200 - } - } - } -} -node { - name: "gradients/Mean_grad/Reshape/shape" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } -} -node { - name: "gradients/Const" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1 - } - } - } -} -node { - name: "gradients/Shape" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - } - } - } - } - } -} -node { - name: "gradients/Fill" - op: "Fill" - input: "gradients/Shape" - input: "gradients/Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "gradients/Mean_grad/Reshape" - op: "Reshape" - input: "gradients/Fill" - input: "gradients/Mean_grad/Reshape/shape" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value 
{ - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } -} -node { - name: "gradients/Mean_grad/Tile" - op: "Tile" - input: "gradients/Mean_grad/Reshape" - input: "gradients/Mean_grad/Tile/multiples" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tmultiples" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - } - } - } - } -} -node { - name: "gradients/Mean_grad/truediv" - op: "RealDiv" - input: "gradients/Mean_grad/Tile" - input: "gradients/Mean_grad/Cast" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - } - } - } - } -} -node { - name: "gradients/Reshape_3_grad/Reshape" - op: "Reshape" - input: "gradients/Mean_grad/truediv" - input: "gradients/Reshape_3_grad/Shape" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - } - } - } - } -} -node { - name: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims" - op: "ExpandDims" - input: "gradients/Reshape_3_grad/Reshape" - input: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims/dim" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: 
"Tdim" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 1 - } - } - } - } - } -} -node { - name: "Const" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "Slice_2/begin" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 0 - } - } - } -} -node { - name: "Sub_2/y" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "concat_1/axis" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { 
- name: "concat_1/values_0" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: -1 - } - } - } -} -node { - name: "Slice_1/size" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } -} -node { - name: "Sub_1/y" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "Shape_2" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\310\000\000\000\n\000\000\000" - } - } - } -} -node { - name: "Rank_2" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } 
- } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 2 - } - } - } -} -node { - name: "Sub_1" - op: "Sub" - input: "Rank_2" - input: "Sub_1/y" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "Slice_1/begin" - op: "Pack" - input: "Sub_1" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "N" - value { - i: 1 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "axis" - value { - i: 0 - } - } -} -node { - name: "Slice_1" - op: "Slice" - input: "Shape_2" - input: "Slice_1/begin" - input: "Slice_1/size" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "Index" - value { - type: DT_INT32 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } -} -node { - name: "concat_1" - op: "ConcatV2" - input: "concat_1/values_0" - input: "Slice_1" - input: "concat_1/axis" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } -} -node { - name: 
"concat/axis" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 0 - } - } - } -} -node { - name: "concat/values_0" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: -1 - } - } - } -} -node { - name: "Slice/size" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 1 - } - } - int_val: 1 - } - } - } -} -node { - name: "Sub/y" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 1 - } - } - } -} -node { - name: "Shape_1" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } 
- } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\310\000\000\000\n\000\000\000" - } - } - } -} -node { - name: "Rank_1" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 2 - } - } - } -} -node { - name: "Sub" - op: "Sub" - input: "Rank_1" - input: "Sub/y" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "Slice/begin" - op: "Pack" - input: "Sub" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "N" - value { - i: 1 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } - attr { - key: "axis" - value { - i: 0 - } - } -} -node { - name: "Slice" - op: "Slice" - input: "Shape_1" - input: "Slice/begin" - input: "Slice/size" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "Index" - value { - type: DT_INT32 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } -} -node { - name: "concat" - op: "ConcatV2" - input: "concat/values_0" - input: "Slice" - input: "concat/axis" - device: 
"/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "N" - value { - i: 2 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } -} -node { - name: "Shape" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\310\000\000\000\n\000\000\000" - } - } - } -} -node { - name: "Rank" - op: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 2 - } - } - } -} -node { - name: "Sub_2" - op: "Sub" - input: "Rank" - input: "Sub_2/y" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } -} -node { - name: "Slice_2/size" - op: "Pack" - input: "Sub_2" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "N" - value { - i: 1 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } 
- } - attr { - key: "axis" - value { - i: 0 - } - } -} -node { - name: "Slice_2" - op: "Slice" - input: "Shape" - input: "Slice_2/begin" - input: "Slice_2/size" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "Index" - value { - type: DT_INT32 - } - } - attr { - key: "T" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 1 - } - } - } - } - } -} -node { - name: "logits_biases" - op: "VariableV2" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "_class" - value { - list { - s: "loc:@logits_biases" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 10 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 10 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "logits_biases/read" - op: "Identity" - input: "logits_biases" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@logits_biases" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 10 - } - } - } - } - } -} -node { - name: "logits_weights" - op: "VariableV2" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "_class" - value { - list { - s: "loc:@logits_weights" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 100 - } - dim { - size: 10 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 100 - } - dim { - size: 10 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node 
{ - name: "logits_weights/read" - op: "Identity" - input: "logits_weights" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@logits_weights" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 100 - } - dim { - size: 10 - } - } - } - } - } -} -node { - name: "hidden_biases" - op: "VariableV2" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "_class" - value { - list { - s: "loc:@hidden_biases" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 100 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 100 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "hidden_biases/read" - op: "Identity" - input: "hidden_biases" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@hidden_biases" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 100 - } - } - } - } - } -} -node { - name: "hidden_weights" - op: "VariableV2" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "_class" - value { - list { - s: "loc:@hidden_weights" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 784 - } - dim { - size: 100 - } - } - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "shape" - value { - shape { - dim { - size: 784 - } - dim { - size: 100 - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "hidden_weights/read" - op: "Identity" - input: "hidden_weights" - device: 
"/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@hidden_weights" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 784 - } - dim { - size: 100 - } - } - } - } - } -} -node { - name: "Reshape/shape" - op: "Const" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 2 - } - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - dim { - size: 2 - } - } - tensor_content: "\310\000\000\000\377\377\377\377" - } - } - } -} -node { - name: "mnist_dataset_train_2/one_hot/depth" - op: "Const" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 10 - } - } - } -} -node { - name: "mnist_dataset_train_2/one_hot/off_value" - op: "Const" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 0 - } - } - } -} -node { - name: "mnist_dataset_train_2/one_hot/on_value" - op: "Const" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_FLOAT - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_FLOAT - tensor_shape { - } - float_val: 1 - } - } - } -} -node { - name: "mnist_dataset_train_2/random_shuffle_queue_DequeueMany/n" - op: "Const" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - 
key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "dtype" - value { - type: DT_INT32 - } - } - attr { - key: "value" - value { - tensor { - dtype: DT_INT32 - tensor_shape { - } - int_val: 200 - } - } - } -} -node { - name: "mnist_dataset_train_1/random_shuffle_queue" - op: "RandomShuffleQueueV2" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "capacity" - value { - i: 20000 - } - } - attr { - key: "component_types" - value { - list { - type: DT_FLOAT - type: DT_INT64 - } - } - } - attr { - key: "container" - value { - s: "" - } - } - attr { - key: "min_after_dequeue" - value { - i: 4000 - } - } - attr { - key: "seed" - value { - i: 0 - } - } - attr { - key: "seed2" - value { - i: 0 - } - } - attr { - key: "shapes" - value { - list { - shape { - dim { - size: 28 - } - dim { - size: 28 - } - dim { - size: 1 - } - } - shape { - } - } - } - } - attr { - key: "shared_name" - value { - s: "" - } - } -} -node { - name: "mnist_dataset_train_2/random_shuffle_queue_DequeueMany" - op: "QueueDequeueManyV2" - input: "mnist_dataset_train_1/random_shuffle_queue" - input: "mnist_dataset_train_2/random_shuffle_queue_DequeueMany/n" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "_output_shapes" - value { - list { - shape { - unknown_rank: true - } - shape { - unknown_rank: true - } - } - } - } - attr { - key: "component_types" - value { - list { - type: DT_FLOAT - type: DT_INT64 - } - } - } - attr { - key: "timeout_ms" - value { - i: -1 - } - } -} -node { - name: "Reshape" - op: "Reshape" - input: "mnist_dataset_train_2/random_shuffle_queue_DequeueMany" - input: "Reshape/shape" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - 
size: -1 - } - } - } - } - } -} -node { - name: "MatMul" - op: "MatMul" - input: "Reshape" - input: "hidden_weights/read" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 100 - } - } - } - } - } - attr { - key: "transpose_a" - value { - b: false - } - } - attr { - key: "transpose_b" - value { - b: false - } - } -} -node { - name: "add" - op: "Add" - input: "MatMul" - input: "hidden_biases/read" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 100 - } - } - } - } - } -} -node { - name: "Relu" - op: "Relu" - input: "add" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 100 - } - } - } - } - } -} -node { - name: "MatMul_1" - op: "MatMul" - input: "Relu" - input: "logits_weights/read" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 10 - } - } - } - } - } - attr { - key: "transpose_a" - value { - b: false - } - } - attr { - key: "transpose_b" - value { - b: false - } - } -} -node { - name: "add_1" - op: "Add" - input: "MatMul_1" - input: "logits_biases/read" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - 
type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 10 - } - } - } - } - } -} -node { - name: "Reshape_1" - op: "Reshape" - input: "add_1" - input: "concat" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 10 - } - } - } - } - } -} -node { - name: "mnist_dataset_train_2/one_hot" - op: "OneHot" - input: "mnist_dataset_train_2/random_shuffle_queue_DequeueMany:1" - input: "mnist_dataset_train_2/one_hot/depth" - input: "mnist_dataset_train_2/one_hot/on_value" - input: "mnist_dataset_train_2/one_hot/off_value" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "TI" - value { - type: DT_INT64 - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - unknown_rank: true - } - } - } - } - attr { - key: "axis" - value { - i: -1 - } - } -} -node { - name: "Reshape_2" - op: "Reshape" - input: "mnist_dataset_train_2/one_hot" - input: "concat_1" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: 10 - } - } - } - } - } -} -node { - name: "SoftmaxCrossEntropyWithLogits" - op: "SoftmaxCrossEntropyWithLogits" - input: "Reshape_1" - input: "Reshape_2" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: 
"_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - } - shape { - dim { - size: 200 - } - dim { - size: 10 - } - } - } - } - } -} -node { - name: "gradients/SoftmaxCrossEntropyWithLogits_grad/PreventGradient" - op: "PreventGradient" - input: "SoftmaxCrossEntropyWithLogits:1" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 10 - } - } - } - } - } - attr { - key: "message" - value { - s: "Currently there is no way to take the second derivative of softmax_cross_entropy_with_logits due to the fused implementation\'s interaction with tf.gradients()" - } - } -} -node { - name: "gradients/SoftmaxCrossEntropyWithLogits_grad/mul" - op: "Mul" - input: "gradients/SoftmaxCrossEntropyWithLogits_grad/ExpandDims" - input: "gradients/SoftmaxCrossEntropyWithLogits_grad/PreventGradient" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 10 - } - } - } - } - } -} -node { - name: "gradients/Reshape_1_grad/Reshape" - op: "Reshape" - input: "gradients/SoftmaxCrossEntropyWithLogits_grad/mul" - input: "gradients/Reshape_1_grad/Shape" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 10 - } - } - } - } - } -} -node { - name: "gradients/add_1_grad/Sum_1" - op: 
"Sum" - input: "gradients/Reshape_1_grad/Reshape" - input: "gradients/add_1_grad/BroadcastGradientArgs:1" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 10 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "gradients/add_1_grad/Reshape_1" - op: "Reshape" - input: "gradients/add_1_grad/Sum_1" - input: "gradients/add_1_grad/Shape_1" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 10 - } - } - } - } - } -} -node { - name: "gradients/add_1_grad/Sum" - op: "Sum" - input: "gradients/Reshape_1_grad/Reshape" - input: "gradients/add_1_grad/BroadcastGradientArgs" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 10 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "gradients/add_1_grad/Reshape" - op: "Reshape" - input: "gradients/add_1_grad/Sum" - input: "gradients/add_1_grad/Shape" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list 
{ - shape { - dim { - size: 200 - } - dim { - size: 10 - } - } - } - } - } -} -node { - name: "gradients/add_1_grad/tuple/group_deps" - op: "NoOp" - input: "^gradients/add_1_grad/Reshape" - input: "^gradients/add_1_grad/Reshape_1" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - } - } - } -} -node { - name: "gradients/add_1_grad/tuple/control_dependency_1" - op: "Identity" - input: "gradients/add_1_grad/Reshape_1" - input: "^gradients/add_1_grad/tuple/group_deps" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_class" - value { - list { - s: "loc:@gradients/add_1_grad/Reshape_1" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 10 - } - } - } - } - } -} -node { - name: "GradientDescent/update_logits_biases/ApplyGradientDescent" - op: "ApplyGradientDescent" - input: "logits_biases" - input: "GradientDescent/learning_rate" - input: "gradients/add_1_grad/tuple/control_dependency_1" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@logits_biases" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 10 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: false - } - } -} -node { - name: "gradients/add_1_grad/tuple/control_dependency" - op: "Identity" - input: "gradients/add_1_grad/Reshape" - input: "^gradients/add_1_grad/tuple/group_deps" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_class" - value { - list { - s: "loc:@gradients/add_1_grad/Reshape" 
- } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 10 - } - } - } - } - } -} -node { - name: "gradients/MatMul_1_grad/MatMul_1" - op: "MatMul" - input: "Relu" - input: "gradients/add_1_grad/tuple/control_dependency" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 100 - } - dim { - size: 10 - } - } - } - } - } - attr { - key: "transpose_a" - value { - b: true - } - } - attr { - key: "transpose_b" - value { - b: false - } - } -} -node { - name: "gradients/MatMul_1_grad/MatMul" - op: "MatMul" - input: "gradients/add_1_grad/tuple/control_dependency" - input: "logits_weights/read" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 100 - } - } - } - } - } - attr { - key: "transpose_a" - value { - b: false - } - } - attr { - key: "transpose_b" - value { - b: true - } - } -} -node { - name: "gradients/MatMul_1_grad/tuple/group_deps" - op: "NoOp" - input: "^gradients/MatMul_1_grad/MatMul" - input: "^gradients/MatMul_1_grad/MatMul_1" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - } - } - } -} -node { - name: "gradients/MatMul_1_grad/tuple/control_dependency_1" - op: "Identity" - input: "gradients/MatMul_1_grad/MatMul_1" - input: "^gradients/MatMul_1_grad/tuple/group_deps" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - 
} - attr { - key: "_class" - value { - list { - s: "loc:@gradients/MatMul_1_grad/MatMul_1" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 100 - } - dim { - size: 10 - } - } - } - } - } -} -node { - name: "GradientDescent/update_logits_weights/ApplyGradientDescent" - op: "ApplyGradientDescent" - input: "logits_weights" - input: "GradientDescent/learning_rate" - input: "gradients/MatMul_1_grad/tuple/control_dependency_1" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@logits_weights" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 100 - } - dim { - size: 10 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: false - } - } -} -node { - name: "gradients/MatMul_1_grad/tuple/control_dependency" - op: "Identity" - input: "gradients/MatMul_1_grad/MatMul" - input: "^gradients/MatMul_1_grad/tuple/group_deps" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_class" - value { - list { - s: "loc:@gradients/MatMul_1_grad/MatMul" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 100 - } - } - } - } - } -} -node { - name: "gradients/Relu_grad/ReluGrad" - op: "ReluGrad" - input: "gradients/MatMul_1_grad/tuple/control_dependency" - input: "Relu" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 100 - } - } - } - } - } -} -node { - name: "gradients/add_grad/Sum_1" - op: "Sum" - input: "gradients/Relu_grad/ReluGrad" - input: 
"gradients/add_grad/BroadcastGradientArgs:1" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 100 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "gradients/add_grad/Reshape_1" - op: "Reshape" - input: "gradients/add_grad/Sum_1" - input: "gradients/add_grad/Shape_1" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 100 - } - } - } - } - } -} -node { - name: "gradients/add_grad/Sum" - op: "Sum" - input: "gradients/Relu_grad/ReluGrad" - input: "gradients/add_grad/BroadcastGradientArgs" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 100 - } - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "gradients/add_grad/Reshape" - op: "Reshape" - input: "gradients/add_grad/Sum" - input: "gradients/add_grad/Shape" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 100 - } - } - } - } - } -} 
-node { - name: "gradients/add_grad/tuple/group_deps" - op: "NoOp" - input: "^gradients/add_grad/Reshape" - input: "^gradients/add_grad/Reshape_1" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - } - } - } -} -node { - name: "gradients/add_grad/tuple/control_dependency_1" - op: "Identity" - input: "gradients/add_grad/Reshape_1" - input: "^gradients/add_grad/tuple/group_deps" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_class" - value { - list { - s: "loc:@gradients/add_grad/Reshape_1" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 100 - } - } - } - } - } -} -node { - name: "GradientDescent/update_hidden_biases/ApplyGradientDescent" - op: "ApplyGradientDescent" - input: "hidden_biases" - input: "GradientDescent/learning_rate" - input: "gradients/add_grad/tuple/control_dependency_1" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@hidden_biases" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 100 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: false - } - } -} -node { - name: "gradients/add_grad/tuple/control_dependency" - op: "Identity" - input: "gradients/add_grad/Reshape" - input: "^gradients/add_grad/tuple/group_deps" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_class" - value { - list { - s: "loc:@gradients/add_grad/Reshape" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - 
dim { - size: 100 - } - } - } - } - } -} -node { - name: "gradients/MatMul_grad/MatMul_1" - op: "MatMul" - input: "Reshape" - input: "gradients/add_grad/tuple/control_dependency" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: 100 - } - } - } - } - } - attr { - key: "transpose_a" - value { - b: true - } - } - attr { - key: "transpose_b" - value { - b: false - } - } -} -node { - name: "gradients/MatMul_grad/MatMul" - op: "MatMul" - input: "gradients/add_grad/tuple/control_dependency" - input: "hidden_weights/read" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - dim { - size: 784 - } - } - } - } - } - attr { - key: "transpose_a" - value { - b: false - } - } - attr { - key: "transpose_b" - value { - b: true - } - } -} -node { - name: "gradients/MatMul_grad/tuple/group_deps" - op: "NoOp" - input: "^gradients/MatMul_grad/MatMul" - input: "^gradients/MatMul_grad/MatMul_1" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - } - } - } -} -node { - name: "gradients/MatMul_grad/tuple/control_dependency_1" - op: "Identity" - input: "gradients/MatMul_grad/MatMul_1" - input: "^gradients/MatMul_grad/tuple/group_deps" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_class" - value { - list { - s: "loc:@gradients/MatMul_grad/MatMul_1" - } - } - } - attr { - key: 
"_output_shapes" - value { - list { - shape { - dim { - size: -1 - } - dim { - size: 100 - } - } - } - } - } -} -node { - name: "GradientDescent/update_hidden_weights/ApplyGradientDescent" - op: "ApplyGradientDescent" - input: "hidden_weights" - input: "GradientDescent/learning_rate" - input: "gradients/MatMul_grad/tuple/control_dependency_1" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "_class" - value { - list { - s: "loc:@hidden_weights" - } - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 784 - } - dim { - size: 100 - } - } - } - } - } - attr { - key: "use_locking" - value { - b: false - } - } -} -node { - name: "GradientDescent" - op: "NoOp" - input: "^GradientDescent/update_hidden_weights/ApplyGradientDescent" - input: "^GradientDescent/update_hidden_biases/ApplyGradientDescent" - input: "^GradientDescent/update_logits_weights/ApplyGradientDescent" - input: "^GradientDescent/update_logits_biases/ApplyGradientDescent" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "_XlaCluster" - value { - s: "cluster_2" - } - } - attr { - key: "_output_shapes" - value { - list { - } - } - } -} -node { - name: "Reshape_3" - op: "Reshape" - input: "SoftmaxCrossEntropyWithLogits" - input: "Slice_2" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tshape" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - } - } - attr { - key: "_output_shapes" - value { - list { - shape { - dim { - size: 200 - } - } - } - } - } -} -node { - name: "Mean" - op: "Mean" - input: "Reshape_3" - input: "Const" - device: "/job:localhost/replica:0/task:0/device:XLA_CPU:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "Tidx" - value { - type: DT_INT32 - } - } - attr { - key: "_XlaCluster" - value { - s: "cluster_1" - 
} - } - attr { - key: "_output_shapes" - value { - list { - shape { - } - } - } - } - attr { - key: "keep_dims" - value { - b: false - } - } -} -node { - name: "_send_Mean_0" - op: "_Send" - input: "Mean" - device: "/job:localhost/replica:0/task:0/cpu:0" - attr { - key: "T" - value { - type: DT_FLOAT - } - } - attr { - key: "client_terminated" - value { - b: true - } - } - attr { - key: "recv_device" - value { - s: "/job:localhost/replica:0/task:0/cpu:0" - } - } - attr { - key: "send_device" - value { - s: "/job:localhost/replica:0/task:0/cpu:0" - } - } - attr { - key: "send_device_incarnation" - value { - i: -5924635994370253548 - } - } - attr { - key: "tensor_name" - value { - s: "Mean:0" - } - } -} -library { -} -versions { - producer: 21 -} diff --git a/tensorflow/tensorboard/components/tf_graph_loader/demo/index.html b/tensorflow/tensorboard/components/tf_graph_loader/demo/index.html deleted file mode 100644 index 2ffb2a1a59c..00000000000 --- a/tensorflow/tensorboard/components/tf_graph_loader/demo/index.html +++ /dev/null @@ -1,75 +0,0 @@ - - - - - -TF Graph Loader Demo - - - diff --git a/tensorflow/tensorboard/components/tf_graph_loader/test/index.html b/tensorflow/tensorboard/components/tf_graph_loader/test/index.html deleted file mode 100644 index c8e2027f42a..00000000000 --- a/tensorflow/tensorboard/components/tf_graph_loader/test/index.html +++ /dev/null @@ -1,30 +0,0 @@ - - - - - - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/tf_graph_loader/tf-graph-loader.html b/tensorflow/tensorboard/components/tf_graph_loader/tf-graph-loader.html deleted file mode 100644 index 8d59cbd2aac..00000000000 --- a/tensorflow/tensorboard/components/tf_graph_loader/tf-graph-loader.html +++ /dev/null @@ -1,184 +0,0 @@ - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/tf_histogram_dashboard/BUILD b/tensorflow/tensorboard/components/tf_histogram_dashboard/BUILD deleted file mode 100644 index e510e4b4671..00000000000 --- 
a/tensorflow/tensorboard/components/tf_histogram_dashboard/BUILD +++ /dev/null @@ -1,41 +0,0 @@ -package(default_visibility = ["//tensorflow/tensorboard:internal"]) - -load("//tensorflow/tensorboard/defs:web.bzl", "ts_web_library") - -licenses(["notice"]) # Apache 2.0 - -ts_web_library( - name = "tf_histogram_dashboard", - srcs = ["tf-histogram-dashboard.html"], - path = "/tf-histogram-dashboard", - deps = [ - "//tensorflow/tensorboard/components/tf_backend", - "//tensorflow/tensorboard/components/tf_color_scale", - "//tensorflow/tensorboard/components/tf_dashboard_common", - "//tensorflow/tensorboard/components/tf_imports:lodash", - "//tensorflow/tensorboard/components/tf_imports:polymer", - "//tensorflow/tensorboard/components/vz_histogram_timeseries", - "@org_polymer_iron_collapse", - "@org_polymer_paper_icon_button", - "@org_polymer_paper_styles", - ], -) - -ts_web_library( - name = "demo", - srcs = ["index.html"] + glob(["data/**"]), - path = "/tf-histogram-dashboard", - deps = [ - ":tf_histogram_dashboard", - "//tensorflow/tensorboard/components/tf_imports:d3", - "//tensorflow/tensorboard/components/tf_imports:webcomponentsjs", - "@org_polymer_iron_demo_helpers", - "@org_polymer_paper_styles", - ], -) - -filegroup( - name = "all_files", - srcs = glob(["**"]), - tags = ["notsan"], -) diff --git a/tensorflow/tensorboard/components/tf_histogram_dashboard/data/histograms_run_run1_tag_histo1.json b/tensorflow/tensorboard/components/tf_histogram_dashboard/data/histograms_run_run1_tag_histo1.json deleted file mode 100644 index a5600a356e8..00000000000 --- a/tensorflow/tensorboard/components/tf_histogram_dashboard/data/histograms_run_run1_tag_histo1.json +++ /dev/null @@ -1 +0,0 @@ -[[400.0, 40, [-0.3584790755077172, 3.0267252195784047, 20.0, 24.012225532303315, 48.29045006426564, [-0.35363819004775493, -0.29226296698161564, -0.19961953895336082, 0.3214892636797772, 0.5177616740489182, 0.56953784145381, 0.6264916255991911, 0.7580548669750213, 0.8338603536725235, 
1.220854943811942, 1.3429404381931362, 1.47723448201245, 1.624957930213695, 1.7874537232350647, 1.9661990955585713, 2.379100905625872, 2.6170109961884593, 3.1665833053880363], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 1.0, 1.0, 2.0]]]] \ No newline at end of file diff --git a/tensorflow/tensorboard/components/tf_histogram_dashboard/data/histograms_run_run2_tag_histo1.json b/tensorflow/tensorboard/components/tf_histogram_dashboard/data/histograms_run_run2_tag_histo1.json deleted file mode 100644 index 407c375d2fc..00000000000 --- a/tensorflow/tensorboard/components/tf_histogram_dashboard/data/histograms_run_run2_tag_histo1.json +++ /dev/null @@ -1 +0,0 @@ -[[400.0, 40, [-2.599286228987632, 3.5098048900144323, 20.0, 10.792285491200078, 66.66796979177158, [-2.379100905625872, -1.9661990955585713, -1.624957930213695, -1.47723448201245, -1.109868130738129, -1.0089710279437536, -0.42790220995778355, -0.2195814928486969, 0.47069243095356195, 0.7580548669750213, 0.917246389039776, 1.3429404381931362, 1.624957930213695, 1.7874537232350647, 2.1628190051144287, 2.6170109961884593, 2.8787120958073054, 3.8315657995195243], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 1.0, 1.0, 2.0, 1.0, 1.0]]]] \ No newline at end of file diff --git a/tensorflow/tensorboard/components/tf_histogram_dashboard/data/histograms_run_run2_tag_histo2.json b/tensorflow/tensorboard/components/tf_histogram_dashboard/data/histograms_run_run2_tag_histo2.json deleted file mode 100644 index 752b621ab03..00000000000 --- a/tensorflow/tensorboard/components/tf_histogram_dashboard/data/histograms_run_run2_tag_histo2.json +++ /dev/null @@ -1 +0,0 @@ -[[400.0, 40, [-0.8286852465281818, 2.0954239138728523, 20.0, 13.546880465642861, 24.14836803774091, [-0.7580548669750213, -0.38900200905253046, -0.06996543062044111, 0.07696197368248522, 0.19961953895336082, 0.2656936063469233, 0.29226296698161564, 0.5177616740489182, 0.7580548669750213, 0.917246389039776, 
1.109868130738129, 1.220854943811942, 1.624957930213695, 2.1628190051144287], [2.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 1.0, 1.0, 1.0, 2.0, 3.0]]]] \ No newline at end of file diff --git a/tensorflow/tensorboard/components/tf_histogram_dashboard/data/logdir b/tensorflow/tensorboard/components/tf_histogram_dashboard/data/logdir deleted file mode 100644 index b6362b45d77..00000000000 --- a/tensorflow/tensorboard/components/tf_histogram_dashboard/data/logdir +++ /dev/null @@ -1 +0,0 @@ -{"logdir": "/foo/some/fake/logdir"} \ No newline at end of file diff --git a/tensorflow/tensorboard/components/tf_histogram_dashboard/data/runs.json b/tensorflow/tensorboard/components/tf_histogram_dashboard/data/runs.json deleted file mode 100644 index cbe657af6b6..00000000000 --- a/tensorflow/tensorboard/components/tf_histogram_dashboard/data/runs.json +++ /dev/null @@ -1,4 +0,0 @@ -{ - "run1": {"histograms": ["histo1"]}, - "run2": {"histograms": ["histo2", "histo1"]} -} \ No newline at end of file diff --git a/tensorflow/tensorboard/components/tf_histogram_dashboard/index.html b/tensorflow/tensorboard/components/tf_histogram_dashboard/index.html deleted file mode 100644 index 7f1e2f9ff89..00000000000 --- a/tensorflow/tensorboard/components/tf_histogram_dashboard/index.html +++ /dev/null @@ -1,67 +0,0 @@ - - - - - - - - -Distribution Dashboard Demo - - - - diff --git a/tensorflow/tensorboard/components/tf_histogram_dashboard/tf-histogram-dashboard.html b/tensorflow/tensorboard/components/tf_histogram_dashboard/tf-histogram-dashboard.html deleted file mode 100644 index 1821ce3b6f3..00000000000 --- a/tensorflow/tensorboard/components/tf_histogram_dashboard/tf-histogram-dashboard.html +++ /dev/null @@ -1,167 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/tf_image_dashboard/BUILD b/tensorflow/tensorboard/components/tf_image_dashboard/BUILD deleted file mode 100644 index 1e2833f74c5..00000000000 --- 
a/tensorflow/tensorboard/components/tf_image_dashboard/BUILD +++ /dev/null @@ -1,44 +0,0 @@ -package(default_visibility = ["//tensorflow/tensorboard:internal"]) - -load("//tensorflow/tensorboard/defs:web.bzl", "ts_web_library") - -licenses(["notice"]) # Apache 2.0 - -ts_web_library( - name = "tf_image_dashboard", - srcs = [ - "tf-image-dashboard.html", - "tf-image-loader.html", - ], - path = "/tf-image-dashboard", - deps = [ - "//tensorflow/tensorboard/components/tf_backend", - "//tensorflow/tensorboard/components/tf_color_scale", - "//tensorflow/tensorboard/components/tf_dashboard_common", - "//tensorflow/tensorboard/components/tf_imports:d3", - "//tensorflow/tensorboard/components/tf_imports:lodash", - "//tensorflow/tensorboard/components/tf_imports:polymer", - "@org_polymer_paper_dialog", - "@org_polymer_paper_icon_button", - "@org_polymer_paper_slider", - "@org_polymer_paper_spinner", - ], -) - -ts_web_library( - name = "demo", - srcs = ["index.html"] + glob(["data/**"]), - path = "/tf-image-dashboard", - deps = [ - "//tensorflow/tensorboard/components/tf_image_dashboard", - "//tensorflow/tensorboard/components/tf_imports:webcomponentsjs", - "@org_polymer_iron_demo_helpers", - "@org_polymer_paper_styles", - ], -) - -filegroup( - name = "all_files", - srcs = glob(["**"]), - tags = ["notsan"], -) diff --git a/tensorflow/tensorboard/components/tf_image_dashboard/data/images_run_run1_tag_im1_2Fimage_2F0.json b/tensorflow/tensorboard/components/tf_image_dashboard/data/images_run_run1_tag_im1_2Fimage_2F0.json deleted file mode 100644 index 3dec4322134..00000000000 --- a/tensorflow/tensorboard/components/tf_image_dashboard/data/images_run_run1_tag_im1_2Fimage_2F0.json +++ /dev/null @@ -1,9 +0,0 @@ -[ - { - "wall_time":1459200389.088045, - "width":4, - "height":4, - "step":0, - "query":"tag=im1%2Fimage%2F0&index=0&run=run1" - } -] diff --git a/tensorflow/tensorboard/components/tf_image_dashboard/data/images_run_run1_tag_im2_2Fimage_2F0.json 
b/tensorflow/tensorboard/components/tf_image_dashboard/data/images_run_run1_tag_im2_2Fimage_2F0.json deleted file mode 100644 index 16152b8626a..00000000000 --- a/tensorflow/tensorboard/components/tf_image_dashboard/data/images_run_run1_tag_im2_2Fimage_2F0.json +++ /dev/null @@ -1,9 +0,0 @@ -[ - { - "wall_time":1459200389.093653, - "width":4, - "height":4, - "step":0, - "query":"tag=im2%2Fimage%2F0&index=0&run=run1" - } -] diff --git a/tensorflow/tensorboard/components/tf_image_dashboard/data/images_run_run2_tag_im1_2Fimage_2F0.json b/tensorflow/tensorboard/components/tf_image_dashboard/data/images_run_run2_tag_im1_2Fimage_2F0.json deleted file mode 100644 index a717b79c5de..00000000000 --- a/tensorflow/tensorboard/components/tf_image_dashboard/data/images_run_run2_tag_im1_2Fimage_2F0.json +++ /dev/null @@ -1,9 +0,0 @@ -[ - { - "wall_time":1459200389.117463, - "width":4, - "height":4, - "step":0, - "query":"tag=im1%2Fimage%2F0&index=0&run=run2" - } -] diff --git a/tensorflow/tensorboard/components/tf_image_dashboard/data/individualImage_tag_im1_2Fimage_2F0_index_0_run_run1.png b/tensorflow/tensorboard/components/tf_image_dashboard/data/individualImage_tag_im1_2Fimage_2F0_index_0_run_run1.png deleted file mode 100644 index 346fd0076be..00000000000 Binary files a/tensorflow/tensorboard/components/tf_image_dashboard/data/individualImage_tag_im1_2Fimage_2F0_index_0_run_run1.png and /dev/null differ diff --git a/tensorflow/tensorboard/components/tf_image_dashboard/data/individualImage_tag_im1_2Fimage_2F0_index_0_run_run2.png b/tensorflow/tensorboard/components/tf_image_dashboard/data/individualImage_tag_im1_2Fimage_2F0_index_0_run_run2.png deleted file mode 100644 index 26d2d10acaf..00000000000 Binary files a/tensorflow/tensorboard/components/tf_image_dashboard/data/individualImage_tag_im1_2Fimage_2F0_index_0_run_run2.png and /dev/null differ diff --git a/tensorflow/tensorboard/components/tf_image_dashboard/data/individualImage_tag_im2_2Fimage_2F0_index_0_run_run1.png 
b/tensorflow/tensorboard/components/tf_image_dashboard/data/individualImage_tag_im2_2Fimage_2F0_index_0_run_run1.png deleted file mode 100644 index 6c419062942..00000000000 Binary files a/tensorflow/tensorboard/components/tf_image_dashboard/data/individualImage_tag_im2_2Fimage_2F0_index_0_run_run1.png and /dev/null differ diff --git a/tensorflow/tensorboard/components/tf_image_dashboard/data/logdir b/tensorflow/tensorboard/components/tf_image_dashboard/data/logdir deleted file mode 100644 index c7d82022cc0..00000000000 --- a/tensorflow/tensorboard/components/tf_image_dashboard/data/logdir +++ /dev/null @@ -1 +0,0 @@ -{"logdir": "/some/fake/logdir"} \ No newline at end of file diff --git a/tensorflow/tensorboard/components/tf_image_dashboard/data/runs.json b/tensorflow/tensorboard/components/tf_image_dashboard/data/runs.json deleted file mode 100644 index b75de5b6614..00000000000 --- a/tensorflow/tensorboard/components/tf_image_dashboard/data/runs.json +++ /dev/null @@ -1,13 +0,0 @@ -{ - "run1":{ - "images":[ - "im1/image/0", - "im2/image/0" - ] - }, - "run2":{ - "images":[ - "im1/image/0" - ] - } -} diff --git a/tensorflow/tensorboard/components/tf_image_dashboard/index.html b/tensorflow/tensorboard/components/tf_image_dashboard/index.html deleted file mode 100644 index 27a31d5ad50..00000000000 --- a/tensorflow/tensorboard/components/tf_image_dashboard/index.html +++ /dev/null @@ -1,72 +0,0 @@ - - - - - - - - - Image Dashboard Demo - - - - - - - - diff --git a/tensorflow/tensorboard/components/tf_image_dashboard/tf-image-dashboard.html b/tensorflow/tensorboard/components/tf_image_dashboard/tf-image-dashboard.html deleted file mode 100644 index 5d46847eb88..00000000000 --- a/tensorflow/tensorboard/components/tf_image_dashboard/tf-image-dashboard.html +++ /dev/null @@ -1,160 +0,0 @@ - - - - - - - - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/tf_image_dashboard/tf-image-loader.html 
b/tensorflow/tensorboard/components/tf_image_dashboard/tf-image-loader.html deleted file mode 100644 index 41fb12eefa7..00000000000 --- a/tensorflow/tensorboard/components/tf_image_dashboard/tf-image-loader.html +++ /dev/null @@ -1,234 +0,0 @@ - - - - - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/tf_imports/BUILD b/tensorflow/tensorboard/components/tf_imports/BUILD deleted file mode 100644 index 84b46bf0053..00000000000 --- a/tensorflow/tensorboard/components/tf_imports/BUILD +++ /dev/null @@ -1,499 +0,0 @@ -package(default_visibility = ["//tensorflow/tensorboard:internal"]) - -load("//tensorflow/tensorboard/defs:hacks.bzl", "tensorboard_typescript_bundle") -load("//tensorflow/tensorboard/defs:web.bzl", "ts_web_library") - -licenses(["notice"]) # Apache 2.0 - -ts_web_library( - name = "webcomponentsjs", - srcs = ["@org_definitelytyped//:webcomponents.js.d.ts"], - path = "/webcomponentsjs", - visibility = ["//visibility:public"], - exports = ["@org_polymer_webcomponentsjs"], -) - -ts_web_library( - name = "polymer", - srcs = ["@org_definitelytyped//:polymer.d.ts"], - path = "/polymer", - visibility = ["//visibility:public"], - exports = ["@org_polymer"], - deps = [":webcomponentsjs"], -) - -ts_web_library( - name = "lodash", - srcs = [ - "lodash.html", - "@org_definitelytyped//:lodash.d.ts", - ], - path = "/tf-imports", - visibility = ["//visibility:public"], - deps = ["@com_lodash"], -) - -ts_web_library( - name = "threejs", - srcs = [ - "threejs.html", - "@org_definitelytyped//:three.d.ts", - "@org_threejs//:OrbitControls.js", - "@org_threejs//:three.js", - ], - path = "/tf-imports", - visibility = ["//visibility:public"], -) - -ts_web_library( - name = "numericjs", - srcs = [ - "numericjs.html", - "@com_numericjs//:numeric.js", - ], - path = "/tf-imports", - visibility = ["//visibility:public"], -) - -ts_web_library( - name = "weblas", - srcs = [ - "weblas.html", - "@io_github_waylonflinn_weblas//:weblas.js", - ], - path = "/tf-imports", - 
visibility = ["//visibility:public"], -) - -ts_web_library( - name = "graphlib", - srcs = [ - "graphlib.html", - "@io_github_cpettitt_graphlib//:graphlib.core.js", - ], - path = "/tf-imports", - visibility = ["//visibility:public"], - deps = [":lodash"], -) - -ts_web_library( - name = "dagre", - srcs = [ - "dagre.html", - "@io_github_cpettitt_dagre//:dagre.core.js", - ], - path = "/tf-imports", - visibility = ["//visibility:public"], - deps = [ - ":graphlib", - ":lodash", - ], -) - -ts_web_library( - name = "d3", - srcs = [ - "d3.d.ts", - "d3.html", - "@org_d3js//:d3.min.js", - ], - path = "/tf-imports", - visibility = ["//visibility:public"], -) - -ts_web_library( - name = "plottable", - srcs = [ - "plottable.d.ts", - "plottable.html", - ], - path = "/tf-imports", - visibility = ["//visibility:public"], - deps = [ - ":d3", - ":plottable_js_css", - ], -) - -ts_web_library( - name = "plottable_js_css", - srcs = [ - "@com_palantir_plottable//:package/plottable.css", - "@com_palantir_plottable//:package/plottable.js", - ], - path = "/tf-imports", - strip_prefix = "package", - visibility = ["//visibility:private"], -) - -ts_web_library( - name = "web_component_tester", - testonly = 1, - visibility = ["//visibility:public"], - exports = [ - ":chai_typings", - ":mocha_typings", - ":sinon_typings", - "@org_npmjs_registry_web_component_tester", - ], -) - -ts_web_library( - name = "chai_typings", - testonly = 1, - srcs = ["@org_definitelytyped//:chai.d.ts"], - path = "/chai", - visibility = ["//visibility:private"], -) - -ts_web_library( - name = "mocha_typings", - testonly = 1, - srcs = ["@org_definitelytyped//:mocha.d.ts"], - path = "/mocha", - visibility = ["//visibility:private"], -) - -ts_web_library( - name = "sinon_typings", - testonly = 1, - srcs = ["@org_definitelytyped//:sinon.d.ts"], - path = "/sinonjs", - visibility = ["//visibility:private"], -) - -# Generate single TypeScript typings file for d3.js with no ES6 imports. 
-# -# The DefinitelyTyped definition of d3 v4 was written under the assumption that -# we want to use d3 in a modularized way. We don't want to do that because its -# import statements use NodeJS namespaces, and the Web Compiler only supports -# W3C, ECMA, and IETF standards. -tensorboard_typescript_bundle( - name = "d3_typings", - out = "d3.d.ts", - namespace_srcs = {"d3": [ - "d3-transition.d.ts", - "@org_definitelytyped_types_d3_path//:index.d.ts", - "@org_definitelytyped_types_d3_time//:index.d.ts", - "@org_definitelytyped_types_d3_dsv//:index.d.ts", - "@org_definitelytyped_types_d3_color//:index.d.ts", - "@org_definitelytyped_types_d3_selection//:index.d.ts", - "@org_definitelytyped_types_d3_shape//:index.d.ts", - "@org_definitelytyped_types_d3_scale//:index.d.ts", - "@org_definitelytyped_types_d3_request//:index.d.ts", - "@org_definitelytyped_types_d3_interpolate//:index.d.ts", - "@org_definitelytyped_types_d3_drag//:index.d.ts", - "@org_definitelytyped_types_d3_brush//:index.d.ts", - "@org_definitelytyped_types_d3_axis//:index.d.ts", - "@org_definitelytyped_types_d3_zoom//:index.d.ts", - "@org_definitelytyped_types_d3_array//:index.d.ts", - "@org_definitelytyped_types_d3_chord//:index.d.ts", - "@org_definitelytyped_types_d3_collection//:index.d.ts", - "@org_definitelytyped_types_d3_dispatch//:index.d.ts", - "@org_definitelytyped_types_d3_ease//:index.d.ts", - "@org_definitelytyped_types_d3_force//:index.d.ts", - "@org_definitelytyped_types_d3_format//:index.d.ts", - "@org_definitelytyped_types_d3_hierarchy//:index.d.ts", - "@org_definitelytyped_types_d3_polygon//:index.d.ts", - "@org_definitelytyped_types_d3_quadtree//:index.d.ts", - "@org_definitelytyped_types_d3_queue//:index.d.ts", - "@org_definitelytyped_types_d3_random//:index.d.ts", - "@org_definitelytyped_types_d3_timer//:index.d.ts", - "@org_definitelytyped_types_d3_voronoi//:index.d.ts", - ]}, -) - -# It would be nice if Plottable released a .d.ts file for plottable.js like -# they did for previous 
versions. -tensorboard_typescript_bundle( - name = "plottable_typings", - out = "plottable.d.ts", - namespace_srcs = { - "Plottable": [ - "@com_palantir_plottable//:package/build/src/core/dataset.d.ts", - "@com_palantir_plottable//:package/build/src/core/interfaces.d.ts", - "@com_palantir_plottable//:package/build/src/core/version.d.ts", - ], - "Plottable.Animators": [ - "@com_palantir_plottable//:package/build/src/animators/animator.d.ts", - "@com_palantir_plottable//:package/build/src/animators/easingAnimator.d.ts", - "@com_palantir_plottable//:package/build/src/animators/nullAnimator.d.ts", - ], - "Plottable.Axes": [ - "@com_palantir_plottable//:package/build/src/axes/axis.d.ts", - "@com_palantir_plottable//:package/build/src/axes/categoryAxis.d.ts", - "@com_palantir_plottable//:package/build/src/axes/numericAxis.d.ts", - "@com_palantir_plottable//:package/build/src/axes/timeAxis.d.ts", - ], - "Plottable.Components": [ - "@com_palantir_plottable//:package/build/src/components/component.d.ts", - "@com_palantir_plottable//:package/build/src/components/componentContainer.d.ts", - "@com_palantir_plottable//:package/build/src/components/dragBoxLayer.d.ts", - "@com_palantir_plottable//:package/build/src/components/dragLineLayer.d.ts", - "@com_palantir_plottable//:package/build/src/components/gridlines.d.ts", - "@com_palantir_plottable//:package/build/src/components/group.d.ts", - "@com_palantir_plottable//:package/build/src/components/guideLineLayer.d.ts", - "@com_palantir_plottable//:package/build/src/components/interpolatedColorLegend.d.ts", - "@com_palantir_plottable//:package/build/src/components/label.d.ts", - "@com_palantir_plottable//:package/build/src/components/legend.d.ts", - "@com_palantir_plottable//:package/build/src/components/plotGroup.d.ts", - "@com_palantir_plottable//:package/build/src/components/selectionBoxLayer.d.ts", - "@com_palantir_plottable//:package/build/src/components/table.d.ts", - 
"@com_palantir_plottable//:package/build/src/components/xDragBoxLayer.d.ts", - "@com_palantir_plottable//:package/build/src/components/yDragBoxLayer.d.ts", - ], - "Plottable.Configs": [ - "@com_palantir_plottable//:package/build/src/core/config.d.ts", - ], - "Plottable.Formatters": [ - "@com_palantir_plottable//:package/build/src/core/formatters.d.ts", - ], - "Plottable.RenderController": [ - "@com_palantir_plottable//:package/build/src/core/renderController.d.ts", - ], - "Plottable.RenderPolicies": [ - "@com_palantir_plottable//:package/build/src/core/renderPolicy.d.ts", - ], - "Plottable.SymbolFactories": [ - "@com_palantir_plottable//:package/build/src/core/symbolFactories.d.ts", - ], - "Plottable.Dispatchers": [ - "@com_palantir_plottable//:package/build/src/dispatchers/dispatcher.d.ts", - "@com_palantir_plottable//:package/build/src/dispatchers/keyDispatcher.d.ts", - "@com_palantir_plottable//:package/build/src/dispatchers/mouseDispatcher.d.ts", - "@com_palantir_plottable//:package/build/src/dispatchers/touchDispatcher.d.ts", - ], - "Plottable.Drawers": [ - "@com_palantir_plottable//:package/build/src/drawers/arcDrawer.d.ts", - "@com_palantir_plottable//:package/build/src/drawers/arcOutlineDrawer.d.ts", - "@com_palantir_plottable//:package/build/src/drawers/areaDrawer.d.ts", - "@com_palantir_plottable//:package/build/src/drawers/canvasBuffer.d.ts", - "@com_palantir_plottable//:package/build/src/drawers/canvasDrawer.d.ts", - "@com_palantir_plottable//:package/build/src/drawers/drawStep.d.ts", - "@com_palantir_plottable//:package/build/src/drawers/drawer.d.ts", - "@com_palantir_plottable//:package/build/src/drawers/lineDrawer.d.ts", - "@com_palantir_plottable//:package/build/src/drawers/rectangleDrawer.d.ts", - "@com_palantir_plottable//:package/build/src/drawers/segmentDrawer.d.ts", - "@com_palantir_plottable//:package/build/src/drawers/svgDrawer.d.ts", - "@com_palantir_plottable//:package/build/src/drawers/symbolDrawer.d.ts", - ], - "Plottable.Interactions": [ 
- "@com_palantir_plottable//:package/build/src/interactions/clickInteraction.d.ts", - "@com_palantir_plottable//:package/build/src/interactions/dragInteraction.d.ts", - "@com_palantir_plottable//:package/build/src/interactions/interaction.d.ts", - "@com_palantir_plottable//:package/build/src/interactions/keyInteraction.d.ts", - "@com_palantir_plottable//:package/build/src/interactions/panZoomInteraction.d.ts", - "@com_palantir_plottable//:package/build/src/interactions/pointerInteraction.d.ts", - ], - "Plottable.Plots": [ - "@com_palantir_plottable//:package/build/src/plots/areaPlot.d.ts", - "@com_palantir_plottable//:package/build/src/plots/barPlot.d.ts", - "@com_palantir_plottable//:package/build/src/plots/clusteredBarPlot.d.ts", - "@com_palantir_plottable//:package/build/src/plots/commons.d.ts", - "@com_palantir_plottable//:package/build/src/plots/linePlot.d.ts", - "@com_palantir_plottable//:package/build/src/plots/piePlot.d.ts", - "@com_palantir_plottable//:package/build/src/plots/plot.d.ts", - "@com_palantir_plottable//:package/build/src/plots/rectanglePlot.d.ts", - "@com_palantir_plottable//:package/build/src/plots/scatterPlot.d.ts", - "@com_palantir_plottable//:package/build/src/plots/segmentPlot.d.ts", - "@com_palantir_plottable//:package/build/src/plots/stackedAreaPlot.d.ts", - "@com_palantir_plottable//:package/build/src/plots/stackedBarPlot.d.ts", - "@com_palantir_plottable//:package/build/src/plots/waterfallPlot.d.ts", - "@com_palantir_plottable//:package/build/src/plots/xyPlot.d.ts", - ], - "Plottable.Scales": [ - "@com_palantir_plottable//:package/build/src/scales/index.d.ts", - "@com_palantir_plottable//:package/build/src/scales/categoryScale.d.ts", - "@com_palantir_plottable//:package/build/src/scales/colorScale.d.ts", - "@com_palantir_plottable//:package/build/src/scales/interpolatedColorScale.d.ts", - "@com_palantir_plottable//:package/build/src/scales/linearScale.d.ts", - "@com_palantir_plottable//:package/build/src/scales/modifiedLogScale.d.ts", 
- "@com_palantir_plottable//:package/build/src/scales/quantitativeScale.d.ts", - "@com_palantir_plottable//:package/build/src/scales/scale.d.ts", - "@com_palantir_plottable//:package/build/src/scales/timeScale.d.ts", - ], - "Plottable.Scales.TickGenerators": [ - "@com_palantir_plottable//:package/build/src/scales/tickGenerators.d.ts", - ], - "Plottable.Utils": [ - "@com_palantir_plottable//:package/build/src/utils/addD3SelectionMulti.d.ts", - "@com_palantir_plottable//:package/build/src/utils/bucket.d.ts", - "@com_palantir_plottable//:package/build/src/utils/callbackSet.d.ts", - "@com_palantir_plottable//:package/build/src/utils/coerceD3.d.ts", - "@com_palantir_plottable//:package/build/src/utils/entityStore.d.ts", - "@com_palantir_plottable//:package/build/src/utils/makeEnum.d.ts", - "@com_palantir_plottable//:package/build/src/utils/map.d.ts", - "@com_palantir_plottable//:package/build/src/utils/set.d.ts", - "@com_palantir_plottable//:package/build/src/utils/transformAwareTranslator.d.ts", - ], - "Plottable.Utils.Array": [ - "@com_palantir_plottable//:package/build/src/utils/arrayUtils.d.ts", - ], - "Plottable.Utils.Color": [ - "@com_palantir_plottable//:package/build/src/utils/colorUtils.d.ts", - ], - "Plottable.Utils.DOM": [ - "@com_palantir_plottable//:package/build/src/utils/domUtils.d.ts", - ], - "Plottable.Utils.Math": [ - "@com_palantir_plottable//:package/build/src/utils/mathUtils.d.ts", - ], - "Plottable.Utils.Stacking": [ - "@com_palantir_plottable//:package/build/src/utils/stackingUtils.d.ts", - ], - "Plottable.Utils.Window": [ - "@com_palantir_plottable//:package/build/src/utils/windowUtils.d.ts", - ], - }, - namespace_symbol_aliases = { - "Plottable.Animators": { - "AttributeToAppliedProjector": "Plottable.AttributeToAppliedProjector", - "SimpleSelection": "Plottable.SimpleSelection", - }, - "Plottable.Axes": { - "Component": "Plottable.Components.Component", - "Formatter": "Plottable.Formatters.Formatter", - "Point": "Plottable.Point", - 
"QuantitativeScale": "Plottable.Scales.QuantitativeScale", - "Scale": "Plottable.Scales.Scale", - "Scales": "Plottable.Scales", - "SimpleSelection": "Plottable.SimpleSelection", - "SpaceRequest": "Plottable.SpaceRequest", - }, - "Plottable.Components": { - "Bounds": "Plottable.Bounds", - "Formatter": "Plottable.Formatters.Formatter", - "IEntity": "Plottable.IEntity", - "Interactions": "Plottable.Interactions", - "Plots": "Plottable.Plots", - "Point": "Plottable.Point", - "QuantitativeScale": "Plottable.Scales.QuantitativeScale", - "Scales": "Plottable.Scales", - "SimpleSelection": "Plottable.SimpleSelection", - "SpaceRequest": "Plottable.SpaceRequest", - "SymbolFactory": "Plottable.SymbolFactories.SymbolFactory", - }, - "Plottable.RenderController": { - "Component": "Plottable.Components.Component", - "RenderPolicies": "Plottable.RenderPolicies", - }, - "Plottable.SymbolFactories": { - "d3Shape": "d3", - }, - "Plottable.Dispatchers": { - "Component": "Plottable.Components.Component", - "Dispatchers": "Plottable.Dispatchers", - "Point": "Plottable.Point", - }, - "Plottable.Drawers": { - "AttributeToAppliedProjector": "Plottable.AttributeToAppliedProjector", - "AttributeToProjector": "Plottable.AttributeToProjector", - "Dataset": "Plottable.Dataset", - "IAccessor": "Plottable.IAccessor", - "IAnimator": "Plottable.Animators.IAnimator", - "SimpleSelection": "Plottable.SimpleSelection", - "SymbolFactory": "Plottable.SymbolFactories.SymbolFactory", - }, - "Plottable.Interactions": { - "Component": "Plottable.Components.Component", - "Point": "Plottable.Point", - "TransformableScale": "Plottable.Scales.TransformableScale", - }, - "Plottable.Plots": { - "AppliedDrawStep": "Plottable.Drawers.AppliedDrawStep", - "AttributeToProjector": "Plottable.AttributeToProjector", - "Bounds": "Plottable.Bounds", - "Component": "Plottable.Components.Component", - "Dataset": "Plottable.Dataset", - "DrawStep": "Plottable.Drawers.DrawStep", - "Drawers": "Plottable.Drawers", - "Formatter": 
"Plottable.Formatters.Formatter", - "IAccessor": "Plottable.IAccessor", - "IAnimator": "Plottable.Animators.IAnimator", - "IDrawer": "Plottable.Drawers.IDrawer", - "IEntity": "Plottable.IEntity", - "IScaleCallback": "Plottable.Scales.IScaleCallback", - "Plots": "Plottable.Plots", - "Point": "Plottable.Point", - "Projector": "Plottable.Projector", - "ProxyDrawer": "Plottable.Drawers.ProxyDrawer", - "QuantitativeScale": "Plottable.Scales.QuantitativeScale", - "Range": "Plottable.Range", - "Scale": "Plottable.Scales.Scale", - "SimpleSelection": "Plottable.SimpleSelection", - "SymbolFactory": "Plottable.SymbolFactories.SymbolFactory", - "TransformableScale": "Plottable.Scales.TransformableScale", - "Utils": "Plottable.Utils", - "d3Shape": "d3", - }, - "Plottable.Scales": { - "Dataset": "Plottable.Dataset", - "Scales": "Plottable.Scales", - }, - "Plottable.Scales.TickGenerators": { - "QuantitativeScale": "Plottable.Scales.QuantitativeScale", - }, - "Plottable.Utils": { - "Bounds": "Plottable.Bounds", - "Component": "Plottable.Components.Component", - "Dataset": "Plottable.Dataset", - "IAccessor": "Plottable.IAccessor", - "Point": "Plottable.Point", - "Range": "Plottable.Range", - "SimpleSelection": "Plottable.SimpleSelection", - "Utils": "Plottable.Utils", - }, - }, - namespace_symbol_aliases_public = { - "Plottable": { - "Axis": "Plottable.Axes.Axis", - "AxisOrientation": "Plottable.Axes.AxisOrientation", - "ClickCallback": "Plottable.Interactions.ClickCallback", - "Component": "Plottable.Components.Component", - "ComponentCallback": "Plottable.Components.ComponentCallback", - "ComponentContainer": "Plottable.Components.ComponentContainer", - "Dispatcher": "Plottable.Dispatchers.Dispatcher", - "DragBoxCallback": "Plottable.Components.DragBoxCallback", - "DragCallback": "Plottable.Interactions.DragCallback", - "EaseFn": "Plottable.Animators.EaseFn", - "EaseName": "Plottable.Animators.EaseName", - "Easing": "Plottable.Animators.Easing", - "Formatter": 
"Plottable.Formatters.Formatter", - "IAnimator": "Plottable.Animators.IAnimator", - "IDragLineCallback": "Plottable.Components.IDragLineCallback", - "IDrawer": "Plottable.Drawers.IDrawer", - "IResizeHandler": "Plottable.Components.IResizeHandler", - "IScaleCallback": "Plottable.Scales.IScaleCallback", - "Interaction": "Plottable.Interactions.Interaction", - "Key": "Plottable.Interactions.Key", - "KeyCallback": "Plottable.Interactions.KeyCallback", - "Null": "Plottable.Animators.Null", - "Plot": "Plottable.Plots.Plot", - "PointerCallback": "Plottable.Interactions.PointerCallback", - "ProxyDrawer": "Plottable.Drawers.ProxyDrawer", - "QuantitativeScale": "Plottable.Scales.QuantitativeScale", - "Renderer": "Plottable.Plots.Renderer", - "Scale": "Plottable.Scales.Scale", - "SymbolFactory": "Plottable.SymbolFactories.SymbolFactory", - "TimeInterval": "Plottable.Axes.TimeInterval", - "TransformableScale": "Plottable.Scales.TransformableScale", - "XAlignment": "Plottable.Components.XAlignment", - "XYPlot": "Plottable.Plots.XYPlot", - "YAlignment": "Plottable.Components.YAlignment", - }, - }, -) - -# Removes the 'declare module' block inside this file, but keeps its content. -genrule( - name = "kludge_d3_transition", - srcs = ["@org_definitelytyped_types_d3_transition//:index.d.ts"], - outs = ["d3-transition.d.ts"], - cmd = "sed '/^declare module/d' $< | awk '/^}$$/ && !p {p++;next}1' >$@", -) - -filegroup( - name = "all_files", - srcs = glob(["**"]), - tags = ["notsan"], -) diff --git a/tensorflow/tensorboard/components/tf_imports/README.md b/tensorflow/tensorboard/components/tf_imports/README.md deleted file mode 100644 index b1cabc61b9b..00000000000 --- a/tensorflow/tensorboard/components/tf_imports/README.md +++ /dev/null @@ -1,2 +0,0 @@ -This file acts as import routers for third party javascript libraries, -e.g. Plottable and D3. 
diff --git a/tensorflow/tensorboard/components/tf_imports/d3.html b/tensorflow/tensorboard/components/tf_imports/d3.html deleted file mode 100644 index 76ca302709a..00000000000 --- a/tensorflow/tensorboard/components/tf_imports/d3.html +++ /dev/null @@ -1,50 +0,0 @@ - - - - - diff --git a/tensorflow/tensorboard/components/tf_imports/dagre.html b/tensorflow/tensorboard/components/tf_imports/dagre.html deleted file mode 100644 index b90dc58e390..00000000000 --- a/tensorflow/tensorboard/components/tf_imports/dagre.html +++ /dev/null @@ -1,45 +0,0 @@ - - - - - - - - diff --git a/tensorflow/tensorboard/components/tf_imports/lodash.html b/tensorflow/tensorboard/components/tf_imports/lodash.html deleted file mode 100644 index 65ff6a4b032..00000000000 --- a/tensorflow/tensorboard/components/tf_imports/lodash.html +++ /dev/null @@ -1,18 +0,0 @@ - - - diff --git a/tensorflow/tensorboard/components/tf_imports/numericjs.html b/tensorflow/tensorboard/components/tf_imports/numericjs.html deleted file mode 100644 index 81fa9491688..00000000000 --- a/tensorflow/tensorboard/components/tf_imports/numericjs.html +++ /dev/null @@ -1,43 +0,0 @@ - - - - - diff --git a/tensorflow/tensorboard/components/tf_imports/plottable.html b/tensorflow/tensorboard/components/tf_imports/plottable.html deleted file mode 100644 index 77ad544d5a0..00000000000 --- a/tensorflow/tensorboard/components/tf_imports/plottable.html +++ /dev/null @@ -1,44 +0,0 @@ - - - - - - - diff --git a/tensorflow/tensorboard/components/tf_imports/threejs.html b/tensorflow/tensorboard/components/tf_imports/threejs.html deleted file mode 100644 index 7f4233b5713..00000000000 --- a/tensorflow/tensorboard/components/tf_imports/threejs.html +++ /dev/null @@ -1,43 +0,0 @@ - - - - - - diff --git a/tensorflow/tensorboard/components/tf_imports/weblas.html b/tensorflow/tensorboard/components/tf_imports/weblas.html deleted file mode 100644 index c07020598fc..00000000000 --- a/tensorflow/tensorboard/components/tf_imports/weblas.html +++ 
/dev/null @@ -1,42 +0,0 @@ - - - - - diff --git a/tensorflow/tensorboard/components/tf_option_selector/BUILD b/tensorflow/tensorboard/components/tf_option_selector/BUILD deleted file mode 100644 index 3f7eed25cb1..00000000000 --- a/tensorflow/tensorboard/components/tf_option_selector/BUILD +++ /dev/null @@ -1,21 +0,0 @@ -package(default_visibility = ["//tensorflow/tensorboard:internal"]) - -load("//tensorflow/tensorboard/defs:web.bzl", "ts_web_library") - -licenses(["notice"]) # Apache 2.0 - -ts_web_library( - name = "tf_option_selector", - srcs = ["tf-option-selector.html"], - path = "/tf-option-selector", - deps = [ - "//tensorflow/tensorboard/components/tf_dashboard_common", - "//tensorflow/tensorboard/components/tf_imports:polymer", - ], -) - -filegroup( - name = "all_files", - srcs = glob(["**"]), - tags = ["notsan"], -) diff --git a/tensorflow/tensorboard/components/tf_option_selector/tf-option-selector.html b/tensorflow/tensorboard/components/tf_option_selector/tf-option-selector.html deleted file mode 100644 index d6fc9d6861f..00000000000 --- a/tensorflow/tensorboard/components/tf_option_selector/tf-option-selector.html +++ /dev/null @@ -1,94 +0,0 @@ - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/tf_profile_dashboard/BUILD b/tensorflow/tensorboard/components/tf_profile_dashboard/BUILD deleted file mode 100644 index 5d04618a545..00000000000 --- a/tensorflow/tensorboard/components/tf_profile_dashboard/BUILD +++ /dev/null @@ -1,25 +0,0 @@ -package(default_visibility = ["//tensorflow/tensorboard:internal"]) - -load("//tensorflow/tensorboard/defs:web.bzl", "ts_web_library") - -licenses(["notice"]) # Apache 2.0 - -ts_web_library( - name = "tf_profile_dashboard", - srcs = [ - "tf-profile-dashboard.html", - ], - path = "/tf-profile-dashboard", - deps = [ - "//tensorflow/tensorboard/components/tf_backend", - "//tensorflow/tensorboard/components/tf_dashboard_common", - "//tensorflow/tensorboard/components/tf_graph_controls", - "@org_polymer", - ], 
-) - -filegroup( - name = "all_files", - srcs = glob(["**"]), - tags = ["notsan"], -) diff --git a/tensorflow/tensorboard/components/tf_profile_dashboard/demo/BUILD b/tensorflow/tensorboard/components/tf_profile_dashboard/demo/BUILD deleted file mode 100644 index 3cc20ba352f..00000000000 --- a/tensorflow/tensorboard/components/tf_profile_dashboard/demo/BUILD +++ /dev/null @@ -1,24 +0,0 @@ -package(default_visibility = ["//tensorflow/tensorboard:internal"]) - -load("//tensorflow/tensorboard/defs:web.bzl", "ts_web_library") - -licenses(["notice"]) # Apache 2.0 - -ts_web_library( - name = "demo", - srcs = ["index.html"] + glob(["data/**"]), - path = "/tf-profile-dashboard/demo", - deps = [ - "//tensorflow/tensorboard/components/tf_profile_dashboard", - "//tensorflow/tensorboard/components/tf_trace_viewer:demo", - "@org_polymer", - "@org_polymer_iron_demo_helpers", - "@org_polymer_webcomponentsjs", - ], -) - -filegroup( - name = "all_files", - srcs = glob(["**"]), - tags = ["notsan"], -) diff --git a/tensorflow/tensorboard/components/tf_profile_dashboard/demo/data/logdir b/tensorflow/tensorboard/components/tf_profile_dashboard/demo/data/logdir deleted file mode 100644 index ecaaa8ac758..00000000000 --- a/tensorflow/tensorboard/components/tf_profile_dashboard/demo/data/logdir +++ /dev/null @@ -1 +0,0 @@ -{"logdir": "/some/fake/logdir"} diff --git a/tensorflow/tensorboard/components/tf_profile_dashboard/demo/data/plugin/profile/data_run_bar_tag_trace_viewer.json b/tensorflow/tensorboard/components/tf_profile_dashboard/demo/data/plugin/profile/data_run_bar_tag_trace_viewer.json deleted file mode 100644 index bc1a08b535f..00000000000 --- a/tensorflow/tensorboard/components/tf_profile_dashboard/demo/data/plugin/profile/data_run_bar_tag_trace_viewer.json +++ /dev/null @@ -1,27 +0,0 @@ -{ - "traceEvents": [ - {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 826, "ph": "C", - "name": "counter", "args": {"value": 10}}, - {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 826, 
"ph": "B", - "name": "A long name that doesnt fit but is exceedingly informative", - "args": {"name_false": false, "value_true": true}}, - {"cat": "PERF", "pid": 22630, "ts": 835, "ph": "I", "s": "p", - "name": "ProcessWideEvent1", "args": {}} - ], - "stackFrames": { - "1": { - "category": "m1", - "name": "main" - }, - "7": { - "category": "m2", - "name": "frame7", - "parent": "1" - }, - "8": { - "category": "m2", - "name": "frame8", - "parent": "1" - } - } -} diff --git a/tensorflow/tensorboard/components/tf_profile_dashboard/demo/data/plugin/profile/data_run_bar_tag_unsupported.json b/tensorflow/tensorboard/components/tf_profile_dashboard/demo/data/plugin/profile/data_run_bar_tag_unsupported.json deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/tensorflow/tensorboard/components/tf_profile_dashboard/demo/data/plugin/profile/data_run_foo_tag_trace_viewer.json b/tensorflow/tensorboard/components/tf_profile_dashboard/demo/data/plugin/profile/data_run_foo_tag_trace_viewer.json deleted file mode 100644 index e1d57394e35..00000000000 --- a/tensorflow/tensorboard/components/tf_profile_dashboard/demo/data/plugin/profile/data_run_foo_tag_trace_viewer.json +++ /dev/null @@ -1,105 +0,0 @@ -{ - "traceEvents": [ - {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 826, "ph": "C", - "name": "counter", "args": {"value": 10}}, - {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 826, "ph": "B", - "name": "A long name that doesnt fit but is exceedingly informative", - "args": {"name_false": false, "value_true": true}}, - {"cat": "PERF", "pid": 22630, "ts": 835, "ph": "I", "s": "p", - "name": "ProcessWideEvent1", "args": {}}, - - {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 827, "ph": "B", - "name": "Asub with a name that wont fit", "args": {}}, - {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 828, "ph": "E", - "name": "Asub", "args": {}}, - {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 829, "ph": "B", - "name": "Asub", "args": {}}, - {"cat": "PREF", 
"pid": 22630, "tid": 22630, "dur": 15, "ts": 820, "ph": "X", - "name": "Long X type", "args": {}, "sf": 7, "esf": 8}, - {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 832, "ph": "E", - "name": "Asub", "args": {}}, - {"cat": "PREF", "pid": 22630, "tid": 22630, "dur": 2, "ts": 818, "ph": "X", - "name": "X1", "args": {}}, - {"cat": "PREF", "pid": 22630, "tid": 22630, "dur": 2, "ts": 818, "ph": "X", - "name": "X same ts and dur as X1", "args": {}}, - {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 832, "ph": "C", - "name": "counter", "args": {"value": 1}}, - {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 833, "ph": "E", - "name": "", "args": {}}, - - {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 835, "ph": "I", - "name": "ThreadLevelI1", "args": {}}, - - {"cat": "PERF", "ts": 880, "ph": "I", "s": "g", "name": "GlobalEvent1", - "args": {}}, - - {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 837, "ph": "I", - "name": "ThreadLevelI2", "args": {}}, - - {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 839, "ph": "C", - "name": "counter", "args": {"value": 5}}, - {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 840, "ph": "B", - "name": "A not as long a name", "args": {}}, - {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 848, "ph": "E", - "name": "A not as long a name", "args": {}}, - {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 848, "ph": "C", - "name": "counter", "args": {"value": 1}}, - {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 854, "ph": "C", - "name": "counter", "args": {"value": 10}}, - - {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 850, "ph": "B", - "name": "B", "args": {}}, - {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 854, "ph": "E", - "name": "B", "args": {}}, - - {"cat": "PERF", "pid": 22630, "tid": 22631, "ts": 827, "ph": "B", - "name": "A", "args": {}}, - {"cat": "PERF", "pid": 22630, "tid": 22631, "ts": 835, "ph": "I", - "name": "ThreadLevelImmediate Three", "args": {}}, - {"cat": "PERF", "pid": 22630, "tid": 22631, 
"ts": 845, "ph": "I", - "name": "ThreadLevelImmediate4", "args": {}}, - {"cat": "PERF", "pid": 22630, "tid": 22631, "ts": 854, "ph": "E", - "name": "A", "args": {}}, - - {"cat": "PREF", "pid": 22630, "tid": 22630, "ts": 860, "ph": "B", - "name": "B/E over X", "args": {}}, - {"cat": "PREF", "pid": 22630, "tid": 22630, "dur": 10, "ts": 860, "ph": "X", - "name": "X", "args": {}}, - {"cat": "PREF", "pid": 22630, "tid": 22630, "ts": 860, "ph": "B", - "name": "B/E under X", "args": {}}, - {"cat": "PREF", "pid": 22630, "tid": 22630, "ts": 870, "ph": "E", - "name": "B/E under X", "args": {}}, - {"cat": "PREF", "pid": 22630, "tid": 22630, "ts": 870, "ph": "E", - "name": "B/E over X", "args": {}}, - - {"cat": "SAMPLE", "pid": 22630, "tid": 22631, "ts": 870, "ph": "P", - "name": "SampleA", "args": {}}, - {"cat": "SAMPLE", "pid": 22630, "tid": 22631, "ts": 875, "ph": "P", - "name": "SampleB", "args": {}}, - {"cat": "SAMPLE", "pid": 22630, "tid": 22631, "ts": 878, "ph": "P", - "name": "SampleC", "args": {}, "sf": 8}, - - {"cat": "__metadata", "pid": 22630, "tid": 22630, "ts": 0, "ph": "M", - "name": "thread_name", "args": {"name": "threadA"}}, - {"cat": "__metadata", "pid": 22630, "tid": 22631, "ts": 0, "ph": "M", - "name": "thread_name", "args": {"name": "threadB"}}, - {"cat": "__metadata", "pid": 22630, "tid": 22632, "ts": 0, "ph": "M", - "name": "thread_name", "args": {"name": "threadC"}} - ], - "stackFrames": { - "1": { - "category": "m1", - "name": "main" - }, - "7": { - "category": "m2", - "name": "frame7", - "parent": "1" - }, - "8": { - "category": "m2", - "name": "frame8", - "parent": "1" - } - } -} diff --git a/tensorflow/tensorboard/components/tf_profile_dashboard/demo/data/plugin/profile/tags.json b/tensorflow/tensorboard/components/tf_profile_dashboard/demo/data/plugin/profile/tags.json deleted file mode 100644 index 12ef5bf8b2e..00000000000 --- a/tensorflow/tensorboard/components/tf_profile_dashboard/demo/data/plugin/profile/tags.json +++ /dev/null @@ -1 +0,0 @@ 
-{"foo": ["trace_viewer"], "bar": ["unsupported", "trace_viewer"]} diff --git a/tensorflow/tensorboard/components/tf_profile_dashboard/demo/index.html b/tensorflow/tensorboard/components/tf_profile_dashboard/demo/index.html deleted file mode 100644 index 15064a54f8f..00000000000 --- a/tensorflow/tensorboard/components/tf_profile_dashboard/demo/index.html +++ /dev/null @@ -1,75 +0,0 @@ - - - - - - - - - Profile Dashboard Demo - - - - - - - - diff --git a/tensorflow/tensorboard/components/tf_profile_dashboard/tf-profile-dashboard.html b/tensorflow/tensorboard/components/tf_profile_dashboard/tf-profile-dashboard.html deleted file mode 100644 index 4028f0e0f06..00000000000 --- a/tensorflow/tensorboard/components/tf_profile_dashboard/tf-profile-dashboard.html +++ /dev/null @@ -1,222 +0,0 @@ - - - - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/tf_runs_selector/BUILD b/tensorflow/tensorboard/components/tf_runs_selector/BUILD deleted file mode 100644 index 30265c8d294..00000000000 --- a/tensorflow/tensorboard/components/tf_runs_selector/BUILD +++ /dev/null @@ -1,27 +0,0 @@ -package(default_visibility = ["//tensorflow:internal"]) - -load("//tensorflow/tensorboard/defs:web.bzl", "ts_web_library") - -licenses(["notice"]) # Apache 2.0 - -ts_web_library( - name = "tf_runs_selector", - srcs = [ - "tf-runs-selector.html", - ], - path = "/tf-runs-selector", - deps = [ - "//tensorflow/tensorboard/components/tf_backend", - "//tensorflow/tensorboard/components/tf_dashboard_common", - "//tensorflow/tensorboard/components/tf_imports:polymer", - "@org_polymer_paper_button", - "@org_polymer_paper_dialog", - "@org_polymer_paper_styles", - ], -) - -filegroup( - name = "all_files", - srcs = glob(["**"]), - tags = ["notsan"], -) diff --git a/tensorflow/tensorboard/components/tf_runs_selector/tf-runs-selector.html b/tensorflow/tensorboard/components/tf_runs_selector/tf-runs-selector.html deleted file mode 100644 index 6964bb076de..00000000000 --- 
a/tensorflow/tensorboard/components/tf_runs_selector/tf-runs-selector.html +++ /dev/null @@ -1,195 +0,0 @@ - - - - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/tf_scalar_dashboard/BUILD b/tensorflow/tensorboard/components/tf_scalar_dashboard/BUILD deleted file mode 100644 index 7cc192b4640..00000000000 --- a/tensorflow/tensorboard/components/tf_scalar_dashboard/BUILD +++ /dev/null @@ -1,38 +0,0 @@ -package(default_visibility = ["//tensorflow/tensorboard:internal"]) - -load("//tensorflow/tensorboard/defs:web.bzl", "ts_web_library") - -licenses(["notice"]) # Apache 2.0 - -ts_web_library( - name = "tf_scalar_dashboard", - srcs = [ - "tf-scalar-dashboard.html", - "tf-smoothing-input.html", - ], - path = "/tf-scalar-dashboard", - deps = [ - "//tensorflow/tensorboard/components/tf_backend", - "//tensorflow/tensorboard/components/tf_color_scale", - "//tensorflow/tensorboard/components/tf_dashboard_common", - "//tensorflow/tensorboard/components/tf_imports:lodash", - "//tensorflow/tensorboard/components/tf_imports:polymer", - "//tensorflow/tensorboard/components/tf_runs_selector", - "//tensorflow/tensorboard/components/vz_line_chart", - "@org_polymer_iron_collapse", - "@org_polymer_paper_checkbox", - "@org_polymer_paper_dropdown_menu", - "@org_polymer_paper_icon_button", - "@org_polymer_paper_input", - "@org_polymer_paper_item", - "@org_polymer_paper_menu", - "@org_polymer_paper_slider", - "@org_polymer_paper_styles", - ], -) - -filegroup( - name = "all_files", - srcs = glob(["**"]), - tags = ["notsan"], -) diff --git a/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/BUILD b/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/BUILD deleted file mode 100644 index 0e892b1aa30..00000000000 --- a/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/BUILD +++ /dev/null @@ -1,27 +0,0 @@ -package(default_visibility = ["//tensorflow/tensorboard:internal"]) - -load("//tensorflow/tensorboard/defs:web.bzl", "ts_web_library") - 
-licenses(["notice"]) # Apache 2.0 - -ts_web_library( - name = "demo", - srcs = ["index.html"], - path = "/tf-scalar-dashboard/demo", - deps = [ - "//tensorflow/tensorboard/components/tf_backend", - "//tensorflow/tensorboard/components/tf_imports:d3", - "//tensorflow/tensorboard/components/tf_imports:polymer", - "//tensorflow/tensorboard/components/tf_imports:webcomponentsjs", - "//tensorflow/tensorboard/components/tf_scalar_dashboard", - "//tensorflow/tensorboard/demo:demo_data", - "@org_polymer_iron_demo_helpers", - "@org_polymer_paper_styles", - ], -) - -filegroup( - name = "all_files", - srcs = glob(["**"]), - tags = ["notsan"], -) diff --git a/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/logdir b/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/logdir deleted file mode 100644 index b6362b45d77..00000000000 --- a/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/logdir +++ /dev/null @@ -1 +0,0 @@ -{"logdir": "/foo/some/fake/logdir"} \ No newline at end of file diff --git a/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/runs.json b/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/runs.json deleted file mode 100644 index d45f530763c..00000000000 --- a/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/runs.json +++ /dev/null @@ -1,4 +0,0 @@ -{ - "run1": {"scalars": ["foo/sin", "foo/cos", "foo/square", "bar/square"]}, - "run2": {"scalars": ["foo/cos", "foo/square", "bar/square"]} -} \ No newline at end of file diff --git a/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars.json b/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars.json deleted file mode 100644 index bc269395b68..00000000000 --- a/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars.json +++ /dev/null @@ -1 +0,0 @@ -{"run2": {"foo/cos": [[0.0, 0, 2.0], [10.0, 1, 1.0806045532226562], [20.0, 2, -0.832293689250946], [30.0, 3, -1.979984998703003], 
[40.0, 4, -1.3072872161865234]], "bar/square": [[0.0, 0, 0.0], [10.0, 1, 1.0], [20.0, 2, 4.0], [30.0, 3, 9.0], [40.0, 4, 16.0]], "foo/square": [[0.0, 0, 0.0], [10.0, 1, 2.0], [20.0, 2, 8.0], [30.0, 3, 18.0], [40.0, 4, 32.0]]}, "run1": {"foo/sin": [[0.0, 0, 0.0], [10.0, 1, 0.8414709568023682], [20.0, 2, 0.9092974066734314], [30.0, 3, 0.14112000167369843], [40.0, 4, -0.756802499294281]], "foo/cos": [[0.0, 0, 1.0], [10.0, 1, 0.5403022766113281], [20.0, 2, -0.416146844625473], [30.0, 3, -0.9899924993515015], [40.0, 4, -0.6536436080932617]], "bar/square": [[0.0, 0, 0.0], [10.0, 1, 1.0], [20.0, 2, 4.0], [30.0, 3, 9.0], [40.0, 4, 16.0]], "foo/square": [[0.0, 0, 0.0], [10.0, 1, 1.0], [20.0, 2, 4.0], [30.0, 3, 9.0], [40.0, 4, 16.0]]}} \ No newline at end of file diff --git a/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars_run_run1_tag_bar_2Fsquare.json b/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars_run_run1_tag_bar_2Fsquare.json deleted file mode 100644 index 6d584fb4a9e..00000000000 --- a/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars_run_run1_tag_bar_2Fsquare.json +++ /dev/null @@ -1 +0,0 @@ -[[0.0, 0, 0.0], [10.0, 1, 1.0], [20.0, 2, 4.0], [30.0, 3, 9.0], [40.0, 4, 16.0]] \ No newline at end of file diff --git a/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars_run_run1_tag_foo_2Fcos.json b/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars_run_run1_tag_foo_2Fcos.json deleted file mode 100644 index 025eaa16e93..00000000000 --- a/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars_run_run1_tag_foo_2Fcos.json +++ /dev/null @@ -1 +0,0 @@ -[[0.0, 0, 1.0], [10.0, 1, 0.5403022766113281], [20.0, 2, -0.416146844625473], [30.0, 3, -0.9899924993515015], [40.0, 4, -0.6536436080932617]] \ No newline at end of file diff --git a/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars_run_run1_tag_foo_2Fsin.json 
b/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars_run_run1_tag_foo_2Fsin.json deleted file mode 100644 index eae69dd78f3..00000000000 --- a/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars_run_run1_tag_foo_2Fsin.json +++ /dev/null @@ -1 +0,0 @@ -[[0.0, 0, 0.0], [10.0, 1, 0.8414709568023682], [20.0, 2, 0.9092974066734314], [30.0, 3, 0.14112000167369843], [40.0, 4, -0.756802499294281]] \ No newline at end of file diff --git a/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars_run_run1_tag_foo_2Fsquare.json b/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars_run_run1_tag_foo_2Fsquare.json deleted file mode 100644 index 6d584fb4a9e..00000000000 --- a/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars_run_run1_tag_foo_2Fsquare.json +++ /dev/null @@ -1 +0,0 @@ -[[0.0, 0, 0.0], [10.0, 1, 1.0], [20.0, 2, 4.0], [30.0, 3, 9.0], [40.0, 4, 16.0]] \ No newline at end of file diff --git a/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars_run_run2_tag_bar_2Fsquare.json b/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars_run_run2_tag_bar_2Fsquare.json deleted file mode 100644 index 6d584fb4a9e..00000000000 --- a/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars_run_run2_tag_bar_2Fsquare.json +++ /dev/null @@ -1 +0,0 @@ -[[0.0, 0, 0.0], [10.0, 1, 1.0], [20.0, 2, 4.0], [30.0, 3, 9.0], [40.0, 4, 16.0]] \ No newline at end of file diff --git a/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars_run_run2_tag_foo_2Fcos.json b/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars_run_run2_tag_foo_2Fcos.json deleted file mode 100644 index dd3593f9d10..00000000000 --- a/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars_run_run2_tag_foo_2Fcos.json +++ /dev/null @@ -1 +0,0 @@ -[[0.0, 0, 2.0], [10.0, 1, 1.0806045532226562], [20.0, 2, -0.832293689250946], 
[30.0, 3, -1.979984998703003], [40.0, 4, -1.3072872161865234]] \ No newline at end of file diff --git a/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars_run_run2_tag_foo_2Fsquare.json b/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars_run_run2_tag_foo_2Fsquare.json deleted file mode 100644 index 0ff9ef0551d..00000000000 --- a/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/data/scalars_run_run2_tag_foo_2Fsquare.json +++ /dev/null @@ -1 +0,0 @@ -[[0.0, 0, 0.0], [10.0, 1, 2.0], [20.0, 2, 8.0], [30.0, 3, 18.0], [40.0, 4, 32.0]] \ No newline at end of file diff --git a/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/index.html b/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/index.html deleted file mode 100644 index 78f657b4104..00000000000 --- a/tensorflow/tensorboard/components/tf_scalar_dashboard/demo/index.html +++ /dev/null @@ -1,70 +0,0 @@ - - - - - - - - - - -Scalar Dashboard Demo - - - - diff --git a/tensorflow/tensorboard/components/tf_scalar_dashboard/tf-scalar-dashboard.html b/tensorflow/tensorboard/components/tf_scalar_dashboard/tf-scalar-dashboard.html deleted file mode 100644 index 848ed5292de..00000000000 --- a/tensorflow/tensorboard/components/tf_scalar_dashboard/tf-scalar-dashboard.html +++ /dev/null @@ -1,293 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/tf_scalar_dashboard/tf-smoothing-input.html b/tensorflow/tensorboard/components/tf_scalar_dashboard/tf-smoothing-input.html deleted file mode 100644 index a0760330001..00000000000 --- a/tensorflow/tensorboard/components/tf_scalar_dashboard/tf-smoothing-input.html +++ /dev/null @@ -1,138 +0,0 @@ - - - - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/tf_storage/BUILD b/tensorflow/tensorboard/components/tf_storage/BUILD deleted file mode 100644 index 197e0ae73d6..00000000000 --- a/tensorflow/tensorboard/components/tf_storage/BUILD +++ /dev/null @@ 
-1,36 +0,0 @@ -package(default_visibility = ["//tensorflow/tensorboard:internal"]) - -load("//tensorflow/tensorboard/defs:defs.bzl", "tensorboard_webcomponent_library") -load("//tensorflow/tensorboard/defs:web.bzl", "ts_web_library") - -licenses(["notice"]) # Apache 2.0 - -ts_web_library( - name = "tf_storage", - srcs = [ - "storage.ts", - "tf-storage.html", - ], - path = "/tf-storage", - deps = [ - "//tensorflow/tensorboard/components/tf_globals", - "//tensorflow/tensorboard/components/tf_imports:lodash", - ], -) - -tensorboard_webcomponent_library( - name = "legacy", - srcs = [":tf_storage"], - destdir = "tf-storage", - deps = [ - "//tensorflow/tensorboard/components/tf_globals:legacy", - "//tensorflow/tensorboard/components/tf_imports_google:lib", - "//third_party/javascript/polymer/v1/polymer:lib", - ], -) - -filegroup( - name = "all_files", - srcs = glob(["**"]), - tags = ["notsan"], -) diff --git a/tensorflow/tensorboard/components/tf_storage/storage.ts b/tensorflow/tensorboard/components/tf_storage/storage.ts deleted file mode 100644 index 873bc483a07..00000000000 --- a/tensorflow/tensorboard/components/tf_storage/storage.ts +++ /dev/null @@ -1,400 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -import {getFakeHash, setFakeHash, TABS, useHash} from '../tf-globals/globals'; - - -/* tslint:disable:no-namespace variable-name */ -/** - * The Storage Module provides storage for URL parameters, and an API for - * getting and setting TensorBoard's stateful URI. - * - * It generates URI components like: events&runPrefix=train* - * which TensorBoard uses after like localhost:8000/#events&runPrefix=train* - * to store state in the URI. - * - * It also allows saving the values to localStorage for long-term persistence. - */ -type StringDict = {[key: string]: string}; - -/** - * A key that users cannot use, since TensorBoard uses this to store info - * about the active tab. - */ -export let TAB = '__tab__'; - -/** - * The name of the property for users to set on a Polymer component - * in order for its stored properties to be stored in the URI unambiguously. - * (No need to set this if you want multiple instances of the component to - * share URI state) - * - * Example: - * - * - * The disambiguator should be set to any unique value so that multiple - * instances of the component can store properties in URI storage. - * - * Because it's hard to dereference this variable in HTML property bindings, - * it is NOT safe to change the disambiguator string without find+replace - * across the codebase. - */ -export let DISAMBIGUATOR = 'disambiguator'; - -/** - * Return a string stored in URI or localStorage. - * Undefined if not found. - */ -export function getString(key: string, useLocalStorage: boolean): string { - if (useLocalStorage) { - return window.localStorage.getItem(key); - } else { - return _componentToDict(_readComponent())[key]; - } -} - -/** - * Set a string in URI or localStorage. 
- */ -export function setString( - key: string, value: string, useLocalStorage: boolean) { - if (useLocalStorage) { - window.localStorage.setItem(key, value); - } else { - const items = _componentToDict(_readComponent()); - items[key] = value; - _writeComponent(_dictToComponent(items)); - } -} - -/** - * Return a boolean stored in stored in URI or localStorage. - * Undefined if not found. - */ -export function getBoolean(key: string, useLocalStorage: boolean): boolean { - const item = getString(key, useLocalStorage); - return item === 'true' ? true : item === 'false' ? false : undefined; -} - -/** - * Store a boolean in URI or localStorage. - */ -export function setBoolean( - key: string, value: boolean, useLocalStorage = false) { - setString(key, value.toString(), useLocalStorage); -} - -/** - * Return a number stored in stored in URI or localStorage. - * Undefined if not found. - */ -export function getNumber(key: string, useLocalStorage: boolean): number { - const item = getString(key, useLocalStorage); - return item === undefined ? undefined : +item; -} - -/** - * Store a number in URI or localStorage. - */ -export function setNumber( - key: string, value: number, useLocalStorage: boolean) { - setString(key, '' + value, useLocalStorage); -} - -/** - * Return an object stored in stored in URI or localStorage. - * Undefined if not found. - */ -export function getObject(key: string, useLocalStorage: boolean): {} { - const item = getString(key, useLocalStorage); - return item === undefined ? undefined : JSON.parse(atob(item)); -} - -/** - * Store an object in URI or localStorage. - */ -export function setObject(key: string, value: {}, useLocalStorage: boolean) { - setString(key, btoa(JSON.stringify(value)), useLocalStorage); -} - -/** - * Get a unique storage name for a (Polymer component, propertyName) tuple. - * - * DISAMBIGUATOR must be set on the component, if other components use the - * same propertyName. 
- */ -export function getURIStorageName( - component: {}, propertyName: string): string { - const d = component[DISAMBIGUATOR]; - const components = d == null ? [propertyName] : [d, propertyName]; - return components.join('.'); -} - -/** - * Return a function that: - * (1) Initializes a Polymer boolean property with a default value, if its - * value is not already set - * (2) Sets up listener that updates Polymer property on hash change. - */ -export function getBooleanInitializer( - propertyName: string, defaultVal: boolean, - useLocalStorage = false): Function { - return _getInitializer( - getBoolean, propertyName, defaultVal, useLocalStorage); -} - -/** - * Return a function that: - * (1) Initializes a Polymer string property with a default value, if its - * value is not already set - * (2) Sets up listener that updates Polymer property on hash change. - */ -export function getStringInitializer( - propertyName: string, defaultVal: string, - useLocalStorage = false): Function { - return _getInitializer( - getString, propertyName, defaultVal, useLocalStorage); -} - -/** - * Return a function that: - * (1) Initializes a Polymer number property with a default value, if its - * value is not already set - * (2) Sets up listener that updates Polymer property on hash change. - */ -export function getNumberInitializer( - propertyName: string, defaultVal: number, - useLocalStorage = false): Function { - return _getInitializer( - getNumber, propertyName, defaultVal, useLocalStorage); -} - -/** - * Return a function that: - * (1) Initializes a Polymer Object property with a default value, if its - * value is not already set - * (2) Sets up listener that updates Polymer property on hash change. - * - * Generates a deep clone of the defaultVal to avoid mutation issues. 
- */ -export function getObjectInitializer( - propertyName: string, defaultVal: {}, useLocalStorage = false): Function { - return _getInitializer( - getObject, propertyName, defaultVal, useLocalStorage); -} - -/** - * Return a function that updates URIStorage when a string property changes. - */ -export function getBooleanObserver( - propertyName: string, defaultVal: boolean, - useLocalStorage = false): Function { - return _getObserver( - getBoolean, setBoolean, propertyName, defaultVal, useLocalStorage); -} - -/** - * Return a function that updates URIStorage when a string property changes. - */ -export function getStringObserver( - propertyName: string, defaultVal: string, - useLocalStorage = false): Function { - return _getObserver( - getString, setString, propertyName, defaultVal, useLocalStorage); -} - -/** - * Return a function that updates URIStorage when a number property changes. - */ -export function getNumberObserver( - propertyName: string, defaultVal: number, - useLocalStorage = false): Function { - return _getObserver( - getNumber, setNumber, propertyName, defaultVal, useLocalStorage); -} - -/** - * Return a function that updates URIStorage when an object property changes. - * Generates a deep clone of the defaultVal to avoid mutation issues. - */ -export function getObjectObserver( - propertyName: string, defaultVal: {}, useLocalStorage = false): Function { - const clone = _.cloneDeep(defaultVal); - return _getObserver( - getObject, setObject, propertyName, clone, useLocalStorage); -} - -/** - * Read component from URI (e.g. returns "events&runPrefix=train*"). - */ -function _readComponent(): string { - return useHash() ? window.location.hash.slice(1) : getFakeHash(); -} - -/** - * Write component to URI. - */ -function _writeComponent(component: string) { - if (useHash()) { - window.location.hash = component; - } else { - setFakeHash(component); - } -} - -/** - * Convert dictionary of strings into a URI Component. 
- * All key value entries get added as key value pairs in the component, - * with the exception of a key with the TAB value, which if present - * gets prepended to the URI Component string for backwards compatibility - * reasons. - */ -function _dictToComponent(items: StringDict): string { - let component = ''; - - // Add the tab name e.g. 'events', 'images', 'histograms' as a prefix - // for backwards compatbility. - if (items[TAB] !== undefined) { - component += items[TAB]; - } - - // Join other strings with &key=value notation - const nonTab = _.pairs(items) - .filter((pair) => pair[0] !== TAB) - .map((pair) => { - return encodeURIComponent(pair[0]) + '=' + - encodeURIComponent(pair[1]); - }) - .join('&'); - - return nonTab.length > 0 ? (component + '&' + nonTab) : component; -} - -/** - * Convert a URI Component into a dictionary of strings. - * Component should consist of key-value pairs joined by a delimiter - * with the exception of the tabName. - * Returns dict consisting of all key-value pairs and - * dict[TAB] = tabName - */ -function _componentToDict(component: string): StringDict { - const items = {} as StringDict; - - const tokens = component.split('&'); - tokens.forEach((token) => { - const kv = token.split('='); - // Special backwards compatibility for URI components like #events - if (kv.length === 1 && _.contains(TABS, kv[0])) { - items[TAB] = kv[0]; - } else if (kv.length === 2) { - items[decodeURIComponent(kv[0])] = decodeURIComponent(kv[1]); - } - }); - return items; -} - -/** - * Return a function that: - * (1) Initializes a Polymer property with a default value, if its - * value is not already set - * (2) Sets up listener that updates Polymer property on hash change. 
- */ -function _getInitializer( - get: (name: string, useLocalStorage: boolean) => T, propertyName: string, - defaultVal: T, useLocalStorage): Function { - return function() { - const URIStorageName = getURIStorageName(this, propertyName); - // setComponentValue will be called every time the hash changes, and is - // responsible for ensuring that new state in the hash will be propagated - // to the component with that property. - // It is important that this function does not re-assign needlessly, - // to avoid Polymer observer churn. - const setComponentValue = () => { - const uriValue = get(URIStorageName, false); - const currentValue = this[propertyName]; - // if uriValue is undefined, we will ensure that the property has the - // default value - if (uriValue === undefined) { - let valueToSet: T; - // if we are using localStorage, we will set the value to the value - // from localStorage. Then, the corresponding observer will proxy - // the localStorage value into URI storage. - // in this way, localStorage takes precedence over the default val - // but not over the URI value. - if (useLocalStorage) { - const useLocalStorageValue = get(URIStorageName, true); - valueToSet = useLocalStorageValue === undefined ? - defaultVal : - useLocalStorageValue; - } else { - valueToSet = defaultVal; - } - if (!_.isEqual(currentValue, valueToSet)) { - // If we don't have an explicit URI value, then we need to ensure - // the property value is equal to the default value. - // We will assign a clone rather than the canonical default, because - // the component receiving this property may mutate it, and we need - // to keep a pristine copy of the default. - this[propertyName] = _.clone(valueToSet); - } - // In this case, we have an explicit URI value, so we will ensure that - // the component has an equivalent value. - } else { - if (!_.isEqual(uriValue, currentValue)) { - this[propertyName] = uriValue; - } - } - }; - // Set the value on the property. 
- setComponentValue(); - // Update it when the hashchanges. - window.addEventListener('hashchange', setComponentValue); - }; -} - -/** - * Return a function that updates URIStorage when a property changes. - */ -function _getObserver( - get: (name: string, useLocalStorage: boolean) => T, - set: (name: string, newVal: T, useLocalStorage: boolean) => void, - propertyName: string, defaultVal: T, useLocalStorage: boolean): Function { - return function() { - const URIStorageName = getURIStorageName(this, propertyName); - const newVal = this[propertyName]; - // if this is a localStorage property, we always synchronize the value - // in localStorage to match the one currently in the URI. - if (useLocalStorage) { - set(URIStorageName, newVal, true); - } - if (!_.isEqual(newVal, get(URIStorageName, false))) { - if (_.isEqual(newVal, defaultVal)) { - _unsetFromURI(URIStorageName); - } else { - set(URIStorageName, newVal, false); - } - } - }; -} - -/** - * Delete a key from the URI. - */ -function _unsetFromURI(key) { - const items = _componentToDict(_readComponent()); - delete items[key]; - _writeComponent(_dictToComponent(items)); -} - diff --git a/tensorflow/tensorboard/components/tf_storage/test/BUILD b/tensorflow/tensorboard/components/tf_storage/test/BUILD deleted file mode 100644 index 32399ba7cbe..00000000000 --- a/tensorflow/tensorboard/components/tf_storage/test/BUILD +++ /dev/null @@ -1,30 +0,0 @@ -package( - default_testonly = True, - default_visibility = ["//tensorflow/tensorboard:internal"], -) - -load("//tensorflow/tensorboard/defs:web.bzl", "ts_web_library") - -licenses(["notice"]) # Apache 2.0 - -ts_web_library( - name = "test", - srcs = [ - "storageTests.ts", - "tests.html", - ], - path = "/tf-storage/test", - deps = [ - "//tensorflow/tensorboard/components/tf_imports:polymer", - "//tensorflow/tensorboard/components/tf_imports:web_component_tester", - "//tensorflow/tensorboard/components/tf_imports:webcomponentsjs", - 
"//tensorflow/tensorboard/components/tf_storage", - ], -) - -filegroup( - name = "all_files", - testonly = 0, - srcs = glob(["**"]), - tags = ["notsan"], -) diff --git a/tensorflow/tensorboard/components/tf_storage/test/storageTests.ts b/tensorflow/tensorboard/components/tf_storage/test/storageTests.ts deleted file mode 100644 index 82dc51f05da..00000000000 --- a/tensorflow/tensorboard/components/tf_storage/test/storageTests.ts +++ /dev/null @@ -1,64 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the 'License'); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an 'AS IS' BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -import {TAB, getString, getNumber, getObject, setString, setNumber, setObject} from '../storage'; -import {TABS} from '../../tf-globals/globals'; - -/* tslint:disable:no-namespace */ -describe('URIStorage', () => { - it('get/setString', () => { - setString('key_a', 'hello', false); - setString('key_b', 'there', false); - chai.assert.equal('hello', getString('key_a', false)); - chai.assert.equal('there', getString('key_b', false)); - chai.assert.equal(null, getString('key_c', false)); - }); - - it('get/setNumber', () => { - setNumber('key_a', 12, false); - setNumber('key_b', 3.4, false); - chai.assert.equal(12, getNumber('key_a', false)); - chai.assert.equal(3.4, getNumber('key_b', false)); - chai.assert.equal(null, getNumber('key_c', false)); - }); - - it('get/setObject', () => { - const obj = {'foo': 2.3, 'bar': 'barstr'}; - setObject('key_a', obj, false); - chai.assert.deepEqual(obj, getObject('key_a', false)); - }); - - it('get/setWeirdValues', () => { - setNumber('key_a', NaN, false); - chai.assert.deepEqual(NaN, getNumber('key_a', false)); - - setNumber('key_a', +Infinity, false); - chai.assert.equal(+Infinity, getNumber('key_a', false)); - - setNumber('key_a', -Infinity, false); - chai.assert.equal(-Infinity, getNumber('key_a', false)); - - setNumber('key_a', 1 / 3, false); - chai.assert.equal(1 / 3, getNumber('key_a', false)); - - setNumber('key_a', -0, false); - chai.assert.equal(-0, getNumber('key_a', false)); - }); - - it('set/getTab', () => { - setString(TAB, TABS[0], false); - chai.assert.equal(TABS[0], getString(TAB, false)); - }); -}); - diff --git a/tensorflow/tensorboard/components/tf_storage/test/tests.html b/tensorflow/tensorboard/components/tf_storage/test/tests.html deleted file mode 100644 index 4668b119d24..00000000000 --- a/tensorflow/tensorboard/components/tf_storage/test/tests.html +++ /dev/null @@ -1,25 +0,0 @@ - - - - - - - - - - diff --git 
a/tensorflow/tensorboard/components/tf_tensorboard/BUILD b/tensorflow/tensorboard/components/tf_tensorboard/BUILD deleted file mode 100644 index 95fb8b7a882..00000000000 --- a/tensorflow/tensorboard/components/tf_tensorboard/BUILD +++ /dev/null @@ -1,65 +0,0 @@ -package(default_visibility = ["//tensorflow/tensorboard:internal"]) - -load("//tensorflow/tensorboard/defs:web.bzl", "ts_web_library") -load("//tensorflow/tensorboard/defs:vulcanize.bzl", "tensorboard_html_binary") - -licenses(["notice"]) # Apache 2.0 - -ts_web_library( - name = "tf_tensorboard", - srcs = [ - "autoReloadBehavior.ts", - "style.html", - "tf-tensorboard.html", - ], - path = "/tf-tensorboard", - visibility = ["//visibility:public"], - deps = [ - "//tensorflow/tensorboard/components/tf_audio_dashboard", - "//tensorflow/tensorboard/components/tf_backend", - "//tensorflow/tensorboard/components/tf_dashboard_common", - "//tensorflow/tensorboard/components/tf_distribution_dashboard", - "//tensorflow/tensorboard/components/tf_globals", - "//tensorflow/tensorboard/components/tf_graph_dashboard", - "//tensorflow/tensorboard/components/tf_histogram_dashboard", - "//tensorflow/tensorboard/components/tf_image_dashboard", - "//tensorflow/tensorboard/components/tf_imports:polymer", - "//tensorflow/tensorboard/components/tf_scalar_dashboard", - "//tensorflow/tensorboard/components/tf_storage", - "//tensorflow/tensorboard/components/tf_text_dashboard", - "//tensorflow/tensorboard/components/vz_projector", - "@org_polymer_font_roboto", - "@org_polymer_iron_icons", - "@org_polymer_paper_button", - "@org_polymer_paper_checkbox", - "@org_polymer_paper_dialog", - "@org_polymer_paper_header_panel", - "@org_polymer_paper_icon_button", - "@org_polymer_paper_tabs", - "@org_polymer_paper_toolbar", - ], -) - -ts_web_library( - name = "demo", - srcs = ["demo.html"], - path = "/tf-tensorboard", - deps = [ - ":tf_tensorboard", - "//tensorflow/tensorboard/demo:demo_data", - ], -) - -tensorboard_html_binary( - name = 
"devserver", - testonly = 1, - input_path = "/tf-tensorboard/demo.html", - output_path = "/index.html", - deps = [":demo"], -) - -filegroup( - name = "all_files", - srcs = glob(["**"]), - tags = ["notsan"], -) diff --git a/tensorflow/tensorboard/components/tf_tensorboard/autoReloadBehavior.ts b/tensorflow/tensorboard/components/tf_tensorboard/autoReloadBehavior.ts deleted file mode 100644 index 54df16f5b5d..00000000000 --- a/tensorflow/tensorboard/components/tf_tensorboard/autoReloadBehavior.ts +++ /dev/null @@ -1,62 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -export var AUTORELOAD_LOCALSTORAGE_KEY = 'TF.TensorBoard.autoReloadEnabled'; - -var getAutoReloadFromLocalStorage: () => boolean = () => { - var val = window.localStorage.getItem(AUTORELOAD_LOCALSTORAGE_KEY); - return val === 'true' || val == null; // defaults to true -}; - -/** - * @polymerBehavior - */ -export var AutoReloadBehavior = { - properties: { - autoReloadEnabled: { - type: Boolean, - observer: '_autoReloadObserver', - value: getAutoReloadFromLocalStorage, - }, - _autoReloadId: { - type: Number, - }, - autoReloadIntervalSecs: { - type: Number, - value: 30, - }, - }, - detached: function() { - window.clearTimeout(this._autoReloadId); - }, - _autoReloadObserver: function(autoReload) { - window.localStorage.setItem(AUTORELOAD_LOCALSTORAGE_KEY, autoReload); - if (autoReload) { - var _this = this; - this._autoReloadId = window.setTimeout( - this._doAutoReload.bind(this), this.autoReloadIntervalSecs * 1000); - } else { - window.clearTimeout(this._autoReloadId); - } - }, - _doAutoReload: function() { - if (this.reload == null) { - throw new Error('AutoReloadBehavior requires a reload method'); - } - this.reload(); - this._autoReloadId = window.setTimeout( - this._doAutoReload.bind(this), this.autoReloadIntervalSecs * 1000); - } -}; diff --git a/tensorflow/tensorboard/components/tf_tensorboard/demo.html b/tensorflow/tensorboard/components/tf_tensorboard/demo.html deleted file mode 100644 index f691f6211bc..00000000000 --- a/tensorflow/tensorboard/components/tf_tensorboard/demo.html +++ /dev/null @@ -1,24 +0,0 @@ - - - - -TensorBoard Demo - - - - diff --git a/tensorflow/tensorboard/components/tf_tensorboard/style.html b/tensorflow/tensorboard/components/tf_tensorboard/style.html deleted file mode 100644 index 575e89e3982..00000000000 --- a/tensorflow/tensorboard/components/tf_tensorboard/style.html +++ /dev/null @@ -1,28 +0,0 @@ - - - - - diff --git 
a/tensorflow/tensorboard/components/tf_tensorboard/test/autoReloadTests.ts b/tensorflow/tensorboard/components/tf_tensorboard/test/autoReloadTests.ts deleted file mode 100644 index b68fd8c9438..00000000000 --- a/tensorflow/tensorboard/components/tf_tensorboard/test/autoReloadTests.ts +++ /dev/null @@ -1,89 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -import {AUTORELOAD_LOCALSTORAGE_KEY, AutoReloadBehavior} from '../autoReloadBehavior'; - -declare function fixture(id: string): void; - -window.HTMLImports.whenReady(() => { - Polymer({ - is: 'autoreload-test-element', - behaviors: [AutoReloadBehavior], - }); - - describe('autoReload-behavior', function() { - let testElement; - const ls = window.localStorage; - const key = AUTORELOAD_LOCALSTORAGE_KEY; - let clock; - let callCount: number; - - beforeEach(function() { - ls.setItem(key, 'false'); // start it turned off so we can mutate fns - testElement = fixture('autoReloadFixture'); - callCount = 0; - testElement.reload = function() { callCount++; }; - }); - - before(function() { clock = sinon.useFakeTimers(); }); - - after(function() { clock.restore(); }); - - it('reads and writes autoReload state from localStorage', function() { - ls.removeItem(key); - testElement = fixture('autoReloadFixture'); - chai.assert.isTrue( - testElement.autoReloadEnabled, 'autoReload defaults to true'); - 
chai.assert.equal(ls.getItem(key), 'true', 'autoReload setting saved'); - testElement = fixture('autoReloadFixture'); - chai.assert.isTrue( - testElement.autoReloadEnabled, 'read true from localStorage'); - testElement.autoReloadEnabled = false; - chai.assert.equal(ls.getItem(key), 'false', 'autoReload setting saved'); - testElement = fixture('autoReloadFixture'); - chai.assert.isFalse( - testElement.autoReloadEnabled, 'read false setting properly'); - testElement.autoReloadEnabled = true; - chai.assert.equal(ls.getItem(key), 'true', 'saved true setting'); - }); - - it('reloads every interval secs when autoReloading', function() { - testElement.autoReloadIntervalSecs = 1; - testElement.autoReloadEnabled = true; - clock.tick(1000); - chai.assert.equal(callCount, 1, 'ticking clock triggered call'); - clock.tick(20 * 1000); - chai.assert.equal(callCount, 21, 'ticking clock 20s triggered 20 calls'); - }); - - it('can cancel pending autoReload', function() { - testElement.autoReloadIntervalSecs = 10; - testElement.autoReloadEnabled = true; - clock.tick(5 * 1000); - testElement.autoReloadEnabled = false; - clock.tick(20 * 1000); - chai.assert.equal(callCount, 0, 'callCount is 0'); - }); - - it('throws an error in absence of reload method', function() { - testElement.reload = undefined; - testElement.autoReloadIntervalSecs = 1; - testElement.autoReloadEnabled = true; - chai.assert.throws(function() { - clock.tick(5000); - }); - }); - }); -}); diff --git a/tensorflow/tensorboard/components/tf_tensorboard/test/e2eTests.html b/tensorflow/tensorboard/components/tf_tensorboard/test/e2eTests.html deleted file mode 100644 index 5efc02ef98a..00000000000 --- a/tensorflow/tensorboard/components/tf_tensorboard/test/e2eTests.html +++ /dev/null @@ -1,31 +0,0 @@ - - - - - - - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/tf_tensorboard/test/e2eTests.ts b/tensorflow/tensorboard/components/tf_tensorboard/test/e2eTests.ts deleted file mode 100644 index 
a00027963be..00000000000 --- a/tensorflow/tensorboard/components/tf_tensorboard/test/e2eTests.ts +++ /dev/null @@ -1,44 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -import {TABS} from '../../tf-globals/globals'; - -describe('end-to-end test', () => { - window.HTMLImports.whenReady(() => { - let tb = d3.select('tf-tensorboard'); - var tabs = (tb.node()).$.tabs; - - function testTab(tabIndex: number) { - it(`selecting ${TABS[tabIndex]} tab`, done => { - // Every dashboard emits a rendered event when it is done rendering. - tb.on('rendered', () => done()); - tabs.set('selected', tabIndex); - }); - } - // Listen for when the default tab has rendered and test other tabs after. - tb.on('rendered', () => { - // The default tab already rendered. Test everything else. - // If a bug happened while rendering the default tab, the test would - // have failed. Re-selecting the default tab and listening for - // "rendered" event won't work since the content is not re-stamped. 
- let selected = +tabs.get('selected'); - for (let i = 0; i < TABS.length; i++) { - if (i !== selected) { - testTab(i); - } - } - }); - }); -}); diff --git a/tensorflow/tensorboard/components/tf_tensorboard/test/fastTabSwitch.html b/tensorflow/tensorboard/components/tf_tensorboard/test/fastTabSwitch.html deleted file mode 100644 index 88bb6edc482..00000000000 --- a/tensorflow/tensorboard/components/tf_tensorboard/test/fastTabSwitch.html +++ /dev/null @@ -1,31 +0,0 @@ - - - - - - - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/tf_tensorboard/test/fastTabSwitch.ts b/tensorflow/tensorboard/components/tf_tensorboard/test/fastTabSwitch.ts deleted file mode 100644 index 905ed4ee4aa..00000000000 --- a/tensorflow/tensorboard/components/tf_tensorboard/test/fastTabSwitch.ts +++ /dev/null @@ -1,47 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -import {TABS} from '../../tf-globals/globals'; - -describe('fast tab switch', () => { - window.HTMLImports.whenReady(() => { - let tb = d3.select('tf-tensorboard'); - // tslint:disable-next-line:no-any be quiet tsc - var tabs = (tb.node()).$.tabs; - - // This test will select the events tab. Once the events tab - // renders, will select the graph tab, and immediately select - // the images tab wihout waiting for the graph tab to finish - // rendering. 
Finally, it finishes when the images tab - // has rendered and no errors were thrown. - const eventsTabIndex = TABS.indexOf('events'); - const imagesTabIndex = TABS.indexOf('images'); - const graphTabIndex = TABS.indexOf('graphs'); - - // Listen for when the events tab rendered. - tb.on('rendered', () => { - it('switching to graph tab and immediately to images', done => { - // Select the graph tab. - tabs.set('selected', graphTabIndex); - // Interrupt graph rendering by immediately selecting the images tab - // and finish when the images tab has rendered. - tb.on('rendered', () => done()); - tabs.set('selected', imagesTabIndex); - }); - }); - // Select the events tab. - tabs.set('selected', eventsTabIndex); - }); -}); diff --git a/tensorflow/tensorboard/components/tf_tensorboard/test/index.html b/tensorflow/tensorboard/components/tf_tensorboard/test/index.html deleted file mode 100644 index 8806f36fad9..00000000000 --- a/tensorflow/tensorboard/components/tf_tensorboard/test/index.html +++ /dev/null @@ -1,35 +0,0 @@ - - - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/tf_tensorboard/test/tensorboardTests.html b/tensorflow/tensorboard/components/tf_tensorboard/test/tensorboardTests.html deleted file mode 100644 index 2122cb79b16..00000000000 --- a/tensorflow/tensorboard/components/tf_tensorboard/test/tensorboardTests.html +++ /dev/null @@ -1,44 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/tf_tensorboard/test/tensorboardTests.ts b/tensorflow/tensorboard/components/tf_tensorboard/test/tensorboardTests.ts deleted file mode 100644 index 06ff446f186..00000000000 --- a/tensorflow/tensorboard/components/tf_tensorboard/test/tensorboardTests.ts +++ /dev/null @@ -1,132 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the 'License'); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an 'AS IS' BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -import {TABS} from '../../tf-globals/globals'; - -describe('tf-tensorboard tests', () => { - window.HTMLImports.whenReady(() => { - let tensorboard: any; - beforeEach(function() { - tensorboard = fixture('tensorboardFixture'); - tensorboard.demoDir = 'data'; - tensorboard.autoReloadEnabled = false; - }); - - it('specified tabs are correct', function(done) { - setTimeout(function() { - let tabs = tensorboard.$.tabs.getElementsByTagName('paper-tab'); - let tabMode = Array.prototype.map.call(tabs, (x) => x.dataMode); - chai.assert.deepEqual(tabMode, TABS, 'mode is correct'); - let tabText = - Array.prototype.map.call(tabs, (x) => x.innerText.toLowerCase()); - chai.assert.deepEqual(tabText, TABS, 'text is correct'); - done(); - }); - }); - - it('renders injected content', function() { - let injected = tensorboard.querySelector('#inject-me'); - chai.assert.isNotNull(injected); - }); - - describe('reloading the selected dashboard', function() { - TABS.forEach((name, tabIndex) => { - // These tabs do not support reload mode. 
- if (name === 'graphs' || name === 'projections') { - return; - } - it(`${name}: calling reload reloads dashboard`, function(done) { - tensorboard.$.tabs.set('selected', tabIndex); - setTimeout(function() { - let called = false; - tensorboard.selectedDashboard().reload = function() { - called = true; - }; - tensorboard.reload(); - chai.assert.isFalse( - tensorboard.$$('#reload-button').disabled, - 'reload button not disabled'); - chai.assert.isTrue(called, `reload was called`); - done(); - }); - }); - }); - }); - - it('reload is disabled for graph dashboard', function(done) { - const idx = TABS.indexOf('graphs'); - chai.assert.notEqual(idx, -1, 'graphs was found'); - tensorboard.$.tabs.set('selected', idx); - setTimeout( - function() { // async so that the queued tab change will happen - let called = false; - tensorboard.selectedDashboard().reload = function() { - called = true; - }; - tensorboard.reload(); - chai.assert.isTrue( - tensorboard.$$('#reload-button').disabled, - 'reload button disabled'); - chai.assert.isFalse(called, `reload was not called`); - done(); - }); - }); - - describe('top right global icons', function() { - it('Clicking the reload button will call reload', function() { - let called = false; - tensorboard.reload = function() { called = true; }; - tensorboard.$$('#reload-button').click(); - chai.assert.isTrue(called); - }); - - it('settings pane is hidden', function() { - chai.assert.equal(tensorboard.$.settings.style['display'], 'none'); - }); - - it('settings icon button opens the settings pane', function(done) { - tensorboard.$$('#settings-button').click(); - // This test is a little hacky since we depend on polymer's - // async behavior, which is difficult to predict. - - // keep checking until the panel is visible. error with a timeout if it - // is broken. 
- function verify() { - if (tensorboard.$.settings.style['display'] !== 'none') { - done(); - } else { - setTimeout(verify, 3); // wait and see if it becomes true - } - } - verify(); - }); - - it('Autoreload checkbox toggle works', function() { - let checkbox = tensorboard.$$('#auto-reload-checkbox'); - chai.assert.equal(checkbox.checked, tensorboard.autoReloadEnabled); - let oldValue = checkbox.checked; - checkbox.click(); - chai.assert.notEqual(oldValue, checkbox.checked); - chai.assert.equal(checkbox.checked, tensorboard.autoReloadEnabled); - }); - - it('Autoreload checkbox contains correct interval info', function() { - let checkbox = tensorboard.$$('#auto-reload-checkbox'); - let timeInSeconds = tensorboard.autoReloadIntervalSecs + 's'; - chai.assert.include(checkbox.innerText, timeInSeconds); - }); - }); - }); -}); diff --git a/tensorflow/tensorboard/components/tf_tensorboard/tf-tensorboard.html b/tensorflow/tensorboard/components/tf_tensorboard/tf-tensorboard.html deleted file mode 100644 index 26b742996aa..00000000000 --- a/tensorflow/tensorboard/components/tf_tensorboard/tf-tensorboard.html +++ /dev/null @@ -1,361 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/tf_text_dashboard/BUILD b/tensorflow/tensorboard/components/tf_text_dashboard/BUILD deleted file mode 100644 index bed551aedfc..00000000000 --- a/tensorflow/tensorboard/components/tf_text_dashboard/BUILD +++ /dev/null @@ -1,45 +0,0 @@ -package(default_visibility = ["//tensorflow/tensorboard:internal"]) - -load("//tensorflow/tensorboard/defs:web.bzl", "ts_web_library") - -licenses(["notice"]) # Apache 2.0 - -ts_web_library( - name = "tf_text_dashboard", - srcs = [ - "tf-text-dashboard.html", - "tf-text-loader.html", - ], - path = "/tf-text-dashboard", - deps = [ - "//tensorflow/tensorboard/components/tf_backend", - "//tensorflow/tensorboard/components/tf_color_scale", - "//tensorflow/tensorboard/components/tf_dashboard_common", - 
"//tensorflow/tensorboard/components/tf_imports:d3", - "//tensorflow/tensorboard/components/tf_imports:lodash", - "//tensorflow/tensorboard/components/tf_imports:polymer", - "@org_polymer_paper_dialog", - "@org_polymer_paper_icon_button", - "@org_polymer_paper_material", - "@org_polymer_paper_slider", - "@org_polymer_paper_spinner", - ], -) - -ts_web_library( - name = "demo", - srcs = ["index.html"] + glob(["data/**"]), - path = "/tf-text-dashboard", - deps = [ - ":tf_text_dashboard", - "//tensorflow/tensorboard/components/tf_imports:webcomponentsjs", - "@org_polymer_iron_demo_helpers", - "@org_polymer_paper_styles", - ], -) - -filegroup( - name = "all_files", - srcs = glob(["**"]), - tags = ["notsan"], -) diff --git a/tensorflow/tensorboard/components/tf_text_dashboard/data/logdir b/tensorflow/tensorboard/components/tf_text_dashboard/data/logdir deleted file mode 100644 index c7d82022cc0..00000000000 --- a/tensorflow/tensorboard/components/tf_text_dashboard/data/logdir +++ /dev/null @@ -1 +0,0 @@ -{"logdir": "/some/fake/logdir"} \ No newline at end of file diff --git a/tensorflow/tensorboard/components/tf_text_dashboard/data/runs.json b/tensorflow/tensorboard/components/tf_text_dashboard/data/runs.json deleted file mode 100644 index aea7de5f917..00000000000 --- a/tensorflow/tensorboard/components/tf_text_dashboard/data/runs.json +++ /dev/null @@ -1 +0,0 @@ -{"fry": ["message", "markdown"], "leela": ["message"]} \ No newline at end of file diff --git a/tensorflow/tensorboard/components/tf_text_dashboard/data/text_run_fry_tag_markdown.json b/tensorflow/tensorboard/components/tf_text_dashboard/data/text_run_fry_tag_markdown.json deleted file mode 100644 index 94183ae13d1..00000000000 --- a/tensorflow/tensorboard/components/tf_text_dashboard/data/text_run_fry_tag_markdown.json +++ /dev/null @@ -1,32 +0,0 @@ -[ - { - "wall_time": 1489715207.593146, - "step": 0, - "text": "

Italics1 Italics2 bold1 bold2

" - }, - { - "wall_time": 1489715207.593801, - "step": 1, - "text": "
    \n
  1. List item one.
  2. \n
  3. List item two.
  4. \n
  5. Sublist
  6. \n
  7. Sublist2
  8. \n
  9. List continues.
  10. \n
" - }, - { - "wall_time": 1489715207.594842, - "step": 2, - "text": "\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n
AnExampleTable
ABC
123
" - }, - { - "wall_time": 1489715207.595761, - "step": 3, - "text": "

hello you

" - }, - { - "wall_time": 1489715207.595761, - "step": 4, - "text": "

TensorFlow

" - }, - { - "wall_time": 1489715207.595761, - "step": 530234352, - "text": "<script>alert('xss')</script>" - } -] diff --git a/tensorflow/tensorboard/components/tf_text_dashboard/data/text_run_fry_tag_message.json b/tensorflow/tensorboard/components/tf_text_dashboard/data/text_run_fry_tag_message.json deleted file mode 100644 index e8cc006c0d0..00000000000 --- a/tensorflow/tensorboard/components/tf_text_dashboard/data/text_run_fry_tag_message.json +++ /dev/null @@ -1,22 +0,0 @@ -[ - { - "wall_time": 1489715207.593146, - "step": 0, - "text": "fry loves garnet" - }, - { - "wall_time": 1489715207.593801, - "step": 1, - "text": "fry loves amethyst" - }, - { - "wall_time": 1489715207.594842, - "step": 2, - "text": "fry loves pearl" - }, - { - "wall_time": 1489715207.595761, - "step": 3, - "text": "fry loves steven" - } -] \ No newline at end of file diff --git a/tensorflow/tensorboard/components/tf_text_dashboard/data/text_run_leela_tag_message.json b/tensorflow/tensorboard/components/tf_text_dashboard/data/text_run_leela_tag_message.json deleted file mode 100644 index 5a6d2598937..00000000000 --- a/tensorflow/tensorboard/components/tf_text_dashboard/data/text_run_leela_tag_message.json +++ /dev/null @@ -1,22 +0,0 @@ -[ - { - "step": 0, - "wall_time": 1489715207.607792, - "text": "leela loves garnet and feels strongly about various issues of the day including the two-cent titanium tax and whether nixon's head contributes to greenhouse gas emissions" - }, - { - "step": 1, - "wall_time": 1489715207.609011, - "text": "leela loves amethyst" - }, - { - "step": 2, - "wall_time": 1489715207.610028, - "text": "leela loves pearl" - }, - { - "step": 3, - "wall_time": 1489715207.611142, - "text": "leela loves someverylongwordwithoutanybreaksorspacessowecanseehowthatishandledbythefrontend" - } -] \ No newline at end of file diff --git a/tensorflow/tensorboard/components/tf_text_dashboard/index.html b/tensorflow/tensorboard/components/tf_text_dashboard/index.html deleted file mode 
100644 index 55ec4d79cf9..00000000000 --- a/tensorflow/tensorboard/components/tf_text_dashboard/index.html +++ /dev/null @@ -1,74 +0,0 @@ - - - - - - - - - text Dashboard Demo - - - - - - - - diff --git a/tensorflow/tensorboard/components/tf_text_dashboard/tf-text-dashboard.html b/tensorflow/tensorboard/components/tf_text_dashboard/tf-text-dashboard.html deleted file mode 100644 index 9b4fd3239c9..00000000000 --- a/tensorflow/tensorboard/components/tf_text_dashboard/tf-text-dashboard.html +++ /dev/null @@ -1,113 +0,0 @@ - - - - - - - - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/tf_text_dashboard/tf-text-loader.html b/tensorflow/tensorboard/components/tf_text_dashboard/tf-text-loader.html deleted file mode 100644 index 374e0478dd1..00000000000 --- a/tensorflow/tensorboard/components/tf_text_dashboard/tf-text-loader.html +++ /dev/null @@ -1,143 +0,0 @@ - - - - - - - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/tf_trace_viewer/BUILD b/tensorflow/tensorboard/components/tf_trace_viewer/BUILD deleted file mode 100644 index 9f582329f1d..00000000000 --- a/tensorflow/tensorboard/components/tf_trace_viewer/BUILD +++ /dev/null @@ -1,30 +0,0 @@ -package(default_visibility = ["//tensorflow/tensorboard:internal"]) - -load("//tensorflow/tensorboard/defs:web.bzl", "ts_web_library") - -licenses(["notice"]) # Apache 2.0 - -ts_web_library( - name = "tf_trace_viewer", - srcs = [ - "tf-trace-viewer.html", - "@org_chromium_catapult_vulcanized_trace_viewer//:trace_viewer_full.html", - ], - path = "/tf-trace-viewer", -) - -ts_web_library( - name = "demo", - srcs = ["demo.html"], - path = "/tf-trace-viewer", - deps = [ - ":tf_trace_viewer", - "//tensorflow/tensorboard/components/tf_trace_viewer/data", - ], -) - -filegroup( - name = "all_files", - srcs = glob(["**"]), - tags = ["notsan"], -) diff --git a/tensorflow/tensorboard/components/tf_trace_viewer/data/BUILD b/tensorflow/tensorboard/components/tf_trace_viewer/data/BUILD deleted file mode 
100644 index c295d38258f..00000000000 --- a/tensorflow/tensorboard/components/tf_trace_viewer/data/BUILD +++ /dev/null @@ -1,17 +0,0 @@ -package(default_visibility = ["//tensorflow/tensorboard:internal"]) - -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") - -licenses(["notice"]) # Apache 2.0 - -web_library( - name = "data", - srcs = glob(["*.json"]), - path = "/tf-trace-viewer/data/plugin/profile", -) - -filegroup( - name = "all_files", - srcs = glob(["**"]), - tags = ["notsan"], -) diff --git a/tensorflow/tensorboard/components/tf_trace_viewer/data/trace.json b/tensorflow/tensorboard/components/tf_trace_viewer/data/trace.json deleted file mode 100644 index e1d57394e35..00000000000 --- a/tensorflow/tensorboard/components/tf_trace_viewer/data/trace.json +++ /dev/null @@ -1,105 +0,0 @@ -{ - "traceEvents": [ - {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 826, "ph": "C", - "name": "counter", "args": {"value": 10}}, - {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 826, "ph": "B", - "name": "A long name that doesnt fit but is exceedingly informative", - "args": {"name_false": false, "value_true": true}}, - {"cat": "PERF", "pid": 22630, "ts": 835, "ph": "I", "s": "p", - "name": "ProcessWideEvent1", "args": {}}, - - {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 827, "ph": "B", - "name": "Asub with a name that wont fit", "args": {}}, - {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 828, "ph": "E", - "name": "Asub", "args": {}}, - {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 829, "ph": "B", - "name": "Asub", "args": {}}, - {"cat": "PREF", "pid": 22630, "tid": 22630, "dur": 15, "ts": 820, "ph": "X", - "name": "Long X type", "args": {}, "sf": 7, "esf": 8}, - {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 832, "ph": "E", - "name": "Asub", "args": {}}, - {"cat": "PREF", "pid": 22630, "tid": 22630, "dur": 2, "ts": 818, "ph": "X", - "name": "X1", "args": {}}, - {"cat": "PREF", "pid": 22630, "tid": 22630, "dur": 2, "ts": 818, "ph": "X", - 
"name": "X same ts and dur as X1", "args": {}}, - {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 832, "ph": "C", - "name": "counter", "args": {"value": 1}}, - {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 833, "ph": "E", - "name": "", "args": {}}, - - {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 835, "ph": "I", - "name": "ThreadLevelI1", "args": {}}, - - {"cat": "PERF", "ts": 880, "ph": "I", "s": "g", "name": "GlobalEvent1", - "args": {}}, - - {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 837, "ph": "I", - "name": "ThreadLevelI2", "args": {}}, - - {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 839, "ph": "C", - "name": "counter", "args": {"value": 5}}, - {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 840, "ph": "B", - "name": "A not as long a name", "args": {}}, - {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 848, "ph": "E", - "name": "A not as long a name", "args": {}}, - {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 848, "ph": "C", - "name": "counter", "args": {"value": 1}}, - {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 854, "ph": "C", - "name": "counter", "args": {"value": 10}}, - - {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 850, "ph": "B", - "name": "B", "args": {}}, - {"cat": "PERF", "pid": 22630, "tid": 22630, "ts": 854, "ph": "E", - "name": "B", "args": {}}, - - {"cat": "PERF", "pid": 22630, "tid": 22631, "ts": 827, "ph": "B", - "name": "A", "args": {}}, - {"cat": "PERF", "pid": 22630, "tid": 22631, "ts": 835, "ph": "I", - "name": "ThreadLevelImmediate Three", "args": {}}, - {"cat": "PERF", "pid": 22630, "tid": 22631, "ts": 845, "ph": "I", - "name": "ThreadLevelImmediate4", "args": {}}, - {"cat": "PERF", "pid": 22630, "tid": 22631, "ts": 854, "ph": "E", - "name": "A", "args": {}}, - - {"cat": "PREF", "pid": 22630, "tid": 22630, "ts": 860, "ph": "B", - "name": "B/E over X", "args": {}}, - {"cat": "PREF", "pid": 22630, "tid": 22630, "dur": 10, "ts": 860, "ph": "X", - "name": "X", "args": {}}, - {"cat": "PREF", "pid": 
22630, "tid": 22630, "ts": 860, "ph": "B", - "name": "B/E under X", "args": {}}, - {"cat": "PREF", "pid": 22630, "tid": 22630, "ts": 870, "ph": "E", - "name": "B/E under X", "args": {}}, - {"cat": "PREF", "pid": 22630, "tid": 22630, "ts": 870, "ph": "E", - "name": "B/E over X", "args": {}}, - - {"cat": "SAMPLE", "pid": 22630, "tid": 22631, "ts": 870, "ph": "P", - "name": "SampleA", "args": {}}, - {"cat": "SAMPLE", "pid": 22630, "tid": 22631, "ts": 875, "ph": "P", - "name": "SampleB", "args": {}}, - {"cat": "SAMPLE", "pid": 22630, "tid": 22631, "ts": 878, "ph": "P", - "name": "SampleC", "args": {}, "sf": 8}, - - {"cat": "__metadata", "pid": 22630, "tid": 22630, "ts": 0, "ph": "M", - "name": "thread_name", "args": {"name": "threadA"}}, - {"cat": "__metadata", "pid": 22630, "tid": 22631, "ts": 0, "ph": "M", - "name": "thread_name", "args": {"name": "threadB"}}, - {"cat": "__metadata", "pid": 22630, "tid": 22632, "ts": 0, "ph": "M", - "name": "thread_name", "args": {"name": "threadC"}} - ], - "stackFrames": { - "1": { - "category": "m1", - "name": "main" - }, - "7": { - "category": "m2", - "name": "frame7", - "parent": "1" - }, - "8": { - "category": "m2", - "name": "frame8", - "parent": "1" - } - } -} diff --git a/tensorflow/tensorboard/components/tf_trace_viewer/demo.html b/tensorflow/tensorboard/components/tf_trace_viewer/demo.html deleted file mode 100644 index dd0029e9679..00000000000 --- a/tensorflow/tensorboard/components/tf_trace_viewer/demo.html +++ /dev/null @@ -1,30 +0,0 @@ - - - - -Trace Viewer Demo - -
- - -
diff --git a/tensorflow/tensorboard/components/tf_trace_viewer/tf-trace-viewer.html b/tensorflow/tensorboard/components/tf_trace_viewer/tf-trace-viewer.html deleted file mode 100644 index a7b0b2cd730..00000000000 --- a/tensorflow/tensorboard/components/tf_trace_viewer/tf-trace-viewer.html +++ /dev/null @@ -1,127 +0,0 @@ - - - - - - diff --git a/tensorflow/tensorboard/components/trace_viewer.html b/tensorflow/tensorboard/components/trace_viewer.html deleted file mode 100644 index c9bcdc9e207..00000000000 --- a/tensorflow/tensorboard/components/trace_viewer.html +++ /dev/null @@ -1,28 +0,0 @@ - - - - -Trace Viewer - - - - - - - - diff --git a/tensorflow/tensorboard/components/vz_distribution_chart/BUILD b/tensorflow/tensorboard/components/vz_distribution_chart/BUILD deleted file mode 100644 index 6645805d0c0..00000000000 --- a/tensorflow/tensorboard/components/vz_distribution_chart/BUILD +++ /dev/null @@ -1,39 +0,0 @@ -package(default_visibility = ["//tensorflow/tensorboard:internal"]) - -load("//tensorflow/tensorboard/defs:web.bzl", "ts_web_library") - -licenses(["notice"]) # Apache 2.0 - -ts_web_library( - name = "vz_distribution_chart", - srcs = [ - "vz-distribution-chart.html", - "vz-distribution-chart.ts", - ], - path = "/vz-distribution-chart", - visibility = ["//visibility:public"], - deps = [ - "//tensorflow/tensorboard/components/tf_imports:lodash", - "//tensorflow/tensorboard/components/tf_imports:plottable", - "//tensorflow/tensorboard/components/tf_imports:polymer", - "//tensorflow/tensorboard/components/vz_line_chart", - ], -) - -ts_web_library( - name = "demo", - srcs = ["index.html"], - path = "/vz-distribution-chart", - deps = [ - ":vz_distribution_chart", - "//tensorflow/tensorboard/components/tf_imports:webcomponentsjs", - "@org_polymer_iron_demo_helpers", - "@org_polymer_paper_styles", - ], -) - -filegroup( - name = "all_files", - srcs = glob(["**"]), - tags = ["notsan"], -) diff --git 
a/tensorflow/tensorboard/components/vz_distribution_chart/index.html b/tensorflow/tensorboard/components/vz_distribution_chart/index.html deleted file mode 100644 index 39db09354bd..00000000000 --- a/tensorflow/tensorboard/components/vz_distribution_chart/index.html +++ /dev/null @@ -1,61 +0,0 @@ - - - - - - - - vz-distribution chart demo - - - - - - - -

Simple distribution chart

- - - - - - diff --git a/tensorflow/tensorboard/components/vz_distribution_chart/vz-distribution-chart.html b/tensorflow/tensorboard/components/vz_distribution_chart/vz-distribution-chart.html deleted file mode 100644 index 1f1fdda9196..00000000000 --- a/tensorflow/tensorboard/components/vz_distribution_chart/vz-distribution-chart.html +++ /dev/null @@ -1,45 +0,0 @@ - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/vz_distribution_chart/vz-distribution-chart.ts b/tensorflow/tensorboard/components/vz_distribution_chart/vz-distribution-chart.ts deleted file mode 100644 index f3911d301d9..00000000000 --- a/tensorflow/tensorboard/components/vz_distribution_chart/vz-distribution-chart.ts +++ /dev/null @@ -1,237 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -import * as ChartHelpers from '../vz-line-chart/vz-chart-helpers'; - -export class DistributionChart { - private run2datasets: {[run: string]: Plottable.Dataset}; - protected runs: string[]; - - protected xAccessor: Plottable.IAccessor; - protected xScale: Plottable.QuantitativeScale; - protected yScale: Plottable.QuantitativeScale; - protected gridlines: Plottable.Components.Gridlines; - protected center: Plottable.Components.Group; - protected xAxis: Plottable.Axes.Numeric|Plottable.Axes.Time; - protected yAxis: Plottable.Axes.Numeric; - protected xLabel: Plottable.Components.AxisLabel; - protected yLabel: Plottable.Components.AxisLabel; - protected outer: Plottable.Components.Table; - protected colorScale: Plottable.Scales.Color; - private plots: Plottable.XYPlot[]; - - private targetSVG: d3.Selection; - - constructor(xType: string, colorScale: Plottable.Scales.Color) { - this.run2datasets = {}; - this.colorScale = colorScale; - this.buildChart(xType); - } - - protected getDataset(run: string) { - if (this.run2datasets[run] === undefined) { - this.run2datasets[run] = new Plottable.Dataset([], {run: run}); - } - return this.run2datasets[run]; - } - - protected buildChart(xType: string) { - if (this.outer) { - this.outer.destroy(); - } - let xComponents = ChartHelpers.getXComponents(xType); - this.xAccessor = xComponents.accessor; - this.xScale = xComponents.scale; - this.xAxis = xComponents.axis; - this.xAxis.margin(0).tickLabelPadding(3); - this.yScale = new Plottable.Scales.Linear(); - this.yAxis = new Plottable.Axes.Numeric(this.yScale, 'left'); - let yFormatter = ChartHelpers.multiscaleFormatter( - ChartHelpers.Y_AXIS_FORMATTER_PRECISION); - this.yAxis.margin(0).tickLabelPadding(5).formatter(yFormatter); - this.yAxis.usesTextWidthApproximation(true); - - let center = this.buildPlot(this.xAccessor, this.xScale, this.yScale); - - this.gridlines = - new 
Plottable.Components.Gridlines(this.xScale, this.yScale); - - this.center = new Plottable.Components.Group([this.gridlines, center]); - this.outer = new Plottable.Components.Table( - [[this.yAxis, this.center], [null, this.xAxis]]); - } - - protected buildPlot(xAccessor, xScale, yScale): Plottable.Component { - let percents = [0, 228, 1587, 3085, 5000, 6915, 8413, 9772, 10000]; - let opacities = _.range(percents.length - 1) - .map((i) => (percents[i + 1] - percents[i]) / 2500); - let accessors = percents.map((p, i) => (datum) => datum[i][1]); - let median = 4; - let medianAccessor = accessors[median]; - - let plots = _.range(accessors.length - 1).map((i) => { - let p = new Plottable.Plots.Area(); - p.x(xAccessor, xScale); - - let y0 = i > median ? accessors[i] : accessors[i + 1]; - let y = i > median ? accessors[i + 1] : accessors[i]; - p.y(y, yScale); - p.y0(y0); - p.attr( - 'fill', - (d: any, i: number, dataset: Plottable.Dataset) => - this.colorScale.scale(dataset.metadata().run)); - p.attr( - 'stroke', - (d: any, i: number, dataset: Plottable.Dataset) => - this.colorScale.scale(dataset.metadata().run)); - p.attr('stroke-weight', (d: any, i: number, m: any) => '0.5px'); - p.attr('stroke-opacity', () => opacities[i]); - p.attr('fill-opacity', () => opacities[i]); - return p; - }); - - let medianPlot = new Plottable.Plots.Line(); - medianPlot.x(xAccessor, xScale); - medianPlot.y(medianAccessor, yScale); - medianPlot.attr( - 'stroke', (d: any, i: number, m: any) => this.colorScale.scale(m.run)); - - this.plots = plots; - return new Plottable.Components.Group(plots); - } - - public setVisibleSeries(runs: string[]) { - this.runs = runs; - let datasets = runs.map((r) => this.getDataset(r)); - this.plots.forEach((p) => p.datasets(datasets)); - } - - /** - * Set the data of a series on the chart. 
- */ - public setSeriesData(name: string, data: any) { - this.getDataset(name).data(data); - } - - public renderTo(targetSVG: d3.Selection) { - this.targetSVG = targetSVG; - this.outer.renderTo(targetSVG); - } - - public redraw() { - this.outer.redraw(); - } - - protected destroy() { - this.outer.destroy(); - } -} - - -Polymer({ - is: 'vz-distribution-chart', - properties: { - /** - * Scale that maps series names to colors. The default colors are from - * d3.d3.schemeCategory10. Use this property to replace the default - * line colors with colors of your own choice. - * @type {Plottable.Scales.Color} - * @required - */ - colorScale: { - type: Object, - value: function() { - return new Plottable.Scales.Color().range(d3.schemeCategory10); - } - }, - /** - * The way to display the X values. Allows: - * - "step" - Linear scale using the "step" property of the datum. - * - "wall_time" - Temporal scale using the "wall_time" property of the - * datum. - * - "relative" - Temporal scale using the "relative" property of the - * datum if it is present or calculating from "wall_time" if it isn't. 
- */ - xType: {type: String, value: 'step'}, - _attached: Boolean, - _chart: Object, - _visibleSeriesCache: { - type: Array, - value: function() { - return [] - } - }, - _seriesDataCache: { - type: Object, - value: function() { - return {} - } - }, - _makeChartAsyncCallbackId: {type: Number, value: null} - }, - observers: [ - '_makeChart(xType, colorScale, _attached)', - '_reloadFromCache(_chart)', - ], - setVisibleSeries: function(names) { - this._visibleSeriesCache = names; - if (this._chart) { - this._chart.setVisibleSeries(names); - this.redraw(); - } - }, - setSeriesData: function(name, data) { - this._seriesDataCache[name] = data; - if (this._chart) { - this._chart.setSeriesData(name, data); - } - }, - redraw: function() { - this._chart.redraw(); - }, - ready: function() { - this.scopeSubtree(this.$.chartdiv, true); - }, - _makeChart: function(xType, colorScale, _attached) { - if (this._makeChartAsyncCallbackId === null) { - this.cancelAsync(this._makeChartAsyncCallbackId); - } - - this._makeChartAsyncCallbackId = this.async(function() { - this._makeChartAsyncCallbackId = null; - if (!_attached) return; - if (this._chart) this._chart.destroy(); - var chart = new DistributionChart(xType, colorScale); - var svg = d3.select(this.$.chartdiv); - chart.renderTo(svg); - this._chart = chart; - }, 350); - }, - _reloadFromCache: function() { - if (this._chart) { - this._chart.setVisibleSeries(this._visibleSeriesCache); - this._visibleSeriesCache.forEach(function(name) { - this._chart.setSeriesData(name, this._seriesDataCache[name] || []); - }.bind(this)); - } - }, - attached: function() { - this._attached = true; - }, - detached: function() { - this._attached = false; - } -}); diff --git a/tensorflow/tensorboard/components/vz_histogram_timeseries/BUILD b/tensorflow/tensorboard/components/vz_histogram_timeseries/BUILD deleted file mode 100644 index 6f6c8d94c37..00000000000 --- a/tensorflow/tensorboard/components/vz_histogram_timeseries/BUILD +++ /dev/null @@ -1,46 +0,0 
@@ -package(default_visibility = ["//tensorflow/tensorboard:internal"]) - -load("//tensorflow/tensorboard/defs:defs.bzl", "tensorboard_webcomponent_library") -load("//tensorflow/tensorboard/defs:web.bzl", "ts_web_library") - -licenses(["notice"]) # Apache 2.0 - -ts_web_library( - name = "vz_histogram_timeseries", - srcs = ["vz-histogram-timeseries.html"], - path = "/vz-histogram-timeseries", - deps = [ - "//tensorflow/tensorboard/components/tf_imports:d3", - "//tensorflow/tensorboard/components/tf_imports:polymer", - ], -) - -ts_web_library( - name = "demo", - srcs = ["index.html"], - path = "/vz-histogram-timeseries", - deps = [ - ":vz_histogram_timeseries", - "//tensorflow/tensorboard/components/tf_imports:webcomponentsjs", - "@org_polymer_iron_demo_helpers", - "@org_polymer_paper_button", - "@org_polymer_paper_styles", - ], -) - -tensorboard_webcomponent_library( - name = "legacy", - srcs = [":vz_histogram_timeseries"], - visibility = ["//learning/vis/vz_elements/catalog:__pkg__"], - destdir = "vz-histogram-timeseries", - deps = [ - "//tensorflow/tensorboard/components/tf_imports_google:lib", - "//third_party/javascript/polymer/v1/polymer:lib", - ], -) - -filegroup( - name = "all_files", - srcs = glob(["**"]), - tags = ["notsan"], -) diff --git a/tensorflow/tensorboard/components/vz_histogram_timeseries/index.html b/tensorflow/tensorboard/components/vz_histogram_timeseries/index.html deleted file mode 100644 index 42efa83eb07..00000000000 --- a/tensorflow/tensorboard/components/vz_histogram_timeseries/index.html +++ /dev/null @@ -1,84 +0,0 @@ - - - - - - - - vz-histogram-timeseries demo - - - - - - - - -

vz-histogram-timeseries mode

- - - - -

vz-histogram-timeseries axis

- - - - - - - diff --git a/tensorflow/tensorboard/components/vz_histogram_timeseries/vz-histogram-timeseries.html b/tensorflow/tensorboard/components/vz_histogram_timeseries/vz-histogram-timeseries.html deleted file mode 100644 index bdba230077d..00000000000 --- a/tensorflow/tensorboard/components/vz_histogram_timeseries/vz-histogram-timeseries.html +++ /dev/null @@ -1,707 +0,0 @@ - - - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/vz_line_chart/BUILD b/tensorflow/tensorboard/components/vz_line_chart/BUILD deleted file mode 100644 index 8bbf8a24d34..00000000000 --- a/tensorflow/tensorboard/components/vz_line_chart/BUILD +++ /dev/null @@ -1,53 +0,0 @@ -package(default_visibility = ["//tensorflow/tensorboard:internal"]) - -load("//tensorflow/tensorboard/defs:defs.bzl", "tensorboard_webcomponent_library") -load("//tensorflow/tensorboard/defs:web.bzl", "ts_web_library") - -licenses(["notice"]) # Apache 2.0 - -ts_web_library( - name = "vz_line_chart", - srcs = [ - "dragZoomInteraction.ts", - "vz-chart-helpers.ts", - "vz-line-chart.html", - "vz-line-chart.ts", - ], - path = "/vz-line-chart", - visibility = ["//visibility:public"], - deps = [ - "//tensorflow/tensorboard/components/tf_imports:d3", - "//tensorflow/tensorboard/components/tf_imports:lodash", - "//tensorflow/tensorboard/components/tf_imports:plottable", - "//tensorflow/tensorboard/components/tf_imports:polymer", - ], -) - -ts_web_library( - name = "demo", - srcs = ["index.html"], - path = "/vz-line-chart", - deps = [ - ":vz_line_chart", - "@org_polymer_iron_demo_helpers", - "@org_polymer_paper_styles", - ], -) - -tensorboard_webcomponent_library( - name = "legacy", - srcs = [":vz_line_chart"], - visibility = ["//learning/vis/vz_elements/catalog:__pkg__"], - destdir = "vz-line-chart", - deps = [ - "//tensorflow/tensorboard/components/tf_imports_google:lib", - "//tensorflow/tensorboard/components/vz_sorting:legacy", - "//third_party/javascript/polymer/v1/polymer:lib", - ], -) - -filegroup( - 
name = "all_files", - srcs = glob(["**"]), - tags = ["notsan"], -) diff --git a/tensorflow/tensorboard/components/vz_line_chart/dragZoomInteraction.ts b/tensorflow/tensorboard/components/vz_line_chart/dragZoomInteraction.ts deleted file mode 100644 index c7f1f30e76b..00000000000 --- a/tensorflow/tensorboard/components/vz_line_chart/dragZoomInteraction.ts +++ /dev/null @@ -1,200 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the 'License'); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an 'AS IS' BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -export class DragZoomLayer extends Plottable.Components.SelectionBoxLayer { - private _dragInteraction: Plottable.Interactions.Drag; - private _doubleClickInteraction: Plottable.Interactions.Click; - private isZoomed = false; - private easeFn: (t: number) => number = d3.easeCubicInOut; - private _animationTime = 750; - private onStart: Function; - private onEnd: Function; - private unzoomMethod: Function; - - /** - * Constructs a SelectionBoxLayer with an attached DragInteraction and - * ClickInteraction. On drag, it triggers an animated zoom into the box - * that was dragged. On double click, it zooms back out to the original - * view, before any zooming. - * The zoom animation uses an easing function (default - * d3.ease('cubic-in-out')) and is customizable. 
- * Usage: Construct the selection box layer and attach x and y scales, - * and then add the layer over the plot you are zooming on using a - * Component Group. - * TODO(danmane) - merge this into Plottable - */ - constructor( - xScale: Plottable.QuantitativeScale, - yScale: Plottable.QuantitativeScale, - unzoomMethod: Function) { - super(); - this.xScale(xScale); - this.yScale(yScale); - this._dragInteraction = new Plottable.Interactions.Drag(); - this._dragInteraction.attachTo(this); - this._doubleClickInteraction = new Plottable.Interactions.Click(); - this._doubleClickInteraction.attachTo(this); - this.setupCallbacks(); - this.unzoomMethod = unzoomMethod; - } - - /** - * Register a method that calls when the DragZoom interaction starts. - */ - public interactionStart(cb: Function) { - this.onStart = cb; - } - - /** - * Register a method that calls when the DragZoom interaction ends. - */ - public interactionEnd(cb: Function) { - this.onEnd = cb; - } - - private setupCallbacks() { - let dragging = false; - this._dragInteraction.onDragStart((startPoint: Plottable.Point) => { - this.bounds({ - topLeft: startPoint, - bottomRight: startPoint, - }); - this.onStart(); - }); - this._dragInteraction.onDrag((startPoint, endPoint) => { - this.bounds({topLeft: startPoint, bottomRight: endPoint}); - this.boxVisible(true); - dragging = true; - }); - this._dragInteraction.onDragEnd((startPoint, endPoint) => { - this.boxVisible(false); - this.bounds({topLeft: startPoint, bottomRight: endPoint}); - if (dragging) { - this.zoom(); - } else { - this.onEnd(); - } - dragging = false; - }); - - this._doubleClickInteraction.onDoubleClick(this.unzoom.bind(this)); - } - - /* Set the time (in ms) over which the zoom will interpolate. - * 0 implies no interpolation. 
(ie zoom is instant) - */ - public animationTime(): number; - public animationTime(animationTime: number): DragZoomLayer; - public animationTime(animationTime?: number): any { - if (animationTime == null) { - return this._animationTime; - } - if (animationTime < 0) { - throw new Error('animationTime cannot be negative'); - } - this._animationTime = animationTime; - return this; - } - - /** - * Set the easing function, which determines how the zoom interpolates - * over time. - */ - public ease(fn: (t: number) => number): DragZoomLayer { - if (typeof(fn) !== 'function') { - throw new Error('ease function must be a function'); - } - if (fn(0) !== 0 || fn(1) !== 1) { - Plottable.Utils.Window.warn( - 'Easing function does not maintain invariant ' + - 'f(0)==0 && f(1)==1. Bad behavior may result.'); - } - this.easeFn = fn; - return this; - } - - // Zoom into extent of the selection box bounds - private zoom() { - let x0: number = this.xExtent()[0].valueOf(); - let x1: number = this.xExtent()[1].valueOf(); - let y0: number = this.yExtent()[1].valueOf(); - let y1: number = this.yExtent()[0].valueOf(); - - if (x0 === x1 || y0 === y1) { - return; - } - - if (!this.isZoomed) { - this.isZoomed = true; - } - this.interpolateZoom(x0, x1, y0, y1); - } - - // Restore the scales to their state before any zoom - private unzoom() { - if (!this.isZoomed) { - return; - } - this.isZoomed = false; - let xScale = this.xScale() as any; - xScale._domainMin = null; - xScale._domainMax = null; - let xDomain = xScale._getExtent(); - this.xScale().domain(xDomain); - this.unzoomMethod(); - } - - // If we are zooming, disable interactions, to avoid contention - private isZooming(isZooming: boolean) { - this._dragInteraction.enabled(!isZooming); - this._doubleClickInteraction.enabled(!isZooming); - } - - private interpolateZoom(x0f: number, x1f: number, y0f: number, y1f: number) { - let x0s: number = this.xScale().domain()[0].valueOf(); - let x1s: number = this.xScale().domain()[1].valueOf(); - 
let y0s: number = this.yScale().domain()[0].valueOf(); - let y1s: number = this.yScale().domain()[1].valueOf(); - - // Copy a ref to the ease fn, so that changing ease wont affect zooms in - // progress. - let ease = this.easeFn; - let interpolator = (a: number, b: number, p: number) => - d3.interpolateNumber(a, b)(ease(p)); - - this.isZooming(true); - let start = Date.now(); - let draw = () => { - let now = Date.now(); - let passed = now - start; - let p = this._animationTime === 0 ? - 1 : - Math.min(1, passed / this._animationTime); - let x0 = interpolator(x0s, x0f, p); - let x1 = interpolator(x1s, x1f, p); - let y0 = interpolator(y0s, y0f, p); - let y1 = interpolator(y1s, y1f, p); - this.xScale().domain([x0, x1]); - this.yScale().domain([y0, y1]); - if (p < 1) { - Plottable.Utils.DOM.requestAnimationFramePolyfill(draw); - } else { - this.onEnd(); - this.isZooming(false); - } - }; - draw(); - } -} diff --git a/tensorflow/tensorboard/components/vz_line_chart/index.html b/tensorflow/tensorboard/components/vz_line_chart/index.html deleted file mode 100644 index 856ab7d1efe..00000000000 --- a/tensorflow/tensorboard/components/vz_line_chart/index.html +++ /dev/null @@ -1,72 +0,0 @@ - - - - - - - - vz-line-chart demo - - - - - - -

Simple line chart

- - - - -

Exponential Smoothing enabled

- - - - - - - diff --git a/tensorflow/tensorboard/components/vz_line_chart/vz-chart-helpers.ts b/tensorflow/tensorboard/components/vz_line_chart/vz-chart-helpers.ts deleted file mode 100644 index fa89e06ada1..00000000000 --- a/tensorflow/tensorboard/components/vz_line_chart/vz-chart-helpers.ts +++ /dev/null @@ -1,219 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -export interface Datum { - wall_time: Date; - step: number; -} - -export interface Scalar { - scalar: number; - smoothed: number; -} - -export type ScalarDatum = Datum & Scalar; - -export type DataFn = (run: string, tag: string) => Promise>; - -export let Y_TOOLTIP_FORMATTER_PRECISION = 4; -export let STEP_FORMATTER_PRECISION = 4; -export let Y_AXIS_FORMATTER_PRECISION = 3; -export let TOOLTIP_Y_PIXEL_OFFSET = 20; -export let TOOLTIP_CIRCLE_SIZE = 4; -export let NAN_SYMBOL_SIZE = 6; - -export interface Point { - x: number; // pixel space - y: number; // pixel space - datum: ScalarDatum; - dataset: Plottable.Dataset; -} - -/* Create a formatter function that will switch between exponential and - * regular display depending on the scale of the number being formatted, - * and show `digits` significant digits. 
- */ -export function multiscaleFormatter(digits: number): ((v: number) => string) { - return (v: number) => { - let absv = Math.abs(v); - if (absv < 1E-15) { - // Sometimes zero-like values get an annoying representation - absv = 0; - } - let f: (x: number) => string; - if (absv >= 1E4) { - f = d3.format('.' + digits + 'e'); - } else if (absv > 0 && absv < 0.01) { - f = d3.format('.' + digits + 'e'); - } else { - f = d3.format('.' + digits + 'g'); - } - return f(v); - }; -} - -/* Compute an appropriate domain given an array of all the values that are - * going to be displayed. If ignoreOutliers is true, it will ignore the - * lowest 10% and highest 10% of the data when computing a domain. - * It has n log n performance when ignoreOutliers is true, as it needs to - * sort the data. - */ -export function computeDomain(values: number[], ignoreOutliers: boolean) { - // Don't include infinities and NaNs in the domain computation. - values = values.filter(z => isFinite(z)); - - if (values.length === 0) { - return [-0.1, 1.1]; - } - let a: number; - let b: number; - if (ignoreOutliers) { - let sorted = _.sortBy(values); - a = d3.quantile(sorted, 0.05); - b = d3.quantile(sorted, 0.95); - } else { - a = d3.min(values); - b = d3.max(values); - } - - let padding: number; - let span = b - a; - if (span === 0) { - // If b===a, we would create an empty range. We instead select the range - // [0, 2*a] if a > 0, or [-2*a, 0] if a < 0, plus a little bit of - // extra padding on the top and bottom of the plot. - padding = Math.abs(a) * 1.1 + 1.1; - } else { - padding = span * 0.2; - } - - let lower: number; - if (a >= 0 && a < span) { - // We include the intercept (y = 0) if doing so less than doubles the span - // of the y-axis. (We actually select a lower bound that's slightly less - // than 0 so that 0.00 will clearly be written on the lower edge of the - // chart. The label on the lowest tick is often filtered out.) 
- lower = -0.1 * b; - } else { - lower = a - padding; - } - - - let domain = [lower, b + padding]; - domain = d3.scaleLinear().domain(domain).nice().domain(); - return domain; -} - -export function accessorize(key: string): Plottable.IAccessor { - // tslint:disable-next-line:no-any be quiet tsc - return (d: any, index: number, dataset: Plottable.Dataset) => d[key]; -} - -export interface XComponents { - /* tslint:disable */ - scale: Plottable.Scales.Linear|Plottable.Scales.Time, - axis: Plottable.Axes.Numeric|Plottable.Axes.Time, - accessor: Plottable.IAccessor, - /* tslint:enable */ -} - -export let stepFormatter = - Plottable.Formatters.siSuffix(STEP_FORMATTER_PRECISION); -export function stepX(): XComponents { - let scale = new Plottable.Scales.Linear(); - let axis = new Plottable.Axes.Numeric(scale, 'bottom'); - axis.formatter(stepFormatter); - return { - scale: scale, - axis: axis, - accessor: (d: Datum) => d.step, - }; -} - -export let timeFormatter = Plottable.Formatters.time('%a %b %e, %H:%M:%S'); - -export function wallX(): XComponents { - let scale = new Plottable.Scales.Time(); - return { - scale: scale, - axis: new Plottable.Axes.Time(scale, 'bottom'), - accessor: (d: Datum) => d.wall_time, - }; -} -export let relativeAccessor = - // tslint:disable-next-line:no-any be quiet tsc - (d: any, index: number, dataset: Plottable.Dataset) => { - // We may be rendering the final-point datum for scatterplot. - // If so, we will have already provided the 'relative' property - if (d.relative != null) { - return d.relative; - } - let data = dataset.data(); - // I can't imagine how this function would be called when the data is - // empty (after all, it iterates over the data), but lets guard just - // to be safe. - let first = data.length > 0 ? 
+data[0].wall_time : 0; - return (+d.wall_time - first) / (60 * 60 * 1000); // ms to hours - }; - -export let relativeFormatter = (n: number) => { - // we will always show 2 units of precision, e.g days and hours, or - // minutes and seconds, but not hours and minutes and seconds - let ret = ''; - let days = Math.floor(n / 24); - n -= (days * 24); - if (days) { - ret += days + 'd '; - } - let hours = Math.floor(n); - n -= hours; - n *= 60; - if (hours || days) { - ret += hours + 'h '; - } - let minutes = Math.floor(n); - n -= minutes; - n *= 60; - if (minutes || hours || days) { - ret += minutes + 'm '; - } - let seconds = Math.floor(n); - return ret + seconds + 's'; -}; -export function relativeX(): XComponents { - let scale = new Plottable.Scales.Linear(); - return { - scale: scale, - axis: new Plottable.Axes.Numeric(scale, 'bottom'), - accessor: relativeAccessor, - }; -} - -// a very literal definition of NaN: true for NaN for a non-number type -// or null, etc. False for Infinity or -Infinity -export let isNaN = (x) => +x !== x; - -export function getXComponents(xType: string): XComponents { - switch (xType) { - case 'step': - return stepX(); - case 'wall_time': - return wallX(); - case 'relative': - return relativeX(); - default: - throw new Error('invalid xType: ' + xType); - } -} diff --git a/tensorflow/tensorboard/components/vz_line_chart/vz-line-chart.html b/tensorflow/tensorboard/components/vz_line_chart/vz-line-chart.html deleted file mode 100644 index 38e0d7cb8d8..00000000000 --- a/tensorflow/tensorboard/components/vz_line_chart/vz-line-chart.html +++ /dev/null @@ -1,131 +0,0 @@ - - - - - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/vz_line_chart/vz-line-chart.ts b/tensorflow/tensorboard/components/vz_line_chart/vz-line-chart.ts deleted file mode 100644 index 5da6190ea24..00000000000 --- a/tensorflow/tensorboard/components/vz_line_chart/vz-line-chart.ts +++ /dev/null @@ -1,773 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. 
All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -/* tslint:disable:no-namespace variable-name */ - -import {DragZoomLayer} from './dragZoomInteraction' -import * as ChartHelpers from './vz-chart-helpers' - -Polymer({ - is: 'vz-line-chart', - properties: { - /** - * Scale that maps series names to colors. The default colors are from - * d3.schemeCategory10. Use this property to replace the default line - * colors with colors of your own choice. - * @type {Plottable.Scales.Color} - * @required - */ - colorScale: { - type: Object, - value: function() { - return new Plottable.Scales.Color().range(d3.schemeCategory10); - } - }, - - /** - * Whether smoothing is enabled or not. If true, smoothed lines will be - * plotted in the chart while the unsmoothed lines will be ghosted in - * the background. - * - * The smoothing algorithm is a simple moving average, which, given a - * point p and a window w, replaces p with a simple average of the - * points in the [p - floor(w/2), p + floor(w/2)] range. If there - * aren't enough points to cover the entire window to the left, the - * window is reduced to fit exactly the amount of elements available. - * This means that the smoothed line will be less in and gradually - * become more smooth until the desired window is reached. However when - * there aren't enough points on the right, the line stops being - * rendered at all. 
- */ - smoothingEnabled: {type: Boolean, value: false}, - - /** - * Weight (between 0.0 and 1.0) of the smoothing. This weight controls - * the window size, and a weight of 1.0 means using 50% of the entire - * dataset as the window, while a weight of 0.0 means using a window of - * 0 (and thus replacing each point with themselves). - * - * The growth between 0.0 and 1.0 is not linear though. Because - * changing the window from 0% to 30% of the dataset smooths the line a - * lot more than changing the window from 70% to 100%, an exponential - * function is used instead: http://i.imgur.com/bDrhEZU.png. This - * function increases the size of the window slowly at the beginning - * and gradually speeds up the growth, but 0.0 still means a window of - * 0 and 1.0 still means a window of the dataset's length. - */ - smoothingWeight: {type: Number, value: 0.6}, - - /** - * The way to display the X values. Allows: - * - "step" - Linear scale using the "step" property of the datum. - * - "wall_time" - Temporal scale using the "wall_time" property of the - * datum. - * - "relative" - Temporal scale using the "relative" property of the - * datum if it is present or calculating from "wall_time" if it isn't. - */ - xType: {type: String, value: 'step'}, - - /** - * The scale for the y-axis. Allows: - * - "linear" - linear scale (Plottable.Scales.Linear) - * - "log" - modified-log scale (Plottable.Scales.ModifiedLog) - */ - yScaleType: {type: String, value: 'linear'}, - - /** - * Whether to ignore outlier data when computing the yScale domain. - */ - - ignoreYOutliers: { - type: Boolean, - value: false, - }, - - /** - * Change how the tooltip is sorted. Allows: - * - "default" - Sort the tooltip by input order. - * - "ascending" - Sort the tooltip by ascending value. - * - "descending" - Sort the tooltip by descending value. - * - "nearest" - Sort the tooltip by closest to cursor. 
- */ - tooltipSortingMethod: {type: String, value: 'default'}, - - /** - * Change how the tooltip is positioned. Allows: - * - "bottom" - Position the tooltip on the bottom of the chart. - * - "right" - Position the tooltip to the right of the chart. - */ - tooltipPosition: {type: String, value: 'bottom'}, - - _attached: Boolean, - _chart: Object, - _visibleSeriesCache: { - type: Array, - value: function() { - return [] - } - }, - _seriesDataCache: { - type: Object, - value: function() { - return {} - } - }, - _makeChartAsyncCallbackId: {type: Number, value: null} - }, - observers: [ - '_makeChart(xType, yScaleType, colorScale, _attached)', - '_reloadFromCache(_chart)', - '_smoothingChanged(smoothingEnabled, smoothingWeight, _chart)', - '_tooltipSortingMethodChanged(tooltipSortingMethod, _chart)', - '_tooltipPositionChanged(tooltipPosition, _chart)', - '_outliersChanged(ignoreYOutliers, _chart)' - ], - - /** - * Sets the series that the chart displays. Series with other names will - * not be displayed. - * - * @param {Array} names Array with the names of the series to - * display. - */ - setVisibleSeries: function(names) { - this._visibleSeriesCache = names; - if (this._chart) { - this._chart.setVisibleSeries(names); - this.redraw(); - } - }, - - /** - * Sets the data of one of the series. Note that to display this series - * its name must be in the setVisibleSeries() array. - * - * @param {string} name Name of the series. - * @param {Array} data Data of the series. This is - * an array of objects with at least the following properties: - * - step: (Number) - index of the datum. - * - wall_time: (Date) - Date object with the datum's time. - * - scalar: (Number) - Value of the datum. - */ - setSeriesData: function(name, data) { - this._seriesDataCache[name] = data; - if (this._chart) { - this._chart.setSeriesData(name, data); - } - }, - - /** - * Re-renders the chart. Useful if e.g. the container size changed. 
- */ - redraw: function() { - this._chart.redraw(); - }, - attached: function() { - this._attached = true; - }, - detached: function() { - this._attached = false; - }, - ready: function() { - this.scopeSubtree(this.$.tooltip, true); - this.scopeSubtree(this.$.chartdiv, true); - }, - _makeChart: function(xType, yScaleType, colorScale, _attached) { - if (this._makeChartAsyncCallbackId !== null) { - this.cancelAsync(this._makeChartAsyncCallbackId); - this._makeChartAsyncCallbackId = null; - } - - this._makeChartAsyncCallbackId = this.async(function() { - this._makeChartAsyncCallbackId = null; - if (!this._attached) return; - if (this._chart) this._chart.destroy(); - var tooltip = d3.select(this.$.tooltip); - var chart = new LineChart(xType, yScaleType, colorScale, tooltip); - var div = d3.select(this.$.chartdiv); - chart.renderTo(div); - this._chart = chart; - }, 350); - }, - _reloadFromCache: function() { - if (this._chart) { - this._chart.setVisibleSeries(this._visibleSeriesCache); - this._visibleSeriesCache.forEach(function(name) { - this._chart.setSeriesData(name, this._seriesDataCache[name] || []); - }.bind(this)); - } - }, - _smoothingChanged: function() { - if (!this._chart) { - return; - } - if (this.smoothingEnabled) { - this._chart.smoothingUpdate(this.smoothingWeight); - } else { - this._chart.smoothingDisable(); - } - }, - _outliersChanged: function() { - if (!this._chart) { - return; - } - this._chart.ignoreYOutliers(this.ignoreYOutliers); - }, - _tooltipSortingMethodChanged: function() { - if (this._chart) { - this._chart.setTooltipSortingMethod(this.tooltipSortingMethod); - } - }, - _tooltipPositionChanged: function() { - if (this._chart) { - this._chart.setTooltipPosition(this.tooltipPosition); - } - } -}); - -class LineChart { - private name2datasets: {[name: string]: Plottable.Dataset}; - private seriesNames: string[]; - - private xAccessor: Plottable.IAccessor; - private xScale: Plottable.QuantitativeScale; - private yScale: 
Plottable.QuantitativeScale; - private gridlines: Plottable.Components.Gridlines; - private center: Plottable.Components.Group; - private xAxis: Plottable.Axes.Numeric|Plottable.Axes.Time; - private yAxis: Plottable.Axes.Numeric; - private outer: Plottable.Components.Table; - private colorScale: Plottable.Scales.Color; - private tooltip: d3.Selection; - private dzl: DragZoomLayer; - - private linePlot: Plottable.Plots.Line; - private smoothLinePlot: Plottable.Plots.Line; - private scatterPlot: Plottable.Plots.Scatter; - private nanDisplay: Plottable.Plots.Scatter; - private scalarAccessor: Plottable.IAccessor; - private smoothedAccessor: Plottable.IAccessor; - private lastPointsDataset: Plottable.Dataset; - private datasets: Plottable.Dataset[]; - private onDatasetChanged: (dataset: Plottable.Dataset) => void; - private nanDataset: Plottable.Dataset; - private smoothingWeight: number; - private smoothingEnabled: Boolean; - private tooltipSortingMethod: string; - private tooltipPosition: string; - private _ignoreYOutliers: boolean; - - private targetSVG: d3.Selection; - - constructor( - xType: string, yScaleType: string, colorScale: Plottable.Scales.Color, - tooltip: d3.Selection) { - this.seriesNames = []; - this.name2datasets = {}; - this.colorScale = colorScale; - this.tooltip = tooltip; - this.datasets = []; - this._ignoreYOutliers = false; - // lastPointDataset is a dataset that contains just the last point of - // every dataset we're currently drawing. - this.lastPointsDataset = new Plottable.Dataset(); - this.nanDataset = new Plottable.Dataset(); - // need to do a single bind, so we can deregister the callback from - // old Plottable.Datasets. (Deregistration is done by identity checks.) 
- this.onDatasetChanged = this._onDatasetChanged.bind(this); - this.buildChart(xType, yScaleType); - } - - private buildChart(xType: string, yScaleType: string) { - if (this.outer) { - this.outer.destroy(); - } - let xComponents = ChartHelpers.getXComponents(xType); - this.xAccessor = xComponents.accessor; - this.xScale = xComponents.scale; - this.xAxis = xComponents.axis; - this.xAxis.margin(0).tickLabelPadding(3); - this.yScale = LineChart.getYScaleFromType(yScaleType); - this.yAxis = new Plottable.Axes.Numeric(this.yScale, 'left'); - let yFormatter = ChartHelpers.multiscaleFormatter( - ChartHelpers.Y_AXIS_FORMATTER_PRECISION); - this.yAxis.margin(0).tickLabelPadding(5).formatter(yFormatter); - this.yAxis.usesTextWidthApproximation(true); - - this.dzl = new DragZoomLayer( - this.xScale, this.yScale, this.updateSpecialDatasets.bind(this)); - - let center = this.buildPlot(this.xAccessor, this.xScale, this.yScale); - - this.gridlines = - new Plottable.Components.Gridlines(this.xScale, this.yScale); - - let xZeroLine = new Plottable.Components.GuideLineLayer('horizontal'); - xZeroLine.scale(this.yScale).value(0); - let yZeroLine = new Plottable.Components.GuideLineLayer('vertical'); - yZeroLine.scale(this.xScale).value(0); - - this.center = new Plottable.Components.Group( - [this.gridlines, xZeroLine, yZeroLine, center, this.dzl]); - this.outer = new Plottable.Components.Table( - [[this.yAxis, this.center], [null, this.xAxis]]); - } - - private buildPlot(xAccessor, xScale, yScale): Plottable.Component { - this.scalarAccessor = (d: ChartHelpers.ScalarDatum) => d.scalar; - this.smoothedAccessor = (d: ChartHelpers.ScalarDatum) => d.smoothed; - let linePlot = new Plottable.Plots.Line(); - linePlot.x(xAccessor, xScale); - linePlot.y(this.scalarAccessor, yScale); - linePlot.attr( - 'stroke', - (d: ChartHelpers.Datum, i: number, dataset: Plottable.Dataset) => - this.colorScale.scale(dataset.metadata().name)); - this.linePlot = linePlot; - let group = 
this.setupTooltips(linePlot); - - let smoothLinePlot = new Plottable.Plots.Line(); - smoothLinePlot.x(xAccessor, xScale); - smoothLinePlot.y(this.smoothedAccessor, yScale); - smoothLinePlot.attr( - 'stroke', - (d: ChartHelpers.Datum, i: number, dataset: Plottable.Dataset) => - this.colorScale.scale(dataset.metadata().name)); - this.smoothLinePlot = smoothLinePlot; - - // The scatterPlot will display the last point for each dataset. - // This way, if there is only one datum for the series, it is still - // visible. We hide it when tooltips are active to keep things clean. - let scatterPlot = new Plottable.Plots.Scatter(); - scatterPlot.x(xAccessor, xScale); - scatterPlot.y(this.scalarAccessor, yScale); - scatterPlot.attr('fill', (d: any) => this.colorScale.scale(d.name)); - scatterPlot.attr('opacity', 1); - scatterPlot.size(ChartHelpers.TOOLTIP_CIRCLE_SIZE * 2); - scatterPlot.datasets([this.lastPointsDataset]); - this.scatterPlot = scatterPlot; - - let nanDisplay = new Plottable.Plots.Scatter(); - nanDisplay.x(xAccessor, xScale); - nanDisplay.y((x) => x.displayY, yScale); - nanDisplay.attr('fill', (d: any) => this.colorScale.scale(d.name)); - nanDisplay.attr('opacity', 1); - nanDisplay.size(ChartHelpers.NAN_SYMBOL_SIZE * 2); - nanDisplay.datasets([this.nanDataset]); - nanDisplay.symbol(Plottable.SymbolFactories.triangle); - this.nanDisplay = nanDisplay; - - return new Plottable.Components.Group( - [nanDisplay, scatterPlot, smoothLinePlot, group]); - } - - /** Updates the chart when a dataset changes. Called every time the data of - * a dataset changes to update the charts. 
- */ - private _onDatasetChanged(dataset: Plottable.Dataset) { - if (this.smoothingEnabled) { - this.resmoothDataset(dataset); - } - this.updateSpecialDatasets(); - } - - public ignoreYOutliers(ignoreYOutliers: boolean) { - if (ignoreYOutliers !== this._ignoreYOutliers) { - this._ignoreYOutliers = ignoreYOutliers; - this.updateSpecialDatasets(); - } - } - - private updateSpecialDatasets() { - if (this.smoothingEnabled) { - this.updateSpecialDatasetsWithAccessor(this.smoothedAccessor); - } else { - this.updateSpecialDatasetsWithAccessor(this.scalarAccessor); - } - } - - /** Constructs special datasets. Each special dataset contains exceptional - * values from all of the regular datasets, e.g. last points in series, or - * NaN values. Those points will have a `name` and `relative` property added - * (since usually those are context in the surrounding dataset). - * The accessor will point to the correct data to access. - */ - private updateSpecialDatasetsWithAccessor(accessor: - Plottable.IAccessor) { - let lastPointsData = - this.datasets - .map((d) => { - let datum = null; - // filter out NaNs to ensure last point is a clean one - let nonNanData = - d.data().filter((x) => !isNaN(accessor(x, -1, d))); - if (nonNanData.length > 0) { - let idx = nonNanData.length - 1; - datum = nonNanData[idx]; - datum.name = d.metadata().name; - datum.relative = ChartHelpers.relativeAccessor(datum, -1, d); - } - return datum; - }) - .filter((x) => x != null); - this.lastPointsDataset.data(lastPointsData); - - // Take a dataset, return an array of NaN data points - // the NaN points will have a "displayY" property which is the - // y-value of a nearby point that was not NaN (0 if all points are NaN) - let datasetToNaNData = (d: Plottable.Dataset) => { - let displayY = null; - let data = d.data(); - let i = 0; - while (i < data.length && displayY == null) { - if (!isNaN(accessor(data[i], -1, d))) { - displayY = accessor(data[i], -1, d); - } - i++; - } - if (displayY == null) { - 
displayY = 0; - } - let nanData = []; - for (i = 0; i < data.length; i++) { - if (!isNaN(accessor(data[i], -1, d))) { - displayY = accessor(data[i], -1, d); - } else { - data[i].name = d.metadata().name; - data[i].displayY = displayY; - data[i].relative = ChartHelpers.relativeAccessor(data[i], -1, d); - nanData.push(data[i]); - } - } - return nanData; - }; - let nanData = _.flatten(this.datasets.map(datasetToNaNData)); - this.nanDataset.data(nanData); - - let datasetToValues: (d: Plottable.Dataset) => number[] = (d) => { - return d.data().map((x) => accessor(x, -1, d)); - }; - let vals = _.flatten(this.datasets.map(datasetToValues)); - vals = vals.filter((x) => x === x && x !== Infinity && x !== -Infinity); - let domain = ChartHelpers.computeDomain(vals, this._ignoreYOutliers); - this.yScale.domain(domain); - } - - private setupTooltips(plot: Plottable.XYPlot): - Plottable.Components.Group { - let pi = new Plottable.Interactions.Pointer(); - pi.attachTo(plot); - // PointsComponent is a Plottable Component that will hold the little - // circles we draw over the closest data points - let pointsComponent = new Plottable.Component(); - let group = new Plottable.Components.Group([plot, pointsComponent]); - - let hideTooltips = () => { - this.tooltip.style('opacity', 0); - this.scatterPlot.attr('opacity', 1); - pointsComponent.content().selectAll('.point').remove(); - }; - - let enabled = true; - let disableTooltips = () => { - enabled = false; - hideTooltips(); - }; - let enableTooltips = () => { - enabled = true; - }; - - this.dzl.interactionStart(disableTooltips); - this.dzl.interactionEnd(enableTooltips); - - pi.onPointerMove((p: Plottable.Point) => { - if (!enabled) { - return; - } - let target: ChartHelpers.Point = { - x: p.x, - y: p.y, - datum: null, - dataset: null, - }; - - - let bbox: SVGRect = (this.gridlines.content().node()).getBBox(); - - // pts is the closets point to the tooltip for each dataset - let pts = plot.datasets() - .map((dataset) => 
this.findClosestPoint(target, dataset)) - .filter(x => x != null); - let intersectsBBox = Plottable.Utils.DOM.intersectsBBox; - // We draw tooltips for points that are NaN, or are currently visible - let ptsForTooltips = pts.filter( - (p) => intersectsBBox(p.x, p.y, bbox) || isNaN(p.datum.scalar)); - // Only draw little indicator circles for the non-NaN points - let ptsToCircle = ptsForTooltips.filter((p) => !isNaN(p.datum.scalar)); - - let ptsSelection: any = - pointsComponent.content().selectAll('.point').data( - ptsToCircle, - (p: ChartHelpers.Point) => p.dataset.metadata().name); - if (pts.length !== 0) { - ptsSelection.enter().append('circle').classed('point', true); - ptsSelection.attr('r', ChartHelpers.TOOLTIP_CIRCLE_SIZE) - .attr('cx', (p) => p.x) - .attr('cy', (p) => p.y) - .style('stroke', 'none') - .attr( - 'fill', - (p) => this.colorScale.scale(p.dataset.metadata().name)); - ptsSelection.exit().remove(); - this.drawTooltips(ptsForTooltips, target); - } else { - hideTooltips(); - } - }); - - pi.onPointerExit(hideTooltips); - - return group; - } - - private drawTooltips( - points: ChartHelpers.Point[], target: ChartHelpers.Point) { - // Formatters for value, step, and wall_time - this.scatterPlot.attr('opacity', 0); - let valueFormatter = ChartHelpers.multiscaleFormatter( - ChartHelpers.Y_TOOLTIP_FORMATTER_PRECISION); - - let dist = (p: ChartHelpers.Point) => - Math.pow(p.x - target.x, 2) + Math.pow(p.y - target.y, 2); - let closestDist = _.min(points.map(dist)); - - let valueSortMethod = this.scalarAccessor; - if (this.smoothingEnabled) { - valueSortMethod = this.smoothedAccessor; - } - - if (this.tooltipSortingMethod === 'ascending') { - points = _.sortBy(points, (d) => valueSortMethod(d.datum, -1, d.dataset)); - } else if (this.tooltipSortingMethod === 'descending') { - points = _.sortBy(points, (d) => valueSortMethod(d.datum, -1, d.dataset)) - .reverse(); - } else if (this.tooltipSortingMethod === 'nearest') { - points = _.sortBy(points, dist); - } 
else { - // The 'default' sorting method maintains the order of names passed to - // setVisibleSeries(). However we reverse that order when defining the - // datasets. So we must call reverse again to restore the order. - points = points.slice(0).reverse(); - } - - let rows = this.tooltip.select('tbody') - .html('') - .selectAll('tr') - .data(points) - .enter() - .append('tr'); - // Grey out the point if any of the following are true: - // - The cursor is outside of the x-extent of the dataset - // - The point's y value is NaN - rows.classed('distant', (d) => { - let firstPoint = d.dataset.data()[0]; - let lastPoint = _.last(d.dataset.data()); - let firstX = this.xScale.scale(this.xAccessor(firstPoint, 0, d.dataset)); - let lastX = this.xScale.scale(this.xAccessor(lastPoint, 0, d.dataset)); - let s = this.smoothingEnabled ? d.datum.smoothed : d.datum.scalar; - return target.x < firstX || target.x > lastX || isNaN(s); - }); - rows.classed('closest', (p) => dist(p) === closestDist); - // It is a bit hacky that we are manually applying the width to the swatch - // and the nowrap property to the text here. The reason is as follows: - // the style gets updated asynchronously by Polymer scopeSubtree observer. - // Which means we would get incorrect sizing information since the text - // would wrap by default. However, we need correct measurements so that - // we can stop the text from falling off the edge of the screen. - // therefore, we apply the size-critical styles directly. - rows.style('white-space', 'nowrap'); - rows.append('td') - .append('span') - .classed('swatch', true) - .style( - 'background-color', - (d) => this.colorScale.scale(d.dataset.metadata().name)); - rows.append('td').text((d) => d.dataset.metadata().name); - if (this.smoothingEnabled) { - rows.append('td').text( - (d) => isNaN(d.datum.smoothed) ? 'NaN' : - valueFormatter(d.datum.smoothed)); - } - rows.append('td').text( - (d) => isNaN(d.datum.scalar) ? 
'NaN' : valueFormatter(d.datum.scalar)); - rows.append('td').text((d) => ChartHelpers.stepFormatter(d.datum.step)); - rows.append('td').text( - (d) => ChartHelpers.timeFormatter(d.datum.wall_time)); - rows.append('td').text( - (d) => ChartHelpers.relativeFormatter( - ChartHelpers.relativeAccessor(d.datum, -1, d.dataset))); - - // compute left position - let documentWidth = document.body.clientWidth; - let node: any = this.tooltip.node(); - let parentRect = node.parentElement.getBoundingClientRect(); - let nodeRect = node.getBoundingClientRect(); - // prevent it from falling off the right side of the screen - let left = documentWidth - parentRect.left - nodeRect.width - 60, top = 0; - - if (this.tooltipPosition === 'right') { - left = Math.min(parentRect.width, left); - } else { // 'bottom' - left = Math.min(0, left); - top = parentRect.height + ChartHelpers.TOOLTIP_Y_PIXEL_OFFSET; - } - - this.tooltip.style('transform', 'translate(' + left + 'px,' + top + 'px)'); - this.tooltip.style('opacity', 1); - } - - private findClosestPoint( - target: ChartHelpers.Point, - dataset: Plottable.Dataset): ChartHelpers.Point { - let points: ChartHelpers.Point[] = dataset.data().map((d, i) => { - let x = this.xAccessor(d, i, dataset); - let y = this.smoothingEnabled ? this.smoothedAccessor(d, i, dataset) : - this.scalarAccessor(d, i, dataset); - return { - x: this.xScale.scale(x), - y: this.yScale.scale(y), - datum: d, - dataset: dataset, - }; - }); - let idx: number = - _.sortedIndex(points, target, (p: ChartHelpers.Point) => p.x); - if (idx === points.length) { - return points[points.length - 1]; - } else if (idx === 0) { - return points[0]; - } else { - let prev = points[idx - 1]; - let next = points[idx]; - let prevDist = Math.abs(prev.x - target.x); - let nextDist = Math.abs(next.x - target.x); - return prevDist < nextDist ? 
prev : next; - } - } - - private resmoothDataset(dataset: Plottable.Dataset) { - let data = dataset.data(); - const smoothingWeight = this.smoothingWeight; - let last = data.length > 0 ? data[0].scalar : NaN; - data.forEach((d) => { - if (!_.isFinite(last)) { - d.smoothed = d.scalar; - } else { - // 1st-order IIR low-pass filter to attenuate the higher- - // frequency components of the time-series. - d.smoothed = last * smoothingWeight + (1 - smoothingWeight) * d.scalar; - } - last = d.smoothed; - }); - } - - private getDataset(name: string) { - if (this.name2datasets[name] === undefined) { - this.name2datasets[name] = new Plottable.Dataset([], {name: name}); - } - return this.name2datasets[name]; - } - - static getYScaleFromType(yScaleType: string): - Plottable.QuantitativeScale { - if (yScaleType === 'log') { - return new Plottable.Scales.ModifiedLog(); - } else if (yScaleType === 'linear') { - return new Plottable.Scales.Linear(); - } else { - throw new Error('Unrecognized yScale type ' + yScaleType); - } - } - - /** - * Update the selected series on the chart. - */ - public setVisibleSeries(names: string[]) { - names = names.sort(); - this.seriesNames = names; - - names.reverse(); // draw first series on top - this.datasets.forEach((d) => d.offUpdate(this.onDatasetChanged)); - this.datasets = names.map((r) => this.getDataset(r)); - this.datasets.forEach((d) => d.onUpdate(this.onDatasetChanged)); - this.linePlot.datasets(this.datasets); - - if (this.smoothingEnabled) { - this.smoothLinePlot.datasets(this.datasets); - } - this.updateSpecialDatasets(); - } - - /** - * Set the data of a series on the chart. 
- */ - public setSeriesData(name: string, data: ChartHelpers.ScalarDatum[]) { - this.getDataset(name).data(data); - } - - public smoothingUpdate(weight: number) { - this.smoothingWeight = weight; - this.datasets.forEach((d) => this.resmoothDataset(d)); - - if (!this.smoothingEnabled) { - this.linePlot.addClass('ghost'); - this.scatterPlot.y(this.smoothedAccessor, this.yScale); - this.smoothingEnabled = true; - this.smoothLinePlot.datasets(this.datasets); - } - - this.updateSpecialDatasetsWithAccessor(this.smoothedAccessor); - } - - public smoothingDisable() { - if (this.smoothingEnabled) { - this.linePlot.removeClass('ghost'); - this.scatterPlot.y(this.scalarAccessor, this.yScale); - this.smoothLinePlot.datasets([]); - this.smoothingEnabled = false; - this.updateSpecialDatasetsWithAccessor(this.scalarAccessor); - } - } - - public setTooltipSortingMethod(method: string) { - this.tooltipSortingMethod = method; - } - - public setTooltipPosition(position: string) { - this.tooltipPosition = position; - } - - public renderTo(targetSVG: d3.Selection) { - this.targetSVG = targetSVG; - this.outer.renderTo(targetSVG); - } - - public redraw() { - this.outer.redraw(); - } - - public destroy() { - this.outer.destroy(); - } -} diff --git a/tensorflow/tensorboard/components/vz_projector/BUILD b/tensorflow/tensorboard/components/vz_projector/BUILD deleted file mode 100644 index acc1312a944..00000000000 --- a/tensorflow/tensorboard/components/vz_projector/BUILD +++ /dev/null @@ -1,110 +0,0 @@ -package(default_visibility = ["//tensorflow/tensorboard:internal"]) - -load("//tensorflow/tensorboard/defs:web.bzl", "ts_web_library") - -licenses(["notice"]) # Apache 2.0 - -ts_web_library( - name = "vz_projector", - srcs = [ - "analyticsLogger.ts", - "bundle.html", - "data.ts", - "data-provider.ts", - "data-provider-demo.ts", - "data-provider-proto.ts", - "data-provider-server.ts", - "external.d.ts", - "knn.ts", - "label.ts", - "logging.ts", - "projectorEventContext.ts", - 
"projectorScatterPlotAdapter.ts", - "renderContext.ts", - "scatterPlot.ts", - "scatterPlotRectangleSelector.ts", - "scatterPlotVisualizer.ts", - "scatterPlotVisualizer3DLabels.ts", - "scatterPlotVisualizerCanvasLabels.ts", - "scatterPlotVisualizerPolylines.ts", - "scatterPlotVisualizerSprites.ts", - "styles.html", - "util.ts", - "vector.ts", - "vz-projector.html", - "vz-projector.ts", - "vz-projector-app.html", - "vz-projector-bookmark-panel.html", - "vz-projector-bookmark-panel.ts", - "vz-projector-colab.html", - "vz-projector-dashboard.html", - "vz-projector-data-panel.html", - "vz-projector-data-panel.ts", - "vz-projector-input.html", - "vz-projector-input.ts", - "vz-projector-inspector-panel.html", - "vz-projector-inspector-panel.ts", - "vz-projector-legend.html", - "vz-projector-legend.ts", - "vz-projector-metadata-card.html", - "vz-projector-metadata-card.ts", - "vz-projector-projections-panel.html", - "vz-projector-projections-panel.ts", - "vz-projector-util.ts", - ], - path = "/vz-projector", - visibility = ["//visibility:public"], - deps = [ - ":bh_tsne", - ":heap", - ":sptree", - "//tensorflow/tensorboard/components/tf_dashboard_common", - "//tensorflow/tensorboard/components/tf_imports:d3", - "//tensorflow/tensorboard/components/tf_imports:numericjs", - "//tensorflow/tensorboard/components/tf_imports:polymer", - "//tensorflow/tensorboard/components/tf_imports:threejs", - "//tensorflow/tensorboard/components/tf_imports:weblas", - "@org_polymer_iron_collapse", - "@org_polymer_iron_icons", - "@org_polymer_paper_button", - "@org_polymer_paper_checkbox", - "@org_polymer_paper_dialog", - "@org_polymer_paper_dialog_scrollable", - "@org_polymer_paper_dropdown_menu", - "@org_polymer_paper_icon_button", - "@org_polymer_paper_input", - "@org_polymer_paper_item", - "@org_polymer_paper_listbox", - "@org_polymer_paper_slider", - "@org_polymer_paper_spinner", - "@org_polymer_paper_styles", - "@org_polymer_paper_toast", - "@org_polymer_paper_toggle_button", - 
"@org_polymer_paper_tooltip", - ], -) - -ts_web_library( - name = "heap", - srcs = ["heap.ts"], - path = "/vz-projector", -) - -ts_web_library( - name = "sptree", - srcs = ["sptree.ts"], - path = "/vz-projector", -) - -ts_web_library( - name = "bh_tsne", - srcs = ["bh_tsne.ts"], - path = "/vz-projector", - deps = [":sptree"], -) - -filegroup( - name = "all_files", - srcs = glob(["**"]), - tags = ["notsan"], -) diff --git a/tensorflow/tensorboard/components/vz_projector/analyticsLogger.ts b/tensorflow/tensorboard/components/vz_projector/analyticsLogger.ts deleted file mode 100644 index aa1f86927da..00000000000 --- a/tensorflow/tensorboard/components/vz_projector/analyticsLogger.ts +++ /dev/null @@ -1,67 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -import {ProjectionType} from './data'; - -export class AnalyticsLogger { - private eventLogging: boolean; - private pageViewLogging: boolean; - - /** - * Constructs an event logger using Google Analytics. It assumes there is a - * Google Analytics script added to the page elsewhere. If there is no such - * script, the logger acts as a no-op. - * - * @param pageViewLogging Whether to log page views. - * @param eventLogging Whether to log user interaction. 
- */ - constructor(pageViewLogging: boolean, eventLogging: boolean) { - if (typeof ga === 'undefined' || ga == null) { - this.eventLogging = false; - this.pageViewLogging = false; - return; - } - this.eventLogging = eventLogging; - this.pageViewLogging = pageViewLogging; - } - - logPageView(pageTitle: string) { - if (this.pageViewLogging) { - // Always send a page view. - ga('send', {hitType: 'pageview', page: `/v/${pageTitle}`}); - } - } - - logProjectionChanged(projection: ProjectionType) { - if (this.eventLogging) { - ga('send', { - hitType: 'event', - eventCategory: 'Projection', - eventAction: 'click', - eventLabel: projection - }); - } - } - - logWebGLDisabled() { - if (this.eventLogging) { - ga('send', { - hitType: 'event', - eventCategory: 'Error', - eventAction: 'PageLoad', - eventLabel: 'WebGL_disabled' - }); - } - } -} diff --git a/tensorflow/tensorboard/components/vz_projector/bh_tsne.ts b/tensorflow/tensorboard/components/vz_projector/bh_tsne.ts deleted file mode 100644 index 063d57ec401..00000000000 --- a/tensorflow/tensorboard/components/vz_projector/bh_tsne.ts +++ /dev/null @@ -1,473 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -/** - * This is a fork of the Karpathy's TSNE.js (original license below). 
- * This fork implements Barnes-Hut approximation and runs in O(NlogN) - * time, as opposed to the Karpathy's O(N^2) version. - * - * @author smilkov@google.com (Daniel Smilkov) - */ - -/** - * @license - * The MIT License (MIT) - * Copyright (c) 2015 Andrej Karpathy - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN - * THE SOFTWARE. - */ - -import {SPNode, SPTree} from './sptree'; - -type AugmSPNode = SPNode&{numCells: number, yCell: number[], rCell: number}; - -/** - * Barnes-hut approximation level. Higher means more approximation and faster - * results. Recommended value mentioned in the paper is 0.8. - */ -const THETA = 0.8; - -const MIN_POSSIBLE_PROB = 1E-9; - -// Variables used for memorizing the second random number since running -// gaussRandom() generates two random numbers at the cost of 1 atomic -// computation. This optimization results in 2X speed-up of the generator. 
-let return_v = false; -let v_val = 0.0; - -/** Returns the square euclidean distance between two vectors. */ -export function dist2(a: number[], b: number[]): number { - if (a.length !== b.length) { - throw new Error('Vectors a and b must be of same length'); - } - - let result = 0; - for (let i = 0; i < a.length; ++i) { - let diff = a[i] - b[i]; - result += diff * diff; - } - return result; -} - -/** Returns the square euclidean distance between two 2D points. */ -export function dist2_2D(a: number[], b: number[]): number { - let dX = a[0] - b[0]; - let dY = a[1] - b[1]; - return dX * dX + dY * dY; -} - -/** Returns the square euclidean distance between two 3D points. */ -export function dist2_3D(a: number[], b: number[]): number { - let dX = a[0] - b[0]; - let dY = a[1] - b[1]; - let dZ = a[2] - b[2]; - return dX * dX + dY * dY + dZ * dZ; -} - -function gaussRandom(rng: () => number): number { - if (return_v) { - return_v = false; - return v_val; - } - let u = 2 * rng() - 1; - let v = 2 * rng() - 1; - let r = u * u + v * v; - if (r === 0 || r > 1) { - return gaussRandom(rng); - } - let c = Math.sqrt(-2 * Math.log(r) / r); - v_val = v * c; // cache this for next function call for efficiency - return_v = true; - return u * c; -}; - -// return random normal number -function randn(rng: () => number, mu: number, std: number) { - return mu + gaussRandom(rng) * std; -}; - -// utilitity that creates contiguous vector of zeros of size n -function zeros(n: number): Float64Array { - return new Float64Array(n); -}; - -// utility that returns a matrix filled with random numbers -// generated by the provided generator. -function randnMatrix(n: number, d: number, rng: () => number) { - let nd = n * d; - let x = zeros(nd); - for (let i = 0; i < nd; ++i) { - x[i] = randn(rng, 0.0, 1E-4); - } - return x; -}; - -// utility that returns a matrix filled with the provided value. 
-function arrayofs(n: number, d: number, val: number) { - let x: number[][] = []; - for (let i = 0; i < n; ++i) { - x.push(d === 3 ? [val, val, val] : [val, val]); - } - return x; -}; - -// compute (p_{i|j} + p_{j|i})/(2n) -function nearest2P( - nearest: {index: number, dist: number}[][], perplexity: number, - tol: number) { - let N = nearest.length; - let Htarget = Math.log(perplexity); // target entropy of distribution - let P = zeros(N * N); // temporary probability matrix - let K = nearest[0].length; - let pRow: number[] = new Array(K); // pij[]. - - for (let i = 0; i < N; ++i) { - let neighbors = nearest[i]; - let betaMin = -Infinity; - let betaMax = Infinity; - let beta = 1; // initial value of precision - let maxTries = 50; - - // perform binary search to find a suitable precision beta - // so that the entropy of the distribution is appropriate - let numTries = 0; - while (true) { - // compute entropy and kernel row with beta precision - let psum = 0.0; - for (let k = 0; k < neighbors.length; ++k) { - let neighbor = neighbors[k]; - let pij = (i === neighbor.index) ? 0 : Math.exp(-neighbor.dist * beta); - pij = Math.max(pij, MIN_POSSIBLE_PROB); - pRow[k] = pij; - psum += pij; - } - // normalize p and compute entropy - let Hhere = 0.0; - for (let k = 0; k < pRow.length; ++k) { - pRow[k] /= psum; - let pij = pRow[k]; - if (pij > 1E-7) { - Hhere -= pij * Math.log(pij); - }; - } - - // adjust beta based on result - if (Hhere > Htarget) { - // entropy was too high (distribution too diffuse) - // so we need to increase the precision for more peaky distribution - betaMin = beta; // move up the bounds - if (betaMax === Infinity) { - beta = beta * 2; - } else { - beta = (beta + betaMax) / 2; - } - - } else { - // converse case. 
make distrubtion less peaky - betaMax = beta; - if (betaMin === -Infinity) { - beta = beta / 2; - } else { - beta = (beta + betaMin) / 2; - } - } - numTries++; - // stopping conditions: too many tries or got a good precision - if (numTries >= maxTries || Math.abs(Hhere - Htarget) < tol) { - break; - } - } - - // copy over the final prow to P at row i - for (let k = 0; k < pRow.length; ++k) { - let pij = pRow[k]; - let j = neighbors[k].index; - P[i * N + j] = pij; - } - } // end loop over examples i - - // symmetrize P and normalize it to sum to 1 over all ij - let N2 = N * 2; - for (let i = 0; i < N; ++i) { - for (let j = i + 1; j < N; ++j) { - let i_j = i * N + j; - let j_i = j * N + i; - let value = (P[i_j] + P[j_i]) / N2; - P[i_j] = value; - P[j_i] = value; - } - } - return P; -}; - -// helper function -function sign(x: number) { - return x > 0 ? 1 : x < 0 ? -1 : 0; -} - -function computeForce_2d( - force: number[], mult: number, pointA: number[], pointB: number[]) { - force[0] += mult * (pointA[0] - pointB[0]); - force[1] += mult * (pointA[1] - pointB[1]); -} - -function computeForce_3d( - force: number[], mult: number, pointA: number[], pointB: number[]) { - force[0] += mult * (pointA[0] - pointB[0]); - force[1] += mult * (pointA[1] - pointB[1]); - force[2] += mult * (pointA[2] - pointB[2]); -} - -export interface TSNEOptions { - /** How many dimensions. */ - dim: number; - /** Roughly how many neighbors each point influences. */ - perplexity?: number; - /** Learning rate. */ - epsilon?: number; - /** A random number generator. 
*/ - rng?: () => number; -} - -export class TSNE { - private perplexity: number; - private epsilon: number; - /** Random generator */ - private rng: () => number; - private iter = 0; - private Y: Float64Array; - private N: number; - private P: Float64Array; - private gains: number[][]; - private ystep: number[][]; - private nearest: {index: number, dist: number}[][]; - private dim: number; - private dist2: (a: number[], b: number[]) => number; - private computeForce: - (force: number[], mult: number, pointA: number[], - pointB: number[]) => void; - - constructor(opt: TSNEOptions) { - opt = opt || {dim: 2}; - this.perplexity = opt.perplexity || 30; - this.epsilon = opt.epsilon || 10; - this.rng = opt.rng || Math.random; - this.dim = opt.dim; - if (opt.dim === 2) { - this.dist2 = dist2_2D; - this.computeForce = computeForce_2d; - } else if (opt.dim === 3) { - this.dist2 = dist2_3D; - this.computeForce = computeForce_3d; - } else { - throw new Error('Only 2D and 3D is supported'); - } - } - - // this function takes a fattened distance matrix and creates - // matrix P from them. - // D is assumed to be provided as an array of size N^2. 
- initDataDist(nearest: {index: number, dist: number}[][]) { - let N = nearest.length; - this.nearest = nearest; - this.P = nearest2P(nearest, this.perplexity, 1E-4); - this.N = N; - this.initSolution(); // refresh this - } - - // (re)initializes the solution to random - initSolution() { - // generate random solution to t-SNE - this.Y = randnMatrix(this.N, this.dim, this.rng); // the solution - this.gains = arrayofs(this.N, this.dim, 1.0); // step gains - // to accelerate progress in unchanging directions - this.ystep = arrayofs(this.N, this.dim, 0.0); // momentum accumulator - this.iter = 0; - } - - // return pointer to current solution - getSolution() { return this.Y; } - - // perform a single step of optimization to improve the embedding - step() { - this.iter += 1; - let N = this.N; - - let grad = this.costGrad(this.Y); // evaluate gradient - - // perform gradient step - let ymean = this.dim === 3 ? [0, 0, 0] : [0, 0]; - for (let i = 0; i < N; ++i) { - for (let d = 0; d < this.dim; ++d) { - let gid = grad[i][d]; - let sid = this.ystep[i][d]; - let gainid = this.gains[i][d]; - - // compute gain update - let newgain = sign(gid) === sign(sid) ? gainid * 0.8 : gainid + 0.2; - if (newgain < 0.01) { - newgain = 0.01; // clamp - } - this.gains[i][d] = newgain; // store for next turn - - // compute momentum step direction - let momval = this.iter < 250 ? 0.5 : 0.8; - let newsid = momval * sid - this.epsilon * newgain * grad[i][d]; - this.ystep[i][d] = newsid; // remember the step we took - - // step! 
- let i_d = i * this.dim + d; - this.Y[i_d] += newsid; - ymean[d] += this.Y[i_d]; // accumulate mean so that we - // can center later - } - } - - // reproject Y to be zero mean - for (let i = 0; i < N; ++i) { - for (let d = 0; d < this.dim; ++d) { - this.Y[i * this.dim + d] -= ymean[d] / N; - } - } - } - - // return cost and gradient, given an arrangement - costGrad(Y: Float64Array): number[][] { - let N = this.N; - let P = this.P; - - // Trick that helps with local optima. - let alpha = this.iter < 100 ? 4 : 1; - - // Make data for the SP tree. - let points: number[][] = new Array(N); // (x, y)[] - for (let i = 0; i < N; ++i) { - let iTimesD = i * this.dim; - let row = new Array(this.dim); - for (let d = 0; d < this.dim; ++d) { - row[d] = Y[iTimesD + d]; - } - points[i] = row; - } - - // Make a tree. - let tree = new SPTree(points); - let root = tree.root as AugmSPNode; - // Annotate the tree. - - let annotateTree = - (node: AugmSPNode): {numCells: number, yCell: number[]} => { - let numCells = 1; - if (node.children == null) { - // Update the current node and tell the parent. - node.numCells = numCells; - node.yCell = node.point; - return {numCells, yCell: node.yCell}; - } - // node.point is a 2 or 3-dim number[], so slice() makes a copy. - let yCell = node.point.slice(); - for (let i = 0; i < node.children.length; ++i) { - let child = node.children[i]; - if (child == null) { - continue; - } - let result = annotateTree(child as AugmSPNode); - numCells += result.numCells; - for (let d = 0; d < this.dim; ++d) { - yCell[d] += result.yCell[d]; - } - } - // Update the node and tell the parent. - node.numCells = numCells; - node.yCell = yCell.map(v => v / numCells); - return {numCells, yCell}; - }; - - // Augment the tree with more info. 
- annotateTree(root); - tree.visit((node: AugmSPNode, low: number[], high: number[]) => { - node.rCell = high[0] - low[0]; - return false; - }); - // compute current Q distribution, unnormalized first - let grad: number[][] = []; - let Z = 0; - let forces: [number[], number[]][] = new Array(N); - for (let i = 0; i < N; ++i) { - let pointI = points[i]; - // Compute the positive forces for the i-th node. - let Fpos = this.dim === 3 ? [0, 0, 0] : [0, 0]; - let neighbors = this.nearest[i]; - for (let k = 0; k < neighbors.length; ++k) { - let j = neighbors[k].index; - let pij = P[i * N + j]; - let pointJ = points[j]; - let squaredDistItoJ = this.dist2(pointI, pointJ); - let premult = pij / (1 + squaredDistItoJ); - this.computeForce(Fpos, premult, pointI, pointJ); - } - // Compute the negative forces for the i-th node. - let FnegZ = this.dim === 3 ? [0, 0, 0] : [0, 0]; - tree.visit((node: AugmSPNode) => { - let squaredDistToCell = this.dist2(pointI, node.yCell); - // Squared distance from point i to cell. - if (node.children == null || - (squaredDistToCell > 0 && - node.rCell / Math.sqrt(squaredDistToCell) < THETA)) { - let qijZ = 1 / (1 + squaredDistToCell); - let dZ = node.numCells * qijZ; - Z += dZ; - dZ *= qijZ; - this.computeForce(FnegZ, dZ, pointI, node.yCell); - return true; - } - // Cell is too close to approximate. - let squaredDistToPoint = this.dist2(pointI, node.point); - let qijZ = 1 / (1 + squaredDistToPoint); - Z += qijZ; - qijZ *= qijZ; - this.computeForce(FnegZ, qijZ, pointI, node.point); - return false; - }, true); - forces[i] = [Fpos, FnegZ]; - } - // Normalize the negative forces and compute the gradient. 
- const A = 4 * alpha; - const B = 4 / Z; - for (let i = 0; i < N; ++i) { - let [FPos, FNegZ] = forces[i]; - let gsum = new Array(this.dim); - for (let d = 0; d < this.dim; ++d) { - gsum[d] = A * FPos[d] - B * FNegZ[d]; - } - grad.push(gsum); - } - return grad; - } -} diff --git a/tensorflow/tensorboard/components/vz_projector/bundle.html b/tensorflow/tensorboard/components/vz_projector/bundle.html deleted file mode 100644 index f5a25230a0b..00000000000 --- a/tensorflow/tensorboard/components/vz_projector/bundle.html +++ /dev/null @@ -1,48 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/vz_projector/data-provider-demo.ts b/tensorflow/tensorboard/components/vz_projector/data-provider-demo.ts deleted file mode 100644 index 1410a84a8e4..00000000000 --- a/tensorflow/tensorboard/components/vz_projector/data-provider-demo.ts +++ /dev/null @@ -1,127 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -import {DataSet, SpriteAndMetadataInfo, State} from './data'; -import {ProjectorConfig, DataProvider, EmbeddingInfo, TENSORS_MSG_ID} from './data-provider'; -import * as dataProvider from './data-provider'; -import * as logging from './logging'; - -const BYTES_EXTENSION = '.bytes'; - -/** Data provider that loads data from a demo folder. 
*/ -export class DemoDataProvider implements DataProvider { - private projectorConfigPath: string; - private projectorConfig: ProjectorConfig; - - constructor(projectorConfigPath: string) { - this.projectorConfigPath = projectorConfigPath; - } - - private getEmbeddingInfo(tensorName: string): EmbeddingInfo { - let embeddings = this.projectorConfig.embeddings; - for (let i = 0; i < embeddings.length; i++) { - let embedding = embeddings[i]; - if (embedding.tensorName === tensorName) { - return embedding; - } - } - return null; - } - - retrieveRuns(callback: (runs: string[]) => void): void { - callback(['Demo']); - } - - retrieveProjectorConfig(run: string, callback: (d: ProjectorConfig) => void) - : void { - const msgId = logging.setModalMessage('Fetching projector config...'); - - const xhr = new XMLHttpRequest(); - xhr.open('GET', this.projectorConfigPath); - xhr.onerror = (err) => { - let errorMessage = err.message; - // If the error is a valid XMLHttpResponse, it's possible this is a - // cross-origin error. 
- if (xhr.responseText != null) { - errorMessage = 'Cannot fetch projector config, possibly a ' + - 'Cross-Origin request error.'; - } - logging.setErrorMessage(errorMessage, 'fetching projector config'); - }; - xhr.onload = () => { - const projectorConfig = JSON.parse(xhr.responseText) as ProjectorConfig; - logging.setModalMessage(null, msgId); - this.projectorConfig = projectorConfig; - callback(projectorConfig); - }; - xhr.send(); - } - - retrieveTensor(run: string, tensorName: string, - callback: (ds: DataSet) => void) { - let embedding = this.getEmbeddingInfo(tensorName); - let url = `${embedding.tensorPath}`; - if (embedding.tensorPath.substr(-1 * BYTES_EXTENSION.length) === - BYTES_EXTENSION) { - dataProvider.retrieveTensorAsBytes( - this, this.getEmbeddingInfo(tensorName), run, tensorName, url, - callback); - } else { - logging.setModalMessage('Fetching tensors...', TENSORS_MSG_ID); - const request = new XMLHttpRequest(); - request.open('GET', url); - request.responseType = 'arraybuffer'; - - request.onerror = () => { - logging.setErrorMessage(request.responseText, 'fetching tensors'); - }; - request.onload = () => { - dataProvider.parseTensors(request.response).then(points => { - callback(new DataSet(points)); - }); - }; - request.send(); - } - } - - retrieveSpriteAndMetadata(run: string, tensorName: string, - callback: (r: SpriteAndMetadataInfo) => void) { - let embedding = this.getEmbeddingInfo(tensorName); - let spriteImagePath = null; - if (embedding.sprite && embedding.sprite.imagePath) { - spriteImagePath = embedding.sprite.imagePath; - } - dataProvider.retrieveSpriteAndMetadataInfo( - embedding.metadataPath, spriteImagePath, embedding.sprite, callback); - } - - getBookmarks( - run: string, tensorName: string, callback: (r: State[]) => void) { - let embedding = this.getEmbeddingInfo(tensorName); - let msgId = logging.setModalMessage('Fetching bookmarks...'); - - const xhr = new XMLHttpRequest(); - xhr.open('GET', embedding.bookmarksPath); - 
xhr.onerror = (err) => { - logging.setErrorMessage(xhr.responseText); - }; - xhr.onload = () => { - const bookmarks = JSON.parse(xhr.responseText) as State[]; - logging.setModalMessage(null, msgId); - callback(bookmarks); - }; - xhr.send(); - } -} diff --git a/tensorflow/tensorboard/components/vz_projector/data-provider-proto.ts b/tensorflow/tensorboard/components/vz_projector/data-provider-proto.ts deleted file mode 100644 index 67124a92323..00000000000 --- a/tensorflow/tensorboard/components/vz_projector/data-provider-proto.ts +++ /dev/null @@ -1,88 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -import {DataPoint, DataProto, DataSet, SpriteAndMetadataInfo, PointMetadata, State} from './data'; -import {analyzeMetadata, ProjectorConfig, DataProvider} from './data-provider'; - - -export class ProtoDataProvider implements DataProvider { - private dataProto: DataProto; - - constructor(dataProto: DataProto) { - this.dataProto = dataProto; - } - - retrieveRuns(callback: (runs: string[]) => void): void { - callback(['proto']); - } - - retrieveProjectorConfig(run: string, callback: (d: ProjectorConfig) => void) { - callback({ - modelCheckpointPath: 'proto', - embeddings: [{ - tensorName: 'proto', - tensorShape: this.dataProto.shape, - metadataPath: 'proto' - }] - }); - } - - retrieveTensor(run: string, tensorName: string, - callback: (ds: DataSet) => void) { - callback(this.flatArrayToDataset(this.dataProto.tensor)); - } - - retrieveSpriteAndMetadata(run: string, tensorName: string, - callback: (r: SpriteAndMetadataInfo) => void): void { - let columnNames = this.dataProto.metadata.columns.map(c => c.name); - let n = this.dataProto.shape[0]; - let pointsMetadata: PointMetadata[] = new Array(n); - this.dataProto.metadata.columns.forEach(c => { - let values = c.numericValues || c.stringValues; - for (let i = 0; i < n; i++) { - pointsMetadata[i] = pointsMetadata[i] || {}; - pointsMetadata[i][c.name] = values[i]; - } - }); - callback({ - stats: analyzeMetadata(columnNames, pointsMetadata), - pointsInfo: pointsMetadata - }); - } - - getBookmarks(run: string, tensorName: string, - callback: (r: State[]) => void): void { - return callback([]); - } - - private flatArrayToDataset(tensor: number[]): DataSet { - let points: DataPoint[] = []; - let n = this.dataProto.shape[0]; - let d = this.dataProto.shape[1]; - if (n * d !== tensor.length) { - throw 'The shape doesn\'t match the length of the flattened array'; - } - for (let i = 0; i < n; i++) { - let offset = i * d; - points.push({ - vector: 
new Float32Array(tensor.slice(offset, offset + d)), - metadata: {}, - projections: null, - index: i - }); - } - return new DataSet(points); - } -} diff --git a/tensorflow/tensorboard/components/vz_projector/data-provider-server.ts b/tensorflow/tensorboard/components/vz_projector/data-provider-server.ts deleted file mode 100644 index 02720ebf6a7..00000000000 --- a/tensorflow/tensorboard/components/vz_projector/data-provider-server.ts +++ /dev/null @@ -1,137 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -import {DataSet, SpriteAndMetadataInfo, State} from './data'; -import * as dataProvider from './data-provider'; -import {DataProvider, EmbeddingInfo, ProjectorConfig} from './data-provider'; -import * as logging from './logging'; - -// Limit for the number of data points we receive from the server. -export const LIMIT_NUM_POINTS = 100000; - -/** - * Data provider that loads data provided by a python server (usually backed - * by a checkpoint file). 
- */ -export class ServerDataProvider implements DataProvider { - private routePrefix: string; - private runProjectorConfigCache: {[run: string]: ProjectorConfig} = {}; - - constructor(routePrefix: string) { - this.routePrefix = routePrefix; - } - - private getEmbeddingInfo(run: string, tensorName: string, - callback: (e: EmbeddingInfo) => void): void { - this.retrieveProjectorConfig(run, config => { - const embeddings = config.embeddings; - for (let i = 0; i < embeddings.length; i++) { - const embedding = embeddings[i]; - if (embedding.tensorName === tensorName) { - callback(embedding); - return; - } - } - callback(null); - }); - } - - retrieveRuns(callback: (runs: string[]) => void): void { - const msgId = logging.setModalMessage('Fetching runs...'); - - const xhr = new XMLHttpRequest(); - xhr.open('GET', `${this.routePrefix}/runs`); - xhr.onerror = (err) => { - logging.setErrorMessage(xhr.responseText, 'fetching runs'); - }; - xhr.onload = () => { - const runs = JSON.parse(xhr.responseText); - logging.setModalMessage(null, msgId); - callback(runs); - }; - xhr.send(); - } - - retrieveProjectorConfig(run: string, callback: (d: ProjectorConfig) => void) - : void { - if (run in this.runProjectorConfigCache) { - callback(this.runProjectorConfigCache[run]); - return; - } - - const msgId = logging.setModalMessage('Fetching projector config...'); - - const xhr = new XMLHttpRequest(); - xhr.open('GET', `${this.routePrefix}/info?run=${run}`); - xhr.onerror = (err) => { - logging.setErrorMessage(xhr.responseText, 'fetching projector config'); - }; - xhr.onload = () => { - const config = JSON.parse(xhr.responseText) as ProjectorConfig; - logging.setModalMessage(null, msgId); - this.runProjectorConfigCache[run] = config; - callback(config); - }; - xhr.send(); - } - - retrieveTensor(run: string, tensorName: string, - callback: (ds: DataSet) => void) { - this.getEmbeddingInfo(run, tensorName, embedding => { - dataProvider.retrieveTensorAsBytes( - this, embedding, run, 
tensorName, - `${this.routePrefix}/tensor?run=${run}&name=${tensorName}` + - `&num_rows=${LIMIT_NUM_POINTS}`, - callback); - }); - } - - retrieveSpriteAndMetadata(run: string, tensorName: string, - callback: (r: SpriteAndMetadataInfo) => void) { - this.getEmbeddingInfo(run, tensorName, embedding => { - let metadataPath = null; - if (embedding.metadataPath) { - metadataPath = - `${this.routePrefix}/metadata?` + - `run=${run}&name=${tensorName}&num_rows=${LIMIT_NUM_POINTS}`; - } - let spriteImagePath = null; - if (embedding.sprite && embedding.sprite.imagePath) { - spriteImagePath = - `${this.routePrefix}/sprite_image?run=${run}&name=${tensorName}`; - } - dataProvider.retrieveSpriteAndMetadataInfo(metadataPath, spriteImagePath, - embedding.sprite, callback); - }); - } - - getBookmarks( - run: string, tensorName: string, callback: (r: State[]) => void) { - const msgId = logging.setModalMessage('Fetching bookmarks...'); - - const xhr = new XMLHttpRequest(); - xhr.open( - 'GET', `${this.routePrefix}/bookmarks?run=${run}&name=${tensorName}`); - xhr.onerror = (err) => { - logging.setErrorMessage(xhr.responseText, 'fetching bookmarks'); - }; - xhr.onload = () => { - logging.setModalMessage(null, msgId); - const bookmarks = JSON.parse(xhr.responseText); - callback(bookmarks); - }; - xhr.send(); - } -} diff --git a/tensorflow/tensorboard/components/vz_projector/data-provider.ts b/tensorflow/tensorboard/components/vz_projector/data-provider.ts deleted file mode 100644 index c8eede798c6..00000000000 --- a/tensorflow/tensorboard/components/vz_projector/data-provider.ts +++ /dev/null @@ -1,429 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -import {ColumnStats, DataPoint, DataSet, SpriteAndMetadataInfo, PointMetadata, State} from './data'; -import * as logging from './logging'; -import {runAsyncTask} from './util'; - -/** Maximum number of colors supported in the color map. */ -const NUM_COLORS_COLOR_MAP = 50; -const MAX_SPRITE_IMAGE_SIZE_PX = 8192; - -export const METADATA_MSG_ID = 'metadata'; -export const TENSORS_MSG_ID = 'tensors'; - -/** Matches the json format of `projector_config.proto` */ -export interface SpriteMetadata { - imagePath: string; - singleImageDim: [number, number]; -} - -/** Matches the json format of `projector_config.proto` */ -export interface EmbeddingInfo { - /** Name of the tensor. */ - tensorName: string; - /** The shape of the tensor. */ - tensorShape: [number, number]; - /** - * The path to the tensors TSV file. If empty, it is assumed that the tensor - * is stored in the checkpoint file. - */ - tensorPath?: string; - /** The path to the metadata file associated with the tensor. */ - metadataPath?: string; - /** The path to the bookmarks file associated with the tensor. */ - bookmarksPath?: string; - sprite?: SpriteMetadata; -} - -/** - * Matches the json format of `projector_config.proto` - * This should be kept in sync with the code in vz-projector-data-panel which - * holds a template for users to build a projector config JSON object from the - * projector UI. 
- */ -export interface ProjectorConfig { - embeddings: EmbeddingInfo[]; - modelCheckpointPath?: string; -} - -export type ServingMode = 'demo' | 'server' | 'proto'; - -/** Interface between the data storage and the UI. */ -export interface DataProvider { - /** Returns a list of run names that have embedding config files. */ - retrieveRuns(callback: (runs: string[]) => void): void; - - /** - * Returns the projector configuration: number of tensors, their shapes, - * and their associated metadata files. - */ - retrieveProjectorConfig(run: string, - callback: (d: ProjectorConfig) => void): void; - - /** Fetches and returns the tensor with the specified name. */ - retrieveTensor(run: string, tensorName: string, - callback: (ds: DataSet) => void); - - /** - * Fetches the metadata for the specified tensor. - */ - retrieveSpriteAndMetadata(run: string, tensorName: string, - callback: (r: SpriteAndMetadataInfo) => void): void; - - getBookmarks(run: string, tensorName: string, callback: (r: State[]) => void): - void; -} - -export function retrieveTensorAsBytes( - dp: DataProvider, embedding: EmbeddingInfo, run: string, tensorName: string, - tensorsPath: string, callback: (ds: DataSet) => void) { - // Get the tensor. 
- logging.setModalMessage('Fetching tensor values...', TENSORS_MSG_ID); - let xhr = new XMLHttpRequest(); - xhr.open('GET', tensorsPath); - xhr.responseType = 'arraybuffer'; - xhr.onprogress = (ev) => { - if (ev.lengthComputable) { - let percent = (ev.loaded * 100 / ev.total).toFixed(1); - logging.setModalMessage( - 'Fetching tensor values: ' + percent + '%', TENSORS_MSG_ID); - } - }; - xhr.onload = () => { - if (xhr.status !== 200) { - let msg = String.fromCharCode.apply(null, new Uint8Array(xhr.response)); - logging.setErrorMessage(msg, 'fetching tensors'); - return; - } - let data: Float32Array; - try { - data = new Float32Array(xhr.response); - } catch (e) { - logging.setErrorMessage(e, 'parsing tensor bytes'); - return; - } - - let dim = embedding.tensorShape[1]; - let N = data.length / dim; - if (embedding.tensorShape[0] > N) { - logging.setWarningMessage( - `Showing the first ${N.toLocaleString()}` + - ` of ${embedding.tensorShape[0].toLocaleString()} data points`); - } - parseTensorsFromFloat32Array(data, dim).then(dataPoints => { - callback(new DataSet(dataPoints)); - }); - }; - xhr.send(); -} - -export function parseRawTensors( - content: ArrayBuffer, callback: (ds: DataSet) => void) { - parseTensors(content).then(data => { - callback(new DataSet(data)); - }); -} - -export function parseRawMetadata( - contents: ArrayBuffer, callback: (r: SpriteAndMetadataInfo) => void) { - parseMetadata(contents).then(result => callback(result)); -} - -/** - * Parse an ArrayBuffer in a streaming fashion line by line (or custom delim). - * Can handle very large files. - * - * @param content The array buffer. - * @param callback The callback called on each line. - * @param chunkSize The size of each read chunk, defaults to ~1MB. (optional) - * @param delim The delimiter used to split a line, defaults to '\n'. (optional) - * @returns A promise for when it is finished. 
- */ -function streamParse( - content: ArrayBuffer, callback: (line: string) => void, chunkSize = 1000000, - delim = '\n'): Promise { - return new Promise((resolve, reject) => { - let offset = 0; - let bufferSize = content.byteLength - 1; - let data = ''; - - function readHandler(str) { - offset += chunkSize; - let parts = str.split(delim); - let first = data + parts[0]; - if (parts.length === 1) { - data = first; - readChunk(offset, chunkSize); - return; - } - data = parts[parts.length - 1]; - callback(first); - for (let i = 1; i < parts.length - 1; i++) { - callback(parts[i]); - } - if (offset >= bufferSize) { - if (data) { - callback(data); - } - resolve(); - return; - } - readChunk(offset, chunkSize); - } - - function readChunk(offset: number, size: number) { - const contentChunk = content.slice(offset, offset + size); - - const blob = new Blob([contentChunk]); - const file = new FileReader(); - file.onload = (e: any) => readHandler(e.target.result); - file.readAsText(blob); - } - - readChunk(offset, chunkSize); - }); -} - -/** Parses a tsv text file. */ -export function parseTensors( - content: ArrayBuffer, valueDelim = '\t'): Promise { - logging.setModalMessage('Parsing tensors...', TENSORS_MSG_ID); - - return new Promise((resolve, reject) => { - const data: DataPoint[] = []; - let numDim: number; - - streamParse(content, (line: string) => { - line = line.trim(); - if (line === '') { - return; - } - const row = line.split(valueDelim); - const dataPoint: DataPoint = { - metadata: {}, - vector: null, - index: data.length, - projections: null, - }; - // If the first label is not a number, take it as the label. 
- if (isNaN(row[0] as any) || numDim === row.length - 1) { - dataPoint.metadata['label'] = row[0]; - dataPoint.vector = new Float32Array(row.slice(1).map(Number)); - } else { - dataPoint.vector = new Float32Array(row.map(Number)); - } - data.push(dataPoint); - if (numDim == null) { - numDim = dataPoint.vector.length; - } - if (numDim !== dataPoint.vector.length) { - logging.setModalMessage( - 'Parsing failed. Vector dimensions do not match'); - throw Error('Parsing failed'); - } - if (numDim <= 1) { - logging.setModalMessage( - 'Parsing failed. Found a vector with only one dimension?'); - throw Error('Parsing failed'); - } - }).then(() => { - logging.setModalMessage(null, TENSORS_MSG_ID); - resolve(data); - }); - }); -} - -/** Parses a tsv text file. */ -export function parseTensorsFromFloat32Array(data: Float32Array, - dim: number): Promise { - return runAsyncTask('Parsing tensors...', () => { - const N = data.length / dim; - const dataPoints: DataPoint[] = []; - let offset = 0; - for (let i = 0; i < N; ++i) { - dataPoints.push({ - metadata: {}, - vector: data.subarray(offset, offset + dim), - index: i, - projections: null, - }); - offset += dim; - } - return dataPoints; - }, TENSORS_MSG_ID).then(dataPoints => { - logging.setModalMessage(null, TENSORS_MSG_ID); - return dataPoints; - }); -} - -export function analyzeMetadata( - columnNames, pointsMetadata: PointMetadata[]): ColumnStats[] { - const columnStats: ColumnStats[] = columnNames.map(name => { - return { - name: name, - isNumeric: true, - tooManyUniqueValues: false, - min: Number.POSITIVE_INFINITY, - max: Number.NEGATIVE_INFINITY - }; - }); - - const mapOfValues: [{[value: string]: number}] = - columnNames.map(() => new Object()); - - pointsMetadata.forEach(metadata => { - columnNames.forEach((name: string, colIndex: number) => { - const stats = columnStats[colIndex]; - const map = mapOfValues[colIndex]; - const value = metadata[name]; - - // Skip missing values. 
- if (value == null) { - return; - } - - if (!stats.tooManyUniqueValues) { - if (value in map) { - map[value]++; - } else { - map[value] = 1; - } - if (Object.keys(map).length > NUM_COLORS_COLOR_MAP) { - stats.tooManyUniqueValues = true; - } - } - if (isNaN(value as any)) { - stats.isNumeric = false; - } else { - metadata[name] = +value; - stats.min = Math.min(stats.min, +value); - stats.max = Math.max(stats.max, +value); - } - }); - }); - columnStats.forEach((stats, colIndex) => { - stats.uniqueEntries = Object.keys(mapOfValues[colIndex]).map(label => { - return {label, count: mapOfValues[colIndex][label]}; - }); - }); - return columnStats; -} - -export function parseMetadata(content: ArrayBuffer): - Promise { - logging.setModalMessage('Parsing metadata...', METADATA_MSG_ID); - - return new Promise((resolve, reject) => { - let pointsMetadata: PointMetadata[] = []; - let hasHeader = false; - let lineNumber = 0; - let columnNames = ['label']; - streamParse(content, (line: string) => { - if (line.trim().length === 0) { - return; - } - if (lineNumber === 0) { - hasHeader = line.indexOf('\t') >= 0; - - // If the first row doesn't contain metadata keys, we assume that the - // values are labels. - if (hasHeader) { - columnNames = line.split('\t'); - lineNumber++; - return; - } - } - - lineNumber++; - - let rowValues = line.split('\t'); - let metadata: PointMetadata = {}; - pointsMetadata.push(metadata); - columnNames.forEach((name: string, colIndex: number) => { - let value = rowValues[colIndex]; - // Normalize missing values. - value = (value === '' ? 
null : value); - metadata[name] = value; - }); - }).then(() => { - logging.setModalMessage(null, METADATA_MSG_ID); - resolve({ - stats: analyzeMetadata(columnNames, pointsMetadata), - pointsInfo: pointsMetadata - }); - }); - }); -} - -export function fetchImage(url: string): Promise { - return new Promise((resolve, reject) => { - let image = new Image(); - image.onload = () => resolve(image); - image.onerror = (err) => reject(err); - image.crossOrigin = ''; - image.src = url; - }); -} - -export function retrieveSpriteAndMetadataInfo(metadataPath: string, - spriteImagePath: string, spriteMetadata: SpriteMetadata, - callback: (r: SpriteAndMetadataInfo) => void) { - let metadataPromise: Promise = Promise.resolve({}); - if (metadataPath) { - metadataPromise = new Promise((resolve, reject) => { - logging.setModalMessage('Fetching metadata...', METADATA_MSG_ID); - - const request = new XMLHttpRequest(); - request.open('GET', metadataPath); - request.responseType = 'arraybuffer'; - - request.onerror = () => { - logging.setErrorMessage(request.responseText, 'fetching metadata'); - reject(); - }; - request.onload = () => { - resolve(parseMetadata(request.response)); - }; - request.send(null); - }); - } - let spriteMsgId = null; - let spritesPromise: Promise = null; - if (spriteImagePath) { - spriteMsgId = logging.setModalMessage('Fetching sprite image...'); - spritesPromise = fetchImage(spriteImagePath); - } - - // Fetch the metadata and the image in parallel. 
- Promise.all([metadataPromise, spritesPromise]).then(values => { - if (spriteMsgId) { - logging.setModalMessage(null, spriteMsgId); - } - const [metadata, spriteImage] = values; - - if (spriteImage && (spriteImage.height > MAX_SPRITE_IMAGE_SIZE_PX || - spriteImage.width > MAX_SPRITE_IMAGE_SIZE_PX)) { - logging.setModalMessage( - `Error: Sprite image of dimensions ${spriteImage.width}px x ` + - `${spriteImage.height}px exceeds maximum dimensions ` + - `${MAX_SPRITE_IMAGE_SIZE_PX}px x ${MAX_SPRITE_IMAGE_SIZE_PX}px`); - } else { - metadata.spriteImage = spriteImage; - metadata.spriteMetadata = spriteMetadata; - callback(metadata); - } - }); -} diff --git a/tensorflow/tensorboard/components/vz_projector/data.ts b/tensorflow/tensorboard/components/vz_projector/data.ts deleted file mode 100644 index c4e81985fc8..00000000000 --- a/tensorflow/tensorboard/components/vz_projector/data.ts +++ /dev/null @@ -1,547 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -import {TSNE} from './bh_tsne'; -import {SpriteMetadata} from './data-provider'; -import * as knn from './knn'; -import * as logging from './logging'; -import * as scatterPlot from './scatterPlot'; -import * as util from './util'; -import * as vector from './vector'; - -export type DistanceFunction = (a: number[], b: number[]) => number; -export type ProjectionComponents3D = [string, string, string]; - -export interface PointMetadata { [key: string]: number|string; } - -export interface DataProto { - shape: [number, number]; - tensor: number[]; - metadata: { - columns: Array< - {name: string; stringValues: string[]; numericValues: number[];}>; - }; -} - -/** Statistics for a metadata column. */ -export interface ColumnStats { - name: string; - isNumeric: boolean; - tooManyUniqueValues: boolean; - uniqueEntries?: Array<{label: string, count: number}>; - min: number; - max: number; -} - -export interface SpriteAndMetadataInfo { - stats?: ColumnStats[]; - pointsInfo?: PointMetadata[]; - spriteImage?: HTMLImageElement; - spriteMetadata?: SpriteMetadata; -} - -/** A single collection of points which make up a sequence through space. */ -export interface Sequence { - /** Indices into the DataPoints array in the Data object. */ - pointIndices: number[]; -} - -export interface DataPoint { - /** The point in the original space. */ - vector: Float32Array; - - /* - * Metadata for each point. Each metadata is a set of key/value pairs - * where the value can be a string or a number. 
- */ - metadata: PointMetadata; - - /** index of the sequence, used for highlighting on click */ - sequenceIndex?: number; - - /** index in the original data source */ - index: number; - - /** This is where the calculated projections space are cached */ - projections: {[key: string]: number}; -} - -const IS_FIREFOX = navigator.userAgent.toLowerCase().indexOf('firefox') >= 0; -/** Controls whether nearest neighbors computation is done on the GPU or CPU. */ -const KNN_GPU_ENABLED = util.hasWebGLSupport() && !IS_FIREFOX; - -export const TSNE_SAMPLE_SIZE = 10000; -export const PCA_SAMPLE_SIZE = 50000; -/** Number of dimensions to sample when doing approximate PCA. */ -export const PCA_SAMPLE_DIM = 200; -/** Number of pca components to compute. */ -const NUM_PCA_COMPONENTS = 10; -/** - * Reserved metadata attributes used for sequence information - * NOTE: Use "__seq_next__" as "__next__" is deprecated. - */ -const SEQUENCE_METADATA_ATTRS = ['__next__', '__seq_next__']; - -function getSequenceNextPointIndex(pointMetadata: PointMetadata): number|null { - let sequenceAttr = null; - for (let metadataAttr of SEQUENCE_METADATA_ATTRS) { - if (metadataAttr in pointMetadata && pointMetadata[metadataAttr] !== '') { - sequenceAttr = pointMetadata[metadataAttr]; - break; - } - } - if (sequenceAttr == null) { - return null; - } - return +sequenceAttr; -} - -/** - * Dataset contains a DataPoints array that should be treated as immutable. This - * acts as a working subset of the original data, with cached properties - * from computationally expensive operations. Because creating a subset - * requires normalizing and shifting the vector space, we make a copy of the - * data so we can still always create new subsets based on the original data. - */ -export class DataSet { - points: DataPoint[]; - sequences: Sequence[]; - - shuffledDataIndices: number[] = []; - - /** - * This keeps a list of all current projections so you can easily test to see - * if it's been calculated already. 
- */ - projections: {[projection: string]: boolean} = {}; - nearest: knn.NearestEntry[][]; - nearestK: number; - tSNEIteration: number = 0; - tSNEShouldStop = true; - dim: [number, number] = [0, 0]; - hasTSNERun: boolean = false; - spriteAndMetadataInfo: SpriteAndMetadataInfo; - fracVariancesExplained: number[]; - - private tsne: TSNE; - - /** Creates a new Dataset */ - constructor( - points: DataPoint[], spriteAndMetadataInfo?: SpriteAndMetadataInfo) { - this.points = points; - this.shuffledDataIndices = util.shuffle(util.range(this.points.length)); - this.sequences = this.computeSequences(points); - this.dim = [this.points.length, this.points[0].vector.length]; - this.spriteAndMetadataInfo = spriteAndMetadataInfo; - } - - private computeSequences(points: DataPoint[]) { - // Keep a list of indices seen so we don't compute sequences for a given - // point twice. - let indicesSeen = new Int8Array(points.length); - // Compute sequences. - let indexToSequence: {[index: number]: Sequence} = {}; - let sequences: Sequence[] = []; - for (let i = 0; i < points.length; i++) { - if (indicesSeen[i]) { - continue; - } - indicesSeen[i] = 1; - - // Ignore points without a sequence attribute. - let next = getSequenceNextPointIndex(points[i].metadata); - if (next == null) { - continue; - } - if (next in indexToSequence) { - let existingSequence = indexToSequence[next]; - // Pushing at the beginning of the array. - existingSequence.pointIndices.unshift(i); - indexToSequence[i] = existingSequence; - continue; - } - // The current point is pointing to a new/unseen sequence. 
- let newSequence: Sequence = {pointIndices: []}; - indexToSequence[i] = newSequence; - sequences.push(newSequence); - let currentIndex = i; - while (points[currentIndex]) { - newSequence.pointIndices.push(currentIndex); - let next = getSequenceNextPointIndex(points[currentIndex].metadata); - if (next != null) { - indicesSeen[next] = 1; - currentIndex = next; - } else { - currentIndex = -1; - } - } - } - return sequences; - } - - projectionCanBeRendered(projection: ProjectionType): boolean { - if (projection !== 'tsne') { - return true; - } - return this.tSNEIteration > 0; - } - - /** - * Returns a new subset dataset by copying out data. We make a copy because - * we have to modify the vectors by normalizing them. - * - * @param subset Array of indices of points that we want in the subset. - * - * @return A subset of the original dataset. - */ - getSubset(subset?: number[]): DataSet { - const pointsSubset = ((subset != null) && (subset.length > 0)) ? - subset.map(i => this.points[i]) : - this.points; - let points = pointsSubset.map(dp => { - return { - metadata: dp.metadata, - index: dp.index, - vector: dp.vector.slice(), - projections: {} as {[key: string]: number} - }; - }); - return new DataSet(points, this.spriteAndMetadataInfo); - } - - /** - * Computes the centroid, shifts all points to that centroid, - * then makes them all unit norm. - */ - normalize() { - // Compute the centroid of all data points. - let centroid = vector.centroid(this.points, a => a.vector); - if (centroid == null) { - throw Error('centroid should not be null'); - } - // Shift all points by the centroid and make them unit norm. - for (let id = 0; id < this.points.length; ++id) { - let dataPoint = this.points[id]; - dataPoint.vector = vector.sub(dataPoint.vector, centroid); - vector.unit(dataPoint.vector); - } - } - - /** Projects the dataset onto a given vector and caches the result. 
*/ - projectLinear(dir: vector.Vector, label: string) { - this.projections[label] = true; - this.points.forEach(dataPoint => { - dataPoint.projections[label] = vector.dot(dataPoint.vector, dir); - }); - } - - /** Projects the dataset along the top 10 principal components. */ - projectPCA(): Promise { - if (this.projections['pca-0'] != null) { - return Promise.resolve(null); - } - return util.runAsyncTask('Computing PCA...', () => { - // Approximate pca vectors by sampling the dimensions. - let dim = this.points[0].vector.length; - let vectors = this.shuffledDataIndices.map(i => this.points[i].vector); - if (dim > PCA_SAMPLE_DIM) { - vectors = vector.projectRandom(vectors, PCA_SAMPLE_DIM); - } - let sampledVectors = vectors.slice(0, PCA_SAMPLE_SIZE); - - let sigma = numeric.div( - numeric.dot(numeric.transpose(sampledVectors), sampledVectors), - sampledVectors.length); - let svd = numeric.svd(sigma); - - let variances: number[] = svd.S; - let totalVariance = 0; - for (let i = 0; i < variances.length; ++i) { - totalVariance += variances[i]; - } - for (let i = 0; i < variances.length; ++i) { - variances[i] /= totalVariance; - } - this.fracVariancesExplained = variances; - - let U: number[][] = svd.U; - let pcaVectors = vectors.map(vector => { - let newV = new Float32Array(NUM_PCA_COMPONENTS); - for (let newDim = 0; newDim < NUM_PCA_COMPONENTS; newDim++) { - let dot = 0; - for (let oldDim = 0; oldDim < vector.length; oldDim++) { - dot += vector[oldDim] * U[oldDim][newDim]; - } - newV[newDim] = dot; - } - return newV; - }); - for (let d = 0; d < NUM_PCA_COMPONENTS; d++) { - let label = 'pca-' + d; - this.projections[label] = true; - for (let i = 0; i < pcaVectors.length; i++) { - let pointIndex = this.shuffledDataIndices[i]; - this.points[pointIndex].projections[label] = pcaVectors[i][d]; - } - } - }); - } - - /** Runs tsne on the data. 
*/ - projectTSNE( - perplexity: number, learningRate: number, tsneDim: number, - stepCallback: (iter: number) => void) { - this.hasTSNERun = true; - let k = Math.floor(3 * perplexity); - let opt = {epsilon: learningRate, perplexity: perplexity, dim: tsneDim}; - this.tsne = new TSNE(opt); - this.tSNEShouldStop = false; - this.tSNEIteration = 0; - - let sampledIndices = this.shuffledDataIndices.slice(0, TSNE_SAMPLE_SIZE); - let step = () => { - if (this.tSNEShouldStop) { - stepCallback(null); - this.tsne = null; - return; - } - this.tsne.step(); - let result = this.tsne.getSolution(); - sampledIndices.forEach((index, i) => { - let dataPoint = this.points[index]; - - dataPoint.projections['tsne-0'] = result[i * tsneDim + 0]; - dataPoint.projections['tsne-1'] = result[i * tsneDim + 1]; - if (tsneDim === 3) { - dataPoint.projections['tsne-2'] = result[i * tsneDim + 2]; - } - }); - this.tSNEIteration++; - stepCallback(this.tSNEIteration); - requestAnimationFrame(step); - }; - - // Nearest neighbors calculations. - let knnComputation: Promise; - - if (this.nearest != null && k === this.nearestK) { - // We found the nearest neighbors before and will reuse them. - knnComputation = Promise.resolve(this.nearest); - } else { - let sampledData = sampledIndices.map(i => this.points[i]); - this.nearestK = k; - knnComputation = KNN_GPU_ENABLED ? - knn.findKNNGPUCosine(sampledData, k, (d => d.vector)) : - knn.findKNN( - sampledData, k, (d => d.vector), - (a, b, limit) => vector.cosDistNorm(a, b)); - } - knnComputation.then(nearest => { - this.nearest = nearest; - util.runAsyncTask('Initializing T-SNE...', () => { - this.tsne.initDataDist(this.nearest); - }).then(step); - }); - } - - /** - * Merges metadata to the dataset and returns whether it succeeded. 
- */ - mergeMetadata(metadata: SpriteAndMetadataInfo): boolean { - if (metadata.pointsInfo.length !== this.points.length) { - let errorMessage = `Number of tensors (${this.points.length}) do not` + - ` match the number of lines in metadata` + - ` (${metadata.pointsInfo.length}).`; - - if (metadata.stats.length === 1 && - this.points.length + 1 === metadata.pointsInfo.length) { - // If there is only one column of metadata and the number of points is - // exactly one less than the number of metadata lines, this is due to an - // unnecessary header line in the metadata and we can show a meaningful - // error. - logging.setErrorMessage( - errorMessage + ' Single column metadata should not have a header ' + - 'row.', - 'merging metadata'); - return false; - } else if ( - metadata.stats.length > 1 && - this.points.length - 1 === metadata.pointsInfo.length) { - // If there are multiple columns of metadata and the number of points is - // exactly one greater than the number of lines in the metadata, this - // means there is a missing metadata header. - logging.setErrorMessage( - errorMessage + ' Multi-column metadata should have a header ' + - 'row with column labels.', - 'merging metadata'); - return false; - } - - logging.setWarningMessage(errorMessage); - } - this.spriteAndMetadataInfo = metadata; - metadata.pointsInfo.slice(0, this.points.length) - .forEach((m, i) => this.points[i].metadata = m); - return true; - } - - stopTSNE() { - this.tSNEShouldStop = true; - } - - /** - * Finds the nearest neighbors of the query point using a - * user-specified distance metric. - */ - findNeighbors(pointIndex: number, distFunc: DistanceFunction, numNN: number): - knn.NearestEntry[] { - // Find the nearest neighbors of a particular point. - let neighbors = knn.findKNNofPoint( - this.points, pointIndex, numNN, (d => d.vector), distFunc); - // TODO(smilkov): Figure out why we slice. 
- let result = neighbors.slice(0, numNN); - return result; - } - - /** - * Search the dataset based on a metadata field. - */ - query(query: string, inRegexMode: boolean, fieldName: string): number[] { - let predicate = util.getSearchPredicate(query, inRegexMode, fieldName); - let matches: number[] = []; - this.points.forEach((point, id) => { - if (predicate(point)) { - matches.push(id); - } - }); - return matches; - } -} - -export type ProjectionType = 'tsne' | 'pca' | 'custom'; - -export class Projection { - constructor( - public projectionType: ProjectionType, - public projectionComponents: ProjectionComponents3D, - public dimensionality: number, public dataSet: DataSet) {} -} - -export interface ColorOption { - name: string; - desc?: string; - map?: (value: string|number) => string; - /** List of items for the color map. Defined only for categorical map. */ - items?: {label: string, count: number}[]; - /** Threshold values and their colors. Defined for gradient color map. */ - thresholds?: {value: number, color: string}[]; - isSeparator?: boolean; - tooManyUniqueValues?: boolean; -} - -/** - * An interface that holds all the data for serializing the current state of - * the world. - */ -export class State { - /** A label identifying this state. */ - label: string = ''; - - /** Whether this State is selected in the bookmarks pane. */ - isSelected: boolean = false; - - /** The selected projection tab. */ - selectedProjection: ProjectionType; - - /** Dimensions of the DataSet. 
*/ - dataSetDimensions: [number, number]; - - /** t-SNE parameters */ - tSNEIteration: number = 0; - tSNEPerplexity: number = 0; - tSNELearningRate: number = 0; - tSNEis3d: boolean = true; - - /** PCA projection component dimensions */ - pcaComponentDimensions: number[] = []; - - /** Custom projection parameters */ - customSelectedSearchByMetadataOption: string; - customXLeftText: string; - customXLeftRegex: boolean; - customXRightText: string; - customXRightRegex: boolean; - customYUpText: string; - customYUpRegex: boolean; - customYDownText: string; - customYDownRegex: boolean; - - /** The computed projections of the tensors. */ - projections: Array<{[key: string]: number}> = []; - - /** Filtered dataset indices. */ - filteredPoints: number[]; - - /** The indices of selected points. */ - selectedPoints: number[] = []; - - /** Camera state (2d/3d, position, target, zoom, etc). */ - cameraDef: scatterPlot.CameraDef; - - /** Color by option. */ - selectedColorOptionName: string; - forceCategoricalColoring: boolean; - - /** Label by option. */ - selectedLabelOption: string; -} - -export function getProjectionComponents( - projection: ProjectionType, - components: (number|string)[]): ProjectionComponents3D { - if (components.length > 3) { - throw new RangeError('components length must be <= 3'); - } - const projectionComponents: [string, string, string] = [null, null, null]; - const prefix = (projection === 'custom') ? 
'linear' : projection; - for (let i = 0; i < components.length; ++i) { - if (components[i] == null) { - continue; - } - projectionComponents[i] = prefix + '-' + components[i]; - } - return projectionComponents; -} - -export function stateGetAccessorDimensions(state: State): Array { - let dimensions: Array; - switch (state.selectedProjection) { - case 'pca': - dimensions = state.pcaComponentDimensions.slice(); - break; - case 'tsne': - dimensions = [0, 1]; - if (state.tSNEis3d) { - dimensions.push(2); - } - break; - case 'custom': - dimensions = ['x', 'y']; - break; - default: - throw new Error('Unexpected fallthrough'); - } - return dimensions; -} diff --git a/tensorflow/tensorboard/components/vz_projector/external.d.ts b/tensorflow/tensorboard/components/vz_projector/external.d.ts deleted file mode 100644 index cbc1512c215..00000000000 --- a/tensorflow/tensorboard/components/vz_projector/external.d.ts +++ /dev/null @@ -1,51 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// TODO(smilkov): Split into weblas.d.ts and numeric.d.ts and write -// typings for numeric. 
-interface Tensor { - new(size: [number, number], data: Float32Array); - transfer(): Float32Array; - delete(): void; -} - -interface Weblas { - sgemm(M: number, N: number, K: number, alpha: number, - A: Float32Array, B: Float32Array, beta: number, C: Float32Array): - Float32Array; - pipeline: { - Tensor: Tensor; - sgemm(alpha: number, A: Tensor, B: Tensor, beta: number, - C: Tensor): Tensor; - }; - util: { - transpose(M: number, N: number, data: Float32Array): Tensor; - }; - -} - -declare let numeric: any; -declare let weblas: Weblas; - -interface AnalyticsEventType { - hitType: string; - page?: string; - eventCategory?: string; - eventAction?: string; - eventLabel?: string; - eventValue?: number; -} - -declare let ga: (command: string, eventObj: AnalyticsEventType) => void; \ No newline at end of file diff --git a/tensorflow/tensorboard/components/vz_projector/heap.ts b/tensorflow/tensorboard/components/vz_projector/heap.ts deleted file mode 100644 index ac3144e6493..00000000000 --- a/tensorflow/tensorboard/components/vz_projector/heap.ts +++ /dev/null @@ -1,146 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -/** Min key heap. */ -export type HeapItem = { - key: number, - value: T -}; - -/** - * Min-heap data structure. Provides O(1) for peek, returning the smallest key. - */ -// TODO(jart): Rename to Heap and use Comparator. 
-export class MinHeap { - private arr: HeapItem[] = []; - - /** Push an element with the provided key. */ - push(key: number, value: T): void { - this.arr.push({key, value}); - this.bubbleUp(this.arr.length - 1); - } - - /** Pop the element with the smallest key. */ - pop(): HeapItem { - if (this.arr.length === 0) { - throw new Error('pop() called on empty binary heap'); - } - let item = this.arr[0]; - let last = this.arr.length - 1; - this.arr[0] = this.arr[last]; - this.arr.pop(); - if (last > 0) { - this.bubbleDown(0); - } - return item; - }; - - /** Returns, but doesn't remove the element with the smallest key */ - peek(): HeapItem { return this.arr[0]; } - - /** - * Pops the element with the smallest key and at the same time - * adds the newly provided element. This is faster than calling - * pop() and push() separately. - */ - popPush(key: number, value: T): HeapItem { - if (this.arr.length === 0) { - throw new Error('pop() called on empty binary heap'); - } - let item = this.arr[0]; - this.arr[0] = {key, value}; - if (this.arr.length > 0) { - this.bubbleDown(0); - } - return item; - } - - /** Returns the number of elements in the heap. */ - size(): number { return this.arr.length; } - - /** Returns all the items in the heap. 
*/ - items(): HeapItem[] { return this.arr; } - - private swap(a: number, b: number) { - let temp = this.arr[a]; - this.arr[a] = this.arr[b]; - this.arr[b] = temp; - } - - private bubbleDown(pos: number) { - let left = (pos << 1) + 1; - let right = left + 1; - let largest = pos; - if (left < this.arr.length && this.arr[left].key < this.arr[largest].key) { - largest = left; - } - if (right < this.arr.length && - this.arr[right].key < this.arr[largest].key) { - largest = right; - } - if (largest !== pos) { - this.swap(largest, pos); - this.bubbleDown(largest); - } - } - - private bubbleUp(pos: number) { - if (pos <= 0) { - return; - } - let parent = ((pos - 1) >> 1); - if (this.arr[pos].key < this.arr[parent].key) { - this.swap(pos, parent); - this.bubbleUp(parent); - } - } -} - -/** List that keeps the K elements with the smallest keys. */ -export class KMin { - private k: number; - private maxHeap = new MinHeap(); - - /** Constructs a new k-min data structure with the provided k. */ - constructor(k: number) { this.k = k; } - - /** Adds an element to the list. */ - add(key: number, value: T) { - if (this.maxHeap.size() < this.k) { - this.maxHeap.push(-key, value); - return; - } - let largest = this.maxHeap.peek(); - // If the new element is smaller, replace the largest with the new element. - if (key < -largest.key) { - this.maxHeap.popPush(-key, value); - } - } - - /** Returns the k items with the smallest keys. */ - getMinKItems(): T[] { - let items = this.maxHeap.items(); - items.sort((a, b) => b.key - a.key); - return items.map(a => a.value); - } - - /** Returns the size of the list. */ - getSize(): number { return this.maxHeap.size(); } - - /** Returns the largest key in the list. */ - getLargestKey(): number { - return this.maxHeap.size() === 0 ? 
null : -this.maxHeap.peek().key; - } -} diff --git a/tensorflow/tensorboard/components/vz_projector/knn.ts b/tensorflow/tensorboard/components/vz_projector/knn.ts deleted file mode 100644 index 906e077b5d7..00000000000 --- a/tensorflow/tensorboard/components/vz_projector/knn.ts +++ /dev/null @@ -1,235 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -import {runAsyncTask} from './util'; -import * as logging from './logging'; -import {KMin} from './heap'; -import {Vector} from './vector'; -import * as vector from './vector'; - -export type NearestEntry = { - index: number, - dist: number -}; - -/** - * Optimal size for the height of the matrix when doing computation on the GPU - * using WebGL. This was found experimentally. - * - * This also guarantees that for computing pair-wise distance for up to 10K - * vectors, no more than 40MB will be allocated in the GPU. Without the - * allocation limit, we can freeze the graphics of the whole OS. - */ -const OPTIMAL_GPU_BLOCK_SIZE = 256; -/** Id of message box used for knn gpu progress bar. */ -const KNN_GPU_MSG_ID = 'knn-gpu'; - -/** - * Returns the K nearest neighbors for each vector where the distance - * computation is done on the GPU (WebGL) using cosine distance. - * - * @param dataPoints List of data points, where each data point holds an - * n-dimensional vector. 
- * @param k Number of nearest neighbors to find. - * @param accessor A method that returns the vector, given the data point. - */ -export function findKNNGPUCosine( - dataPoints: T[], k: number, - accessor: (dataPoint: T) => Float32Array): Promise { - let N = dataPoints.length; - let dim = accessor(dataPoints[0]).length; - - // The goal is to compute a large matrix multiplication A*A.T where A is of - // size NxD and A.T is its transpose. This results in a NxN matrix which - // could be too big to store on the GPU memory. To avoid memory overflow, we - // compute multiple A*partial_A.T where partial_A is of size BxD (B is much - // smaller than N). This results in storing only NxB size matrices on the GPU - // at a given time. - - // A*A.T will give us NxN matrix holding the cosine distance between every - // pair of points, which we sort using KMin data structure to obtain the - // K nearest neighbors for each point. - let typedArray = vector.toTypedArray(dataPoints, accessor); - let bigMatrix = new weblas.pipeline.Tensor([N, dim], typedArray); - let nearest: NearestEntry[][] = new Array(N); - let numPieces = Math.ceil(N / OPTIMAL_GPU_BLOCK_SIZE); - let M = Math.floor(N / numPieces); - let modulo = N % numPieces; - let offset = 0; - let progress = 0; - let progressDiff = 1 / (2 * numPieces); - let piece = 0; - - function step(resolve: (result: NearestEntry[][]) => void) { - let progressMsg = - 'Finding nearest neighbors: ' + (progress * 100).toFixed() + '%'; - runAsyncTask(progressMsg, () => { - let B = piece < modulo ? M + 1 : M; - let typedB = new Float32Array(B * dim); - for (let i = 0; i < B; ++i) { - let vector = accessor(dataPoints[offset + i]); - for (let d = 0; d < dim; ++d) { - typedB[i * dim + d] = vector[d]; - } - } - let partialMatrix = new weblas.pipeline.Tensor([B, dim], typedB); - // Result is N x B matrix. 
- let result = - weblas.pipeline.sgemm(1, bigMatrix, partialMatrix, null, null); - let partial = result.transfer(); - partialMatrix.delete(); - result.delete(); - progress += progressDiff; - for (let i = 0; i < B; i++) { - let kMin = new KMin(k); - let iReal = offset + i; - for (let j = 0; j < N; j++) { - if (j === iReal) { - continue; - } - let cosDist = 1 - partial[j * B + i]; // [j, i]; - kMin.add(cosDist, {index: j, dist: cosDist}); - } - nearest[iReal] = kMin.getMinKItems(); - } - progress += progressDiff; - offset += B; - piece++; - }, KNN_GPU_MSG_ID).then(() => { - if (piece < numPieces) { - step(resolve); - } else { - logging.setModalMessage(null, KNN_GPU_MSG_ID); - bigMatrix.delete(); - resolve(nearest); - } - }, error => { - // GPU failed. Reverting back to CPU. - logging.setModalMessage(null, KNN_GPU_MSG_ID); - let distFunc = (a, b, limit) => vector.cosDistNorm(a, b); - findKNN(dataPoints, k, accessor, distFunc).then(nearest => { - resolve(nearest); - }); - }); - } - return new Promise(resolve => step(resolve)); -} - -/** - * Returns the K nearest neighbors for each vector where the distance - * computation is done on the CPU using a user-specified distance method. - * - * @param dataPoints List of data points, where each data point holds an - * n-dimensional vector. - * @param k Number of nearest neighbors to find. - * @param accessor A method that returns the vector, given the data point. - * @param dist Method that takes two vectors and a limit, and computes the - * distance between two vectors, with the ability to stop early if the - * distance is above the limit. - */ -export function findKNN( - dataPoints: T[], k: number, accessor: (dataPoint: T) => Float32Array, - dist: (a: Vector, b: Vector, limit: number) => - number): Promise { - return runAsyncTask('Finding nearest neighbors...', () => { - let N = dataPoints.length; - let nearest: NearestEntry[][] = new Array(N); - // Find the distances from node i. 
- let kMin: KMin[] = new Array(N); - for (let i = 0; i < N; i++) { - kMin[i] = new KMin(k); - } - for (let i = 0; i < N; i++) { - let a = accessor(dataPoints[i]); - let kMinA = kMin[i]; - for (let j = i + 1; j < N; j++) { - let kMinB = kMin[j]; - let limitI = kMinA.getSize() === k ? - kMinA.getLargestKey() || Number.MAX_VALUE : - Number.MAX_VALUE; - let limitJ = kMinB.getSize() === k ? - kMinB.getLargestKey() || Number.MAX_VALUE : - Number.MAX_VALUE; - let limit = Math.max(limitI, limitJ); - let dist2ItoJ = dist(a, accessor(dataPoints[j]), limit); - if (dist2ItoJ >= 0) { - kMinA.add(dist2ItoJ, {index: j, dist: dist2ItoJ}); - kMinB.add(dist2ItoJ, {index: i, dist: dist2ItoJ}); - } - } - } - for (let i = 0; i < N; i++) { - nearest[i] = kMin[i].getMinKItems(); - } - return nearest; - }); -} - -/** Calculates the minimum distance between a search point and a rectangle. */ -function minDist( - point: [number, number], x1: number, y1: number, x2: number, y2: number) { - let x = point[0]; - let y = point[1]; - let dx1 = x - x1; - let dx2 = x - x2; - let dy1 = y - y1; - let dy2 = y - y2; - - if (dx1 * dx2 <= 0) { // x is between x1 and x2 - if (dy1 * dy2 <= 0) { // (x,y) is inside the rectangle - return 0; // return 0 as point is in rect - } - return Math.min(Math.abs(dy1), Math.abs(dy2)); - } - if (dy1 * dy2 <= 0) { // y is between y1 and y2 - // We know it is already inside the rectangle - return Math.min(Math.abs(dx1), Math.abs(dx2)); - } - let corner: [number, number]; - if (x > x2) { - // Upper-right vs lower-right. - corner = y > y2 ? [x2, y2] : [x2, y1]; - } else { - // Upper-left vs lower-left. - corner = y > y2 ? [x1, y2] : [x1, y1]; - } - return Math.sqrt(vector.dist22D([x, y], corner)); -} - -/** - * Returns the nearest neighbors of a particular point. - * - * @param dataPoints List of data points. - * @param pointIndex The index of the point we need the nearest neighbors of. - * @param k Number of nearest neighbors to search for. 
- * @param accessor Method that maps a data point => vector (array of numbers). - * @param distance Method that takes two vectors and returns their distance. - */ -export function findKNNofPoint( - dataPoints: T[], pointIndex: number, k: number, - accessor: (dataPoint: T) => Float32Array, - distance: (a: Vector, b: Vector) => number) { - let kMin = new KMin(k); - let a = accessor(dataPoints[pointIndex]); - for (let i = 0; i < dataPoints.length; ++i) { - if (i === pointIndex) { - continue; - } - let b = accessor(dataPoints[i]); - let dist = distance(a, b); - kMin.add(dist, {index: i, dist: dist}); - } - return kMin.getMinKItems(); -} diff --git a/tensorflow/tensorboard/components/vz_projector/label.ts b/tensorflow/tensorboard/components/vz_projector/label.ts deleted file mode 100644 index 67987f06ea3..00000000000 --- a/tensorflow/tensorboard/components/vz_projector/label.ts +++ /dev/null @@ -1,151 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -export interface BoundingBox { - loX: number; - loY: number; - hiX: number; - hiY: number; -} - -/** - * Accelerates label placement by dividing the view into a uniform grid. - * Labels only need to be tested for collision with other labels that overlap - * the same grid cells. This is a fork of {@code amoeba.CollisionGrid}. 
- */ -export class CollisionGrid { - private numHorizCells: number; - private numVertCells: number; - private grid: BoundingBox[][]; - private bound: BoundingBox; - private cellWidth: number; - private cellHeight: number; - - /** - * Constructs a new Collision grid. - * - * @param bound The bound of the grid. Labels out of bounds will be rejected. - * @param cellWidth Width of a cell in the grid. - * @param cellHeight Height of a cell in the grid. - */ - constructor(bound: BoundingBox, cellWidth: number, cellHeight: number) { - /** The bound of the grid. Labels out of bounds will be rejected. */ - this.bound = bound; - - /** Width of a cell in the grid. */ - this.cellWidth = cellWidth; - - /** Height of a cell in the grid. */ - this.cellHeight = cellHeight; - - /** Number of grid cells along the x axis. */ - this.numHorizCells = Math.ceil(this.boundWidth(bound) / cellWidth); - - /** Number of grid cells along the y axis. */ - this.numVertCells = Math.ceil(this.boundHeight(bound) / cellHeight); - - /** - * The 2d grid (stored as a 1d array.) Each cell consists of an array of - * BoundingBoxes for objects that are in the cell. - */ - this.grid = new Array(this.numHorizCells * this.numVertCells); - } - - private boundWidth(bound: BoundingBox) { return bound.hiX - bound.loX; } - - private boundHeight(bound: BoundingBox) { return bound.hiY - bound.loY; } - - private boundsIntersect(a: BoundingBox, b: BoundingBox) { - return !(a.loX > b.hiX || a.loY > b.hiY || a.hiX < b.loX || a.hiY < b.loY); - } - - /** - * Checks if a given bounding box has any conflicts in the grid and inserts it - * if none are found. - * - * @param bound The bound to insert. - * @param justTest If true, just test if it conflicts, without inserting. - * @return True if the bound was successfully inserted; false if it - * could not be inserted due to a conflict. - */ - insert(bound: BoundingBox, justTest = false): boolean { - // Reject if the label is out of bounds. 
- if ((bound.hiX < this.bound.loX) || (bound.loX > this.bound.hiX) || - (bound.hiY < this.bound.loY) || (bound.loY > this.bound.hiY)) { - return false; - } - - let minCellX = this.getCellX(bound.loX); - let maxCellX = this.getCellX(bound.hiX); - let minCellY = this.getCellY(bound.loY); - let maxCellY = this.getCellY(bound.hiY); - - // Check all overlapped cells to verify that we can insert. - let baseIdx = minCellY * this.numHorizCells + minCellX; - let idx = baseIdx; - for (let j = minCellY; j <= maxCellY; j++) { - for (let i = minCellX; i <= maxCellX; i++) { - let cell = this.grid[idx++]; - if (cell) { - for (let k = 0; k < cell.length; k++) { - if (this.boundsIntersect(bound, cell[k])) { - return false; - } - } - } - } - idx += this.numHorizCells - (maxCellX - minCellX + 1); - } - - if (justTest) { - return true; - } - - // Insert into the overlapped cells. - idx = baseIdx; - for (let j = minCellY; j <= maxCellY; j++) { - for (let i = minCellX; i <= maxCellX; i++) { - if (!this.grid[idx]) { - this.grid[idx] = [bound]; - } else { - this.grid[idx].push(bound); - } - idx++; - } - idx += this.numHorizCells - (maxCellX - minCellX + 1); - } - return true; - } - - /** - * Returns the x index of the grid cell where the given x coordinate falls. - * - * @param x the coordinate, in world space. - * @return the x index of the cell. - */ - private getCellX(x: number) { - return Math.floor((x - this.bound.loX) / this.cellWidth); - }; - - /** - * Returns the y index of the grid cell where the given y coordinate falls. - * - * @param y the coordinate, in world space. - * @return the y index of the cell. 
- */ - private getCellY(y: number) { - return Math.floor((y - this.bound.loY) / this.cellHeight); - }; -} diff --git a/tensorflow/tensorboard/components/vz_projector/logging.ts b/tensorflow/tensorboard/components/vz_projector/logging.ts deleted file mode 100644 index 59f37206012..00000000000 --- a/tensorflow/tensorboard/components/vz_projector/logging.ts +++ /dev/null @@ -1,103 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -/** Duration in ms for showing warning messages to the user */ -const WARNING_DURATION_MS = 10000; - -let dom: HTMLElement = null; -let msgId = 0; -let numActiveMessages = 0; - -export function setDomContainer(domElement: HTMLElement) { - dom = domElement; -} - -/** - * Updates the user message with the provided id. - * - * @param msg The message shown to the user. If null, the message is removed. - * @param id The id of an existing message. If no id is provided, a unique id - * is assigned. - * @param title The title of the notification. - * @param isErrorMsg If true, the message is error and the dialog will have a - * close button. - * @return The id of the message. 
- */ -export function setModalMessage( - msg: string, id: string = null, title = null, isErrorMsg = false): string { - if (dom == null) { - console.warn('Can\'t show modal message before the dom is initialized'); - return; - } - if (id == null) { - id = (msgId++).toString(); - } - let dialog = dom.querySelector('#notification-dialog') as any; - dialog.querySelector('.close-button').style.display = - isErrorMsg ? null : 'none'; - let spinner = dialog.querySelector('.progress'); - spinner.style.display = isErrorMsg ? 'none' : null; - spinner.active = isErrorMsg ? null : true; - dialog.querySelector('#notification-title').innerHTML = title; - let msgsContainer = dialog.querySelector('#notify-msgs') as HTMLElement; - if (isErrorMsg) { - msgsContainer.innerHTML = ''; - } else { - const errors = msgsContainer.querySelectorAll('.error'); - for (let i = 0; i < errors.length; i++) { - msgsContainer.removeChild(errors[i]); - } - } - let divId = `notify-msg-${id}`; - let msgDiv = dialog.querySelector('#' + divId) as HTMLDivElement; - if (msgDiv == null) { - msgDiv = document.createElement('div'); - msgDiv.className = 'notify-msg ' + (isErrorMsg ? 'error' : ''); - msgDiv.id = divId; - - msgsContainer.insertBefore(msgDiv, msgsContainer.firstChild); - - if (!isErrorMsg) { - numActiveMessages++; - } else { - numActiveMessages = 0; - } - } - if (msg == null) { - numActiveMessages--; - if (numActiveMessages === 0) { - dialog.close(); - } - msgDiv.remove(); - } else { - msgDiv.innerText = msg; - dialog.open(); - } - return id; -} - -export function setErrorMessage(errMsg: string, task?: string) { - setModalMessage(errMsg, null, 'Error ' + (task != null ? task : ''), true); -} - -/** - * Shows a warning message to the user for a certain amount of time. 
- */ -export function setWarningMessage(msg: string): void { - let toast = dom.querySelector('#toast') as any; - toast.text = msg; - toast.duration = WARNING_DURATION_MS; - toast.open(); -} diff --git a/tensorflow/tensorboard/components/vz_projector/projectorEventContext.ts b/tensorflow/tensorboard/components/vz_projector/projectorEventContext.ts deleted file mode 100644 index 36f5c4c5841..00000000000 --- a/tensorflow/tensorboard/components/vz_projector/projectorEventContext.ts +++ /dev/null @@ -1,45 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -import {DistanceFunction, Projection} from './data'; -import {NearestEntry} from './knn'; - -export type HoverListener = (index: number) => void; -export type SelectionChangedListener = - (selectedPointIndices: number[], neighborsOfFirstPoint: NearestEntry[]) => - void; -export type ProjectionChangedListener = (projection: Projection) => void; -export type DistanceMetricChangedListener = - (distanceMetric: DistanceFunction) => void; -export interface ProjectorEventContext { - /** Register a callback to be invoked when the mouse hovers over a point. */ - registerHoverListener(listener: HoverListener); - /** Notify the hover system that a point is under the mouse. */ - notifyHoverOverPoint(pointIndex: number); - /** Registers a callback to be invoked when the selection changes. 
*/ - registerSelectionChangedListener(listener: SelectionChangedListener); - /** - * Notify the selection system that a client has changed the selected point - * set. - */ - notifySelectionChanged(newSelectedPointIndices: number[]); - /** Registers a callback to be invoked when the projection changes. */ - registerProjectionChangedListener(listener: ProjectionChangedListener); - /** Notify listeners that a reprojection occurred. */ - notifyProjectionChanged(projection: Projection); - registerDistanceMetricChangedListener(listener: - DistanceMetricChangedListener); - notifyDistanceMetricChanged(distMetric: DistanceFunction); -} diff --git a/tensorflow/tensorboard/components/vz_projector/projectorScatterPlotAdapter.ts b/tensorflow/tensorboard/components/vz_projector/projectorScatterPlotAdapter.ts deleted file mode 100644 index c0da9526598..00000000000 --- a/tensorflow/tensorboard/components/vz_projector/projectorScatterPlotAdapter.ts +++ /dev/null @@ -1,711 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -import {DataSet, DistanceFunction, Projection, ProjectionComponents3D, State} from './data'; -import {NearestEntry} from './knn'; -import {ProjectorEventContext} from './projectorEventContext'; -import {LabelRenderParams} from './renderContext'; -import {ScatterPlot} from './scatterPlot'; -import {ScatterPlotVisualizer3DLabels} from './scatterPlotVisualizer3DLabels'; -import {ScatterPlotVisualizerCanvasLabels} from './scatterPlotVisualizerCanvasLabels'; -import {ScatterPlotVisualizerPolylines} from './scatterPlotVisualizerPolylines'; -import {ScatterPlotVisualizerSprites} from './scatterPlotVisualizerSprites'; -import * as vector from './vector'; - -const LABEL_FONT_SIZE = 10; -const LABEL_SCALE_DEFAULT = 1.0; -const LABEL_SCALE_LARGE = 2; -const LABEL_FILL_COLOR_SELECTED = 0x000000; -const LABEL_FILL_COLOR_HOVER = 0x000000; -const LABEL_FILL_COLOR_NEIGHBOR = 0x000000; -const LABEL_STROKE_COLOR_SELECTED = 0xFFFFFF; -const LABEL_STROKE_COLOR_HOVER = 0xFFFFFF; -const LABEL_STROKE_COLOR_NEIGHBOR = 0xFFFFFF; - -const POINT_COLOR_UNSELECTED = 0xE3E3E3; -const POINT_COLOR_NO_SELECTION = 0x7575D9; -const POINT_COLOR_SELECTED = 0xFA6666; -const POINT_COLOR_HOVER = 0x760B4F; - -const POINT_SCALE_DEFAULT = 1.0; -const POINT_SCALE_SELECTED = 1.2; -const POINT_SCALE_NEIGHBOR = 1.2; -const POINT_SCALE_HOVER = 1.2; - -const LABELS_3D_COLOR_UNSELECTED = 0xFFFFFF; -const LABELS_3D_COLOR_NO_SELECTION = 0xFFFFFF; - -const SPRITE_IMAGE_COLOR_UNSELECTED = 0xFFFFFF; -const SPRITE_IMAGE_COLOR_NO_SELECTION = 0xFFFFFF; - -const POLYLINE_START_HUE = 60; -const POLYLINE_END_HUE = 360; -const POLYLINE_SATURATION = 1; -const POLYLINE_LIGHTNESS = .3; - -const POLYLINE_DEFAULT_OPACITY = .2; -const POLYLINE_DEFAULT_LINEWIDTH = 2; -const POLYLINE_SELECTED_OPACITY = .9; -const POLYLINE_SELECTED_LINEWIDTH = 3; -const POLYLINE_DESELECTED_OPACITY = .05; - -const SCATTER_PLOT_CUBE_LENGTH = 2; - -/** Color scale for 
nearest neighbors. */ -const NN_COLOR_SCALE = - d3.scaleLinear() - .domain([1, 0.7, 0.4]) - .range(['hsl(285, 80%, 40%)', 'hsl(0, 80%, 65%)', 'hsl(40, 70%, 60%)']) - .clamp(true); - -/** - * Interprets projector events and assembes the arrays and commands necessary - * to use the ScatterPlot to render the current projected data set. - */ -export class ProjectorScatterPlotAdapter { - public scatterPlot: ScatterPlot; - private projection: Projection; - private hoverPointIndex: number; - private selectedPointIndices: number[]; - private neighborsOfFirstSelectedPoint: NearestEntry[]; - private renderLabelsIn3D: boolean = false; - private labelPointAccessor: string; - private legendPointColorer: (ds: DataSet, index: number) => string; - private distanceMetric: DistanceFunction; - - private spriteVisualizer: ScatterPlotVisualizerSprites; - private labels3DVisualizer: ScatterPlotVisualizer3DLabels; - private canvasLabelsVisualizer: ScatterPlotVisualizerCanvasLabels; - private polylineVisualizer: ScatterPlotVisualizerPolylines; - - constructor( - private scatterPlotContainer: HTMLElement, - projectorEventContext: ProjectorEventContext) { - this.scatterPlot = - new ScatterPlot(scatterPlotContainer, projectorEventContext); - projectorEventContext.registerProjectionChangedListener(projection => { - this.projection = projection; - this.updateScatterPlotWithNewProjection(projection); - }); - projectorEventContext.registerSelectionChangedListener( - (selectedPointIndices, neighbors) => { - this.selectedPointIndices = selectedPointIndices; - this.neighborsOfFirstSelectedPoint = neighbors; - this.updateScatterPlotPositions(); - this.updateScatterPlotAttributes(); - this.scatterPlot.render(); - }); - projectorEventContext.registerHoverListener(hoverPointIndex => { - this.hoverPointIndex = hoverPointIndex; - this.updateScatterPlotAttributes(); - this.scatterPlot.render(); - }); - projectorEventContext.registerDistanceMetricChangedListener( - distanceMetric => { - this.distanceMetric 
= distanceMetric; - this.updateScatterPlotAttributes(); - this.scatterPlot.render(); - }); - this.createVisualizers(false); - } - - notifyProjectionPositionsUpdated() { - this.updateScatterPlotPositions(); - this.scatterPlot.render(); - } - - setDataSet(dataSet: DataSet) { - if (this.projection != null) { - // TODO(nicholsonc): setDataSet needs to go away, the projection is the - // atomic unit of update. - this.projection.dataSet = dataSet; - } - if (this.polylineVisualizer != null) { - this.polylineVisualizer.setDataSet(dataSet); - } - if (this.labels3DVisualizer != null) { - this.labels3DVisualizer.setLabelStrings( - this.generate3DLabelsArray(dataSet, this.labelPointAccessor)); - } - if (this.spriteVisualizer == null) { - return; - } - this.spriteVisualizer.clearSpriteAtlas(); - if ((dataSet == null) || (dataSet.spriteAndMetadataInfo == null)) { - return; - } - const metadata = dataSet.spriteAndMetadataInfo; - if ((metadata.spriteImage == null) || (metadata.spriteMetadata == null)) { - return; - } - const n = dataSet.points.length; - const spriteIndices = new Float32Array(n); - for (let i = 0; i < n; ++i) { - spriteIndices[i] = dataSet.points[i].index; - } - this.spriteVisualizer.setSpriteAtlas( - metadata.spriteImage, metadata.spriteMetadata.singleImageDim, - spriteIndices); - } - - set3DLabelMode(renderLabelsIn3D: boolean) { - this.renderLabelsIn3D = renderLabelsIn3D; - this.createVisualizers(renderLabelsIn3D); - this.updateScatterPlotAttributes(); - this.scatterPlot.render(); - } - - setLegendPointColorer( - legendPointColorer: (ds: DataSet, index: number) => string) { - this.legendPointColorer = legendPointColorer; - } - - setLabelPointAccessor(labelPointAccessor: string) { - this.labelPointAccessor = labelPointAccessor; - if (this.labels3DVisualizer != null) { - const ds = (this.projection == null) ? 
null : this.projection.dataSet; - this.labels3DVisualizer.setLabelStrings( - this.generate3DLabelsArray(ds, labelPointAccessor)); - } - } - - resize() { - this.scatterPlot.resize(); - } - - populateBookmarkFromUI(state: State) { - state.cameraDef = this.scatterPlot.getCameraDef(); - } - - restoreUIFromBookmark(state: State) { - this.scatterPlot.setCameraParametersForNextCameraCreation( - state.cameraDef, false); - } - - updateScatterPlotPositions() { - const ds = (this.projection == null) ? null : this.projection.dataSet; - const projectionComponents = - (this.projection == null) ? null : this.projection.projectionComponents; - const newPositions = - this.generatePointPositionArray(ds, projectionComponents); - this.scatterPlot.setPointPositions(newPositions); - } - - updateScatterPlotAttributes() { - if (this.projection == null) { - return; - } - const dataSet = this.projection.dataSet; - const selectedSet = this.selectedPointIndices; - const hoverIndex = this.hoverPointIndex; - const neighbors = this.neighborsOfFirstSelectedPoint; - const pointColorer = this.legendPointColorer; - - const pointColors = this.generatePointColorArray( - dataSet, pointColorer, this.distanceMetric, selectedSet, neighbors, - hoverIndex, this.renderLabelsIn3D, this.getSpriteImageMode()); - const pointScaleFactors = this.generatePointScaleFactorArray( - dataSet, selectedSet, neighbors, hoverIndex); - const labels = this.generateVisibleLabelRenderParams( - dataSet, selectedSet, neighbors, hoverIndex); - const polylineColors = - this.generateLineSegmentColorMap(dataSet, pointColorer); - const polylineOpacities = - this.generateLineSegmentOpacityArray(dataSet, selectedSet); - const polylineWidths = - this.generateLineSegmentWidthArray(dataSet, selectedSet); - - this.scatterPlot.setPointColors(pointColors); - this.scatterPlot.setPointScaleFactors(pointScaleFactors); - this.scatterPlot.setLabels(labels); - this.scatterPlot.setPolylineColors(polylineColors); - 
this.scatterPlot.setPolylineOpacities(polylineOpacities); - this.scatterPlot.setPolylineWidths(polylineWidths); - } - - render() { - this.scatterPlot.render(); - } - - generatePointPositionArray( - ds: DataSet, projectionComponents: ProjectionComponents3D): Float32Array { - if (ds == null) { - return null; - } - - const xScaler = d3.scaleLinear(); - const yScaler = d3.scaleLinear(); - let zScaler = null; - { - // Determine max and min of each axis of our data. - const xExtent = d3.extent( - ds.points, - (p, i) => ds.points[i].projections[projectionComponents[0]]); - const yExtent = d3.extent( - ds.points, - (p, i) => ds.points[i].projections[projectionComponents[1]]); - - const range = - [-SCATTER_PLOT_CUBE_LENGTH / 2, SCATTER_PLOT_CUBE_LENGTH / 2]; - - xScaler.domain(xExtent).range(range); - yScaler.domain(yExtent).range(range); - - if (projectionComponents[2] != null) { - const zExtent = d3.extent( - ds.points, - (p, i) => ds.points[i].projections[projectionComponents[2]]); - zScaler = d3.scaleLinear(); - zScaler.domain(zExtent).range(range); - } - } - - const positions = new Float32Array(ds.points.length * 3); - let dst = 0; - - ds.points.forEach((d, i) => { - positions[dst++] = - xScaler(ds.points[i].projections[projectionComponents[0]]); - positions[dst++] = - yScaler(ds.points[i].projections[projectionComponents[1]]); - positions[dst++] = 0.0; - }); - - if (zScaler) { - dst = 2; - ds.points.forEach((d, i) => { - positions[dst] = - zScaler(ds.points[i].projections[projectionComponents[2]]); - dst += 3; - }); - } - - return positions; - } - - generateVisibleLabelRenderParams( - ds: DataSet, selectedPointIndices: number[], - neighborsOfFirstPoint: NearestEntry[], - hoverPointIndex: number): LabelRenderParams { - if (ds == null) { - return null; - } - - const selectedPointCount = - (selectedPointIndices == null) ? 0 : selectedPointIndices.length; - const neighborCount = - (neighborsOfFirstPoint == null) ? 
0 : neighborsOfFirstPoint.length; - const n = selectedPointCount + neighborCount + - ((hoverPointIndex != null) ? 1 : 0); - - const visibleLabels = new Uint32Array(n); - const scale = new Float32Array(n); - const opacityFlags = new Int8Array(n); - const fillColors = new Uint8Array(n * 3); - const strokeColors = new Uint8Array(n * 3); - const labelStrings: string[] = []; - - scale.fill(LABEL_SCALE_DEFAULT); - opacityFlags.fill(1); - - let dst = 0; - - if (hoverPointIndex != null) { - labelStrings.push( - this.getLabelText(ds, hoverPointIndex, this.labelPointAccessor)); - visibleLabels[dst] = hoverPointIndex; - scale[dst] = LABEL_SCALE_LARGE; - opacityFlags[dst] = 0; - const fillRgb = styleRgbFromHexColor(LABEL_FILL_COLOR_HOVER); - packRgbIntoUint8Array( - fillColors, dst, fillRgb[0], fillRgb[1], fillRgb[2]); - const strokeRgb = styleRgbFromHexColor(LABEL_STROKE_COLOR_HOVER); - packRgbIntoUint8Array( - strokeColors, dst, strokeRgb[0], strokeRgb[1], strokeRgb[1]); - ++dst; - } - - // Selected points - { - const n = selectedPointCount; - const fillRgb = styleRgbFromHexColor(LABEL_FILL_COLOR_SELECTED); - const strokeRgb = styleRgbFromHexColor(LABEL_STROKE_COLOR_SELECTED); - for (let i = 0; i < n; ++i) { - const labelIndex = selectedPointIndices[i]; - labelStrings.push( - this.getLabelText(ds, labelIndex, this.labelPointAccessor)); - visibleLabels[dst] = labelIndex; - scale[dst] = LABEL_SCALE_LARGE; - opacityFlags[dst] = (n === 1) ? 
0 : 1; - packRgbIntoUint8Array( - fillColors, dst, fillRgb[0], fillRgb[1], fillRgb[2]); - packRgbIntoUint8Array( - strokeColors, dst, strokeRgb[0], strokeRgb[1], strokeRgb[2]); - ++dst; - } - } - - // Neighbors - { - const n = neighborCount; - const fillRgb = styleRgbFromHexColor(LABEL_FILL_COLOR_NEIGHBOR); - const strokeRgb = styleRgbFromHexColor(LABEL_STROKE_COLOR_NEIGHBOR); - for (let i = 0; i < n; ++i) { - const labelIndex = neighborsOfFirstPoint[i].index; - labelStrings.push( - this.getLabelText(ds, labelIndex, this.labelPointAccessor)); - visibleLabels[dst] = labelIndex; - packRgbIntoUint8Array( - fillColors, dst, fillRgb[0], fillRgb[1], fillRgb[2]); - packRgbIntoUint8Array( - strokeColors, dst, strokeRgb[0], strokeRgb[1], strokeRgb[2]); - ++dst; - } - } - - return new LabelRenderParams( - new Float32Array(visibleLabels), labelStrings, scale, opacityFlags, - LABEL_FONT_SIZE, fillColors, strokeColors); - } - - generatePointScaleFactorArray( - ds: DataSet, selectedPointIndices: number[], - neighborsOfFirstPoint: NearestEntry[], - hoverPointIndex: number): Float32Array { - if (ds == null) { - return new Float32Array(0); - } - - const scale = new Float32Array(ds.points.length); - scale.fill(POINT_SCALE_DEFAULT); - - const selectedPointCount = - (selectedPointIndices == null) ? 0 : selectedPointIndices.length; - const neighborCount = - (neighborsOfFirstPoint == null) ? 0 : neighborsOfFirstPoint.length; - - // Scale up all selected points. - { - const n = selectedPointCount; - for (let i = 0; i < n; ++i) { - const p = selectedPointIndices[i]; - scale[p] = POINT_SCALE_SELECTED; - } - } - - // Scale up the neighbor points. - { - const n = neighborCount; - for (let i = 0; i < n; ++i) { - const p = neighborsOfFirstPoint[i].index; - scale[p] = POINT_SCALE_NEIGHBOR; - } - } - - // Scale up the hover point. 
- if (hoverPointIndex != null) { - scale[hoverPointIndex] = POINT_SCALE_HOVER; - } - - return scale; - } - - generateLineSegmentColorMap( - ds: DataSet, legendPointColorer: (ds: DataSet, index: number) => string): - {[polylineIndex: number]: Float32Array} { - let polylineColorArrayMap: {[polylineIndex: number]: Float32Array} = {}; - if (ds == null) { - return polylineColorArrayMap; - } - - for (let i = 0; i < ds.sequences.length; i++) { - let sequence = ds.sequences[i]; - let colors = new Float32Array(2 * (sequence.pointIndices.length - 1) * 3); - let colorIndex = 0; - - if (legendPointColorer) { - for (let j = 0; j < sequence.pointIndices.length - 1; j++) { - const c1 = - new THREE.Color(legendPointColorer(ds, sequence.pointIndices[j])); - const c2 = new THREE.Color( - legendPointColorer(ds, sequence.pointIndices[j + 1])); - colors[colorIndex++] = c1.r; - colors[colorIndex++] = c1.g; - colors[colorIndex++] = c1.b; - colors[colorIndex++] = c2.r; - colors[colorIndex++] = c2.g; - colors[colorIndex++] = c2.b; - } - } else { - for (let j = 0; j < sequence.pointIndices.length - 1; j++) { - const c1 = - getDefaultPointInPolylineColor(j, sequence.pointIndices.length); - const c2 = getDefaultPointInPolylineColor( - j + 1, sequence.pointIndices.length); - colors[colorIndex++] = c1.r; - colors[colorIndex++] = c1.g; - colors[colorIndex++] = c1.b; - colors[colorIndex++] = c2.r; - colors[colorIndex++] = c2.g; - colors[colorIndex++] = c2.b; - } - } - - polylineColorArrayMap[i] = colors; - } - - return polylineColorArrayMap; - } - - generateLineSegmentOpacityArray(ds: DataSet, selectedPoints: number[]): - Float32Array { - if (ds == null) { - return new Float32Array(0); - } - const opacities = new Float32Array(ds.sequences.length); - const selectedPointCount = - (selectedPoints == null) ? 
0 : selectedPoints.length; - if (selectedPointCount > 0) { - opacities.fill(POLYLINE_DESELECTED_OPACITY); - const i = ds.points[selectedPoints[0]].sequenceIndex; - opacities[i] = POLYLINE_SELECTED_OPACITY; - } else { - opacities.fill(POLYLINE_DEFAULT_OPACITY); - } - return opacities; - } - - generateLineSegmentWidthArray(ds: DataSet, selectedPoints: number[]): - Float32Array { - if (ds == null) { - return new Float32Array(0); - } - const widths = new Float32Array(ds.sequences.length); - widths.fill(POLYLINE_DEFAULT_LINEWIDTH); - const selectedPointCount = - (selectedPoints == null) ? 0 : selectedPoints.length; - if (selectedPointCount > 0) { - const i = ds.points[selectedPoints[0]].sequenceIndex; - widths[i] = POLYLINE_SELECTED_LINEWIDTH; - } - return widths; - } - - generatePointColorArray( - ds: DataSet, legendPointColorer: (ds: DataSet, index: number) => string, - distFunc: DistanceFunction, selectedPointIndices: number[], - neighborsOfFirstPoint: NearestEntry[], hoverPointIndex: number, - label3dMode: boolean, spriteImageMode: boolean): Float32Array { - if (ds == null) { - return new Float32Array(0); - } - - const selectedPointCount = - (selectedPointIndices == null) ? 0 : selectedPointIndices.length; - const neighborCount = - (neighborsOfFirstPoint == null) ? 0 : neighborsOfFirstPoint.length; - const colors = new Float32Array(ds.points.length * 3); - - let unselectedColor = POINT_COLOR_UNSELECTED; - let noSelectionColor = POINT_COLOR_NO_SELECTION; - - if (label3dMode) { - unselectedColor = LABELS_3D_COLOR_UNSELECTED; - noSelectionColor = LABELS_3D_COLOR_NO_SELECTION; - } - - if (spriteImageMode) { - unselectedColor = SPRITE_IMAGE_COLOR_UNSELECTED; - noSelectionColor = SPRITE_IMAGE_COLOR_NO_SELECTION; - } - - // Give all points the unselected color. 
- { - const n = ds.points.length; - let dst = 0; - if (selectedPointCount > 0) { - const c = new THREE.Color(unselectedColor); - for (let i = 0; i < n; ++i) { - colors[dst++] = c.r; - colors[dst++] = c.g; - colors[dst++] = c.b; - } - } else { - if (legendPointColorer != null) { - for (let i = 0; i < n; ++i) { - const c = new THREE.Color(legendPointColorer(ds, i)); - colors[dst++] = c.r; - colors[dst++] = c.g; - colors[dst++] = c.b; - } - } else { - const c = new THREE.Color(noSelectionColor); - for (let i = 0; i < n; ++i) { - colors[dst++] = c.r; - colors[dst++] = c.g; - colors[dst++] = c.b; - } - } - } - } - - // Color the selected points. - { - const n = selectedPointCount; - const c = new THREE.Color(POINT_COLOR_SELECTED); - for (let i = 0; i < n; ++i) { - let dst = selectedPointIndices[i] * 3; - colors[dst++] = c.r; - colors[dst++] = c.g; - colors[dst++] = c.b; - } - } - - // Color the neighbors. - { - const n = neighborCount; - let minDist = n > 0 ? neighborsOfFirstPoint[0].dist : 0; - for (let i = 0; i < n; ++i) { - const c = new THREE.Color( - dist2color(distFunc, neighborsOfFirstPoint[i].dist, minDist)); - let dst = neighborsOfFirstPoint[i].index * 3; - colors[dst++] = c.r; - colors[dst++] = c.g; - colors[dst++] = c.b; - } - } - - // Color the hover point. 
- if (hoverPointIndex != null) { - const c = new THREE.Color(POINT_COLOR_HOVER); - let dst = hoverPointIndex * 3; - colors[dst++] = c.r; - colors[dst++] = c.g; - colors[dst++] = c.b; - } - - return colors; - } - - generate3DLabelsArray(ds: DataSet, accessor: string) { - if ((ds == null) || (accessor == null)) { - return null; - } - let labels: string[] = []; - const n = ds.points.length; - for (let i = 0; i < n; ++i) { - labels.push(this.getLabelText(ds, i, accessor)); - } - return labels; - } - - private getLabelText(ds: DataSet, i: number, accessor: string) { - return ds.points[i].metadata[accessor].toString(); - } - - private updateScatterPlotWithNewProjection(projection: Projection) { - if (projection == null) { - this.createVisualizers(this.renderLabelsIn3D); - this.scatterPlot.render(); - return; - } - this.setDataSet(projection.dataSet); - this.scatterPlot.setDimensions(projection.dimensionality); - if (projection.dataSet.projectionCanBeRendered(projection.projectionType)) { - this.updateScatterPlotAttributes(); - this.notifyProjectionPositionsUpdated(); - } - this.scatterPlot.setCameraParametersForNextCameraCreation(null, false); - } - - private createVisualizers(inLabels3DMode: boolean) { - const ds = (this.projection == null) ? 
null : this.projection.dataSet; - const scatterPlot = this.scatterPlot; - scatterPlot.removeAllVisualizers(); - this.labels3DVisualizer = null; - this.canvasLabelsVisualizer = null; - this.spriteVisualizer = null; - this.polylineVisualizer = null; - if (inLabels3DMode) { - this.labels3DVisualizer = new ScatterPlotVisualizer3DLabels(); - this.labels3DVisualizer.setLabelStrings( - this.generate3DLabelsArray(ds, this.labelPointAccessor)); - } else { - this.spriteVisualizer = new ScatterPlotVisualizerSprites(); - scatterPlot.addVisualizer(this.spriteVisualizer); - this.canvasLabelsVisualizer = - new ScatterPlotVisualizerCanvasLabels(this.scatterPlotContainer); - } - this.polylineVisualizer = new ScatterPlotVisualizerPolylines(); - this.setDataSet(ds); - if (this.spriteVisualizer) { - scatterPlot.addVisualizer(this.spriteVisualizer); - } - if (this.labels3DVisualizer) { - scatterPlot.addVisualizer(this.labels3DVisualizer); - } - if (this.canvasLabelsVisualizer) { - scatterPlot.addVisualizer(this.canvasLabelsVisualizer); - } - scatterPlot.addVisualizer(this.polylineVisualizer); - } - - private getSpriteImageMode(): boolean { - if (this.projection == null) { - return false; - } - const ds = this.projection.dataSet; - if ((ds == null) || (ds.spriteAndMetadataInfo == null)) { - return false; - } - return ds.spriteAndMetadataInfo.spriteImage != null; - } -} - -function packRgbIntoUint8Array( - rgbArray: Uint8Array, labelIndex: number, r: number, g: number, b: number) { - rgbArray[labelIndex * 3] = r; - rgbArray[labelIndex * 3 + 1] = g; - rgbArray[labelIndex * 3 + 2] = b; -} - -function styleRgbFromHexColor(hex: number): [number, number, number] { - const c = new THREE.Color(hex); - return [(c.r * 255) | 0, (c.g * 255) | 0, (c.b * 255) | 0]; -} - -function getDefaultPointInPolylineColor( - index: number, totalPoints: number): THREE.Color { - let hue = POLYLINE_START_HUE + - (POLYLINE_END_HUE - POLYLINE_START_HUE) * index / totalPoints; - - let rgb = d3.hsl(hue, 
POLYLINE_SATURATION, POLYLINE_LIGHTNESS).rgb(); - return new THREE.Color(rgb.r / 255, rgb.g / 255, rgb.b / 255); -} - -/** - * Normalizes the distance so it can be visually encoded with color. - * The normalization depends on the distance metric (cosine vs euclidean). - */ -export function normalizeDist( - distFunc: DistanceFunction, d: number, minDist: number): number { - return (distFunc === vector.dist) ? (minDist / d) : (1 - d); -} - -/** Normalizes and encodes the provided distance with color. */ -export function dist2color( - distFunc: DistanceFunction, d: number, minDist: number): string { - return NN_COLOR_SCALE(normalizeDist(distFunc, d, minDist)); -} diff --git a/tensorflow/tensorboard/components/vz_projector/renderContext.ts b/tensorflow/tensorboard/components/vz_projector/renderContext.ts deleted file mode 100644 index 8d5232a8048..00000000000 --- a/tensorflow/tensorboard/components/vz_projector/renderContext.ts +++ /dev/null @@ -1,53 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http:www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -/** - * LabelRenderParams describes the set of points that should have labels - * rendered next to them. 
- */ -export class LabelRenderParams { - constructor( - public pointIndices: Float32Array, public labelStrings: string[], - public scaleFactors: Float32Array, public useSceneOpacityFlags: Int8Array, - public defaultFontSize: number, public fillColors: Uint8Array, - public strokeColors: Uint8Array) {} -} - -/** Details about the camera projection being used to render the scene. */ -export enum CameraType { - Perspective, - Orthographic -} - -/** - * RenderContext contains all of the state required to color and render the data - * set. ScatterPlot passes this to every attached visualizer as part of the - * render callback. - * TODO(nicholsonc): This should only contain the data that's changed between - * each frame. Data like colors / scale factors / labels should be reapplied - * only when they change. - */ -export class RenderContext { - constructor( - public camera: THREE.Camera, public cameraType: CameraType, - public cameraTarget: THREE.Vector3, public screenWidth: number, - public screenHeight: number, public nearestCameraSpacePointZ: number, - public farthestCameraSpacePointZ: number, public backgroundColor: number, - public pointColors: Float32Array, public pointScaleFactors: Float32Array, - public labels: LabelRenderParams, - public polylineColors: {[polylineIndex: number]: Float32Array}, - public polylineOpacities: Float32Array, - public polylineWidths: Float32Array) {} -} diff --git a/tensorflow/tensorboard/components/vz_projector/scatterPlot.ts b/tensorflow/tensorboard/components/vz_projector/scatterPlot.ts deleted file mode 100644 index 283b608e836..00000000000 --- a/tensorflow/tensorboard/components/vz_projector/scatterPlot.ts +++ /dev/null @@ -1,723 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -import {ProjectorEventContext} from './projectorEventContext'; -import {CameraType, LabelRenderParams, RenderContext} from './renderContext'; -import {BoundingBox, ScatterPlotRectangleSelector} from './scatterPlotRectangleSelector'; -import {ScatterPlotVisualizer} from './scatterPlotVisualizer'; -import * as util from './util'; -import {Point2D, Point3D} from './vector'; - -const BACKGROUND_COLOR = 0xffffff; - -/** - * The length of the cube (diameter of the circumscribing sphere) where all the - * points live. - */ -const CUBE_LENGTH = 2; -const MAX_ZOOM = 5 * CUBE_LENGTH; -const MIN_ZOOM = 0.025 * CUBE_LENGTH; - -// Constants relating to the camera parameters. -const PERSP_CAMERA_FOV_VERTICAL = 70; -const PERSP_CAMERA_NEAR_CLIP_PLANE = 0.01; -const PERSP_CAMERA_FAR_CLIP_PLANE = 100; -const ORTHO_CAMERA_FRUSTUM_HALF_EXTENT = 1.2; - -// Key presses. -const SHIFT_KEY = 16; -const CTRL_KEY = 17; - -const START_CAMERA_POS_3D = new THREE.Vector3(0.45, 0.9, 1.6); -const START_CAMERA_TARGET_3D = new THREE.Vector3(0, 0, 0); -const START_CAMERA_POS_2D = new THREE.Vector3(0, 0, 4); -const START_CAMERA_TARGET_2D = new THREE.Vector3(0, 0, 0); - -const ORBIT_MOUSE_ROTATION_SPEED = 1; -const ORBIT_ANIMATION_ROTATION_CYCLE_IN_SECONDS = 7; - -export type OnCameraMoveListener = - (cameraPosition: THREE.Vector3, cameraTarget: THREE.Vector3) => void; - -/** Supported modes of interaction. 
*/ -export enum MouseMode { - AREA_SELECT, - CAMERA_AND_CLICK_SELECT -} - -/** Defines a camera, suitable for serialization. */ -export class CameraDef { - orthographic: boolean = false; - position: Point3D; - target: Point3D; - zoom: number; -} - -/** - * Maintains a three.js instantiation and context, - * animation state, and all other logic that's - * independent of how a 3D scatter plot is actually rendered. Also holds an - * array of visualizers and dispatches application events to them. - */ -export class ScatterPlot { - private visualizers: ScatterPlotVisualizer[] = []; - - private onCameraMoveListeners: OnCameraMoveListener[] = []; - - private height: number; - private width: number; - - private mouseMode: MouseMode; - private backgroundColor: number = BACKGROUND_COLOR; - - private dimensionality: number = 3; - private renderer: THREE.WebGLRenderer; - - private scene: THREE.Scene; - private pickingTexture: THREE.WebGLRenderTarget; - private light: THREE.PointLight; - - private cameraDef: CameraDef = null; - private camera: THREE.Camera; - private orbitAnimationOnNextCameraCreation: boolean = false; - private orbitCameraControls: any; - private orbitAnimationId: number; - - private worldSpacePointPositions: Float32Array; - private pointColors: Float32Array; - private pointScaleFactors: Float32Array; - private labels: LabelRenderParams; - private polylineColors: {[polylineIndex: number]: Float32Array}; - private polylineOpacities: Float32Array; - private polylineWidths: Float32Array; - - private selecting = false; - private nearestPoint: number; - private mouseIsDown = false; - private isDragSequence = false; - private rectangleSelector: ScatterPlotRectangleSelector; - - constructor( - private container: HTMLElement, - private projectorEventContext: ProjectorEventContext) { - this.getLayoutValues(); - - this.scene = new THREE.Scene(); - this.renderer = new THREE.WebGLRenderer( - {alpha: true, premultipliedAlpha: false, antialias: false}); - 
this.renderer.setClearColor(BACKGROUND_COLOR, 1); - this.container.appendChild(this.renderer.domElement); - this.light = new THREE.PointLight(0xFFECBF, 1, 0); - this.scene.add(this.light); - - this.setDimensions(3); - this.recreateCamera(this.makeDefaultCameraDef(this.dimensionality)); - this.renderer.render(this.scene, this.camera); - - this.rectangleSelector = new ScatterPlotRectangleSelector( - this.container, - (boundingBox: BoundingBox) => this.selectBoundingBox(boundingBox)); - this.addInteractionListeners(); - } - - private addInteractionListeners() { - this.container.addEventListener('mousemove', this.onMouseMove.bind(this)); - this.container.addEventListener('mousedown', this.onMouseDown.bind(this)); - this.container.addEventListener('mouseup', this.onMouseUp.bind(this)); - this.container.addEventListener('click', this.onClick.bind(this)); - window.addEventListener('keydown', this.onKeyDown.bind(this), false); - window.addEventListener('keyup', this.onKeyUp.bind(this), false); - } - - private addCameraControlsEventListeners(cameraControls: any) { - // Start is called when the user stars interacting with - // controls. - cameraControls.addEventListener('start', () => { - this.stopOrbitAnimation(); - this.onCameraMoveListeners.forEach( - l => l(this.camera.position, cameraControls.target)); - }); - - // Change is called everytime the user interacts with the controls. - cameraControls.addEventListener('change', () => { - this.render(); - }); - - // End is called when the user stops interacting with the - // controls (e.g. on mouse up, after dragging). 
- cameraControls.addEventListener('end', () => {}); - } - - private makeOrbitControls( - camera: THREE.Camera, cameraDef: CameraDef, cameraIs3D: boolean) { - if (this.orbitCameraControls != null) { - this.orbitCameraControls.dispose(); - } - const occ = - new (THREE as any).OrbitControls(camera, this.renderer.domElement); - occ.target0 = new THREE.Vector3( - cameraDef.target[0], cameraDef.target[1], cameraDef.target[2]); - occ.position0 = new THREE.Vector3().copy(camera.position); - occ.zoom0 = cameraDef.zoom; - occ.enableRotate = cameraIs3D; - occ.autoRotate = false; - occ.rotateSpeed = ORBIT_MOUSE_ROTATION_SPEED; - if (cameraIs3D) { - occ.mouseButtons.ORBIT = THREE.MOUSE.LEFT; - occ.mouseButtons.PAN = THREE.MOUSE.RIGHT; - } else { - occ.mouseButtons.ORBIT = null; - occ.mouseButtons.PAN = THREE.MOUSE.LEFT; - } - occ.reset(); - - this.camera = camera; - this.orbitCameraControls = occ; - this.addCameraControlsEventListeners(this.orbitCameraControls); - } - - private makeCamera3D(cameraDef: CameraDef, w: number, h: number) { - let camera: THREE.PerspectiveCamera; - { - const aspectRatio = w / h; - camera = new THREE.PerspectiveCamera( - PERSP_CAMERA_FOV_VERTICAL, aspectRatio, PERSP_CAMERA_NEAR_CLIP_PLANE, - PERSP_CAMERA_FAR_CLIP_PLANE); - camera.position.set( - cameraDef.position[0], cameraDef.position[1], cameraDef.position[2]); - const at = new THREE.Vector3( - cameraDef.target[0], cameraDef.target[1], cameraDef.target[2]); - camera.lookAt(at); - camera.zoom = cameraDef.zoom; - camera.updateProjectionMatrix(); - } - this.camera = camera; - this.makeOrbitControls(camera, cameraDef, true); - } - - private makeCamera2D(cameraDef: CameraDef, w: number, h: number) { - let camera: THREE.OrthographicCamera; - const target = new THREE.Vector3( - cameraDef.target[0], cameraDef.target[1], cameraDef.target[2]); - { - const aspectRatio = w / h; - let left = -ORTHO_CAMERA_FRUSTUM_HALF_EXTENT; - let right = ORTHO_CAMERA_FRUSTUM_HALF_EXTENT; - let bottom = 
-ORTHO_CAMERA_FRUSTUM_HALF_EXTENT; - let top = ORTHO_CAMERA_FRUSTUM_HALF_EXTENT; - // Scale up the larger of (w, h) to match the aspect ratio. - if (aspectRatio > 1) { - left *= aspectRatio; - right *= aspectRatio; - } else { - top /= aspectRatio; - bottom /= aspectRatio; - } - camera = - new THREE.OrthographicCamera(left, right, top, bottom, -1000, 1000); - camera.position.set( - cameraDef.position[0], cameraDef.position[1], cameraDef.position[2]); - camera.up = new THREE.Vector3(0, 1, 0); - camera.lookAt(target); - camera.zoom = cameraDef.zoom; - camera.updateProjectionMatrix(); - } - this.camera = camera; - this.makeOrbitControls(camera, cameraDef, false); - } - - private makeDefaultCameraDef(dimensionality: number): CameraDef { - const def = new CameraDef(); - def.orthographic = (dimensionality === 2); - def.zoom = 1.0; - if (def.orthographic) { - def.position = - [START_CAMERA_POS_2D.x, START_CAMERA_POS_2D.y, START_CAMERA_POS_2D.z]; - def.target = [ - START_CAMERA_TARGET_2D.x, START_CAMERA_TARGET_2D.y, - START_CAMERA_TARGET_2D.z - ]; - } else { - def.position = - [START_CAMERA_POS_3D.x, START_CAMERA_POS_3D.y, START_CAMERA_POS_3D.z]; - def.target = [ - START_CAMERA_TARGET_3D.x, START_CAMERA_TARGET_3D.y, - START_CAMERA_TARGET_3D.z - ]; - } - return def; - } - - /** Recreate the scatter plot camera from a definition structure. */ - recreateCamera(cameraDef: CameraDef) { - if (cameraDef.orthographic) { - this.makeCamera2D(cameraDef, this.width, this.height); - } else { - this.makeCamera3D(cameraDef, this.width, this.height); - } - this.orbitCameraControls.minDistance = MIN_ZOOM; - this.orbitCameraControls.maxDistance = MAX_ZOOM; - this.orbitCameraControls.update(); - if (this.orbitAnimationOnNextCameraCreation) { - this.startOrbitAnimation(); - } - } - - private onClick(e?: MouseEvent, notify = true) { - if (e && this.selecting) { - return; - } - // Only call event handlers if the click originated from the scatter plot. 
- if (!this.isDragSequence && notify) { - const selection = (this.nearestPoint != null) ? [this.nearestPoint] : []; - this.projectorEventContext.notifySelectionChanged(selection); - } - this.isDragSequence = false; - this.render(); - } - - private onMouseDown(e: MouseEvent) { - this.isDragSequence = false; - this.mouseIsDown = true; - if (this.selecting) { - this.orbitCameraControls.enabled = false; - this.rectangleSelector.onMouseDown(e.offsetX, e.offsetY); - this.setNearestPointToMouse(e); - } else if ( - !e.ctrlKey && this.sceneIs3D() && - this.orbitCameraControls.mouseButtons.ORBIT === THREE.MOUSE.RIGHT) { - // The user happened to press the ctrl key when the tab was active, - // unpressed the ctrl when the tab was inactive, and now he/she - // is back to the projector tab. - this.orbitCameraControls.mouseButtons.ORBIT = THREE.MOUSE.LEFT; - this.orbitCameraControls.mouseButtons.PAN = THREE.MOUSE.RIGHT; - } else if ( - e.ctrlKey && this.sceneIs3D() && - this.orbitCameraControls.mouseButtons.ORBIT === THREE.MOUSE.LEFT) { - // Similarly to the situation above. - this.orbitCameraControls.mouseButtons.ORBIT = THREE.MOUSE.RIGHT; - this.orbitCameraControls.mouseButtons.PAN = THREE.MOUSE.LEFT; - } - } - - /** When we stop dragging/zooming, return to normal behavior. */ - private onMouseUp(e: any) { - if (this.selecting) { - this.orbitCameraControls.enabled = true; - this.rectangleSelector.onMouseUp(); - this.render(); - } - this.mouseIsDown = false; - } - - /** - * When the mouse moves, find the nearest point (if any) and send it to the - * hoverlisteners (usually called from embedding.ts) - */ - private onMouseMove(e: MouseEvent) { - this.isDragSequence = this.mouseIsDown; - // Depending if we're selecting or just navigating, handle accordingly. 
- if (this.selecting && this.mouseIsDown) { - this.rectangleSelector.onMouseMove(e.offsetX, e.offsetY); - this.render(); - } else if (!this.mouseIsDown) { - this.setNearestPointToMouse(e); - this.projectorEventContext.notifyHoverOverPoint(this.nearestPoint); - } - } - - /** For using ctrl + left click as right click, and for circle select */ - private onKeyDown(e: any) { - // If ctrl is pressed, use left click to orbit - if (e.keyCode === CTRL_KEY && this.sceneIs3D()) { - this.orbitCameraControls.mouseButtons.ORBIT = THREE.MOUSE.RIGHT; - this.orbitCameraControls.mouseButtons.PAN = THREE.MOUSE.LEFT; - } - - // If shift is pressed, start selecting - if (e.keyCode === SHIFT_KEY) { - this.selecting = true; - this.container.style.cursor = 'crosshair'; - } - } - - /** For using ctrl + left click as right click, and for circle select */ - private onKeyUp(e: any) { - if (e.keyCode === CTRL_KEY && this.sceneIs3D()) { - this.orbitCameraControls.mouseButtons.ORBIT = THREE.MOUSE.LEFT; - this.orbitCameraControls.mouseButtons.PAN = THREE.MOUSE.RIGHT; - } - - // If shift is released, stop selecting - if (e.keyCode === SHIFT_KEY) { - this.selecting = (this.getMouseMode() === MouseMode.AREA_SELECT); - if (!this.selecting) { - this.container.style.cursor = 'default'; - } - this.render(); - } - } - - /** - * Returns a list of indices of points in a bounding box from the picking - * texture. - * @param boundingBox The bounding box to select from. - */ - private getPointIndicesFromPickingTexture(boundingBox: BoundingBox): - number[] { - if (this.worldSpacePointPositions == null) { - return null; - } - const pointCount = this.worldSpacePointPositions.length / 3; - const dpr = window.devicePixelRatio || 1; - const x = Math.floor(boundingBox.x * dpr); - const y = Math.floor(boundingBox.y * dpr); - const width = Math.floor(boundingBox.width * dpr); - const height = Math.floor(boundingBox.height * dpr); - - // Create buffer for reading all of the pixels from the texture. 
- let pixelBuffer = new Uint8Array(width * height * 4); - - // Read the pixels from the bounding box. - this.renderer.readRenderTargetPixels( - this.pickingTexture, x, this.pickingTexture.height - y, width, height, - pixelBuffer); - - // Keep a flat list of each point and whether they are selected or not. This - // approach is more efficient than using an object keyed by the index. - let pointIndicesSelection = - new Uint8Array(this.worldSpacePointPositions.length); - for (let i = 0; i < width * height; i++) { - const id = (pixelBuffer[i * 4] << 16) | (pixelBuffer[i * 4 + 1] << 8) | - pixelBuffer[i * 4 + 2]; - if (id !== 0xffffff && (id < pointCount)) { - pointIndicesSelection[id] = 1; - } - } - let pointIndices: number[] = []; - for (let i = 0; i < pointIndicesSelection.length; i++) { - if (pointIndicesSelection[i] === 1) { - pointIndices.push(i); - } - } - - return pointIndices; - } - - - private selectBoundingBox(boundingBox: BoundingBox) { - let pointIndices = this.getPointIndicesFromPickingTexture(boundingBox); - this.projectorEventContext.notifySelectionChanged(pointIndices); - } - - private setNearestPointToMouse(e: MouseEvent) { - if (this.pickingTexture == null) { - this.nearestPoint = null; - return; - } - const boundingBox: - BoundingBox = {x: e.offsetX, y: e.offsetY, width: 1, height: 1}; - const pointIndices = this.getPointIndicesFromPickingTexture(boundingBox); - this.nearestPoint = (pointIndices != null) ? 
pointIndices[0] : null; - } - - private getLayoutValues(): Point2D { - this.width = this.container.offsetWidth; - this.height = Math.max(1, this.container.offsetHeight); - return [this.width, this.height]; - } - - private sceneIs3D(): boolean { - return this.dimensionality === 3; - } - - private remove3dAxisFromScene(): THREE.Object3D { - const axes = this.scene.getObjectByName('axes'); - if (axes != null) { - this.scene.remove(axes); - } - return axes; - } - - private add3dAxis() { - const axes = new THREE.AxisHelper(); - axes.name = 'axes'; - this.scene.add(axes); - } - - /** Set 2d vs 3d mode. */ - setDimensions(dimensionality: number) { - if ((dimensionality !== 2) && (dimensionality !== 3)) { - throw new RangeError('dimensionality must be 2 or 3'); - } - this.dimensionality = dimensionality; - - const def = this.cameraDef || this.makeDefaultCameraDef(dimensionality); - this.recreateCamera(def); - - this.remove3dAxisFromScene(); - if (dimensionality === 3) { - this.add3dAxis(); - } - } - - /** Gets the current camera information, suitable for serialization. */ - getCameraDef(): CameraDef { - const def = new CameraDef(); - const pos = this.camera.position; - const tgt = this.orbitCameraControls.target; - def.orthographic = !this.sceneIs3D(); - def.position = [pos.x, pos.y, pos.z]; - def.target = [tgt.x, tgt.y, tgt.z]; - def.zoom = (this.camera as any).zoom; - return def; - } - - /** Sets parameters for the next camera recreation. */ - setCameraParametersForNextCameraCreation( - def: CameraDef, orbitAnimation: boolean) { - this.cameraDef = def; - this.orbitAnimationOnNextCameraCreation = orbitAnimation; - } - - /** Gets the current camera position. */ - getCameraPosition(): Point3D { - const currPos = this.camera.position; - return [currPos.x, currPos.y, currPos.z]; - } - - /** Gets the current camera target. 
*/ - getCameraTarget(): Point3D { - let currTarget = this.orbitCameraControls.target; - return [currTarget.x, currTarget.y, currTarget.z]; - } - - /** Sets up the camera from given position and target coordinates. */ - setCameraPositionAndTarget(position: Point3D, target: Point3D) { - this.stopOrbitAnimation(); - this.camera.position.set(position[0], position[1], position[2]); - this.orbitCameraControls.target.set(target[0], target[1], target[2]); - this.orbitCameraControls.update(); - this.render(); - } - - /** Starts orbiting the camera around its current lookat target. */ - startOrbitAnimation() { - if (!this.sceneIs3D()) { - return; - } - if (this.orbitAnimationId != null) { - this.stopOrbitAnimation(); - } - this.orbitCameraControls.autoRotate = true; - this.orbitCameraControls.rotateSpeed = - ORBIT_ANIMATION_ROTATION_CYCLE_IN_SECONDS; - this.updateOrbitAnimation(); - } - - private updateOrbitAnimation() { - this.orbitCameraControls.update(); - this.orbitAnimationId = - requestAnimationFrame(() => this.updateOrbitAnimation()); - } - - /** Stops the orbiting animation on the camera. */ - stopOrbitAnimation() { - this.orbitCameraControls.autoRotate = false; - this.orbitCameraControls.rotateSpeed = ORBIT_MOUSE_ROTATION_SPEED; - if (this.orbitAnimationId != null) { - cancelAnimationFrame(this.orbitAnimationId); - this.orbitAnimationId = null; - } - } - - /** Adds a visualizer to the set, will start dispatching events to it */ - addVisualizer(visualizer: ScatterPlotVisualizer) { - if (this.scene) { - visualizer.setScene(this.scene); - } - visualizer.onResize(this.width, this.height); - visualizer.onPointPositionsChanged(this.worldSpacePointPositions); - this.visualizers.push(visualizer); - } - - /** Removes all visualizers attached to this scatter plot. */ - removeAllVisualizers() { - this.visualizers.forEach(v => v.dispose()); - this.visualizers = []; - } - - /** Update scatter plot with a new array of packed xyz point positions. 
*/ - setPointPositions(worldSpacePointPositions: Float32Array) { - this.worldSpacePointPositions = worldSpacePointPositions; - this.visualizers.forEach( - v => v.onPointPositionsChanged(worldSpacePointPositions)); - } - - render() { - { - const lightPos = this.camera.position.clone(); - lightPos.x += 1; - lightPos.y += 1; - this.light.position.set(lightPos.x, lightPos.y, lightPos.z); - } - - const cameraType = (this.camera instanceof THREE.PerspectiveCamera) ? - CameraType.Perspective : - CameraType.Orthographic; - - let cameraSpacePointExtents: [number, number] = [0, 0]; - if (this.worldSpacePointPositions != null) { - cameraSpacePointExtents = util.getNearFarPoints( - this.worldSpacePointPositions, this.camera.position, - this.orbitCameraControls.target); - } - - const rc = new RenderContext( - this.camera, cameraType, this.orbitCameraControls.target, this.width, - this.height, cameraSpacePointExtents[0], cameraSpacePointExtents[1], - this.backgroundColor, this.pointColors, this.pointScaleFactors, - this.labels, this.polylineColors, this.polylineOpacities, - this.polylineWidths); - - // Render first pass to picking target. This render fills pickingTexture - // with colors that are actually point ids, so that sampling the texture at - // the mouse's current x,y coordinates will reveal the data point that the - // mouse is over. - this.visualizers.forEach(v => v.onPickingRender(rc)); - - { - const axes = this.remove3dAxisFromScene(); - this.renderer.render(this.scene, this.camera, this.pickingTexture); - if (axes != null) { - this.scene.add(axes); - } - } - - // Render second pass to color buffer, to be displayed on the canvas. 
- this.visualizers.forEach(v => v.onRender(rc)); - - this.renderer.render(this.scene, this.camera); - } - - setMouseMode(mouseMode: MouseMode) { - this.mouseMode = mouseMode; - if (mouseMode === MouseMode.AREA_SELECT) { - this.selecting = true; - this.container.style.cursor = 'crosshair'; - } else { - this.selecting = false; - this.container.style.cursor = 'default'; - } - } - - /** Set the colors for every data point. (RGB triplets) */ - setPointColors(colors: Float32Array) { - this.pointColors = colors; - } - - /** Set the scale factors for every data point. (scalars) */ - setPointScaleFactors(scaleFactors: Float32Array) { - this.pointScaleFactors = scaleFactors; - } - - /** Set the labels to rendered */ - setLabels(labels: LabelRenderParams) { - this.labels = labels; - } - - /** Set the colors for every data polyline. (RGB triplets) */ - setPolylineColors(colors: {[polylineIndex: number]: Float32Array}) { - this.polylineColors = colors; - } - - setPolylineOpacities(opacities: Float32Array) { - this.polylineOpacities = opacities; - } - - setPolylineWidths(widths: Float32Array) { - this.polylineWidths = widths; - } - - getMouseMode(): MouseMode { - return this.mouseMode; - } - - resetZoom() { - this.recreateCamera(this.makeDefaultCameraDef(this.dimensionality)); - this.render(); - } - - setDayNightMode(isNight: boolean) { - const canvases = this.container.querySelectorAll('canvas'); - const filterValue = isNight ? 'invert(100%)' : null; - for (let i = 0; i < canvases.length; i++) { - canvases[i].style.filter = filterValue; - } - } - - resize(render = true) { - const [oldW, oldH] = [this.width, this.height]; - const [newW, newH] = this.getLayoutValues(); - - if (this.dimensionality === 3) { - const camera = (this.camera as THREE.PerspectiveCamera); - camera.aspect = newW / newH; - camera.updateProjectionMatrix(); - } else { - const camera = (this.camera as THREE.OrthographicCamera); - // Scale the ortho frustum by however much the window changed. 
- const scaleW = newW / oldW; - const scaleH = newH / oldH; - const newCamHalfWidth = ((camera.right - camera.left) * scaleW) / 2; - const newCamHalfHeight = ((camera.top - camera.bottom) * scaleH) / 2; - camera.top = newCamHalfHeight; - camera.bottom = -newCamHalfHeight; - camera.left = -newCamHalfWidth; - camera.right = newCamHalfWidth; - camera.updateProjectionMatrix(); - } - - // Accouting for retina displays. - const dpr = window.devicePixelRatio || 1; - this.renderer.setPixelRatio(dpr); - this.renderer.setSize(newW, newH); - - // the picking texture needs to be exactly the same as the render texture. - { - const renderCanvasSize = this.renderer.getSize(); - const pixelRatio = this.renderer.getPixelRatio(); - this.pickingTexture = new THREE.WebGLRenderTarget( - renderCanvasSize.width * pixelRatio, - renderCanvasSize.height * pixelRatio); - this.pickingTexture.texture.minFilter = THREE.LinearFilter; - } - - this.visualizers.forEach(v => v.onResize(newW, newH)); - - if (render) { - this.render(); - }; - } - - onCameraMove(listener: OnCameraMoveListener) { - this.onCameraMoveListeners.push(listener); - } - - clickOnPoint(pointIndex: number) { - this.nearestPoint = pointIndex; - this.onClick(null, false); - } -} diff --git a/tensorflow/tensorboard/components/vz_projector/scatterPlotRectangleSelector.ts b/tensorflow/tensorboard/components/vz_projector/scatterPlotRectangleSelector.ts deleted file mode 100644 index a781877014e..00000000000 --- a/tensorflow/tensorboard/components/vz_projector/scatterPlotRectangleSelector.ts +++ /dev/null @@ -1,107 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -const FILL = '#dddddd'; -const FILL_OPACITY = .2; -const STROKE = '#aaaaaa'; -const STROKE_WIDTH = 2; -const STROKE_DASHARRAY = '10 5'; - -export interface BoundingBox { - // The bounding box (x, y) position refers to the bottom left corner of the - // rect. - x: number; - y: number; - width: number; - height: number; -} - -/** - * A class that manages and renders a data selection rectangle. - */ -export class ScatterPlotRectangleSelector { - private svgElement: SVGElement; - private rectElement: SVGRectElement; - - private isMouseDown: boolean; - private startCoordinates: [number, number]; - private lastBoundingBox: BoundingBox; - - private selectionCallback: (boundingBox: BoundingBox) => void; - - /** - * @param container The container HTML element that the selection SVG rect - * will be a child of. - * @param selectionCallback The callback that accepts a bounding box to be - * called when selection changes. Currently, we only call the callback on - * mouseUp. 
- */ - constructor( - container: HTMLElement, - selectionCallback: (boundingBox: BoundingBox) => void) { - this.svgElement = container.querySelector('#selector') as SVGElement; - this.rectElement = - document.createElementNS('http://www.w3.org/2000/svg', 'rect'); - this.rectElement.style.stroke = STROKE; - this.rectElement.style.strokeDasharray = STROKE_DASHARRAY; - this.rectElement.style.strokeWidth = '' + STROKE_WIDTH; - this.rectElement.style.fill = FILL; - this.rectElement.style.fillOpacity = '' + FILL_OPACITY; - this.svgElement.appendChild(this.rectElement); - - this.selectionCallback = selectionCallback; - this.isMouseDown = false; - } - - onMouseDown(offsetX: number, offsetY: number) { - this.isMouseDown = true; - this.rectElement.style.display = 'block'; - - this.startCoordinates = [offsetX, offsetY]; - this.lastBoundingBox = { - x: this.startCoordinates[0], - y: this.startCoordinates[1], - width: 1, - height: 1 - }; - } - - onMouseMove(offsetX: number, offsetY: number) { - if (!this.isMouseDown) { - return; - } - - this.lastBoundingBox.x = Math.min(offsetX, this.startCoordinates[0]); - this.lastBoundingBox.y = Math.max(offsetY, this.startCoordinates[1]); - this.lastBoundingBox.width = - Math.max(offsetX, this.startCoordinates[0]) - this.lastBoundingBox.x; - this.lastBoundingBox.height = - this.lastBoundingBox.y - Math.min(offsetY, this.startCoordinates[1]); - - this.rectElement.setAttribute('x', '' + this.lastBoundingBox.x); - this.rectElement.setAttribute( - 'y', '' + (this.lastBoundingBox.y - this.lastBoundingBox.height)); - this.rectElement.setAttribute('width', '' + this.lastBoundingBox.width); - this.rectElement.setAttribute('height', '' + this.lastBoundingBox.height); - } - - onMouseUp() { - this.isMouseDown = false; - this.rectElement.style.display = 'none'; - this.rectElement.setAttribute('width', '0'); - this.rectElement.setAttribute('height', '0'); - this.selectionCallback(this.lastBoundingBox); - } -} diff --git 
a/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizer.ts b/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizer.ts deleted file mode 100644 index b0974a20538..00000000000 --- a/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizer.ts +++ /dev/null @@ -1,51 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -import {RenderContext} from './renderContext'; - -/** - * ScatterPlotVisualizer is an interface used by ScatterPlotContainer - * to manage and aggregate any number of concurrent visualization behaviors. - * To add a new visualization to the 3D scatter plot, create a new class that - * implements this interface and attach it to the ScatterPlotContainer. - */ -export interface ScatterPlotVisualizer { - /** Called to initialize the visualizer with the primary scene. */ - setScene(scene: THREE.Scene); - /** - * Called when the main scatter plot tears down the visualizer. Remove all - * objects from the scene, and dispose any heavy resources. - */ - dispose(); - /** - * Called when the positions of the scatter plot points have changed. - */ - onPointPositionsChanged(newWorldSpacePointPositions: Float32Array); - /** - * Called immediately before the main scatter plot performs a picking - * (selection) render. Set up render state for any geometry to use picking IDs - * instead of visual colors. 
- */ - onPickingRender(renderContext: RenderContext); - /** - * Called immediately before the main scatter plot performs a color (visual) - * render. Set up render state, lights, etc here. - */ - onRender(renderContext: RenderContext); - /** - * Called when the canvas size changes. - */ - onResize(newWidth: number, newHeight: number); -} diff --git a/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizer3DLabels.ts b/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizer3DLabels.ts deleted file mode 100644 index 7820af0d48d..00000000000 --- a/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizer3DLabels.ts +++ /dev/null @@ -1,367 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -import {RenderContext} from './renderContext'; -import {ScatterPlotVisualizer} from './scatterPlotVisualizer'; -import * as util from './util'; - -const FONT_SIZE = 80; -const ONE_OVER_FONT_SIZE = 1 / FONT_SIZE; -const LABEL_SCALE = 2.2; // at 1:1 texel/pixel ratio -const LABEL_COLOR = 'black'; -const LABEL_BACKGROUND = 'white'; -const MAX_CANVAS_DIMENSION = 8192; -const NUM_GLYPHS = 256; -const RGB_ELEMENTS_PER_ENTRY = 3; -const XYZ_ELEMENTS_PER_ENTRY = 3; -const UV_ELEMENTS_PER_ENTRY = 2; -const VERTICES_PER_GLYPH = 2 * 3; // 2 triangles, 3 verts per triangle - -/** - * Each label is made up of triangles (two per letter.) Each vertex, then, is - * the corner of one of these triangles (and thus the corner of a letter - * rectangle.) - * Each has the following attributes: - * posObj: The (x, y) position of the vertex within the label, where the - * bottom center of the word is positioned at (0, 0); - * position: The position of the label in worldspace. - * vUv: The (u, v) coordinates that index into the glyphs sheet (range 0, 1.) - * color: The color of the label (matches the corresponding point's color.) - * wordShown: Boolean. Whether or not the label is visible. - */ - -const VERTEX_SHADER = ` - attribute vec2 posObj; - attribute vec3 color; - varying vec2 vUv; - varying vec3 vColor; - - void main() { - vUv = uv; - vColor = color; - - // Rotate label to face camera. 
- - vec4 vRight = vec4( - modelViewMatrix[0][0], modelViewMatrix[1][0], modelViewMatrix[2][0], 0); - - vec4 vUp = vec4( - modelViewMatrix[0][1], modelViewMatrix[1][1], modelViewMatrix[2][1], 0); - - vec4 vAt = -vec4( - modelViewMatrix[0][2], modelViewMatrix[1][2], modelViewMatrix[2][2], 0); - - mat4 pointToCamera = mat4(vRight, vUp, vAt, vec4(0, 0, 0, 1)); - - vec2 scaledPos = posObj * ${ONE_OVER_FONT_SIZE} * ${LABEL_SCALE}; - - vec4 posRotated = pointToCamera * vec4(scaledPos, 0, 1); - vec4 mvPosition = modelViewMatrix * (vec4(position, 0) + posRotated); - gl_Position = projectionMatrix * mvPosition; - }`; - -const FRAGMENT_SHADER = ` - uniform sampler2D texture; - uniform bool picking; - varying vec2 vUv; - varying vec3 vColor; - - void main() { - if (picking) { - gl_FragColor = vec4(vColor, 1.0); - } else { - vec4 fromTexture = texture2D(texture, vUv); - gl_FragColor = vec4(vColor, 1.0) * fromTexture; - } - }`; - -type GlyphTexture = { - texture: THREE.Texture; lengths: Float32Array; offsets: Float32Array; -}; - -/** - * Renders the text labels as 3d geometry in the world. 
- */ -export class ScatterPlotVisualizer3DLabels implements ScatterPlotVisualizer { - private scene: THREE.Scene; - private labelStrings: string[]; - private geometry: THREE.BufferGeometry; - private worldSpacePointPositions: Float32Array; - private pickingColors: Float32Array; - private renderColors: Float32Array; - private material: THREE.ShaderMaterial; - private uniforms: Object; - private labelsMesh: THREE.Mesh; - private positions: THREE.BufferAttribute; - private totalVertexCount: number; - private labelVertexMap: number[][]; - private glyphTexture: GlyphTexture; - - private createGlyphTexture(): GlyphTexture { - let canvas = document.createElement('canvas'); - canvas.width = MAX_CANVAS_DIMENSION; - canvas.height = FONT_SIZE; - let ctx = canvas.getContext('2d'); - ctx.font = 'bold ' + FONT_SIZE * 0.75 + 'px roboto'; - ctx.textBaseline = 'top'; - ctx.fillStyle = LABEL_BACKGROUND; - ctx.rect(0, 0, canvas.width, canvas.height); - ctx.fill(); - ctx.fillStyle = LABEL_COLOR; - let spaceOffset = ctx.measureText(' ').width; - // For each letter, store length, position at the encoded index. 
- let glyphLengths = new Float32Array(NUM_GLYPHS); - let glyphOffset = new Float32Array(NUM_GLYPHS); - let leftCoord = 0; - for (let i = 0; i < NUM_GLYPHS; i++) { - let text = ' ' + String.fromCharCode(i); - let textLength = ctx.measureText(text).width; - glyphLengths[i] = textLength - spaceOffset; - glyphOffset[i] = leftCoord; - ctx.fillText(text, leftCoord - spaceOffset, 0); - leftCoord += textLength; - } - const tex = util.createTexture(canvas); - return {texture: tex, lengths: glyphLengths, offsets: glyphOffset}; - } - - private processLabelVerts(pointCount: number) { - let numTotalLetters = 0; - this.labelVertexMap = []; - for (let i = 0; i < pointCount; i++) { - const label = this.labelStrings[i]; - let vertsArray: number[] = []; - for (let j = 0; j < label.length; j++) { - for (let k = 0; k < VERTICES_PER_GLYPH; k++) { - vertsArray.push(numTotalLetters * VERTICES_PER_GLYPH + k); - } - numTotalLetters++; - } - this.labelVertexMap.push(vertsArray); - } - this.totalVertexCount = numTotalLetters * VERTICES_PER_GLYPH; - } - - private createColorBuffers(pointCount: number) { - this.pickingColors = - new Float32Array(this.totalVertexCount * RGB_ELEMENTS_PER_ENTRY); - this.renderColors = - new Float32Array(this.totalVertexCount * RGB_ELEMENTS_PER_ENTRY); - for (let i = 0; i < pointCount; i++) { - let color = new THREE.Color(i); - this.labelVertexMap[i].forEach((j) => { - this.pickingColors[RGB_ELEMENTS_PER_ENTRY * j] = color.r; - this.pickingColors[RGB_ELEMENTS_PER_ENTRY * j + 1] = color.g; - this.pickingColors[RGB_ELEMENTS_PER_ENTRY * j + 2] = color.b; - this.renderColors[RGB_ELEMENTS_PER_ENTRY * j] = 1.0; - this.renderColors[RGB_ELEMENTS_PER_ENTRY * j + 1] = 1.0; - this.renderColors[RGB_ELEMENTS_PER_ENTRY * j + 2] = 1.0; - }); - } - } - - private createLabels() { - if ((this.labelStrings == null) || - (this.worldSpacePointPositions == null)) { - return; - } - const pointCount = - this.worldSpacePointPositions.length / XYZ_ELEMENTS_PER_ENTRY; - if (pointCount !== 
this.labelStrings.length) { - return; - } - this.glyphTexture = this.createGlyphTexture(); - - this.uniforms = { - texture: {type: 't'}, - picking: {type: 'bool'}, - }; - - this.material = new THREE.ShaderMaterial({ - uniforms: this.uniforms, - transparent: true, - vertexShader: VERTEX_SHADER, - fragmentShader: FRAGMENT_SHADER, - }); - - this.processLabelVerts(pointCount); - this.createColorBuffers(pointCount); - - let positionArray = - new Float32Array(this.totalVertexCount * XYZ_ELEMENTS_PER_ENTRY); - this.positions = - new THREE.BufferAttribute(positionArray, XYZ_ELEMENTS_PER_ENTRY); - - let posArray = - new Float32Array(this.totalVertexCount * XYZ_ELEMENTS_PER_ENTRY); - let uvArray = - new Float32Array(this.totalVertexCount * UV_ELEMENTS_PER_ENTRY); - let colorsArray = - new Float32Array(this.totalVertexCount * RGB_ELEMENTS_PER_ENTRY); - let positionObject = new THREE.BufferAttribute(posArray, 2); - let uv = new THREE.BufferAttribute(uvArray, UV_ELEMENTS_PER_ENTRY); - let colors = new THREE.BufferAttribute(colorsArray, RGB_ELEMENTS_PER_ENTRY); - - this.geometry = new THREE.BufferGeometry(); - this.geometry.addAttribute('posObj', positionObject); - this.geometry.addAttribute('position', this.positions); - this.geometry.addAttribute('uv', uv); - this.geometry.addAttribute('color', colors); - - let lettersSoFar = 0; - for (let i = 0; i < pointCount; i++) { - const label = this.labelStrings[i]; - let leftOffset = 0; - // Determine length of word in pixels. 
- for (let j = 0; j < label.length; j++) { - let letterCode = label.charCodeAt(j); - leftOffset += this.glyphTexture.lengths[letterCode]; - } - leftOffset /= -2; // centers text horizontally around the origin - for (let j = 0; j < label.length; j++) { - let letterCode = label.charCodeAt(j); - let letterWidth = this.glyphTexture.lengths[letterCode]; - let scale = FONT_SIZE; - let right = (leftOffset + letterWidth) / scale; - let left = (leftOffset) / scale; - let top = FONT_SIZE / scale; - - // First triangle - positionObject.setXY(lettersSoFar * VERTICES_PER_GLYPH + 0, left, 0); - positionObject.setXY(lettersSoFar * VERTICES_PER_GLYPH + 1, right, 0); - positionObject.setXY(lettersSoFar * VERTICES_PER_GLYPH + 2, left, top); - - // Second triangle - positionObject.setXY(lettersSoFar * VERTICES_PER_GLYPH + 3, left, top); - positionObject.setXY(lettersSoFar * VERTICES_PER_GLYPH + 4, right, 0); - positionObject.setXY(lettersSoFar * VERTICES_PER_GLYPH + 5, right, top); - - // Set UVs based on letter. - let uLeft = (this.glyphTexture.offsets[letterCode]); - let uRight = (this.glyphTexture.offsets[letterCode] + letterWidth); - // Scale so that uvs lie between 0 and 1 on the texture. 
- uLeft /= MAX_CANVAS_DIMENSION; - uRight /= MAX_CANVAS_DIMENSION; - let vTop = 1; - let vBottom = 0; - uv.setXY(lettersSoFar * VERTICES_PER_GLYPH + 0, uLeft, vTop); - uv.setXY(lettersSoFar * VERTICES_PER_GLYPH + 1, uRight, vTop); - uv.setXY(lettersSoFar * VERTICES_PER_GLYPH + 2, uLeft, vBottom); - uv.setXY(lettersSoFar * VERTICES_PER_GLYPH + 3, uLeft, vBottom); - uv.setXY(lettersSoFar * VERTICES_PER_GLYPH + 4, uRight, vTop); - uv.setXY(lettersSoFar * VERTICES_PER_GLYPH + 5, uRight, vBottom); - - lettersSoFar++; - leftOffset += letterWidth; - } - } - - for (let i = 0; i < pointCount; i++) { - const p = util.vector3FromPackedArray(this.worldSpacePointPositions, i); - this.labelVertexMap[i].forEach((j) => { - this.positions.setXYZ(j, p.x, p.y, p.z); - }); - }; - - this.labelsMesh = new THREE.Mesh(this.geometry, this.material); - this.labelsMesh.frustumCulled = false; - this.scene.add(this.labelsMesh); - } - - private colorLabels(pointColors: Float32Array) { - if (this.labelStrings == null || this.geometry == null || - pointColors == null) { - return; - } - - const colors = this.geometry.getAttribute('color') as THREE.BufferAttribute; - colors.array = this.renderColors; - - const n = pointColors.length / XYZ_ELEMENTS_PER_ENTRY; - let src = 0; - for (let i = 0; i < n; ++i) { - const c = new THREE.Color( - pointColors[src], pointColors[src + 1], pointColors[src + 2]); - const m = this.labelVertexMap[i].length; - for (let j = 0; j < m; ++j) { - colors.setXYZ(this.labelVertexMap[i][j], c.r, c.g, c.b); - } - src += RGB_ELEMENTS_PER_ENTRY; - } - colors.needsUpdate = true; - } - - setScene(scene: THREE.Scene) { - this.scene = scene; - } - - dispose() { - if (this.labelsMesh) { - if (this.scene) { - this.scene.remove(this.labelsMesh); - } - this.labelsMesh = null; - } - if (this.geometry) { - this.geometry.dispose(); - this.geometry = null; - } - if ((this.glyphTexture != null) && (this.glyphTexture.texture != null)) { - this.glyphTexture.texture.dispose(); - 
this.glyphTexture.texture = null; - } - } - - onPickingRender(rc: RenderContext) { - if (this.geometry == null) { - this.createLabels(); - } - if (this.geometry == null) { - return; - } - this.material.uniforms.texture.value = this.glyphTexture.texture; - this.material.uniforms.picking.value = true; - const colors = this.geometry.getAttribute('color') as THREE.BufferAttribute; - colors.array = this.pickingColors; - colors.needsUpdate = true; - } - - onRender(rc: RenderContext) { - if (this.geometry == null) { - this.createLabels(); - } - if (this.geometry == null) { - return; - } - this.colorLabels(rc.pointColors); - this.material.uniforms.texture.value = this.glyphTexture.texture; - this.material.uniforms.picking.value = false; - const colors = this.geometry.getAttribute('color') as THREE.BufferAttribute; - colors.array = this.renderColors; - colors.needsUpdate = true; - } - - onPointPositionsChanged(newPositions: Float32Array) { - this.worldSpacePointPositions = newPositions; - this.dispose(); - } - - setLabelStrings(labelStrings: string[]) { - this.labelStrings = labelStrings; - this.dispose(); - } - - onResize(newWidth: number, newHeight: number) {} -} diff --git a/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizerCanvasLabels.ts b/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizerCanvasLabels.ts deleted file mode 100644 index 2f3146d213c..00000000000 --- a/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizerCanvasLabels.ts +++ /dev/null @@ -1,186 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -import {BoundingBox, CollisionGrid} from './label'; -import {CameraType, RenderContext} from './renderContext'; -import {ScatterPlotVisualizer} from './scatterPlotVisualizer'; -import * as util from './util'; - -const MAX_LABELS_ON_SCREEN = 10000; -const LABEL_STROKE_WIDTH = 3; -const LABEL_FILL_WIDTH = 6; - -/** - * Creates and maintains a 2d canvas on top of the GL canvas. All labels, when - * active, are rendered to the 2d canvas as part of the visible render pass. - */ -export class ScatterPlotVisualizerCanvasLabels implements - ScatterPlotVisualizer { - private worldSpacePointPositions: Float32Array; - private gc: CanvasRenderingContext2D; - private canvas: HTMLCanvasElement; - private labelsActive: boolean = true; - - constructor(container: HTMLElement) { - this.canvas = document.createElement('canvas'); - container.appendChild(this.canvas); - - this.gc = this.canvas.getContext('2d'); - this.canvas.style.position = 'absolute'; - this.canvas.style.left = '0'; - this.canvas.style.top = '0'; - this.canvas.style.pointerEvents = 'none'; - } - - private removeAllLabels() { - const pixelWidth = this.canvas.width * window.devicePixelRatio; - const pixelHeight = this.canvas.height * window.devicePixelRatio; - this.gc.clearRect(0, 0, pixelWidth, pixelHeight); - } - - /** Render all of the non-overlapping visible labels to the canvas. 
*/ - private makeLabels(rc: RenderContext) { - if ((rc.labels == null) || (rc.labels.pointIndices.length === 0)) { - return; - } - if (this.worldSpacePointPositions == null) { - return; - } - - const lrc = rc.labels; - const sceneIs3D: boolean = (rc.cameraType === CameraType.Perspective); - const labelHeight = parseInt(this.gc.font, 10); - const dpr = window.devicePixelRatio; - - let grid: CollisionGrid; - { - const pixw = this.canvas.width * dpr; - const pixh = this.canvas.height * dpr; - const bb: BoundingBox = {loX: 0, hiX: pixw, loY: 0, hiY: pixh}; - grid = new CollisionGrid(bb, pixw / 25, pixh / 50); - } - - let opacityMap = - d3.scalePow() - .exponent(Math.E) - .domain([rc.farthestCameraSpacePointZ, rc.nearestCameraSpacePointZ]) - .range([0.1, 1]); - - const camPos = rc.camera.position; - const camToTarget = camPos.clone().sub(rc.cameraTarget); - let camToPoint = new THREE.Vector3(); - - this.gc.textBaseline = 'middle'; - this.gc.miterLimit = 2; - - // Have extra space between neighboring labels. Don't pack too tightly. - const labelMargin = 2; - // Shift the label to the right of the point circle. - const xShift = 4; - - const n = Math.min(MAX_LABELS_ON_SCREEN, lrc.pointIndices.length); - for (let i = 0; i < n; ++i) { - let point: THREE.Vector3; - { - const pi = lrc.pointIndices[i]; - point = util.vector3FromPackedArray(this.worldSpacePointPositions, pi); - } - - // discard points that are behind the camera - camToPoint.copy(camPos).sub(point); - if (camToTarget.dot(camToPoint) < 0) { - continue; - } - - let [x, y] = util.vector3DToScreenCoords( - rc.camera, rc.screenWidth, rc.screenHeight, point); - x += xShift; - - // Computing the width of the font is expensive, - // so we assume width of 1 at first. Then, if the label doesn't - // conflict with other labels, we measure the actual width. 
- const textBoundingBox: BoundingBox = { - loX: x - labelMargin, - hiX: x + 1 + labelMargin, - loY: y - labelHeight / 2 - labelMargin, - hiY: y + labelHeight / 2 + labelMargin - }; - - if (grid.insert(textBoundingBox, true)) { - const text = lrc.labelStrings[i]; - const fontSize = lrc.defaultFontSize * lrc.scaleFactors[i] * dpr; - this.gc.font = fontSize + 'px roboto'; - - // Now, check with properly computed width. - textBoundingBox.hiX += this.gc.measureText(text).width - 1; - if (grid.insert(textBoundingBox)) { - let opacity = 1; - if (sceneIs3D && (lrc.useSceneOpacityFlags[i] === 1)) { - opacity = opacityMap(camToPoint.length()); - } - this.gc.fillStyle = - this.styleStringFromPackedRgba(lrc.fillColors, i, opacity); - this.gc.strokeStyle = - this.styleStringFromPackedRgba(lrc.strokeColors, i, opacity); - this.gc.lineWidth = LABEL_STROKE_WIDTH; - this.gc.strokeText(text, x, y); - this.gc.lineWidth = LABEL_FILL_WIDTH; - this.gc.fillText(text, x, y); - } - } - } - } - - private styleStringFromPackedRgba( - packedRgbaArray: Uint8Array, colorIndex: number, - opacity: number): string { - const offset = colorIndex * 3; - const r = packedRgbaArray[offset]; - const g = packedRgbaArray[offset + 1]; - const b = packedRgbaArray[offset + 2]; - return 'rgba(' + r + ',' + g + ',' + b + ',' + opacity + ')'; - } - - onResize(newWidth: number, newHeight: number) { - let dpr = window.devicePixelRatio; - this.canvas.width = newWidth * dpr; - this.canvas.height = newHeight * dpr; - this.canvas.style.width = newWidth + 'px'; - this.canvas.style.height = newHeight + 'px'; - } - - dispose() { - this.removeAllLabels(); - this.canvas = null; - this.gc = null; - } - - onPointPositionsChanged(newPositions: Float32Array) { - this.worldSpacePointPositions = newPositions; - this.removeAllLabels(); - } - - onRender(rc: RenderContext) { - if (!this.labelsActive) { - return; - } - - this.removeAllLabels(); - this.makeLabels(rc); - } - - setScene(scene: THREE.Scene) {} - 
onPickingRender(renderContext: RenderContext) {} -} diff --git a/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizerPolylines.ts b/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizerPolylines.ts deleted file mode 100644 index e6d4aeda28b..00000000000 --- a/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizerPolylines.ts +++ /dev/null @@ -1,149 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -import {DataSet} from './data'; -import {RenderContext} from './renderContext'; -import {ScatterPlotVisualizer} from './scatterPlotVisualizer'; -import * as util from './util'; - -const RGB_NUM_ELEMENTS = 3; -const XYZ_NUM_ELEMENTS = 3; - -/** - * Renders polylines that connect multiple points in the dataset. 
- */ -export class ScatterPlotVisualizerPolylines implements ScatterPlotVisualizer { - private dataSet: DataSet; - private scene: THREE.Scene; - private polylines: THREE.Line[]; - private polylinePositionBuffer: - {[polylineIndex: number]: THREE.BufferAttribute} = {}; - private polylineColorBuffer: - {[polylineIndex: number]: THREE.BufferAttribute} = {}; - - private updateSequenceIndicesInDataSet(ds: DataSet) { - for (let i = 0; i < ds.sequences.length; i++) { - const sequence = ds.sequences[i]; - for (let j = 0; j < sequence.pointIndices.length - 1; j++) { - ds.points[sequence.pointIndices[j]].sequenceIndex = i; - ds.points[sequence.pointIndices[j + 1]].sequenceIndex = i; - } - } - } - - private createPolylines(scene: THREE.Scene) { - if (!this.dataSet || !this.dataSet.sequences) { - return; - } - - this.updateSequenceIndicesInDataSet(this.dataSet); - this.polylines = []; - - for (let i = 0; i < this.dataSet.sequences.length; i++) { - const geometry = new THREE.BufferGeometry(); - geometry.addAttribute('position', this.polylinePositionBuffer[i]); - geometry.addAttribute('color', this.polylineColorBuffer[i]); - - const material = new THREE.LineBasicMaterial({ - linewidth: 1, // unused default, overwritten by width array. - opacity: 1.0, // unused default, overwritten by opacity array. 
- transparent: true, - vertexColors: THREE.VertexColors - }); - - const polyline = new THREE.LineSegments(geometry, material); - polyline.frustumCulled = false; - this.polylines.push(polyline); - scene.add(polyline); - } - } - - dispose() { - if (this.polylines == null) { - return; - } - for (let i = 0; i < this.polylines.length; i++) { - this.scene.remove(this.polylines[i]); - this.polylines[i].geometry.dispose(); - } - this.polylines = null; - this.polylinePositionBuffer = {}; - this.polylineColorBuffer = {}; - } - - setScene(scene: THREE.Scene) { - this.scene = scene; - } - - setDataSet(dataSet: DataSet) { - this.dataSet = dataSet; - } - - onPointPositionsChanged(newPositions: Float32Array) { - if ((newPositions == null) || (this.polylines != null)) { - this.dispose(); - } - if ((newPositions == null) || (this.dataSet == null)) { - return; - } - // Set up the position buffer arrays for each polyline. - for (let i = 0; i < this.dataSet.sequences.length; i++) { - let sequence = this.dataSet.sequences[i]; - const vertexCount = 2 * (sequence.pointIndices.length - 1); - - let polylines = new Float32Array(vertexCount * XYZ_NUM_ELEMENTS); - this.polylinePositionBuffer[i] = - new THREE.BufferAttribute(polylines, XYZ_NUM_ELEMENTS); - - let colors = new Float32Array(vertexCount * RGB_NUM_ELEMENTS); - this.polylineColorBuffer[i] = - new THREE.BufferAttribute(colors, RGB_NUM_ELEMENTS); - } - for (let i = 0; i < this.dataSet.sequences.length; i++) { - const sequence = this.dataSet.sequences[i]; - let src = 0; - for (let j = 0; j < sequence.pointIndices.length - 1; j++) { - const p1Index = sequence.pointIndices[j]; - const p2Index = sequence.pointIndices[j + 1]; - const p1 = util.vector3FromPackedArray(newPositions, p1Index); - const p2 = util.vector3FromPackedArray(newPositions, p2Index); - this.polylinePositionBuffer[i].setXYZ(src, p1.x, p1.y, p1.z); - this.polylinePositionBuffer[i].setXYZ(src + 1, p2.x, p2.y, p2.z); - src += 2; - } - 
this.polylinePositionBuffer[i].needsUpdate = true; - } - - if (this.polylines == null) { - this.createPolylines(this.scene); - } - } - - onRender(renderContext: RenderContext) { - if (this.polylines == null) { - return; - } - for (let i = 0; i < this.polylines.length; i++) { - this.polylines[i].material.opacity = renderContext.polylineOpacities[i]; - (this.polylines[i].material as THREE.LineBasicMaterial).linewidth = - renderContext.polylineWidths[i]; - this.polylineColorBuffer[i].array = renderContext.polylineColors[i]; - this.polylineColorBuffer[i].needsUpdate = true; - } - } - - onPickingRender(renderContext: RenderContext) {} - onResize(newWidth: number, newHeight: number) {} -} diff --git a/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizerSprites.ts b/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizerSprites.ts deleted file mode 100644 index be9c1703c72..00000000000 --- a/tensorflow/tensorboard/components/vz_projector/scatterPlotVisualizerSprites.ts +++ /dev/null @@ -1,435 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -import {CameraType, RenderContext} from './renderContext'; -import {ScatterPlotVisualizer} from './scatterPlotVisualizer'; -import * as util from './util'; - -const NUM_POINTS_FOG_THRESHOLD = 5000; -const MIN_POINT_SIZE = 5.0; -const IMAGE_SIZE = 30; - -// Constants relating to the indices of buffer arrays. -const RGB_NUM_ELEMENTS = 3; -const INDEX_NUM_ELEMENTS = 1; -const XYZ_NUM_ELEMENTS = 3; - -const VERTEX_SHADER = ` - // Index of the specific vertex (passed in as bufferAttribute), and the - // variable that will be used to pass it to the fragment shader. - attribute float spriteIndex; - attribute vec3 color; - attribute float scaleFactor; - - varying vec2 xyIndex; - varying vec3 vColor; - - uniform bool sizeAttenuation; - uniform float pointSize; - uniform float spritesPerRow; - uniform float spritesPerColumn; - - void main() { - // Pass index and color values to fragment shader. - vColor = color; - xyIndex = vec2(mod(spriteIndex, spritesPerRow), - floor(spriteIndex / spritesPerColumn)); - - // Transform current vertex by modelViewMatrix (model world position and - // camera world position matrix). - vec4 cameraSpacePos = modelViewMatrix * vec4(position, 1.0); - - // Project vertex in camera-space to screen coordinates using the camera's - // projection matrix. - gl_Position = projectionMatrix * cameraSpacePos; - - // Create size attenuation (if we're in 3D mode) by making the size of - // each point inversly proportional to its distance to the camera. 
- float outputPointSize = pointSize; - if (sizeAttenuation) { - outputPointSize = -pointSize / cameraSpacePos.z; - } - - gl_PointSize = - max(outputPointSize * scaleFactor, ${MIN_POINT_SIZE.toFixed(1)}); - }`; - -const FRAGMENT_SHADER_POINT_TEST_CHUNK = ` - bool point_in_unit_circle(vec2 spriteCoord) { - vec2 centerToP = spriteCoord - vec2(0.5, 0.5); - return dot(centerToP, centerToP) < (0.5 * 0.5); - } - - bool point_in_unit_equilateral_triangle(vec2 spriteCoord) { - vec3 v0 = vec3(0, 1, 0); - vec3 v1 = vec3(0.5, 0, 0); - vec3 v2 = vec3(1, 1, 0); - vec3 p = vec3(spriteCoord, 0); - float p_in_v0_v1 = cross(v1 - v0, p - v0).z; - float p_in_v1_v2 = cross(v2 - v1, p - v1).z; - return (p_in_v0_v1 > 0.0) && (p_in_v1_v2 > 0.0); - } - - bool point_in_unit_square(vec2 spriteCoord) { - return true; - } -`; - -const FRAGMENT_SHADER = ` - varying vec2 xyIndex; - varying vec3 vColor; - - uniform sampler2D texture; - uniform float spritesPerRow; - uniform float spritesPerColumn; - uniform bool isImage; - - ${THREE.ShaderChunk['common']} - ${THREE.ShaderChunk['fog_pars_fragment']} - ${FRAGMENT_SHADER_POINT_TEST_CHUNK} - - void main() { - if (isImage) { - // Coordinates of the vertex within the entire sprite image. - vec2 coords = - (gl_PointCoord + xyIndex) / vec2(spritesPerRow, spritesPerColumn); - gl_FragColor = vec4(vColor, 1.0) * texture2D(texture, coords); - } else { - bool inside = point_in_unit_circle(gl_PointCoord); - if (!inside) { - discard; - } - gl_FragColor = vec4(vColor, 1); - } - ${THREE.ShaderChunk['fog_fragment']} - }`; - -const FRAGMENT_SHADER_PICKING = ` - varying vec2 xyIndex; - varying vec3 vColor; - uniform bool isImage; - - ${FRAGMENT_SHADER_POINT_TEST_CHUNK} - - void main() { - xyIndex; // Silence 'unused variable' warning. 
- if (isImage) { - gl_FragColor = vec4(vColor, 1); - } else { - bool inside = point_in_unit_circle(gl_PointCoord); - if (!inside) { - discard; - } - gl_FragColor = vec4(vColor, 1); - } - }`; - -/** - * Uses GL point sprites to render the dataset. - */ -export class ScatterPlotVisualizerSprites implements ScatterPlotVisualizer { - private scene: THREE.Scene; - private fog: THREE.Fog; - private texture: THREE.Texture = null; - private standinTextureForPoints: THREE.Texture; - private spritesPerRow: number; - private spritesPerColumn: number; - private spriteDimensions: [number, number]; - private spriteIndexBufferAttribute: THREE.BufferAttribute; - private renderMaterial: THREE.ShaderMaterial; - private pickingMaterial: THREE.ShaderMaterial; - - private points: THREE.Points; - private worldSpacePointPositions: Float32Array; - private pickingColors: Float32Array; - private renderColors: Float32Array; - - constructor() { - this.standinTextureForPoints = - util.createTexture(document.createElement('canvas')); - this.renderMaterial = this.createRenderMaterial(false); - this.pickingMaterial = this.createPickingMaterial(false); - } - - private createTextureFromSpriteAtlas( - spriteAtlas: HTMLImageElement, spriteDimensions: [number, number], - spriteIndices: Float32Array) { - this.texture = util.createTexture(spriteAtlas); - this.spritesPerRow = spriteAtlas.width / spriteDimensions[0]; - this.spritesPerColumn = spriteAtlas.height / spriteDimensions[1]; - this.spriteDimensions = spriteDimensions; - this.spriteIndexBufferAttribute = - new THREE.BufferAttribute(spriteIndices, INDEX_NUM_ELEMENTS); - - if (this.points != null) { - (this.points.geometry as THREE.BufferGeometry) - .addAttribute('spriteIndex', this.spriteIndexBufferAttribute); - } - } - - private createUniforms(): any { - return { - texture: {type: 't'}, - spritesPerRow: {type: 'f'}, - spritesPerColumn: {type: 'f'}, - fogColor: {type: 'c'}, - fogNear: {type: 'f'}, - fogFar: {type: 'f'}, - isImage: {type: 'bool'}, - 
sizeAttenuation: {type: 'bool'}, - pointSize: {type: 'f'} - }; - } - - private createRenderMaterial(haveImage: boolean): THREE.ShaderMaterial { - const uniforms = this.createUniforms(); - return new THREE.ShaderMaterial({ - uniforms: uniforms, - vertexShader: VERTEX_SHADER, - fragmentShader: FRAGMENT_SHADER, - transparent: !haveImage, - depthTest: haveImage, - depthWrite: haveImage, - fog: true, - blending: THREE.MultiplyBlending, - }); - } - - private createPickingMaterial(haveImage: boolean): THREE.ShaderMaterial { - const uniforms = this.createUniforms(); - return new THREE.ShaderMaterial({ - uniforms: uniforms, - vertexShader: VERTEX_SHADER, - fragmentShader: FRAGMENT_SHADER_PICKING, - transparent: true, - depthTest: true, - depthWrite: true, - fog: false, - blending: THREE.NormalBlending, - }); - } - - /** - * Create points, set their locations and actually instantiate the - * geometry. - */ - private createPointSprites(scene: THREE.Scene, positions: Float32Array) { - const pointCount = - (positions != null) ? (positions.length / XYZ_NUM_ELEMENTS) : 0; - const geometry = this.createGeometry(pointCount); - - this.fog = new THREE.Fog(0xFFFFFF); // unused value, gets overwritten. - - this.points = new THREE.Points(geometry, this.renderMaterial); - this.points.frustumCulled = false; - if (this.spriteIndexBufferAttribute != null) { - (this.points.geometry as THREE.BufferGeometry) - .addAttribute('spriteIndex', this.spriteIndexBufferAttribute); - } - scene.add(this.points); - } - - private calculatePointSize(sceneIs3D: boolean): number { - if (this.texture != null) { - return sceneIs3D ? IMAGE_SIZE : this.spriteDimensions[0]; - } - const n = (this.worldSpacePointPositions != null) ? - (this.worldSpacePointPositions.length / XYZ_NUM_ELEMENTS) : - 1; - const SCALE = 200; - const LOG_BASE = 8; - const DIVISOR = 1.5; - // Scale point size inverse-logarithmically to the number of points. - const pointSize = SCALE / Math.log(n) / Math.log(LOG_BASE); - return sceneIs3D ? 
pointSize : (pointSize / DIVISOR); - } - - /** - * Set up buffer attributes to be used for the points/images. - */ - private createGeometry(pointCount: number): THREE.BufferGeometry { - const n = pointCount; - - // Fill pickingColors with each point's unique id as its color. - this.pickingColors = new Float32Array(n * RGB_NUM_ELEMENTS); - { - let dst = 0; - for (let i = 0; i < n; i++) { - const c = new THREE.Color(i); - this.pickingColors[dst++] = c.r; - this.pickingColors[dst++] = c.g; - this.pickingColors[dst++] = c.b; - } - } - - const geometry = new THREE.BufferGeometry(); - geometry.addAttribute( - 'position', new THREE.BufferAttribute(null, XYZ_NUM_ELEMENTS)); - geometry.addAttribute( - 'color', new THREE.BufferAttribute(null, RGB_NUM_ELEMENTS)); - geometry.addAttribute( - 'scaleFactor', new THREE.BufferAttribute(null, INDEX_NUM_ELEMENTS)); - return geometry; - } - - private setFogDistances( - sceneIs3D: boolean, nearestPointZ: number, farthestPointZ: number) { - if (sceneIs3D) { - const n = this.worldSpacePointPositions.length / XYZ_NUM_ELEMENTS; - this.fog.near = nearestPointZ; - // If there are fewer points we want less fog. We do this - // by making the "far" value (that is, the distance from the camera to the - // far edge of the fog) proportional to the number of points. 
- let multiplier = - 2 - Math.min(n, NUM_POINTS_FOG_THRESHOLD) / NUM_POINTS_FOG_THRESHOLD; - this.fog.far = farthestPointZ * multiplier; - } else { - this.fog.near = Infinity; - this.fog.far = Infinity; - } - } - - dispose() { - this.disposeGeometry(); - this.disposeTextureAtlas(); - } - - private disposeGeometry() { - if (this.points != null) { - this.scene.remove(this.points); - this.points.geometry.dispose(); - this.points = null; - this.worldSpacePointPositions = null; - } - } - - private disposeTextureAtlas() { - if (this.texture != null) { - this.texture.dispose(); - } - this.texture = null; - this.renderMaterial = null; - this.pickingMaterial = null; - } - - setScene(scene: THREE.Scene) { - this.scene = scene; - } - - setSpriteAtlas( - spriteImage: HTMLImageElement, spriteDimensions: [number, number], - spriteIndices: Float32Array) { - this.disposeTextureAtlas(); - this.createTextureFromSpriteAtlas( - spriteImage, spriteDimensions, spriteIndices); - this.renderMaterial = this.createRenderMaterial(true); - this.pickingMaterial = this.createPickingMaterial(true); - } - - clearSpriteAtlas() { - this.disposeTextureAtlas(); - this.renderMaterial = this.createRenderMaterial(false); - this.pickingMaterial = this.createPickingMaterial(false); - } - - onPointPositionsChanged(newPositions: Float32Array) { - if ((newPositions == null) || (newPositions.length === 0)) { - this.dispose(); - return; - } - if (this.points != null) { - if (this.worldSpacePointPositions.length !== newPositions.length) { - this.disposeGeometry(); - } - } - - this.worldSpacePointPositions = newPositions; - - if (this.points == null) { - this.createPointSprites(this.scene, newPositions); - } - - const positions = (this.points.geometry as THREE.BufferGeometry) - .getAttribute('position') as THREE.BufferAttribute; - positions.array = newPositions; - positions.needsUpdate = true; - } - - onPickingRender(rc: RenderContext) { - if (this.points == null) { - return; - } - - const sceneIs3D: boolean = 
(rc.cameraType === CameraType.Perspective); - - this.pickingMaterial.uniforms.spritesPerRow.value = this.spritesPerRow; - this.pickingMaterial.uniforms.spritesPerRow.value = this.spritesPerColumn; - this.pickingMaterial.uniforms.sizeAttenuation.value = sceneIs3D; - this.pickingMaterial.uniforms.pointSize.value = - this.calculatePointSize(sceneIs3D); - this.points.material = this.pickingMaterial; - - let colors = (this.points.geometry as THREE.BufferGeometry) - .getAttribute('color') as THREE.BufferAttribute; - colors.array = this.pickingColors; - colors.needsUpdate = true; - - let scaleFactors = - (this.points.geometry as THREE.BufferGeometry) - .getAttribute('scaleFactor') as THREE.BufferAttribute; - scaleFactors.array = rc.pointScaleFactors; - scaleFactors.needsUpdate = true; - } - - onRender(rc: RenderContext) { - if (!this.points) { - return; - } - const sceneIs3D: boolean = (rc.camera instanceof THREE.PerspectiveCamera); - - this.setFogDistances( - sceneIs3D, rc.nearestCameraSpacePointZ, rc.farthestCameraSpacePointZ); - - this.scene.fog = this.fog; - this.scene.fog.color = new THREE.Color(rc.backgroundColor); - - this.renderMaterial.uniforms.fogColor.value = this.scene.fog.color; - this.renderMaterial.uniforms.fogNear.value = this.fog.near; - this.renderMaterial.uniforms.fogFar.value = this.fog.far; - this.renderMaterial.uniforms.spritesPerRow.value = this.spritesPerRow; - this.renderMaterial.uniforms.spritesPerColumn.value = this.spritesPerColumn; - this.renderMaterial.uniforms.isImage.value = (this.texture != null); - this.renderMaterial.uniforms.texture.value = - (this.texture != null) ? 
this.texture : this.standinTextureForPoints; - this.renderMaterial.uniforms.sizeAttenuation.value = sceneIs3D; - this.renderMaterial.uniforms.pointSize.value = - this.calculatePointSize(sceneIs3D); - this.points.material = this.renderMaterial; - - let colors = (this.points.geometry as THREE.BufferGeometry) - .getAttribute('color') as THREE.BufferAttribute; - this.renderColors = rc.pointColors; - colors.array = this.renderColors; - colors.needsUpdate = true; - - let scaleFactors = - (this.points.geometry as THREE.BufferGeometry) - .getAttribute('scaleFactor') as THREE.BufferAttribute; - scaleFactors.array = rc.pointScaleFactors; - scaleFactors.needsUpdate = true; - } - - onResize(newWidth: number, newHeight: number) {} -} diff --git a/tensorflow/tensorboard/components/vz_projector/sptree.ts b/tensorflow/tensorboard/components/vz_projector/sptree.ts deleted file mode 100644 index 991369a3352..00000000000 --- a/tensorflow/tensorboard/components/vz_projector/sptree.ts +++ /dev/null @@ -1,175 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -/** N-dimensional point. Usually 2D or 3D. */ -export type Point = number[]; - -export interface BBox { - center: Point; - halfDim: number; -} - -/** A node in a space-partitioning tree. */ -export interface SPNode { - /** The children of this node. 
*/ - children?: SPNode[]; - /** The bounding box of the region this node occupies. */ - box: BBox; - /** One or more points this node has. */ - point: Point; -} - -/** - * A Space-partitioning tree (https://en.wikipedia.org/wiki/Space_partitioning) - * that recursively divides the space into regions of equal sizes. This data - * structure can act both as a Quad tree and an Octree when the data is 2 or - * 3 dimensional respectively. One usage is in t-SNE in order to do Barnes-Hut - * approximation. - */ -export class SPTree { - root: SPNode; - - private masks: number[]; - private dim: number; - - /** - * Constructs a new tree with the provided data. - * - * @param data List of n-dimensional data points. - * @param capacity Number of data points to store in a single node. - */ - constructor(data: Point[]) { - if (data.length < 1) { - throw new Error('There should be at least 1 data point'); - } - // Make a bounding box based on the extent of the data. - this.dim = data[0].length; - // Each node has 2^d children, where d is the dimension of the space. - // Binary masks (e.g. 000, 001, ... 111 in 3D) are used to determine in - // which child (e.g. quadron in 2D) the new point is going to be assigned. - // For more details, see the insert() method and its comments. - this.masks = new Array(Math.pow(2, this.dim)); - for (let d = 0; d < this.masks.length; ++d) { - this.masks[d] = (1 << d); - } - let min: Point = new Array(this.dim); - fillArray(min, Number.POSITIVE_INFINITY); - let max: Point = new Array(this.dim); - fillArray(max, Number.NEGATIVE_INFINITY); - - for (let i = 0; i < data.length; ++i) { - // For each dim get the min and max. - // E.g. For 2-D, get the x_min, x_max, y_min, y_max. - for (let d = 0; d < this.dim; ++d) { - min[d] = Math.min(min[d], data[i][d]); - max[d] = Math.max(max[d], data[i][d]); - } - } - // Create a bounding box with the center of the largest span. 
- let center: Point = new Array(this.dim); - let halfDim = 0; - for (let d = 0; d < this.dim; ++d) { - let span = max[d] - min[d]; - center[d] = min[d] + span / 2; - halfDim = Math.max(halfDim, span / 2); - } - this.root = {box: {center: center, halfDim: halfDim}, point: data[0]}; - for (let i = 1; i < data.length; ++i) { - this.insert(this.root, data[i]); - } - } - - /** - * Visits every node in the tree. Each node can store 1 or more points, - * depending on the node capacity provided in the constructor. - * - * @param accessor Method that takes the currently visited node, and the - * low and high point of the region that this node occupies. E.g. in 2D, - * the low and high points will be the lower-left corner and the upper-right - * corner. - */ - visit( - accessor: (node: SPNode, lowPoint: Point, highPoint: Point) => boolean, - noBox = false) { - this.visitNode(this.root, accessor, noBox); - } - - private visitNode( - node: SPNode, - accessor: (node: SPNode, lowPoint?: Point, highPoint?: Point) => boolean, - noBox: boolean) { - let skipChildren: boolean; - if (noBox) { - skipChildren = accessor(node); - } else { - let lowPoint = new Array(this.dim); - let highPoint = new Array(this.dim); - for (let d = 0; d < this.dim; ++d) { - lowPoint[d] = node.box.center[d] - node.box.halfDim; - highPoint[d] = node.box.center[d] + node.box.halfDim; - } - skipChildren = accessor(node, lowPoint, highPoint); - } - if (!node.children || skipChildren) { - return; - } - for (let i = 0; i < node.children.length; ++i) { - let child = node.children[i]; - if (child) { - this.visitNode(child, accessor, noBox); - } - } - } - - private insert(node: SPNode, p: Point) { - // Subdivide and then add the point to whichever node will accept it. 
- if (node.children == null) { - node.children = new Array(this.masks.length); - } - - // Decide which child will get the new point by constructing a D-bits binary - // signature (D=3 for 3D) where the k-th bit is 1 if the point's k-th - // coordinate is greater than the node's k-th coordinate, 0 otherwise. - // Then the binary signature in decimal system gives us the index of the - // child where the new point should be. - let index = 0; - for (let d = 0; d < this.dim; ++d) { - if (p[d] > node.box.center[d]) { - index |= this.masks[d]; - } - } - if (node.children[index] == null) { - this.makeChild(node, index, p); - } else { - this.insert(node.children[index], p); - } - } - - private makeChild(node: SPNode, index: number, p: Point): void { - let oldC = node.box.center; - let h = node.box.halfDim / 2; - let newC: Point = new Array(this.dim); - for (let d = 0; d < this.dim; ++d) { - newC[d] = (index & (1 << d)) ? oldC[d] + h : oldC[d] - h; - } - node.children[index] = {box: {center: newC, halfDim: h}, point: p}; - } -} - -function fillArray(arr: T[], value: T): void { - for (let i = 0; i < arr.length; ++i) { - arr[i] = value; - } -} diff --git a/tensorflow/tensorboard/components/vz_projector/styles.html b/tensorflow/tensorboard/components/vz_projector/styles.html deleted file mode 100644 index 32dc984b5d6..00000000000 --- a/tensorflow/tensorboard/components/vz_projector/styles.html +++ /dev/null @@ -1,185 +0,0 @@ - - - - - diff --git a/tensorflow/tensorboard/components/vz_projector/test/BUILD b/tensorflow/tensorboard/components/vz_projector/test/BUILD deleted file mode 100644 index fc8659f06a3..00000000000 --- a/tensorflow/tensorboard/components/vz_projector/test/BUILD +++ /dev/null @@ -1,36 +0,0 @@ -package( - default_testonly = True, - default_visibility = ["//tensorflow/tensorboard:internal"], -) - -load("//tensorflow/tensorboard/defs:web.bzl", "ts_web_library") - -licenses(["notice"]) # Apache 2.0 - -ts_web_library( - name = "test", - srcs = [ - "assert.ts", - 
"data-provider_test.ts", - "data_test.ts", - "sptree_test.ts", - "tests.html", - "util_test.ts", - # "scatterPlotRectangleSelector_test.ts", - # "vz-projector-projections-panel_test.ts", - ], - path = "/vz-projector/test", - deps = [ - "//tensorflow/tensorboard/components/tf_imports:polymer", - "//tensorflow/tensorboard/components/tf_imports:web_component_tester", - "//tensorflow/tensorboard/components/tf_imports:webcomponentsjs", - "//tensorflow/tensorboard/components/vz_projector", - ], -) - -filegroup( - name = "all_files", - testonly = 0, - srcs = glob(["**"]), - tags = ["notsan"], -) diff --git a/tensorflow/tensorboard/components/vz_projector/test/data-provider_test.ts b/tensorflow/tensorboard/components/vz_projector/test/data-provider_test.ts deleted file mode 100644 index 59a42ffbfd8..00000000000 --- a/tensorflow/tensorboard/components/vz_projector/test/data-provider_test.ts +++ /dev/null @@ -1,96 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -import {DataPoint, SpriteAndMetadataInfo} from '../data'; -import * as data_provider from '../data-provider'; - -/** - * Converts a string to an ArrayBuffer. 
- */ -function stringToArrayBuffer(str: string): Promise { - return new Promise((resolve, reject) => { - let blob = new Blob([str]); - let file = new FileReader(); - file.onload = (e: any) => { - resolve(e.target.result); - }; - file.readAsArrayBuffer(blob); - }); -} - -/** - * Converts an data array to TSV format. - */ -function dataToTsv(data: string[][]|number[][]) { - let lines = []; - for (let i = 0; i < data.length; i++) { - lines.push(data[i].join('\t')); - } - return lines.join('\n'); -} - -describe('parse tensors', () => { - it('parseTensors', (doneFn) => { - let tensors = [[1.0, 2.0], [2.0, 3.0]]; - stringToArrayBuffer(dataToTsv(tensors)) - .then((tensorsArrayBuffer: ArrayBuffer) => { - data_provider.parseTensors(tensorsArrayBuffer) - .then((data: DataPoint[]) => { - assert.equal(2, data.length); - - assert.deepEqual(new Float32Array(tensors[0]), data[0].vector); - assert.equal(0, data[0].index); - assert.isNull(data[0].projections); - - assert.deepEqual(new Float32Array(tensors[1]), data[1].vector); - assert.equal(1, data[1].index); - assert.isNull(data[1].projections); - doneFn(); - }); - }); - }); - it('parseMetadata', (doneFn) => { - let metadata = [['label', 'fakecol'], ['Г', '0'], ['label1', '1']]; - - stringToArrayBuffer(dataToTsv(metadata)) - .then((metadataArrayBuffer: ArrayBuffer) => { - data_provider.parseMetadata(metadataArrayBuffer) - .then((spriteAndMetadataInfo: SpriteAndMetadataInfo) => { - assert.equal(2, spriteAndMetadataInfo.stats.length); - assert.equal(metadata[0][0], - spriteAndMetadataInfo.stats[0].name); - assert.isFalse(spriteAndMetadataInfo.stats[0].isNumeric); - assert.isFalse( - spriteAndMetadataInfo.stats[0].tooManyUniqueValues); - assert.equal(metadata[0][1], - spriteAndMetadataInfo.stats[1].name); - assert.isTrue(spriteAndMetadataInfo.stats[1].isNumeric); - assert.isFalse( - spriteAndMetadataInfo.stats[1].tooManyUniqueValues); - - assert.equal(2, spriteAndMetadataInfo.pointsInfo.length); - assert.equal(metadata[1][0], - 
spriteAndMetadataInfo.pointsInfo[0]['label']); - assert.equal(+metadata[1][1], - spriteAndMetadataInfo.pointsInfo[0]['fakecol']); - assert.equal(metadata[2][0], - spriteAndMetadataInfo.pointsInfo[1]['label']); - assert.equal(+metadata[2][1], - spriteAndMetadataInfo.pointsInfo[1]['fakecol']); - doneFn(); - }); - }); - }); -}); diff --git a/tensorflow/tensorboard/components/vz_projector/test/data_test.ts b/tensorflow/tensorboard/components/vz_projector/test/data_test.ts deleted file mode 100644 index 5e47c091c5b..00000000000 --- a/tensorflow/tensorboard/components/vz_projector/test/data_test.ts +++ /dev/null @@ -1,104 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -import {DataPoint, DataSet, State, stateGetAccessorDimensions} from '../data'; - -/** - * Helper method that makes a list of points given an array of - * sequence indexes. - * - * @param sequences The i-th entry holds the 'next' attribute for the i-th - * point. - */ -function makePointsWithSequences( - sequences: number[], nextAttr = '__seq_next__') { - let points: DataPoint[] = []; - sequences.forEach((t, i) => { - let metadata: {[key: string]: any} = {}; - metadata[nextAttr] = t >= 0 ? 
t : null; - points.push({ - vector: new Float32Array(0), - metadata: metadata, - projections: {}, - index: i - }); - }); - return points; -} - -describe('constructor_with_sequences', () => { - it('Simple forward pointing sequences, __seq_next__ metadata format', () => { - // The input is: 0->2, 1->None, 2->3, 3->None. This should return - // one sequence 0->2->3. - const points = makePointsWithSequences([2, -1, 3, -1]); - let dataset = new DataSet(points); - assert.equal(1, dataset.sequences.length); - assert.deepEqual([0, 2, 3], dataset.sequences[0].pointIndices); - }); - - it('Simple forward pointing sequences, __next__ metadata format', () => { - // The input is: 0->2, 1->None, 2->3, 3->None. This should return - // one sequence 0->2->3. - const points = makePointsWithSequences([2, -1, 3, -1], '__next__'); - let dataset = new DataSet(points); - assert.equal(1, dataset.sequences.length); - assert.deepEqual([0, 2, 3], dataset.sequences[0].pointIndices); - }); - - it('No sequences', () => { - let points = makePointsWithSequences([-1, -1, -1, -1]); - let dataset = new DataSet(points); - assert.equal(0, dataset.sequences.length); - }); - - it('A sequence that goes backwards and forward in the array', () => { - // The input is: 0->2, 1->0, 2->nothing, 3->1. This should return - // one sequence 3->1->0->2. 
- let points = makePointsWithSequences([2, 0, -1, 1]); - let dataset = new DataSet(points); - assert.equal(1, dataset.sequences.length); - assert.deepEqual([3, 1, 0, 2], dataset.sequences[0].pointIndices); - }); -}); - -describe('stateGetAccessorDimensions', () => { - it('returns [0, 1] for 2d t-SNE', () => { - const state = new State(); - state.selectedProjection = 'tsne'; - state.tSNEis3d = false; - assert.deepEqual([0, 1], stateGetAccessorDimensions(state)); - }); - - it('returns [0, 1, 2] for 3d t-SNE', () => { - const state = new State(); - state.selectedProjection = 'tsne'; - state.tSNEis3d = true; - assert.deepEqual([0, 1, 2], stateGetAccessorDimensions(state)); - }); - - it('returns pca component dimensions array for pca', () => { - const state = new State(); - state.selectedProjection = 'pca'; - state.pcaComponentDimensions = [13, 12, 11, 10]; - assert.deepEqual(state.pcaComponentDimensions, - stateGetAccessorDimensions(state)); - }); - - it('returns ["x", "y"] for custom projections', () => { - const state = new State(); - state.selectedProjection = 'custom'; - assert.deepEqual(['x', 'y'], stateGetAccessorDimensions(state)); - }); -}); diff --git a/tensorflow/tensorboard/components/vz_projector/test/scatterPlotRectangleSelector_test.ts b/tensorflow/tensorboard/components/vz_projector/test/scatterPlotRectangleSelector_test.ts deleted file mode 100644 index 0ee6cf620df..00000000000 --- a/tensorflow/tensorboard/components/vz_projector/test/scatterPlotRectangleSelector_test.ts +++ /dev/null @@ -1,69 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -import {BoundingBox, ScatterPlotRectangleSelector} from '../scatterPlotRectangleSelector'; - -describe('selector callbacks make bounding box start bottom left', () => { - let containerElement: HTMLElement; - let selectionCallback: (boundingBox: BoundingBox) => void; - let selection: ScatterPlotRectangleSelector; - - beforeEach(() => { - containerElement = document.createElement('div'); - const selector = document.createElement('svg'); - selector.id = 'selector'; - containerElement.appendChild(selector); - - selectionCallback = jasmine.createSpy('selectionCallback'); - selection = - new ScatterPlotRectangleSelector(containerElement, selectionCallback); - }); - - it('Simple mouse event starting top left', () => { - selection.onMouseDown(0, 0); - selection.onMouseMove(10, 10); - selection.onMouseUp(); - - expect(selectionCallback) - .toHaveBeenCalledWith({x: 0, y: 10, width: 10, height: 10}); - }); - - it('Simple mouse event starting bottom left', () => { - selection.onMouseDown(0, 10); - selection.onMouseMove(10, 0); - selection.onMouseUp(); - - expect(selectionCallback) - .toHaveBeenCalledWith({x: 0, y: 10, width: 10, height: 10}); - }); - - it('Simple mouse event starting top right', () => { - selection.onMouseDown(10, 0); - selection.onMouseMove(0, 10); - selection.onMouseUp(); - - expect(selectionCallback) - .toHaveBeenCalledWith({x: 0, y: 10, width: 10, height: 10}); - }); - - it('Simple mouse event starting bottom right', () => { - selection.onMouseDown(10, 10); - 
selection.onMouseMove(0, 0); - selection.onMouseUp(); - - expect(selectionCallback) - .toHaveBeenCalledWith({x: 0, y: 10, width: 10, height: 10}); - }); -}); diff --git a/tensorflow/tensorboard/components/vz_projector/test/sptree_test.ts b/tensorflow/tensorboard/components/vz_projector/test/sptree_test.ts deleted file mode 100644 index 7e340ea62f5..00000000000 --- a/tensorflow/tensorboard/components/vz_projector/test/sptree_test.ts +++ /dev/null @@ -1,102 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -import {SPTree} from '../sptree'; - -it('simple 2D data', () => { - let data = [ - [0, 1], - [1, 0], - [1, 1], - [0, 0], - ]; - let tree = new SPTree(data); - // Check that each point is within the bound. - tree.visit((node, low, high) => { - assert.equal(low.length, 2); - assert.equal(high.length, 2); - let point = node.point; - assert.equal(point.length, 2); - // Each point should be in the node's bounding box. - assert.equal( - point[0] >= low[0] && point[0] <= high[0] && point[1] >= low[1] && - point[1] <= high[1], - true); - return false; - }); -}); - -it('simple 3D data', () => { - let data = [ - [0, 1, 0], - [1, 0.4, 2], - [1, 1, 3], - [0, 0, 5], - ]; - let tree = new SPTree(data); - // Check that each point is within the bound. 
- tree.visit((node, low, high) => { - assert.equal(low.length, 3); - assert.equal(high.length, 3); - let point = node.point; - assert.equal(point.length, 3); - // Each point should be in the node's bounding box. - assert.equal( - point[0] >= low[0] && point[0] <= high[0] && point[1] >= low[1] && - point[1] <= high[1] && point[2] >= low[2] && point[2] <= high[2], - true); - return false; - }); -}); - -it('Only visit root', () => { - let data = [ - [0, 1, 0], - [1, 0.4, 2], - [1, 1, 3], - [0, 0, 5], - ]; - let tree = new SPTree(data); - let numVisits = 0; - tree.visit((node, low, high) => { - numVisits++; - return true; - }); - assert.equal(numVisits, 1); -}); - -it('Search in random data', () => { - let N = 10000; - let data = new Array(N); - for (let i = 0; i < N; i++) { - data[i] = [Math.random(), Math.random()]; - } - let tree = new SPTree(data); - let numVisits = 0; - let query = data[Math.floor(Math.random() * N)]; - let found = false; - tree.visit((node, low, high) => { - numVisits++; - if (node.point === query) { - found = true; - return true; - } - let outOfBounds = query[0] < low[0] || query[0] > high[0] || - query[1] < low[1] || query[1] > high[1]; - return outOfBounds; - }); - assert.equal(found, true); - assert.isBelow(numVisits, N / 4); -}); diff --git a/tensorflow/tensorboard/components/vz_projector/test/tests.html b/tensorflow/tensorboard/components/vz_projector/test/tests.html deleted file mode 100644 index a6843d0d6b8..00000000000 --- a/tensorflow/tensorboard/components/vz_projector/test/tests.html +++ /dev/null @@ -1,31 +0,0 @@ - - - - - - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/vz_projector/test/util_test.ts b/tensorflow/tensorboard/components/vz_projector/test/util_test.ts deleted file mode 100644 index c18db95eed7..00000000000 --- a/tensorflow/tensorboard/components/vz_projector/test/util_test.ts +++ /dev/null @@ -1,42 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -import * as util from '../util'; - -describe('getURLParams', () => { - it('search query with valid param returns correct object', () => { - let urlParams = util.getURLParams('?config=http://google.com/'); - assert.deepEqual({'config': 'http://google.com/'}, urlParams); - }); - - it('search query with multiple valid params returns correct object', () => { - let urlParams = util.getURLParams('?config=http://google.com/&foo=bar'); - assert.deepEqual({'config': 'http://google.com/', 'foo': 'bar'}, urlParams); - }); - - it('search query with valid param with URL encoded characters', () => { - let urlParams = util.getURLParams('?config=http://google.com/%20search'); - assert.deepEqual({'config': 'http://google.com/ search'}, urlParams); - }); - - it('search query with pound sign', () => { - let urlParams = util.getURLParams('?config=http://google.com/#foo'); - assert.deepEqual({'config': 'http://google.com/'}, urlParams); - }); - - it('no search query returns empty object', () => { - let urlParams = util.getURLParams(''); - assert.deepEqual({}, urlParams); - }); -}); diff --git a/tensorflow/tensorboard/components/vz_projector/test/vz-projector-projections-panel_test.ts b/tensorflow/tensorboard/components/vz_projector/test/vz-projector-projections-panel_test.ts deleted file mode 100644 index 2bf0c6eb48f..00000000000 --- 
a/tensorflow/tensorboard/components/vz_projector/test/vz-projector-projections-panel_test.ts +++ /dev/null @@ -1,107 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -import {State} from '../data'; -import {ProjectionsPanel} from '../vz-projector-projections-panel'; - -describe('restoreUIFromBookmark', () => { - let projectionsPanel: ProjectionsPanel; - beforeEach(() => { - projectionsPanel = document.createElement(ProjectionsPanel.prototype.is) as - ProjectionsPanel; - - // Set up some of the UI so the elements are found in the production code. 
- const tsnePerplexityContainer = document.createElement('div'); - tsnePerplexityContainer.className = 'tsne-perplexity'; - const tsnePerplexity = document.createElement('span'); - tsnePerplexityContainer.appendChild(tsnePerplexity); - projectionsPanel.appendChild(tsnePerplexityContainer); - - const tsneLearningRateContainer = document.createElement('div'); - tsneLearningRateContainer.className = 'tsne-learning-rate'; - const tsneLearningRate = document.createElement('span'); - tsneLearningRateContainer.appendChild(tsneLearningRate); - projectionsPanel.appendChild(tsneLearningRateContainer); - }); - - it('sets the pcaX/Y properties when setting 2D component values', () => { - spyOn(projectionsPanel, 'setZDropdownEnabled'); - - const s = new State(); - s.pcaComponentDimensions = [0, 1]; - projectionsPanel.restoreUIFromBookmark(s); - - assert.equal(0, projectionsPanel.pcaX); - assert.equal(1, projectionsPanel.pcaY); - - expect(projectionsPanel.setZDropdownEnabled).toHaveBeenCalledWith(false); - }); - - it('sets the pcaX/Y properties when setting 3D component values', () => { - spyOn(projectionsPanel, 'setZDropdownEnabled'); - - const s = new State(); - s.pcaComponentDimensions = [0, 1, 2]; - projectionsPanel.restoreUIFromBookmark(s); - - assert.equal(0, projectionsPanel.pcaX); - assert.equal(1, projectionsPanel.pcaY); - assert.equal(2, projectionsPanel.pcaZ); - - expect(projectionsPanel.setZDropdownEnabled).toHaveBeenCalledWith(true); - }); -}); - -describe('populateBookmarkFromUI', () => { - let projectionsPanel: ProjectionsPanel; - - beforeEach(() => { - projectionsPanel = document.createElement(ProjectionsPanel.prototype.is) as - ProjectionsPanel; - - // Set up some of the UI so the elements are found in the production code. 
- const tsnePerplexityContainer = document.createElement('div'); - tsnePerplexityContainer.className = 'tsne-perplexity'; - const tsnePerplexity = document.createElement('span'); - tsnePerplexityContainer.appendChild(tsnePerplexity); - projectionsPanel.appendChild(tsnePerplexityContainer); - - const tsneLearningRateContainer = document.createElement('div'); - tsneLearningRateContainer.className = 'tsne-learning-rate'; - const tsneLearningRate = document.createElement('span'); - tsneLearningRateContainer.appendChild(tsneLearningRate); - projectionsPanel.appendChild(tsneLearningRateContainer); - }); - - it('gets the PCA component UI values from a 2D PCA projection', () => { - projectionsPanel.pcaX = 0; - projectionsPanel.pcaY = 1; - projectionsPanel.pcaIs3d = false; - - const s = new State(); - projectionsPanel.populateBookmarkFromUI(s); - assert.deepEqual([0, 1], s.pcaComponentDimensions); - }); - - it('gets the PCA component UI values from a 3D PCA projection', () => { - projectionsPanel.pcaX = 0; - projectionsPanel.pcaY = 1; - projectionsPanel.pcaZ = 2; - projectionsPanel.pcaIs3d = true; - - const s = new State(); - projectionsPanel.populateBookmarkFromUI(s); - assert.deepEqual([0, 1, 2], s.pcaComponentDimensions); - }); -}); diff --git a/tensorflow/tensorboard/components/vz_projector/util.ts b/tensorflow/tensorboard/components/vz_projector/util.ts deleted file mode 100644 index bd6df68b1a5..00000000000 --- a/tensorflow/tensorboard/components/vz_projector/util.ts +++ /dev/null @@ -1,252 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -import {DataPoint} from './data'; -import * as logging from './logging'; -import {Point2D} from './vector'; - -/** - * Delay for running expensive tasks, in milliseconds. - * The duration was empirically found so that it leaves enough time for the - * browser to update its UI state before starting an expensive UI-blocking task. - */ -const TASK_DELAY_MS = 200; - -/** Shuffles the array in-place in O(n) time using Fisher-Yates algorithm. */ -export function shuffle(array: T[]): T[] { - let m = array.length; - let t: T; - let i: number; - - // While there remain elements to shuffle. - while (m) { - // Pick a remaining element - i = Math.floor(Math.random() * m--); - // And swap it with the current element. 
- t = array[m]; - array[m] = array[i]; - array[i] = t; - } - return array; -} - -export function range(count: number): number[] { - const rangeOutput: number[] = []; - for (let i = 0; i < count; i++) { - rangeOutput.push(i); - } - return rangeOutput; -} - -export function classed( - element: HTMLElement, className: string, enabled: boolean) { - const classNames = element.className.split(' '); - if (enabled) { - if (className in classNames) { - return; - } else { - classNames.push(className); - } - } else { - const index = classNames.indexOf(className); - if (index === -1) { - return; - } - classNames.splice(index, 1); - } - element.className = classNames.join(' '); -} - -/** Projects a 3d point into screen space */ -export function vector3DToScreenCoords( - cam: THREE.Camera, w: number, h: number, v: THREE.Vector3): Point2D { - let dpr = window.devicePixelRatio; - let pv = new THREE.Vector3().copy(v).project(cam); - - // The screen-space origin is at the middle of the screen, with +y up. - let coords: Point2D = - [((pv.x + 1) / 2 * w) * dpr, -((pv.y - 1) / 2 * h) * dpr]; - return coords; -} - -/** Loads 3 contiguous elements from a packed xyz array into a Vector3. */ -export function vector3FromPackedArray( - a: Float32Array, pointIndex: number): THREE.Vector3 { - const offset = pointIndex * 3; - return new THREE.Vector3(a[offset], a[offset + 1], a[offset + 2]); -} - -/** - * Gets the camera-space z coordinates of the nearest and farthest points. - * Ignores points that are behind the camera. 
- */ -export function getNearFarPoints( - worldSpacePoints: Float32Array, cameraPos: THREE.Vector3, - cameraTarget: THREE.Vector3): [number, number] { - let shortestDist: number = Infinity; - let furthestDist: number = 0; - const camToTarget = new THREE.Vector3().copy(cameraTarget).sub(cameraPos); - const camPlaneNormal = new THREE.Vector3().copy(camToTarget).normalize(); - const n = worldSpacePoints.length / 3; - let src = 0; - let p = new THREE.Vector3(); - let camToPoint = new THREE.Vector3(); - for (let i = 0; i < n; i++) { - p.x = worldSpacePoints[src]; - p.y = worldSpacePoints[src + 1]; - p.z = worldSpacePoints[src + 2]; - src += 3; - - camToPoint.copy(p).sub(cameraPos); - const dist = camPlaneNormal.dot(camToPoint); - if (dist < 0) { - continue; - } - furthestDist = (dist > furthestDist) ? dist : furthestDist; - shortestDist = (dist < shortestDist) ? dist : shortestDist; - } - return [shortestDist, furthestDist]; -} - -/** - * Generate a texture for the points/images and sets some initial params - */ -export function createTexture(image: HTMLImageElement| - HTMLCanvasElement): THREE.Texture { - let tex = new THREE.Texture(image); - tex.needsUpdate = true; - // Used if the texture isn't a power of 2. - tex.minFilter = THREE.LinearFilter; - tex.generateMipmaps = false; - tex.flipY = false; - return tex; -} - -/** - * Assert that the condition is satisfied; if not, log user-specified message - * to the console. - */ -export function assert(condition: boolean, message?: string) { - if (!condition) { - message = message || 'Assertion failed'; - throw new Error(message); - } -} - -export type SearchPredicate = (p: DataPoint) => boolean; - -export function getSearchPredicate( - query: string, inRegexMode: boolean, fieldName: string): SearchPredicate { - let predicate: SearchPredicate; - if (inRegexMode) { - let regExp = new RegExp(query, 'i'); - predicate = p => regExp.test(p.metadata[fieldName].toString()); - } else { - // Doing a case insensitive substring match. 
- query = query.toLowerCase(); - predicate = p => { - let label = p.metadata[fieldName].toString().toLowerCase(); - return label.indexOf(query) >= 0; - }; - } - return predicate; -} - -/** - * Runs an expensive task asynchronously with some delay - * so that it doesn't block the UI thread immediately. - * - * @param message The message to display to the user. - * @param task The expensive task to run. - * @param msgId Optional. ID of an existing message. If provided, will overwrite - * an existing message and won't automatically clear the message when the - * task is done. - * @return The value returned by the task. - */ -export function runAsyncTask( - message: string, task: () => T, msgId: string = null): Promise { - let autoClear = (msgId == null); - msgId = logging.setModalMessage(message, msgId); - return new Promise((resolve, reject) => { - setTimeout(() => { - try { - let result = task(); - // Clearing the old message. - if (autoClear) { - logging.setModalMessage(null, msgId); - } - resolve(result); - } catch (ex) { - reject(ex); - } - return true; - }, TASK_DELAY_MS); - }); -} - - -/** - * Parses the URL for query parameters, e.g. ?foo=1&bar=2 will return - * {'foo': '1', 'bar': '2'}. - * @param url The URL to parse. - * @return A map of queryParam key to its value. - */ -export function getURLParams(url: string): {[key: string]: string} { - if (!url) { - return {}; - } - - let queryString = url.indexOf('?') !== -1 ? url.split('?')[1] : url; - if (queryString.indexOf('#')) { - queryString = queryString.split('#')[0]; - } - - const queryEntries = queryString.split('&'); - let queryParams: {[key: string]: string} = {}; - for (let i = 0; i < queryEntries.length; i++) { - let queryEntryComponents = queryEntries[i].split('='); - queryParams[queryEntryComponents[0].toLowerCase()] = - decodeURIComponent(queryEntryComponents[1]); - } - return queryParams; -} - -/** List of substrings that auto generated tensors have in their name. 
*/ -const SUBSTR_GEN_TENSORS = ['/Adagrad']; - -/** Returns true if the tensor was automatically generated by TF API calls. */ -export function tensorIsGenerated(tensorName: string): boolean { - for (let i = 0; i < SUBSTR_GEN_TENSORS.length; i++) { - if (tensorName.indexOf(SUBSTR_GEN_TENSORS[i]) >= 0) { - return true; - } - } - return false; -} - -export function xor(cond1: boolean, cond2: boolean): boolean { - return (cond1 || cond2) && !(cond1 && cond2); -} - -/** Checks to see if the browser supports webgl. */ -export function hasWebGLSupport(): boolean { - try { - let c = document.createElement('canvas'); - let gl = c.getContext('webgl') || c.getContext('experimental-webgl'); - return gl != null && typeof weblas !== 'undefined'; - } catch (e) { - return false; - } -} diff --git a/tensorflow/tensorboard/components/vz_projector/vector.ts b/tensorflow/tensorboard/components/vz_projector/vector.ts deleted file mode 100644 index cab30483138..00000000000 --- a/tensorflow/tensorboard/components/vz_projector/vector.ts +++ /dev/null @@ -1,265 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -import {assert} from './util'; - -/** - * @fileoverview Useful vector utilities. 
- */ - -export type Vector = Float32Array | number[]; -export type Point2D = [number, number]; -export type Point3D = [number, number, number]; - -/** Returns the dot product of two vectors. */ -export function dot(a: Vector, b: Vector): number { - assert(a.length === b.length, 'Vectors a and b must be of same length'); - let result = 0; - for (let i = 0; i < a.length; ++i) { - result += a[i] * b[i]; - } - return result; -} - -/** Sums all the elements in the vector */ -export function sum(a: Vector): number { - let result = 0; - for (let i = 0; i < a.length; ++i) { - result += a[i]; - } - return result; -} - -/** Returns the sum of two vectors, i.e. a + b */ -export function add(a: Vector, b: Vector): Float32Array { - assert(a.length === b.length, 'Vectors a and b must be of same length'); - let result = new Float32Array(a.length); - for (let i = 0; i < a.length; ++i) { - result[i] = a[i] + b[i]; - } - return result; -} - -/** Subtracts vector b from vector a, i.e. returns a - b */ -export function sub(a: Vector, b: Vector): Float32Array { - assert(a.length === b.length, 'Vectors a and b must be of same length'); - let result = new Float32Array(a.length); - for (let i = 0; i < a.length; ++i) { - result[i] = a[i] - b[i]; - } - return result; -} - -/** Returns the square norm of the vector */ -export function norm2(a: Vector): number { - let result = 0; - for (let i = 0; i < a.length; ++i) { - result += a[i] * a[i]; - } - return result; -} - -/** Returns the euclidean distance between two vectors. */ -export function dist(a: Vector, b: Vector): number { - return Math.sqrt(dist2(a, b)); -} - -/** Returns the square euclidean distance between two vectors. 
*/ -export function dist2(a: Vector, b: Vector): number { - assert(a.length === b.length, 'Vectors a and b must be of same length'); - let result = 0; - for (let i = 0; i < a.length; ++i) { - let diff = a[i] - b[i]; - result += diff * diff; - } - return result; -} - -/** Returns the square euclidean distance between two 2D points. */ -export function dist2_2D(a: Vector, b: Vector): number { - let dX = a[0] - b[0]; - let dY = a[1] - b[1]; - return dX * dX + dY * dY; -} - -/** Returns the square euclidean distance between two 3D points. */ -export function dist2_3D(a: Vector, b: Vector): number { - let dX = a[0] - b[0]; - let dY = a[1] - b[1]; - let dZ = a[2] - b[2]; - return dX * dX + dY * dY + dZ * dZ; -} - -/** Returns the euclidean distance between 2 3D points. */ -export function dist_3D(a: Vector, b: Vector): number { - return Math.sqrt(dist2_3D(a, b)); -} - -/** - * Returns the square euclidean distance between two vectors, with an early - * exit (returns -1) if the distance is >= to the provided limit. - */ -export function dist2WithLimit(a: Vector, b: Vector, limit: number): number { - assert(a.length === b.length, 'Vectors a and b must be of same length'); - let result = 0; - for (let i = 0; i < a.length; ++i) { - let diff = a[i] - b[i]; - result += diff * diff; - if (result >= limit) { - return -1; - } - } - return result; -} - -/** Returns the square euclidean distance between two 2D points. */ -export function dist22D(a: Point2D, b: Point2D): number { - let dX = a[0] - b[0]; - let dY = a[1] - b[1]; - return dX * dX + dY * dY; -} - -/** Modifies the vector in-place to have unit norm. */ -export function unit(a: Vector): void { - let norm = Math.sqrt(norm2(a)); - assert(norm >= 0, 'Norm of the vector must be > 0'); - for (let i = 0; i < a.length; ++i) { - a[i] /= norm; - } -} - -/** - * Projects the vectors to a lower dimension - * - * @param vectors Array of vectors to be projected. - * @param newDim The resulting dimension of the vectors. 
- */ -export function projectRandom(vectors: Float32Array[], newDim: number): - Float32Array[] { - let dim = vectors[0].length; - let N = vectors.length; - let newVectors: Float32Array[] = new Array(N); - for (let i = 0; i < N; ++i) { - newVectors[i] = new Float32Array(newDim); - } - // Make nDim projections. - for (let k = 0; k < newDim; ++k) { - let randomVector = rn(dim); - for (let i = 0; i < N; ++i) { - newVectors[i][k] = dot(vectors[i], randomVector); - } - } - return newVectors; -} - -/** - * Projects a vector onto a 2D plane specified by the two direction vectors. - */ -export function project2d(a: Vector, dir1: Vector, dir2: Vector): Point2D { - return [dot(a, dir1), dot(a, dir2)]; -} - -/** - * Computes the centroid of the data points. If the provided data points are not - * vectors, an accessor function needs to be provided. - */ -export function centroid(dataPoints: T[], accessor?: (a: T) => Vector): - Vector { - if (dataPoints.length === 0) { - return null; - } - if (accessor == null) { - accessor = (a: T) => a; - } - assert(dataPoints.length >= 0, '`vectors` must be of length >= 1'); - let centroid = new Float32Array(accessor(dataPoints[0]).length); - for (let i = 0; i < dataPoints.length; ++i) { - let dataPoint = dataPoints[i]; - let vector = accessor(dataPoint); - for (let j = 0; j < centroid.length; ++j) { - centroid[j] += vector[j]; - } - } - for (let j = 0; j < centroid.length; ++j) { - centroid[j] /= dataPoints.length; - } - return centroid; -} - -/** - * Generates a vector of the specified size where each component is drawn from - * a random (0, 1) gaussian distribution. - */ -export function rn(size: number): Float32Array { - const normal = d3.randomNormal(); - let result = new Float32Array(size); - for (let i = 0; i < size; ++i) { - result[i] = normal(); - } - return result; -} - -/** - * Returns the cosine distance ([0, 2]) between two vectors - * that have been normalized to unit norm. 
- */ -export function cosDistNorm(a: Vector, b: Vector): number { - return 1 - dot(a, b); -} - -/** - * Returns the cosine distance ([0, 2]) between two vectors. - */ -export function cosDist(a: Vector, b: Vector): number { - return 1 - cosSim(a, b); -} - -/** Returns the cosine similarity ([-1, 1]) between two vectors. */ -export function cosSim(a: Vector, b: Vector): number { - return dot(a, b) / Math.sqrt(norm2(a) * norm2(b)); -} - -/** - * Converts list of vectors (matrix) into a 1-dimensional - * typed array with row-first order. - */ -export function toTypedArray( - dataPoints: T[], accessor: (dataPoint: T) => Float32Array): Float32Array { - let N = dataPoints.length; - let dim = accessor(dataPoints[0]).length; - let result = new Float32Array(N * dim); - for (let i = 0; i < N; ++i) { - let vector = accessor(dataPoints[i]); - for (let d = 0; d < dim; ++d) { - result[i * dim + d] = vector[d]; - } - } - return result; -} - -/** - * Transposes an RxC matrix represented as a flat typed array - * into a CxR matrix, again represented as a flat typed array. 
- */ -export function transposeTypedArray( - r: number, c: number, typedArray: Float32Array) { - let result = new Float32Array(r * c); - for (let i = 0; i < r; ++i) { - for (let j = 0; j < c; ++j) { - result[j * r + i] = typedArray[i * c + j]; - } - } - return result; -} diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector-app.html b/tensorflow/tensorboard/components/vz_projector/vz-projector-app.html deleted file mode 100644 index e19f0364c44..00000000000 --- a/tensorflow/tensorboard/components/vz_projector/vz-projector-app.html +++ /dev/null @@ -1,105 +0,0 @@ - - - - - - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector-bookmark-panel.html b/tensorflow/tensorboard/components/vz_projector/vz-projector-bookmark-panel.html deleted file mode 100644 index f3f3f59a948..00000000000 --- a/tensorflow/tensorboard/components/vz_projector/vz-projector-bookmark-panel.html +++ /dev/null @@ -1,207 +0,0 @@ - - - - - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector-bookmark-panel.ts b/tensorflow/tensorboard/components/vz_projector/vz-projector-bookmark-panel.ts deleted file mode 100644 index 53195fa47c0..00000000000 --- a/tensorflow/tensorboard/components/vz_projector/vz-projector-bookmark-panel.ts +++ /dev/null @@ -1,283 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -import {State} from './data'; -import {DataProvider, EmbeddingInfo} from './data-provider'; -import * as logging from './logging'; -import {ProjectorEventContext} from './projectorEventContext'; -import {Projector} from './vz-projector'; -// tslint:disable-next-line:no-unused-variable -import {PolymerElement, PolymerHTMLElement} from './vz-projector-util'; - -// tslint:disable-next-line -export let BookmarkPanelPolymer = PolymerElement({ - is: 'vz-projector-bookmark-panel', - properties: { - savedStates: Object, - // Keep a separate polymer property because the savedStates doesn't change - // when adding and removing states. - hasStates: {type: Boolean, value: false}, - selectedState: Number - } -}); - -export class BookmarkPanel extends BookmarkPanelPolymer { - private projector: Projector; - - // A list containing all of the saved states. - private savedStates: State[]; - private hasStates = false; - private selectedState: number; - private ignoreNextProjectionEvent: boolean; - - private expandLessButton: HTMLButtonElement; - private expandMoreButton: HTMLButtonElement; - - ready() { - this.savedStates = []; - this.setupUploadButton(); - this.ignoreNextProjectionEvent = false; - this.expandLessButton = - this.querySelector('#expand-less') as HTMLButtonElement; - this.expandMoreButton = - this.querySelector('#expand-more') as HTMLButtonElement; - } - - initialize( - projector: Projector, projectorEventContext: ProjectorEventContext) { - this.projector = projector; - projectorEventContext.registerProjectionChangedListener(() => { - if (this.ignoreNextProjectionEvent) { - this.ignoreNextProjectionEvent = false; - } else { - this.clearStateSelection(); - } - }); - } - - setSelectedTensor( - run: string, tensorInfo: EmbeddingInfo, dataProvider: DataProvider) { - // Clear any existing bookmarks. 
- this.addStates(null); - if (tensorInfo && tensorInfo.bookmarksPath) { - // Get any bookmarks that may come when the projector starts up. - dataProvider.getBookmarks(run, tensorInfo.tensorName, bookmarks => { - this.addStates(bookmarks); - this._expandMore(); - }); - } else { - this._expandLess(); - } - } - - /** Handles a click on show bookmarks tray button. */ - _expandMore() { - this.$.panel.show(); - this.expandMoreButton.style.display = 'none'; - this.expandLessButton.style.display = ''; - } - - /** Handles a click on hide bookmarks tray button. */ - _expandLess() { - this.$.panel.hide(); - this.expandMoreButton.style.display = ''; - this.expandLessButton.style.display = 'none'; - } - - /** Handles a click on the add bookmark button. */ - _addBookmark() { - let currentState = this.projector.getCurrentState(); - currentState.label = 'State ' + this.savedStates.length; - currentState.isSelected = true; - - this.selectedState = this.savedStates.length; - - for (let i = 0; i < this.savedStates.length; i++) { - this.savedStates[i].isSelected = false; - // We have to call notifyPath so that polymer knows this element was - // updated. - this.notifyPath('savedStates.' + i + '.isSelected', false, false); - } - - this.push('savedStates', currentState as any); - this.updateHasStates(); - } - - /** Handles a click on the download bookmarks button. */ - _downloadFile() { - let serializedState = this.serializeAllSavedStates(); - let blob = new Blob([serializedState], {type: 'text/plain'}); - let textFile = window.URL.createObjectURL(blob); - - // Force a download. - let a = document.createElement('a'); - document.body.appendChild(a); - a.style.display = 'none'; - a.href = textFile; - (a as any).download = 'state'; - a.click(); - - document.body.removeChild(a); - window.URL.revokeObjectURL(textFile); - } - - /** Handles a click on the upload bookmarks button. 
*/ - _uploadFile() { - let fileInput = this.dom.select('#state-file'); - (fileInput.node() as HTMLInputElement).click(); - } - - private setupUploadButton() { - // Show and setup the load view button. - const fileInput = this.querySelector('#state-file') as HTMLInputElement; - fileInput.onchange = () => { - const file: File = fileInput.files[0]; - // Clear out the value of the file chooser. This ensures that if the user - // selects the same file, we'll re-read it. - fileInput.value = ''; - const fileReader = new FileReader(); - fileReader.onload = (evt) => { - const str: string = fileReader.result; - const savedStates = JSON.parse(str); - - // Verify the bookmarks match. - if (this.savedStatesValid(savedStates)) { - this.addStates(savedStates); - this.loadSavedState(0); - } else { - logging.setWarningMessage( - `Unable to load bookmarks: wrong dataset, expected dataset ` + - `with shape (${savedStates[0].dataSetDimensions}).`); - } - }; - fileReader.readAsText(file); - }; - } - - addStates(savedStates?: State[]) { - if (savedStates == null) { - this.savedStates = []; - } else { - for (let i = 0; i < savedStates.length; i++) { - savedStates[i].isSelected = false; - this.push('savedStates', savedStates[i] as any); - } - } - this.updateHasStates(); - } - - /** Deselects any selected state selection. */ - clearStateSelection() { - for (let i = 0; i < this.savedStates.length; i++) { - this.setSelectionState(i, false); - } - } - - /** Handles a radio button click on a saved state. 
*/ - _radioButtonHandler(evt: Event) { - const index = this.getParentDataIndex(evt); - this.loadSavedState(index); - this.setSelectionState(index, true); - } - - loadSavedState(index: number) { - for (let i = 0; i < this.savedStates.length; i++) { - if (this.savedStates[i].isSelected) { - this.setSelectionState(i, false); - } else if (index === i) { - this.setSelectionState(i, true); - this.ignoreNextProjectionEvent = true; - this.projector.loadState(this.savedStates[i]); - } - } - } - - private setSelectionState(stateIndex: number, selected: boolean) { - this.savedStates[stateIndex].isSelected = selected; - const path = 'savedStates.' + stateIndex + '.isSelected'; - this.notifyPath(path, selected, false); - } - - /** - * Crawls up the DOM to find an ancestor with a data-index attribute. This is - * used to match events to their bookmark index. - */ - private getParentDataIndex(evt: Event) { - for (let i = 0; i < (evt as any).path.length; i++) { - let dataIndex = (evt as any).path[i].getAttribute('data-index'); - if (dataIndex != null) { - return +dataIndex; - } - } - return -1; - } - - /** Handles a clear button click on a bookmark. */ - _clearButtonHandler(evt: Event) { - let index = this.getParentDataIndex(evt); - this.splice('savedStates', index, 1); - this.updateHasStates(); - } - - /** Handles a label change event on a bookmark. */ - _labelChange(evt: Event) { - let index = this.getParentDataIndex(evt); - this.savedStates[index].label = (evt.target as any).value; - } - - /** - * Used to determine whether to select the radio button for a given bookmark. - */ - _isSelectedState(index: number) { - return index === this.selectedState; - } - _isNotSelectedState(index: number) { - return index !== this.selectedState; - } - - /** - * Gets all of the saved states as a serialized string. 
- */ - serializeAllSavedStates(): string { - return JSON.stringify(this.savedStates); - } - - /** - * Loads all of the serialized states and shows them in the list of - * viewable states. - */ - loadSavedStates(serializedStates: string) { - this.savedStates = JSON.parse(serializedStates); - this.updateHasStates(); - } - - /** - * Updates the hasState polymer property. - */ - private updateHasStates() { - this.hasStates = (this.savedStates.length !== 0); - } - - /** Sanity checks a State array to ensure it matches the current dataset. */ - private savedStatesValid(states: State[]): boolean { - for (let i = 0; i < states.length; i++) { - if (states[i].dataSetDimensions[0] !== this.projector.dataSet.dim[0] || - states[i].dataSetDimensions[1] !== this.projector.dataSet.dim[1]) { - return false; - } - } - return true; - } -} -document.registerElement(BookmarkPanel.prototype.is, BookmarkPanel); diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector-colab.html b/tensorflow/tensorboard/components/vz_projector/vz-projector-colab.html deleted file mode 100644 index 2acb570b3c1..00000000000 --- a/tensorflow/tensorboard/components/vz_projector/vz-projector-colab.html +++ /dev/null @@ -1,32 +0,0 @@ - - - - - - - - diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector-dashboard.html b/tensorflow/tensorboard/components/vz_projector/vz-projector-dashboard.html deleted file mode 100644 index 8223c503ecd..00000000000 --- a/tensorflow/tensorboard/components/vz_projector/vz-projector-dashboard.html +++ /dev/null @@ -1,79 +0,0 @@ - - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector-data-panel.html b/tensorflow/tensorboard/components/vz_projector/vz-projector-data-panel.html deleted file mode 100644 index d8dfd6e978c..00000000000 --- a/tensorflow/tensorboard/components/vz_projector/vz-projector-data-panel.html +++ /dev/null @@ -1,402 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - diff --git 
a/tensorflow/tensorboard/components/vz_projector/vz-projector-data-panel.ts b/tensorflow/tensorboard/components/vz_projector/vz-projector-data-panel.ts deleted file mode 100644 index a9b6f6c5a06..00000000000 --- a/tensorflow/tensorboard/components/vz_projector/vz-projector-data-panel.ts +++ /dev/null @@ -1,496 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -import {ColorOption, ColumnStats, SpriteAndMetadataInfo} from './data'; -import {DataProvider, EmbeddingInfo, parseRawMetadata, parseRawTensors, ProjectorConfig} from './data-provider'; -import * as util from './util'; -import {Projector} from './vz-projector'; -import {ColorLegendRenderInfo, ColorLegendThreshold} from './vz-projector-legend'; -// tslint:disable-next-line:no-unused-variable -import {PolymerElement, PolymerHTMLElement} from './vz-projector-util'; - -export let DataPanelPolymer = PolymerElement({ - is: 'vz-projector-data-panel', - properties: { - selectedTensor: {type: String, observer: '_selectedTensorChanged'}, - selectedRun: {type: String, observer: '_selectedRunChanged'}, - selectedColorOptionName: { - type: String, - notify: true, - observer: '_selectedColorOptionNameChanged' - }, - selectedLabelOption: - {type: String, notify: true, observer: '_selectedLabelOptionChanged'}, - normalizeData: Boolean, - showForceCategoricalColorsCheckbox: Boolean - } -}); - 
-export class DataPanel extends DataPanelPolymer { - selectedLabelOption: string; - selectedColorOptionName: string; - showForceCategoricalColorsCheckbox: boolean; - - private normalizeData: boolean; - private labelOptions: string[]; - private colorOptions: ColorOption[]; - forceCategoricalColoring: boolean = false; - - private selectedTensor: string; - private selectedRun: string; - private dataProvider: DataProvider; - private tensorNames: {name: string, shape: number[]}[]; - private runNames: string[]; - private projector: Projector; - private projectorConfig: ProjectorConfig; - private colorLegendRenderInfo: ColorLegendRenderInfo; - private spriteAndMetadata: SpriteAndMetadataInfo; - private metadataFile: string; - - ready() { - this.normalizeData = true; - } - - initialize(projector: Projector, dp: DataProvider) { - this.projector = projector; - this.dataProvider = dp; - this.setupUploadButtons(); - - // Tell the projector whenever the data normalization changes. - // Unknown why, but the polymer checkbox button stops working as soon as - // you do d3.select() on it. - this.querySelector('#normalize-data-checkbox') - .addEventListener('change', () => { - this.projector.setNormalizeData(this.normalizeData); - }); - - let forceCategoricalColoringCheckbox = - this.querySelector('#force-categorical-checkbox'); - forceCategoricalColoringCheckbox.addEventListener('change', () => { - this.setForceCategoricalColoring( - (forceCategoricalColoringCheckbox as HTMLInputElement).checked); - }); - - // Get all the runs. - this.dataProvider.retrieveRuns(runs => { - this.runNames = runs; - // Choose the first run by default. 
- if (this.runNames.length > 0) { - this.selectedRun = runs[0]; - } - }); - } - - setForceCategoricalColoring(forceCategoricalColoring: boolean) { - this.forceCategoricalColoring = forceCategoricalColoring; - (this.querySelector('#force-categorical-checkbox') as HTMLInputElement) - .checked = this.forceCategoricalColoring; - - this.updateMetadataUI(this.spriteAndMetadata.stats, this.metadataFile); - - // The selected color option name doesn't change when we switch to using - // categorical coloring for stats with too many unique values, so we - // manually call this polymer observer so that we update the UI. - this._selectedColorOptionNameChanged(); - } - - getSeparatorClass(isSeparator: boolean): string { - return isSeparator ? 'separator' : null; - } - - metadataChanged( - spriteAndMetadata: SpriteAndMetadataInfo, metadataFile: string) { - this.spriteAndMetadata = spriteAndMetadata; - this.metadataFile = metadataFile; - - this.updateMetadataUI(this.spriteAndMetadata.stats, this.metadataFile); - this.selectedColorOptionName = this.colorOptions[0].name; - } - - private addWordBreaks(longString: string): string { - if (longString == null) { - return ''; - } - return longString.replace(/([\/=-_,])/g, '$1'); - } - - private updateMetadataUI(columnStats: ColumnStats[], metadataFile: string) { - const metadataFileElement = - this.querySelector('#metadata-file') as HTMLSpanElement; - metadataFileElement.innerHTML = this.addWordBreaks(metadataFile); - metadataFileElement.title = metadataFile; - - // Label by options. - let labelIndex = -1; - this.labelOptions = columnStats.map((stats, i) => { - // Make the default label by the first non-numeric column. - if (!stats.isNumeric && labelIndex === -1) { - labelIndex = i; - } - return stats.name; - }); - this.selectedLabelOption = this.labelOptions[Math.max(0, labelIndex)]; - - // Color by options. - const standardColorOption: ColorOption[] = [ - {name: 'No color map'}, - // TODO(smilkov): Implement this. 
- // {name: 'Distance of neighbors', - // desc: 'How far is each point from its neighbors'} - ]; - const metadataColorOption: ColorOption[] = - columnStats - .filter(stats => { - return !stats.tooManyUniqueValues || stats.isNumeric; - }) - .map(stats => { - let map; - let items: {label: string, count: number}[]; - let thresholds: ColorLegendThreshold[]; - let isCategorical = - this.forceCategoricalColoring || !stats.tooManyUniqueValues; - if (isCategorical) { - const scale = d3.scaleOrdinal(d3.schemeCategory20); - let range = scale.range(); - // Re-order the range. - let newRange = range.map((color, i) => { - let index = (i * 3) % range.length; - return range[index]; - }); - items = stats.uniqueEntries; - scale.range(newRange).domain(items.map(x => x.label)); - map = scale; - } else { - thresholds = [ - {color: '#ffffdd', value: stats.min}, - {color: '#1f2d86', value: stats.max} - ]; - map = d3.scaleLinear() - .domain(thresholds.map(t => t.value)) - .range(thresholds.map(t => t.color)); - } - let desc = !isCategorical ? 'gradient' : - stats.uniqueEntries.length + - ((stats.uniqueEntries.length > 20) ? ' non-unique' : '') + - ' colors'; - return { - name: stats.name, - desc: desc, - map: map, - items: items, - thresholds: thresholds, - tooManyUniqueValues: stats.tooManyUniqueValues - }; - }); - - if (metadataColorOption.length > 0) { - // Add a separator line between built-in color maps - // and those based on metadata columns. 
- standardColorOption.push({name: 'Metadata', isSeparator: true}); - } - this.colorOptions = standardColorOption.concat(metadataColorOption); - } - - setNormalizeData(normalizeData: boolean) { - this.normalizeData = normalizeData; - } - - _selectedTensorChanged() { - this.projector.updateDataSet(null, null, null); - if (this.selectedTensor == null) { - return; - } - this.dataProvider.retrieveTensor( - this.selectedRun, this.selectedTensor, ds => { - let metadataFile = - this.getEmbeddingInfoByName(this.selectedTensor).metadataPath; - this.dataProvider.retrieveSpriteAndMetadata( - this.selectedRun, this.selectedTensor, metadata => { - this.projector.updateDataSet(ds, metadata, metadataFile); - }); - }); - this.projector.setSelectedTensor( - this.selectedRun, this.getEmbeddingInfoByName(this.selectedTensor)); - } - - _selectedRunChanged() { - this.dataProvider.retrieveProjectorConfig(this.selectedRun, info => { - this.projectorConfig = info; - let names = - this.projectorConfig.embeddings.map(e => e.tensorName) - .filter(name => { - let shape = this.getEmbeddingInfoByName(name).tensorShape; - return shape.length === 2 && shape[0] > 1 && shape[1] > 1; - }) - .sort((a, b) => { - let embA = this.getEmbeddingInfoByName(a); - let embB = this.getEmbeddingInfoByName(b); - - // Prefer tensors with metadata. - if (util.xor(!!embA.metadataPath, !!embB.metadataPath)) { - return embA.metadataPath ? -1 : 1; - } - - // Prefer non-generated tensors. - let isGenA = util.tensorIsGenerated(a); - let isGenB = util.tensorIsGenerated(b); - if (util.xor(isGenA, isGenB)) { - return isGenB ? -1 : 1; - } - - // Prefer bigger tensors. - let sizeA = embA.tensorShape[0]; - let sizeB = embB.tensorShape[0]; - if (sizeA !== sizeB) { - return sizeB - sizeA; - } - - // Sort alphabetically by tensor name. - return a <= b ? 
-1 : 1; - }); - this.tensorNames = names.map(name => { - return {name, shape: this.getEmbeddingInfoByName(name).tensorShape}; - }); - const wordBreakablePath = - this.addWordBreaks(this.projectorConfig.modelCheckpointPath); - const checkpointFile = - this.querySelector('#checkpoint-file') as HTMLSpanElement; - checkpointFile.innerHTML = wordBreakablePath; - checkpointFile.title = this.projectorConfig.modelCheckpointPath; - - // If in demo mode, let the order decide which tensor to load by default. - const defaultTensor = this.projector.servingMode === 'demo' ? - this.projectorConfig.embeddings[0].tensorName : - names[0]; - if (this.selectedTensor === defaultTensor) { - // Explicitly call the observer. Polymer won't call it if the previous - // string matches the current string. - this._selectedTensorChanged(); - } else { - this.selectedTensor = defaultTensor; - } - }); - } - - _selectedLabelOptionChanged() { - this.projector.setSelectedLabelOption(this.selectedLabelOption); - } - - _selectedColorOptionNameChanged() { - let colorOption: ColorOption; - for (let i = 0; i < this.colorOptions.length; i++) { - if (this.colorOptions[i].name === this.selectedColorOptionName) { - colorOption = this.colorOptions[i]; - break; - } - } - if (!colorOption) { - return; - } - - this.showForceCategoricalColorsCheckbox = !!colorOption.tooManyUniqueValues; - - if (colorOption.map == null) { - this.colorLegendRenderInfo = null; - } else if (colorOption.items) { - let items = colorOption.items.map(item => { - return { - color: colorOption.map(item.label), - label: item.label, - count: item.count - }; - }); - this.colorLegendRenderInfo = {items, thresholds: null}; - } else { - this.colorLegendRenderInfo = { - items: null, - thresholds: colorOption.thresholds - }; - } - this.projector.setSelectedColorOption(colorOption); - } - - private tensorWasReadFromFile(rawContents: ArrayBuffer, fileName: string) { - parseRawTensors(rawContents, ds => { - const checkpointFile = - 
this.querySelector('#checkpoint-file') as HTMLSpanElement; - checkpointFile.innerText = fileName; - checkpointFile.title = fileName; - this.projector.updateDataSet(ds); - }); - } - - private metadataWasReadFromFile(rawContents: ArrayBuffer, fileName: string) { - parseRawMetadata(rawContents, metadata => { - this.projector.updateDataSet(this.projector.dataSet, metadata, fileName); - }); - } - - private getEmbeddingInfoByName(tensorName: string): EmbeddingInfo { - for (let i = 0; i < this.projectorConfig.embeddings.length; i++) { - const e = this.projectorConfig.embeddings[i]; - if (e.tensorName === tensorName) { - return e; - } - } - } - - private setupUploadButtons() { - // Show and setup the upload button. - const fileInput = this.querySelector('#file') as HTMLInputElement; - fileInput.onchange = () => { - const file: File = fileInput.files[0]; - // Clear out the value of the file chooser. This ensures that if the user - // selects the same file, we'll re-read it. - fileInput.value = ''; - const fileReader = new FileReader(); - fileReader.onload = evt => { - const content: ArrayBuffer = fileReader.result; - this.tensorWasReadFromFile(content, file.name); - }; - fileReader.readAsArrayBuffer(file); - }; - - const uploadButton = - this.querySelector('#upload-tensors') as HTMLButtonElement; - uploadButton.onclick = () => { - fileInput.click(); - }; - - // Show and setup the upload metadata button. - const fileMetadataInput = - this.querySelector('#file-metadata') as HTMLInputElement; - fileMetadataInput.onchange = () => { - const file: File = fileMetadataInput.files[0]; - // Clear out the value of the file chooser. This ensures that if the user - // selects the same file, we'll re-read it. 
- fileMetadataInput.value = ''; - const fileReader = new FileReader(); - fileReader.onload = evt => { - const contents: ArrayBuffer = fileReader.result; - this.metadataWasReadFromFile(contents, file.name); - }; - fileReader.readAsArrayBuffer(file); - }; - - const uploadMetadataButton = - this.querySelector('#upload-metadata') as HTMLButtonElement; - uploadMetadataButton.onclick = () => { - fileMetadataInput.click(); - }; - - if (this.projector.servingMode !== 'demo') { - (this.$$('#publish-container') as HTMLElement).style.display = 'none'; - (this.$$('#upload-tensors-step-container') as HTMLElement).style.display = - 'none'; - (this.$$('#upload-metadata-label') as HTMLElement).style.display = 'none'; - } - - (this.$$('#demo-data-buttons-container') as HTMLElement).style.display = - 'block'; - - // Fill out the projector config. - const projectorConfigTemplate = - this.$$('#projector-config-template') as HTMLTextAreaElement; - const projectorConfigTemplateJson: ProjectorConfig = { - embeddings: [{ - tensorName: 'My tensor', - tensorShape: [1000, 50], - tensorPath: 'https://raw.githubusercontent.com/.../tensors.tsv', - metadataPath: - 'https://raw.githubusercontent.com/.../optional.metadata.tsv', - }], - }; - this.setProjectorConfigTemplateJson( - projectorConfigTemplate, projectorConfigTemplateJson); - - // Set up optional field checkboxes. 
- const spriteFieldCheckbox = - this.$$('#config-sprite-checkbox') as HTMLInputElement; - spriteFieldCheckbox.onchange = () => { - if ((spriteFieldCheckbox as any).checked) { - projectorConfigTemplateJson.embeddings[0].sprite = { - imagePath: 'https://github.com/.../optional.sprite.png', - singleImageDim: [32, 32] - }; - } else { - delete projectorConfigTemplateJson.embeddings[0].sprite; - } - this.setProjectorConfigTemplateJson( - projectorConfigTemplate, projectorConfigTemplateJson); - }; - const bookmarksFieldCheckbox = - this.$$('#config-bookmarks-checkbox') as HTMLInputElement; - bookmarksFieldCheckbox.onchange = () => { - if ((bookmarksFieldCheckbox as any).checked) { - projectorConfigTemplateJson.embeddings[0].bookmarksPath = - 'https://raw.githubusercontent.com/.../bookmarks.txt'; - } else { - delete projectorConfigTemplateJson.embeddings[0].bookmarksPath; - } - this.setProjectorConfigTemplateJson( - projectorConfigTemplate, projectorConfigTemplateJson); - }; - const metadataFieldCheckbox = - this.$$('#config-metadata-checkbox') as HTMLInputElement; - metadataFieldCheckbox.onchange = () => { - if ((metadataFieldCheckbox as HTMLInputElement).checked) { - projectorConfigTemplateJson.embeddings[0].metadataPath = - 'https://raw.githubusercontent.com/.../optional.metadata.tsv'; - } else { - delete projectorConfigTemplateJson.embeddings[0].metadataPath; - } - this.setProjectorConfigTemplateJson( - projectorConfigTemplate, projectorConfigTemplateJson); - }; - - // Update the link and the readonly shareable URL. 
- const projectorConfigUrlInput = - this.$$('#projector-config-url') as HTMLInputElement; - const projectorConfigDemoUrlInput = this.$$('#projector-share-url'); - const projectorConfigDemoUrlLink = this.$$('#projector-share-url-link'); - projectorConfigUrlInput.onchange = () => { - let projectorDemoUrl = location.protocol + '//' + location.host + - location.pathname + - '?config=' + (projectorConfigUrlInput as HTMLInputElement).value; - - (projectorConfigDemoUrlInput as HTMLInputElement).value = - projectorDemoUrl; - (projectorConfigDemoUrlLink as HTMLLinkElement).href = projectorDemoUrl; - }; - } - - private setProjectorConfigTemplateJson( - projectorConfigTemplate: HTMLTextAreaElement, config: ProjectorConfig) { - projectorConfigTemplate.value = - JSON.stringify(config, null, /** replacer */ 2 /** white space */); - } - - _getNumTensorsLabel(): string { - return this.tensorNames.length === 1 ? '1 tensor' : - this.tensorNames.length + ' tensors'; - } - - _getNumRunsLabel(): string { - return this.runNames.length === 1 ? '1 run' : - this.runNames.length + ' runs'; - } - - _hasChoices(choices: any[]): boolean { - return choices.length > 1; - } -} - -document.registerElement(DataPanel.prototype.is, DataPanel); diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector-input.html b/tensorflow/tensorboard/components/vz_projector/vz-projector-input.html deleted file mode 100644 index 0d7bf7cdda6..00000000000 --- a/tensorflow/tensorboard/components/vz_projector/vz-projector-input.html +++ /dev/null @@ -1,66 +0,0 @@ - - - - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector-input.ts b/tensorflow/tensorboard/components/vz_projector/vz-projector-input.ts deleted file mode 100644 index e11346d327f..00000000000 --- a/tensorflow/tensorboard/components/vz_projector/vz-projector-input.ts +++ /dev/null @@ -1,113 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// tslint:disable-next-line:no-unused-variable -import {PolymerElement, PolymerHTMLElement} from './vz-projector-util'; - -// tslint:disable-next-line -export let PolymerClass = PolymerElement( - {is: 'vz-projector-input', properties: {label: String, message: String}}); - -export interface InputChangedListener { - (value: string, inRegexMode: boolean): void; -} - -/** Input control with custom capabilities (e.g. regex). */ -export class ProjectorInput extends PolymerClass { - private textChangedListeners: InputChangedListener[]; - private paperInput: HTMLInputElement; - private inRegexModeButton: HTMLButtonElement; - private inRegexMode: boolean; - - /** Message that will be displayed at the bottom of the input control. */ - message: string; - - /** Subscribe to be called everytime the input changes. 
*/ - registerInputChangedListener(listener: InputChangedListener) { - this.textChangedListeners.push(listener); - } - - ready() { - this.inRegexMode = false; - this.textChangedListeners = []; - this.paperInput = this.querySelector('paper-input') as HTMLInputElement; - this.inRegexModeButton = - this.querySelector('paper-button') as HTMLButtonElement; - this.paperInput.setAttribute('error-message', 'Invalid regex'); - - this.paperInput.addEventListener('input', () => { - this.onTextChanged(); - }); - - this.paperInput.addEventListener('keydown', event => { - event.stopPropagation(); - }); - - this.inRegexModeButton.addEventListener( - 'click', () => this.onClickRegexModeButton()); - this.updateRegexModeDisplaySlashes(); - this.onTextChanged(); - } - - private onClickRegexModeButton() { - this.inRegexMode = (this.inRegexModeButton as any).active; - this.updateRegexModeDisplaySlashes(); - this.onTextChanged(); - } - - private notifyInputChanged(value: string, inRegexMode: boolean) { - this.textChangedListeners.forEach(l => l(value, inRegexMode)); - } - - private onTextChanged() { - try { - if (this.inRegexMode) { - new RegExp(this.paperInput.value); - } - } catch (invalidRegexException) { - this.paperInput.setAttribute('invalid', 'true'); - this.message = ''; - this.notifyInputChanged(null, true); - return; - } - this.paperInput.removeAttribute('invalid'); - this.notifyInputChanged(this.paperInput.value, this.inRegexMode); - } - - private updateRegexModeDisplaySlashes() { - const slashes = this.paperInput.querySelectorAll('.slash'); - const display = this.inRegexMode ? 
'' : 'none'; - - for (let i = 0; i < slashes.length; i++) { - (slashes[i] as HTMLDivElement).style.display = display; - } - } - - getValue(): string { - return this.paperInput.value; - } - - getInRegexMode(): boolean { - return this.inRegexMode; - } - - set(value: string, inRegexMode: boolean) { - (this.inRegexModeButton as any).active = inRegexMode; - this.paperInput.value = value; - this.onClickRegexModeButton(); - } -} - -document.registerElement(ProjectorInput.prototype.is, ProjectorInput); diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector-inspector-panel.html b/tensorflow/tensorboard/components/vz_projector/vz-projector-inspector-panel.html deleted file mode 100644 index 1b81094776f..00000000000 --- a/tensorflow/tensorboard/components/vz_projector/vz-projector-inspector-panel.html +++ /dev/null @@ -1,241 +0,0 @@ - - - - - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector-inspector-panel.ts b/tensorflow/tensorboard/components/vz_projector/vz-projector-inspector-panel.ts deleted file mode 100644 index 3ee2c2165f2..00000000000 --- a/tensorflow/tensorboard/components/vz_projector/vz-projector-inspector-panel.ts +++ /dev/null @@ -1,337 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -import {DistanceFunction, SpriteAndMetadataInfo, State} from './data'; -import * as knn from './knn'; -import {ProjectorEventContext} from './projectorEventContext'; -import * as adapter from './projectorScatterPlotAdapter'; -import * as util from './util'; -import * as vector from './vector'; -import {Projector} from './vz-projector'; -import {ProjectorInput} from './vz-projector-input'; -// tslint:disable-next-line:no-unused-variable -import {PolymerElement, PolymerHTMLElement} from './vz-projector-util'; - -/** Limit the number of search results we show to the user. */ -const LIMIT_RESULTS = 100; - -// tslint:disable-next-line -export let PolymerClass = PolymerElement({ - is: 'vz-projector-inspector-panel', - properties: {selectedMetadataField: String, metadataFields: Array} -}); - -export class InspectorPanel extends PolymerClass { - distFunc: DistanceFunction; - numNN: number; - - private projectorEventContext: ProjectorEventContext; - - private selectedMetadataField: string; - private metadataFields: string[]; - private projector: Projector; - private selectedPointIndices: number[]; - private neighborsOfFirstPoint: knn.NearestEntry[]; - private searchBox: ProjectorInput; - - private resetFilterButton: HTMLButtonElement; - private setFilterButton: HTMLButtonElement; - private clearSelectionButton: HTMLButtonElement; - private limitMessage: HTMLDivElement; - - ready() { - this.resetFilterButton = - this.querySelector('.reset-filter') as HTMLButtonElement; - this.setFilterButton = - this.querySelector('.set-filter') as HTMLButtonElement; - this.clearSelectionButton = - this.querySelector('.clear-selection') as HTMLButtonElement; - this.limitMessage = this.querySelector('.limit-msg') as HTMLDivElement; - this.searchBox = this.querySelector('#search-box') as ProjectorInput; - // https://www.polymer-project.org/1.0/docs/devguide/styling#scope-subtree - this.scopeSubtree(this, 
true); - } - - initialize( - projector: Projector, projectorEventContext: ProjectorEventContext) { - this.projector = projector; - this.projectorEventContext = projectorEventContext; - this.setupUI(projector); - projectorEventContext.registerSelectionChangedListener( - (selection, neighbors) => - this.updateInspectorPane(selection, neighbors)); - } - - /** Updates the nearest neighbors list in the inspector. */ - private updateInspectorPane( - indices: number[], neighbors: knn.NearestEntry[]) { - this.neighborsOfFirstPoint = neighbors; - this.selectedPointIndices = indices; - - this.updateFilterButtons(indices.length + neighbors.length); - this.updateNeighborsList(neighbors); - if (neighbors.length === 0) { - this.updateSearchResults(indices); - } else { - this.updateSearchResults([]); - } - } - - private enableResetFilterButton(enabled: boolean) { - this.resetFilterButton.disabled = !enabled; - } - - restoreUIFromBookmark(bookmark: State) { - this.enableResetFilterButton(bookmark.filteredPoints != null); - } - - metadataChanged(spriteAndMetadata: SpriteAndMetadataInfo) { - let labelIndex = -1; - this.metadataFields = spriteAndMetadata.stats.map((stats, i) => { - if (!stats.isNumeric && labelIndex === -1) { - labelIndex = i; - } - return stats.name; - }); - labelIndex = Math.max(0, labelIndex); - // Make the default label the first non-numeric column. - this.selectedMetadataField = spriteAndMetadata.stats[labelIndex].name; - } - - datasetChanged() { - this.enableResetFilterButton(false); - } - - private updateSearchResults(indices: number[]) { - const container = this.querySelector('.matches-list') as HTMLDivElement; - container.style.display = indices.length ? null : 'none'; - const list = container.querySelector('.list') as HTMLDivElement; - list.innerHTML = ''; - if (indices.length === 0) { - return; - } - - this.limitMessage.style.display = - indices.length <= LIMIT_RESULTS ? 
'none' : null; - indices = indices.slice(0, LIMIT_RESULTS); - - for (let i = 0; i < indices.length; i++) { - const index = indices[i]; - - const row = document.createElement('div'); - row.className = 'row'; - - const label = this.getLabelFromIndex(index); - const rowLink = document.createElement('a'); - rowLink.className = 'label'; - rowLink.title = label; - rowLink.innerText = label; - - rowLink.onmouseenter = () => { - this.projectorEventContext.notifyHoverOverPoint(index); - }; - rowLink.onmouseleave = () => { - this.projectorEventContext.notifyHoverOverPoint(null); - }; - rowLink.onclick = () => { - this.projectorEventContext.notifySelectionChanged([index]); - }; - - row.appendChild(rowLink); - list.appendChild(row); - } - } - - private getLabelFromIndex(pointIndex: number): string { - const point = this.projector.dataSet.points[pointIndex]; - return point.metadata[this.selectedMetadataField].toString(); - } - - private updateNeighborsList(neighbors: knn.NearestEntry[]) { - const nnlist = this.querySelector('.nn-list') as HTMLDivElement; - nnlist.innerHTML = ''; - - (this.querySelector('.nn') as HTMLDivElement).style.display = - neighbors.length ? null : 'none'; - - if (neighbors.length === 0) { - return; - } - - this.searchBox.message = ''; - const minDist = neighbors.length > 0 ? 
neighbors[0].dist : 0; - - for (let i = 0; i < neighbors.length; i++) { - const neighbor = neighbors[i]; - - const neighborElement = document.createElement('div'); - neighborElement.className = 'neighbor'; - - const neighborElementLink = document.createElement('a'); - neighborElementLink.className = 'neighbor-link'; - neighborElementLink.title = this.getLabelFromIndex(neighbor.index); - - const labelValueElement = document.createElement('div'); - labelValueElement.className = 'label-and-value'; - - const labelElement = document.createElement('div'); - labelElement.className = 'label'; - labelElement.style.color = - adapter.dist2color(this.distFunc, neighbor.dist, minDist); - labelElement.innerText = this.getLabelFromIndex(neighbor.index); - - const valueElement = document.createElement('div'); - valueElement.className = 'value'; - valueElement.innerText = neighbor.dist.toFixed(3); - - labelValueElement.appendChild(labelElement); - labelValueElement.appendChild(valueElement); - - const barElement = document.createElement('div'); - barElement.className = 'bar'; - - const barFillElement = document.createElement('div'); - barFillElement.className = 'fill'; - barFillElement.style.borderTopColor = - adapter.dist2color(this.distFunc, neighbor.dist, minDist); - barFillElement.style.width = - adapter.normalizeDist(this.distFunc, neighbor.dist, minDist) * 100 + - '%'; - barElement.appendChild(barFillElement); - - for (let j = 1; j < 4; j++) { - const tickElement = document.createElement('div'); - tickElement.className = 'tick'; - tickElement.style.left = j * 100 / 4 + '%'; - barElement.appendChild(tickElement); - } - - neighborElementLink.appendChild(labelValueElement); - neighborElementLink.appendChild(barElement); - neighborElement.appendChild(neighborElementLink); - nnlist.appendChild(neighborElement); - - neighborElementLink.onmouseenter = () => { - this.projectorEventContext.notifyHoverOverPoint(neighbor.index); - }; - neighborElementLink.onmouseleave = () => { - 
this.projectorEventContext.notifyHoverOverPoint(null); - }; - neighborElementLink.onclick = () => { - this.projectorEventContext.notifySelectionChanged([neighbor.index]); - }; - } - } - - private updateFilterButtons(numPoints: number) { - if (numPoints > 1) { - this.setFilterButton.innerText = `Isolate ${numPoints} points`; - this.setFilterButton.disabled = null; - this.clearSelectionButton.disabled = null; - } else { - this.setFilterButton.disabled = true; - this.clearSelectionButton.disabled = true; - } - } - - private setupUI(projector: Projector) { - this.distFunc = vector.cosDist; - const eucDist = - this.querySelector('.distance a.euclidean') as HTMLLinkElement; - eucDist.onclick = () => { - const links = this.querySelectorAll('.distance a'); - for (let i = 0; i < links.length; i++) { - util.classed(links[i] as HTMLElement, 'selected', false); - } - util.classed(eucDist as HTMLElement, 'selected', true); - - this.distFunc = vector.dist; - this.projectorEventContext.notifyDistanceMetricChanged(this.distFunc); - const neighbors = projector.dataSet.findNeighbors( - this.selectedPointIndices[0], this.distFunc, this.numNN); - this.updateNeighborsList(neighbors); - }; - - const cosDist = this.querySelector('.distance a.cosine') as HTMLLinkElement; - cosDist.onclick = () => { - const links = this.querySelectorAll('.distance a'); - for (let i = 0; i < links.length; i++) { - util.classed(links[i] as HTMLElement, 'selected', false); - } - util.classed(cosDist, 'selected', true); - - this.distFunc = vector.cosDist; - this.projectorEventContext.notifyDistanceMetricChanged(this.distFunc); - const neighbors = projector.dataSet.findNeighbors( - this.selectedPointIndices[0], this.distFunc, this.numNN); - this.updateNeighborsList(neighbors); - }; - - // Called whenever the search text input changes. 
- const updateInput = (value: string, inRegexMode: boolean) => { - if (value == null || value.trim() === '') { - this.searchBox.message = ''; - this.projectorEventContext.notifySelectionChanged([]); - return; - } - const indices = projector.dataSet.query( - value, inRegexMode, this.selectedMetadataField); - if (indices.length === 0) { - this.searchBox.message = '0 matches.'; - } else { - this.searchBox.message = `${indices.length} matches.`; - } - this.projectorEventContext.notifySelectionChanged(indices); - }; - this.searchBox.registerInputChangedListener((value, inRegexMode) => { - updateInput(value, inRegexMode); - }); - - // Nearest neighbors controls. - const numNNInput = this.$$('#nn-slider') as HTMLInputElement; - const updateNumNN = () => { - this.numNN = +numNNInput.value; - (this.querySelector('.num-nn .nn-count') as HTMLSpanElement).innerText = - '' + this.numNN; - if (this.selectedPointIndices != null) { - this.projectorEventContext.notifySelectionChanged( - [this.selectedPointIndices[0]]); - } - }; - numNNInput.addEventListener('change', updateNumNN); - updateNumNN(); - - // Filtering dataset. 
- this.setFilterButton.onclick = () => { - const indices = this.selectedPointIndices.concat( - this.neighborsOfFirstPoint.map(n => n.index)); - projector.filterDataset(indices); - this.enableResetFilterButton(true); - this.updateFilterButtons(0); - }; - - this.resetFilterButton.onclick = () => { - projector.resetFilterDataset(); - this.enableResetFilterButton(false); - }; - - this.clearSelectionButton.onclick = () => { - projector.adjustSelectionAndHover([]); - }; - this.enableResetFilterButton(false); - } -} - -document.registerElement(InspectorPanel.prototype.is, InspectorPanel); diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector-legend.html b/tensorflow/tensorboard/components/vz_projector/vz-projector-legend.html deleted file mode 100644 index 4b98d8bded8..00000000000 --- a/tensorflow/tensorboard/components/vz_projector/vz-projector-legend.html +++ /dev/null @@ -1,78 +0,0 @@ - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector-legend.ts b/tensorflow/tensorboard/components/vz_projector/vz-projector-legend.ts deleted file mode 100644 index 1c4ddf940dc..00000000000 --- a/tensorflow/tensorboard/components/vz_projector/vz-projector-legend.ts +++ /dev/null @@ -1,98 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -// tslint:disable-next-line:no-unused-variable -import {PolymerElement, PolymerHTMLElement} from './vz-projector-util'; - -// tslint:disable-next-line -export let LegendPolymer = PolymerElement({ - is: 'vz-projector-legend', - properties: {renderInfo: {type: Object, observer: '_renderInfoChanged'}} -}); - -export interface ColorLegendRenderInfo { - // To be used for categorical map. - items: ColorLegendItem[]; - // To be used for gradient map. - thresholds: ColorLegendThreshold[]; -} - -/** An item in the categorical color legend. */ -export interface ColorLegendItem { - color: string; - label: string; - count: number; -} - -/** An item in the gradient color legend. */ -export interface ColorLegendThreshold { - color: string; - value: number; -} - -export class Legend extends LegendPolymer { - renderInfo: ColorLegendRenderInfo; - - _renderInfoChanged() { - if (this.renderInfo == null) { - return; - } - if (this.renderInfo.thresholds) { - // is under dom-if so we should wait for it to be - // inserted in the dom tree using async(). - this.async(() => this.setupLinearGradient()); - } - } - - _getLastThreshold(): number { - if (this.renderInfo == null || this.renderInfo.thresholds == null) { - return; - } - return this.renderInfo.thresholds[this.renderInfo.thresholds.length - 1] - .value; - } - - private getOffset(value: number): string { - const min = this.renderInfo.thresholds[0].value; - const max = - this.renderInfo.thresholds[this.renderInfo.thresholds.length - 1].value; - return (100 * (value - min) / (max - min)).toFixed(2) + '%'; - } - - private setupLinearGradient() { - const linearGradient = - this.querySelector('#gradient') as SVGLinearGradientElement; - - const width = - (this.querySelector('svg.gradient') as SVGElement).clientWidth; - - // Set the svg to be the width of its parent. 
- (this.querySelector('svg.gradient rect') as SVGRectElement).style.width = - width + 'px'; - - // Remove all children from before. - linearGradient.innerHTML = ''; - - // Add a child in for each gradient threshold. - this.renderInfo.thresholds.forEach(t => { - const stopElement = - document.createElementNS('http://www.w3.org/2000/svg', 'stop'); - stopElement.setAttribute('offset', this.getOffset(t.value)); - stopElement.setAttribute('stop-color', t.color); - }); - } -} - -document.registerElement(Legend.prototype.is, Legend); diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector-metadata-card.html b/tensorflow/tensorboard/components/vz_projector/vz-projector-metadata-card.html deleted file mode 100644 index 4231a61ff30..00000000000 --- a/tensorflow/tensorboard/components/vz_projector/vz-projector-metadata-card.html +++ /dev/null @@ -1,99 +0,0 @@ - - - - - - - - -
-
- - -
- - -
- -
-
-
- - - - diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector-metadata-card.ts b/tensorflow/tensorboard/components/vz_projector/vz-projector-metadata-card.ts deleted file mode 100644 index 939300f3878..00000000000 --- a/tensorflow/tensorboard/components/vz_projector/vz-projector-metadata-card.ts +++ /dev/null @@ -1,88 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -import {PointMetadata} from './data'; -// tslint:disable-next-line:no-unused-variable -import {PolymerElement, PolymerHTMLElement} from './vz-projector-util'; - -// tslint:disable-next-line -export let MetadataCardPolymer = PolymerElement({ - is: 'vz-projector-metadata-card', - properties: { - hasMetadata: {type: Boolean, value: false}, - metadata: {type: Array}, - label: String - } -}); - -export class MetadataCard extends MetadataCardPolymer { - hasMetadata: boolean; - metadata: Array<{key: string, value: string}>; - label: string; - - private labelOption: string; - private pointMetadata: PointMetadata; - - private expandLessButton: HTMLButtonElement; - private expandMoreButton: HTMLButtonElement; - - ready() { - this.expandLessButton = - this.querySelector('#expand-less') as HTMLButtonElement; - this.expandMoreButton = - this.querySelector('#expand-more') as HTMLButtonElement; - } - /** Handles a click on the expand more icon. 
*/ - _expandMore() { - (this.$$('#metadata-container') as any).toggle(); - - this.expandMoreButton.style.display = 'none'; - this.expandLessButton.style.display = ''; - } - - /** Handles a click on the expand less icon. */ - _expandLess() { - (this.$$('#metadata-container') as any).toggle(); - this.expandMoreButton.style.display = ''; - this.expandLessButton.style.display = 'none'; - } - - updateMetadata(pointMetadata?: PointMetadata) { - this.pointMetadata = pointMetadata; - this.hasMetadata = (pointMetadata != null); - - if (pointMetadata) { - let metadata = []; - for (let metadataKey in pointMetadata) { - if (!pointMetadata.hasOwnProperty(metadataKey)) { - continue; - } - metadata.push({key: metadataKey, value: pointMetadata[metadataKey]}); - } - - this.metadata = metadata; - this.label = '' + this.pointMetadata[this.labelOption]; - } - } - - setLabelOption(labelOption: string) { - this.labelOption = labelOption; - if (this.pointMetadata) { - this.label = '' + this.pointMetadata[this.labelOption]; - } - } -} - -document.registerElement(MetadataCard.prototype.is, MetadataCard); diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector-projections-panel.html b/tensorflow/tensorboard/components/vz_projector/vz-projector-projections-panel.html deleted file mode 100644 index b82f3f520b5..00000000000 --- a/tensorflow/tensorboard/components/vz_projector/vz-projector-projections-panel.html +++ /dev/null @@ -1,316 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector-projections-panel.ts b/tensorflow/tensorboard/components/vz_projector/vz-projector-projections-panel.ts deleted file mode 100644 index 377c6c11ad5..00000000000 --- a/tensorflow/tensorboard/components/vz_projector/vz-projector-projections-panel.ts +++ /dev/null @@ -1,589 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -import * as data from './data'; -import {DataSet, Projection, ProjectionType, SpriteAndMetadataInfo, State} from './data'; -import * as util from './util'; -import * as vector from './vector'; -import {Vector} from './vector'; -import {Projector} from './vz-projector'; -import {ProjectorInput} from './vz-projector-input'; -// tslint:disable-next-line:no-unused-variable -import {PolymerElement, PolymerHTMLElement} from './vz-projector-util'; - -const NUM_PCA_COMPONENTS = 10; - -// tslint:disable-next-line -export let ProjectionsPanelPolymer = PolymerElement({ - is: 'vz-projector-projections-panel', - properties: { - pcaIs3d: - {type: Boolean, value: true, observer: '_pcaDimensionToggleObserver'}, - tSNEis3d: - {type: Boolean, value: true, observer: '_tsneDimensionToggleObserver'}, - // PCA projection. - pcaComponents: Array, - pcaX: {type: Number, value: 0, observer: 'showPCAIfEnabled'}, - pcaY: {type: Number, value: 1, observer: 'showPCAIfEnabled'}, - pcaZ: {type: Number, value: 2, observer: 'showPCAIfEnabled'}, - // Custom projection. 
- customSelectedSearchByMetadataOption: { - type: String, - observer: '_customSelectedSearchByMetadataOptionChanged' - }, - } -}); - -type InputControlName = 'xLeft'|'xRight'|'yUp'|'yDown'; - -type CentroidResult = { - centroid?: Vector; numMatches?: number; -}; - -type Centroids = { - [key: string]: Vector; xLeft: Vector; xRight: Vector; yUp: Vector; - yDown: Vector; -}; - -/** - * A polymer component which handles the projection tabs in the projector. - */ -export class ProjectionsPanel extends ProjectionsPanelPolymer { - private projector: Projector; - private pcaComponents: - Array<{id: number, componentNumber: number, percVariance: string}>; - private currentProjection: ProjectionType; - private polymerChangesTriggerReprojection: boolean; - private dataSet: DataSet; - private originalDataSet: DataSet; - private dim: number; - - /** T-SNE perplexity. Roughly how many neighbors each point influences. */ - private perplexity: number; - /** T-SNE learning rate. */ - private learningRate: number; - - private searchByMetadataOptions: string[]; - - /** Centroids for custom projections. */ - private centroidValues: any; - private centroids: Centroids; - /** The centroid across all points. */ - private allCentroid: number[]; - - /** Polymer properties. */ - // TODO(nsthorat): Move these to a separate view controller. - public tSNEis3d: boolean; - public pcaIs3d: boolean; - public pcaX: number; - public pcaY: number; - public pcaZ: number; - public customSelectedSearchByMetadataOption: string; - - /** Polymer elements. 
*/ - private runTsneButton: HTMLButtonElement; - private stopTsneButton: HTMLButtonElement; - private perplexitySlider: HTMLInputElement; - private learningRateInput: HTMLInputElement; - private zDropdown: HTMLElement; - private iterationLabel: HTMLElement; - - private customProjectionXLeftInput: ProjectorInput; - private customProjectionXRightInput: ProjectorInput; - private customProjectionYUpInput: ProjectorInput; - private customProjectionYDownInput: ProjectorInput; - - initialize(projector: Projector) { - this.polymerChangesTriggerReprojection = true; - this.projector = projector; - - // Set up TSNE projections. - this.perplexity = 30; - this.learningRate = 10; - - // Setup Custom projections. - this.centroidValues = {xLeft: null, xRight: null, yUp: null, yDown: null}; - this.clearCentroids(); - - this.setupUIControls(); - } - - ready() { - this.zDropdown = this.querySelector('#z-dropdown') as HTMLElement; - this.runTsneButton = this.querySelector('.run-tsne') as HTMLButtonElement; - this.stopTsneButton = this.querySelector('.stop-tsne') as HTMLButtonElement; - this.perplexitySlider = - this.querySelector('#perplexity-slider') as HTMLInputElement; - this.learningRateInput = - this.querySelector('#learning-rate-slider') as HTMLInputElement; - this.iterationLabel = this.querySelector('.run-tsne-iter') as HTMLElement; - } - - disablePolymerChangesTriggerReprojection() { - this.polymerChangesTriggerReprojection = false; - } - - enablePolymerChangesTriggerReprojection() { - this.polymerChangesTriggerReprojection = true; - } - - private updateTSNEPerplexityFromSliderChange() { - if (this.perplexitySlider) { - this.perplexity = +this.perplexitySlider.value; - } - (this.querySelector('.tsne-perplexity span') as HTMLSpanElement).innerText = - '' + this.perplexity; - } - - private updateTSNELearningRateFromUIChange() { - if (this.learningRateInput) { - this.learningRate = Math.pow(10, +this.learningRateInput.value); - } - (this.querySelector('.tsne-learning-rate span') 
as HTMLSpanElement) - .innerText = '' + this.learningRate; - } - - private setupUIControls() { - { - const self = this; - const inkTabs = this.querySelectorAll('.ink-tab'); - for (let i = 0; i < inkTabs.length; i++) { - inkTabs[i].addEventListener('click', function() { - let id = this.getAttribute('data-tab'); - self.showTab(id); - }); - } - } - - this.runTsneButton.addEventListener('click', () => this.runTSNE()); - this.stopTsneButton.addEventListener( - 'click', () => this.dataSet.stopTSNE()); - - this.perplexitySlider.value = this.perplexity.toString(); - this.perplexitySlider.addEventListener( - 'change', () => this.updateTSNEPerplexityFromSliderChange()); - this.updateTSNEPerplexityFromSliderChange(); - - this.learningRateInput.addEventListener( - 'change', () => this.updateTSNELearningRateFromUIChange()); - this.updateTSNELearningRateFromUIChange(); - - this.setupCustomProjectionInputFields(); - // TODO: figure out why `--paper-input-container-input` css mixin didn't - // work. - const inputs = - this.querySelectorAll('paper-dropdown-menu paper-input input'); - for (let i = 0; i < inputs.length; i++) { - (inputs[i] as HTMLElement).style.fontSize = '14px'; - } - } - - restoreUIFromBookmark(bookmark: State) { - this.disablePolymerChangesTriggerReprojection(); - - // PCA - this.pcaX = bookmark.pcaComponentDimensions[0]; - this.pcaY = bookmark.pcaComponentDimensions[1]; - if (bookmark.pcaComponentDimensions.length === 3) { - this.pcaZ = bookmark.pcaComponentDimensions[2]; - } - this.pcaIs3d = (bookmark.pcaComponentDimensions.length === 3); - - // t-SNE - if (this.perplexitySlider) { - this.perplexitySlider.value = bookmark.tSNEPerplexity.toString(); - } - if (this.learningRateInput) { - this.learningRateInput.value = bookmark.tSNELearningRate.toString(); - } - this.tSNEis3d = bookmark.tSNEis3d; - - // custom - this.customSelectedSearchByMetadataOption = - bookmark.customSelectedSearchByMetadataOption; - if (this.customProjectionXLeftInput) { - 
this.customProjectionXLeftInput.set( - bookmark.customXLeftText, bookmark.customXLeftRegex); - } - if (this.customProjectionXRightInput) { - this.customProjectionXRightInput.set( - bookmark.customXRightText, bookmark.customXRightRegex); - } - if (this.customProjectionYUpInput) { - this.customProjectionYUpInput.set( - bookmark.customYUpText, bookmark.customYUpRegex); - } - if (this.customProjectionYDownInput) { - this.customProjectionYDownInput.set( - bookmark.customYDownText, bookmark.customYDownRegex); - } - this.computeAllCentroids(); - - this.setZDropdownEnabled(this.pcaIs3d); - this.updateTSNEPerplexityFromSliderChange(); - this.updateTSNELearningRateFromUIChange(); - if (this.iterationLabel) { - this.iterationLabel.innerText = bookmark.tSNEIteration.toString(); - } - if (bookmark.selectedProjection != null) { - this.showTab(bookmark.selectedProjection); - } - this.enablePolymerChangesTriggerReprojection(); - } - - populateBookmarkFromUI(bookmark: State) { - this.disablePolymerChangesTriggerReprojection(); - - // PCA - bookmark.pcaComponentDimensions = [this.pcaX, this.pcaY]; - if (this.pcaIs3d) { - bookmark.pcaComponentDimensions.push(this.pcaZ); - } - - // t-SNE - if (this.perplexitySlider != null) { - bookmark.tSNEPerplexity = +this.perplexitySlider.value; - } - if (this.learningRateInput != null) { - bookmark.tSNELearningRate = +this.learningRateInput.value; - } - bookmark.tSNEis3d = this.tSNEis3d; - - // custom - bookmark.customSelectedSearchByMetadataOption = - this.customSelectedSearchByMetadataOption; - if (this.customProjectionXLeftInput != null) { - bookmark.customXLeftText = this.customProjectionXLeftInput.getValue(); - bookmark.customXLeftRegex = - this.customProjectionXLeftInput.getInRegexMode(); - } - if (this.customProjectionXRightInput != null) { - bookmark.customXRightText = this.customProjectionXRightInput.getValue(); - bookmark.customXRightRegex = - this.customProjectionXRightInput.getInRegexMode(); - } - if (this.customProjectionYUpInput != 
null) { - bookmark.customYUpText = this.customProjectionYUpInput.getValue(); - bookmark.customYUpRegex = this.customProjectionYUpInput.getInRegexMode(); - } - if (this.customProjectionYDownInput != null) { - bookmark.customYDownText = this.customProjectionYDownInput.getValue(); - bookmark.customYDownRegex = - this.customProjectionYDownInput.getInRegexMode(); - } - - this.enablePolymerChangesTriggerReprojection(); - } - - // This method is marked as public as it is used as the view method that - // abstracts DOM manipulation so we can stub it in a test. - // TODO(nsthorat): Move this to its own class as the glue between this class - // and the DOM. - setZDropdownEnabled(enabled: boolean) { - if (this.zDropdown) { - if (enabled) { - this.zDropdown.removeAttribute('disabled'); - } else { - this.zDropdown.setAttribute('disabled', 'true'); - } - } - } - - dataSetUpdated(dataSet: DataSet, originalDataSet: DataSet, dim: number) { - this.dataSet = dataSet; - this.originalDataSet = originalDataSet; - this.dim = dim; - const pointCount = (dataSet == null) ? 0 : dataSet.points.length; - const perplexity = Math.max(5, Math.ceil(Math.sqrt(pointCount) / 4)); - this.perplexitySlider.value = perplexity.toString(); - this.updateTSNEPerplexityFromSliderChange(); - this.clearCentroids(); - - (this.querySelector('#tsne-sampling') as HTMLElement).style.display = - pointCount > data.TSNE_SAMPLE_SIZE ? null : 'none'; - const wasSampled = - (dataSet == null) ? false : (dataSet.dim[0] > data.PCA_SAMPLE_DIM || - dataSet.dim[1] > data.PCA_SAMPLE_DIM); - (this.querySelector('#pca-sampling') as HTMLElement).style.display = - wasSampled ? 
null : 'none'; - this.showTab('pca'); - } - - _pcaDimensionToggleObserver() { - this.setZDropdownEnabled(this.pcaIs3d); - this.beginProjection(this.currentProjection); - } - - _tsneDimensionToggleObserver() { - this.beginProjection(this.currentProjection); - } - - metadataChanged(spriteAndMetadata: SpriteAndMetadataInfo) { - // Project by options for custom projections. - let searchByMetadataIndex = -1; - this.searchByMetadataOptions = spriteAndMetadata.stats.map((stats, i) => { - // Make the default label by the first non-numeric column. - if (!stats.isNumeric && searchByMetadataIndex === -1) { - searchByMetadataIndex = i; - } - return stats.name; - }); - this.customSelectedSearchByMetadataOption = - this.searchByMetadataOptions[Math.max(0, searchByMetadataIndex)]; - } - - public showTab(id: ProjectionType) { - this.currentProjection = id; - - const tab = - this.querySelector('.ink-tab[data-tab="' + id + '"]') as HTMLElement; - const allTabs = this.querySelectorAll('.ink-tab'); - for (let i = 0; i < allTabs.length; i++) { - util.classed(allTabs[i] as HTMLElement, 'active', false); - } - - util.classed(tab, 'active', true); - - const allTabContent = this.querySelectorAll('.ink-panel-content'); - for (let i = 0; i < allTabContent.length; i++) { - util.classed(allTabContent[i] as HTMLElement, 'active', false); - } - - util.classed( - this.querySelector('.ink-panel-content[data-panel="' + id + '"]') as - HTMLElement, - 'active', true); - - // guard for unit tests, where polymer isn't attached and $ doesn't exist. - if (this.$ != null) { - const main = this.$['main']; - // In order for the projections panel to animate its height, we need to - // set it explicitly. 
- requestAnimationFrame(() => { - this.style.height = main.clientHeight + 'px'; - }); - } - - this.beginProjection(id); - } - - private beginProjection(projection: ProjectionType) { - if (this.polymerChangesTriggerReprojection === false) { - return; - } - if (projection === 'pca') { - if (this.dataSet != null) { - this.dataSet.stopTSNE(); - } - this.showPCA(); - } else if (projection === 'tsne') { - this.showTSNE(); - } else if (projection === 'custom') { - if (this.dataSet != null) { - this.dataSet.stopTSNE(); - } - this.computeAllCentroids(); - this.reprojectCustom(); - } - } - - private showTSNE() { - const dataSet = this.dataSet; - if (dataSet == null) { - return; - } - const accessors = - data.getProjectionComponents('tsne', [0, 1, this.tSNEis3d ? 2 : null]); - const dimensionality = this.tSNEis3d ? 3 : 2; - const projection = - new Projection('tsne', accessors, dimensionality, dataSet); - this.projector.setProjection(projection); - - if (!this.dataSet.hasTSNERun) { - this.runTSNE(); - } else { - this.projector.notifyProjectionPositionsUpdated(); - } - } - - private runTSNE() { - this.runTsneButton.disabled = true; - this.stopTsneButton.disabled = null; - this.dataSet.projectTSNE( - this.perplexity, this.learningRate, this.tSNEis3d ? 
3 : 2, - (iteration: number) => { - if (iteration != null) { - this.iterationLabel.innerText = '' + iteration; - this.projector.notifyProjectionPositionsUpdated(); - } else { - this.runTsneButton.disabled = null; - this.stopTsneButton.disabled = true; - } - }); - } - - // tslint:disable-next-line:no-unused-variable - private showPCAIfEnabled() { - if (this.polymerChangesTriggerReprojection) { - this.showPCA(); - } - } - - private updateTotalVarianceMessage() { - let variances = this.dataSet.fracVariancesExplained; - let totalVariance = variances[this.pcaX] + variances[this.pcaY]; - let msg = 'Total variance described: '; - if (this.pcaIs3d) { - totalVariance += variances[this.pcaZ]; - } - msg += (totalVariance * 100).toFixed(1) + '%.'; - (this.querySelector('#total-variance') as HTMLElement).innerHTML = msg; - } - - private showPCA() { - if (this.dataSet == null) { - return; - } - this.dataSet.projectPCA().then(() => { - // Polymer properties are 1-based. - const accessors = data.getProjectionComponents( - 'pca', [this.pcaX, this.pcaY, this.pcaZ]); - - const dimensionality = this.pcaIs3d ? 
3 : 2; - const projection = - new Projection('pca', accessors, dimensionality, this.dataSet); - this.projector.setProjection(projection); - let numComponents = Math.min(NUM_PCA_COMPONENTS, this.dataSet.dim[1]); - this.updateTotalVarianceMessage(); - this.pcaComponents = util.range(numComponents).map(i => { - let fracVariance = this.dataSet.fracVariancesExplained[i]; - return { - id: i, - componentNumber: i + 1, - percVariance: (fracVariance * 100).toFixed(1) - }; - }); - }); - } - - private reprojectCustom() { - if (this.centroids == null || this.centroids.xLeft == null || - this.centroids.xRight == null || this.centroids.yUp == null || - this.centroids.yDown == null) { - return; - } - const xDir = vector.sub(this.centroids.xRight, this.centroids.xLeft); - this.dataSet.projectLinear(xDir, 'linear-x'); - - const yDir = vector.sub(this.centroids.yUp, this.centroids.yDown); - this.dataSet.projectLinear(yDir, 'linear-y'); - - const accessors = data.getProjectionComponents('custom', ['x', 'y']); - const projection = new Projection('custom', accessors, 2, this.dataSet); - this.projector.setProjection(projection); - } - - clearCentroids(): void { - this.centroids = {xLeft: null, xRight: null, yUp: null, yDown: null}; - this.allCentroid = null; - } - - _customSelectedSearchByMetadataOptionChanged(newVal: string, oldVal: string) { - if (this.polymerChangesTriggerReprojection === false) { - return; - } - if (this.currentProjection === 'custom') { - this.computeAllCentroids(); - this.reprojectCustom(); - } - } - - private setupCustomProjectionInputFields() { - this.customProjectionXLeftInput = - this.setupCustomProjectionInputField('xLeft'); - this.customProjectionXRightInput = - this.setupCustomProjectionInputField('xRight'); - this.customProjectionYUpInput = this.setupCustomProjectionInputField('yUp'); - this.customProjectionYDownInput = - this.setupCustomProjectionInputField('yDown'); - } - - private computeAllCentroids() { - this.computeCentroid('xLeft'); - 
this.computeCentroid('xRight'); - this.computeCentroid('yUp'); - this.computeCentroid('yDown'); - } - - private computeCentroid(name: InputControlName) { - const input = this.querySelector('#' + name) as ProjectorInput; - if (input == null) { - return; - } - const value = input.getValue(); - if (value == null) { - return; - } - let inRegexMode = input.getInRegexMode(); - let result = this.getCentroid(value, inRegexMode); - if (result.numMatches === 0) { - input.message = '0 matches. Using a random vector.'; - result.centroid = vector.rn(this.dim); - } else { - input.message = `${result.numMatches} matches.`; - } - this.centroids[name] = result.centroid; - this.centroidValues[name] = value; - } - - private setupCustomProjectionInputField(name: InputControlName): - ProjectorInput { - let input = this.querySelector('#' + name) as ProjectorInput; - input.registerInputChangedListener((input, inRegexMode) => { - if (this.polymerChangesTriggerReprojection) { - this.computeCentroid(name); - this.reprojectCustom(); - } - }); - return input; - } - - private getCentroid(pattern: string, inRegexMode: boolean): CentroidResult { - if (pattern == null || pattern === '') { - return {numMatches: 0}; - } - // Search by the original dataset since we often want to filter and project - // only the nearest neighbors of A onto B-C where B and C are not nearest - // neighbors of A. 
- let accessor = (i: number) => this.originalDataSet.points[i].vector; - let r = this.originalDataSet.query( - pattern, inRegexMode, this.customSelectedSearchByMetadataOption); - return {centroid: vector.centroid(r, accessor), numMatches: r.length}; - } - - getPcaSampledDimText() { - return data.PCA_SAMPLE_DIM.toLocaleString(); - } - - getPcaSampleSizeText() { - return data.PCA_SAMPLE_SIZE.toLocaleString(); - } - - getTsneSampleSizeText() { - return data.TSNE_SAMPLE_SIZE.toLocaleString(); - } -} - -document.registerElement(ProjectionsPanel.prototype.is, ProjectionsPanel); diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector-util.ts b/tensorflow/tensorboard/components/vz_projector/vz-projector-util.ts deleted file mode 100644 index 44062062a36..00000000000 --- a/tensorflow/tensorboard/components/vz_projector/vz-projector-util.ts +++ /dev/null @@ -1,34 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -export type Spec = { - is: string; properties?: { - [key: string]: - (Function | - { - type: Function, value?: any; - readonly?: boolean; - notify?: boolean; - observer?: string; - }) - }; - observers?: string[]; -}; - -export function PolymerElement(spec: Spec) { - return Polymer.Class(spec as any) as{new (): PolymerHTMLElement}; -} - -export interface PolymerHTMLElement extends HTMLElement, polymer.Base {} diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector.html b/tensorflow/tensorboard/components/vz_projector/vz-projector.html deleted file mode 100644 index 438ea9f4e97..00000000000 --- a/tensorflow/tensorboard/components/vz_projector/vz-projector.html +++ /dev/null @@ -1,346 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/tensorflow/tensorboard/components/vz_projector/vz-projector.ts b/tensorflow/tensorboard/components/vz_projector/vz-projector.ts deleted file mode 100644 index bf98a4d4785..00000000000 --- a/tensorflow/tensorboard/components/vz_projector/vz-projector.ts +++ /dev/null @@ -1,570 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -import {AnalyticsLogger} from './analyticsLogger'; -import * as data from './data'; -import {ColorOption, ColumnStats, DataPoint, DataProto, DataSet, DistanceFunction, PointMetadata, Projection, SpriteAndMetadataInfo, State, stateGetAccessorDimensions} from './data'; -import {DataProvider, EmbeddingInfo, ServingMode} from './data-provider'; -import {DemoDataProvider} from './data-provider-demo'; -import {ProtoDataProvider} from './data-provider-proto'; -import {ServerDataProvider} from './data-provider-server'; -import * as knn from './knn'; -import * as logging from './logging'; -import {DistanceMetricChangedListener, HoverListener, ProjectionChangedListener, ProjectorEventContext, SelectionChangedListener} from './projectorEventContext'; -import {ProjectorScatterPlotAdapter} from './projectorScatterPlotAdapter'; -import {MouseMode} from './scatterPlot'; -import * as util from './util'; -import {BookmarkPanel} from './vz-projector-bookmark-panel'; -import {DataPanel} from './vz-projector-data-panel'; -import {InspectorPanel} from './vz-projector-inspector-panel'; -import {MetadataCard} from './vz-projector-metadata-card'; -import {ProjectionsPanel} from './vz-projector-projections-panel'; -// tslint:disable-next-line:no-unused-variable -import {PolymerElement, PolymerHTMLElement} from './vz-projector-util'; - -/** - * The minimum number of dimensions the data should have to automatically - * decide to normalize the data. 
- */ -const THRESHOLD_DIM_NORMALIZE = 50; -const POINT_COLOR_MISSING = 'black'; - -export let ProjectorPolymer = PolymerElement({ - is: 'vz-projector', - properties: { - routePrefix: String, - dataProto: {type: String, observer: '_dataProtoChanged'}, - servingMode: String, - projectorConfigJsonPath: String, - pageViewLogging: Boolean, - eventLogging: Boolean - } -}); - -const INDEX_METADATA_FIELD = '__index__'; - -export class Projector extends ProjectorPolymer implements - ProjectorEventContext { - // The working subset of the data source's original data set. - dataSet: DataSet; - servingMode: ServingMode; - // The path to the projector config JSON file for demo mode. - projectorConfigJsonPath: string; - - private selectionChangedListeners: SelectionChangedListener[]; - private hoverListeners: HoverListener[]; - private projectionChangedListeners: ProjectionChangedListener[]; - private distanceMetricChangedListeners: DistanceMetricChangedListener[]; - - private originalDataSet: DataSet; - private dataSetBeforeFilter: DataSet; - private projectorScatterPlotAdapter: ProjectorScatterPlotAdapter; - private dim: number; - - private dataSetFilterIndices: number[]; - private selectedPointIndices: number[]; - private neighborsOfFirstPoint: knn.NearestEntry[]; - private hoverPointIndex: number; - - private dataProvider: DataProvider; - private inspectorPanel: InspectorPanel; - - private selectedColorOption: ColorOption; - private selectedLabelOption: string; - private routePrefix: string; - private normalizeData: boolean; - private projection: Projection; - - /** Polymer component panels */ - private dataPanel: DataPanel; - private bookmarkPanel: BookmarkPanel; - private projectionsPanel: ProjectionsPanel; - private metadataCard: MetadataCard; - - private statusBar: HTMLDivElement; - private analyticsLogger: AnalyticsLogger; - private eventLogging: boolean; - private pageViewLogging: boolean; - - ready() { - logging.setDomContainer(this); - - this.analyticsLogger = - new 
AnalyticsLogger(this.pageViewLogging, this.eventLogging); - this.analyticsLogger.logPageView('embeddings'); - - if (!util.hasWebGLSupport()) { - this.analyticsLogger.logWebGLDisabled(); - logging.setErrorMessage( - 'Your browser or device does not have WebGL enabled. Please enable ' + - 'hardware acceleration, or use a browser that supports WebGL.'); - return; - } - - this.selectionChangedListeners = []; - this.hoverListeners = []; - this.projectionChangedListeners = []; - this.distanceMetricChangedListeners = []; - this.selectedPointIndices = []; - this.neighborsOfFirstPoint = []; - - this.dataPanel = this.$['data-panel'] as DataPanel; - this.inspectorPanel = this.$['inspector-panel'] as InspectorPanel; - this.inspectorPanel.initialize(this, this as ProjectorEventContext); - this.projectionsPanel = this.$['projections-panel'] as ProjectionsPanel; - this.projectionsPanel.initialize(this); - this.bookmarkPanel = this.$['bookmark-panel'] as BookmarkPanel; - this.bookmarkPanel.initialize(this, this as ProjectorEventContext); - this.metadataCard = this.$['metadata-card'] as MetadataCard; - this.statusBar = this.querySelector('#status-bar') as HTMLDivElement; - this.scopeSubtree(this.$$('#notification-dialog'), true); - this.setupUIControls(); - this.initializeDataProvider(); - } - - setSelectedLabelOption(labelOption: string) { - this.selectedLabelOption = labelOption; - this.metadataCard.setLabelOption(this.selectedLabelOption); - this.projectorScatterPlotAdapter.setLabelPointAccessor(labelOption); - this.projectorScatterPlotAdapter.updateScatterPlotAttributes(); - this.projectorScatterPlotAdapter.render(); - } - - setSelectedColorOption(colorOption: ColorOption) { - this.selectedColorOption = colorOption; - this.projectorScatterPlotAdapter.setLegendPointColorer( - this.getLegendPointColorer(colorOption)); - this.projectorScatterPlotAdapter.updateScatterPlotAttributes(); - this.projectorScatterPlotAdapter.render(); - } - - setNormalizeData(normalizeData: boolean) { - 
this.normalizeData = normalizeData; - this.setCurrentDataSet(this.originalDataSet.getSubset()); - } - - updateDataSet( - ds: DataSet, spriteAndMetadata?: SpriteAndMetadataInfo, - metadataFile?: string) { - this.dataSetFilterIndices = null; - this.originalDataSet = ds; - if (ds != null) { - this.normalizeData = - this.originalDataSet.dim[1] >= THRESHOLD_DIM_NORMALIZE; - spriteAndMetadata = spriteAndMetadata || {}; - if (spriteAndMetadata.pointsInfo == null) { - let [pointsInfo, stats] = this.makeDefaultPointsInfoAndStats(ds.points); - spriteAndMetadata.pointsInfo = pointsInfo; - spriteAndMetadata.stats = stats; - } - let metadataMergeSucceeded = ds.mergeMetadata(spriteAndMetadata); - if (!metadataMergeSucceeded) { - return; - } - } - if (this.projectorScatterPlotAdapter != null) { - if (ds == null) { - this.projectorScatterPlotAdapter.setLabelPointAccessor(null); - this.setProjection(null); - } else { - this.projectorScatterPlotAdapter.updateScatterPlotPositions(); - this.projectorScatterPlotAdapter.updateScatterPlotAttributes(); - this.projectorScatterPlotAdapter.resize(); - this.projectorScatterPlotAdapter.render(); - } - } - if (ds != null) { - this.dataPanel.setNormalizeData(this.normalizeData); - this.setCurrentDataSet(ds.getSubset()); - this.projectorScatterPlotAdapter.setLabelPointAccessor( - this.selectedLabelOption); - this.inspectorPanel.datasetChanged(); - - this.inspectorPanel.metadataChanged(spriteAndMetadata); - this.projectionsPanel.metadataChanged(spriteAndMetadata); - this.dataPanel.metadataChanged(spriteAndMetadata, metadataFile); - // Set the container to a fixed height, otherwise in Colab the - // height can grow indefinitely. 
- const container = this.querySelector('#container') as HTMLDivElement; - container.style.height = container.clientHeight + 'px'; - } else { - this.setCurrentDataSet(null); - } - } - - setSelectedTensor(run: string, tensorInfo: EmbeddingInfo) { - this.bookmarkPanel.setSelectedTensor(run, tensorInfo, this.dataProvider); - } - - /** - * Registers a listener to be called any time the selected point set changes. - */ - registerSelectionChangedListener(listener: SelectionChangedListener) { - this.selectionChangedListeners.push(listener); - } - - filterDataset(pointIndices: number[]) { - const selectionSize = this.selectedPointIndices.length; - if (this.dataSetBeforeFilter == null) { - this.dataSetBeforeFilter = this.dataSet; - } - this.setCurrentDataSet(this.dataSet.getSubset(pointIndices)); - this.dataSetFilterIndices = pointIndices; - this.projectorScatterPlotAdapter.updateScatterPlotPositions(); - this.projectorScatterPlotAdapter.updateScatterPlotAttributes(); - this.adjustSelectionAndHover(util.range(selectionSize)); - } - - resetFilterDataset() { - const originalPointIndices = this.selectedPointIndices.map( - filteredIndex => this.dataSet.points[filteredIndex].index); - this.setCurrentDataSet(this.dataSetBeforeFilter); - if (this.projection != null) { - this.projection.dataSet = this.dataSetBeforeFilter; - } - this.dataSetBeforeFilter = null; - this.projectorScatterPlotAdapter.updateScatterPlotPositions(); - this.projectorScatterPlotAdapter.updateScatterPlotAttributes(); - this.dataSetFilterIndices = []; - this.adjustSelectionAndHover(originalPointIndices); - } - - /** - * Used by clients to indicate that a selection has occurred. 
- */ - notifySelectionChanged(newSelectedPointIndices: number[]) { - this.selectedPointIndices = newSelectedPointIndices; - let neighbors: knn.NearestEntry[] = []; - - if (newSelectedPointIndices.length === 1) { - neighbors = this.dataSet.findNeighbors( - newSelectedPointIndices[0], this.inspectorPanel.distFunc, - this.inspectorPanel.numNN); - this.metadataCard.updateMetadata( - this.dataSet.points[newSelectedPointIndices[0]].metadata); - } else { - this.metadataCard.updateMetadata(null); - } - - this.selectionChangedListeners.forEach( - l => l(this.selectedPointIndices, neighbors)); - } - - /** - * Registers a listener to be called any time the mouse hovers over a point. - */ - registerHoverListener(listener: HoverListener) { - this.hoverListeners.push(listener); - } - - /** - * Used by clients to indicate that a hover is occurring. - */ - notifyHoverOverPoint(pointIndex: number) { - this.hoverListeners.forEach(l => l(pointIndex)); - } - - registerProjectionChangedListener(listener: ProjectionChangedListener) { - this.projectionChangedListeners.push(listener); - } - - notifyProjectionChanged(projection: Projection) { - this.projectionChangedListeners.forEach(l => l(projection)); - } - - registerDistanceMetricChangedListener(l: DistanceMetricChangedListener) { - this.distanceMetricChangedListeners.push(l); - } - - notifyDistanceMetricChanged(distMetric: DistanceFunction) { - this.distanceMetricChangedListeners.forEach(l => l(distMetric)); - } - - _dataProtoChanged(dataProtoString: string) { - let dataProto = - dataProtoString ? 
JSON.parse(dataProtoString) as DataProto : null; - this.initializeDataProvider(dataProto); - } - - private makeDefaultPointsInfoAndStats(points: DataPoint[]): - [PointMetadata[], ColumnStats[]] { - let pointsInfo: PointMetadata[] = []; - points.forEach(p => { - let pointInfo: PointMetadata = {}; - pointInfo[INDEX_METADATA_FIELD] = p.index; - pointsInfo.push(pointInfo); - }); - let stats: ColumnStats[] = [{ - name: INDEX_METADATA_FIELD, - isNumeric: false, - tooManyUniqueValues: true, - min: 0, - max: pointsInfo.length - 1 - }]; - return [pointsInfo, stats]; - } - - private initializeDataProvider(dataProto?: DataProto) { - if (this.servingMode === 'demo') { - let projectorConfigUrl: string; - - // Only in demo mode do we allow the config being passed via URL. - let urlParams = util.getURLParams(window.location.search); - if ('config' in urlParams) { - projectorConfigUrl = urlParams['config']; - } else { - projectorConfigUrl = this.projectorConfigJsonPath; - } - this.dataProvider = new DemoDataProvider(projectorConfigUrl); - } else if (this.servingMode === 'server') { - if (!this.routePrefix) { - throw 'route-prefix is a required parameter'; - } - this.dataProvider = new ServerDataProvider(this.routePrefix); - } else if (this.servingMode === 'proto' && dataProto != null) { - this.dataProvider = new ProtoDataProvider(dataProto); - } - - this.dataPanel.initialize(this, this.dataProvider); - } - - private getLegendPointColorer(colorOption: ColorOption): - (ds: DataSet, index: number) => string { - if ((colorOption == null) || (colorOption.map == null)) { - return null; - } - const colorer = (ds: DataSet, i: number) => { - let value = ds.points[i].metadata[this.selectedColorOption.name]; - if (value == null) { - return POINT_COLOR_MISSING; - } - return colorOption.map(value); - }; - return colorer; - } - - private get3DLabelModeButton(): any { - return this.querySelector('#labels3DMode'); - } - - private get3DLabelMode(): boolean { - const label3DModeButton = 
this.get3DLabelModeButton(); - return (label3DModeButton as any).active; - } - - adjustSelectionAndHover(selectedPointIndices: number[], hoverIndex?: number) { - this.notifySelectionChanged(selectedPointIndices); - this.notifyHoverOverPoint(hoverIndex); - this.setMouseMode(MouseMode.CAMERA_AND_CLICK_SELECT); - } - - private setMouseMode(mouseMode: MouseMode) { - let selectModeButton = this.querySelector('#selectMode'); - (selectModeButton as any).active = (mouseMode === MouseMode.AREA_SELECT); - this.projectorScatterPlotAdapter.scatterPlot.setMouseMode(mouseMode); - } - - private setCurrentDataSet(ds: DataSet) { - this.adjustSelectionAndHover([]); - if (this.dataSet != null) { - this.dataSet.stopTSNE(); - } - if ((ds != null) && this.normalizeData) { - ds.normalize(); - } - this.dim = (ds == null) ? 0 : ds.dim[1]; - (this.querySelector('span.numDataPoints') as HTMLSpanElement).innerText = - (ds == null) ? '0' : '' + ds.dim[0]; - (this.querySelector('span.dim') as HTMLSpanElement).innerText = - (ds == null) ? '0' : '' + ds.dim[1]; - - this.dataSet = ds; - - this.projectionsPanel.dataSetUpdated( - this.dataSet, this.originalDataSet, this.dim); - - this.projectorScatterPlotAdapter.setDataSet(this.dataSet); - this.projectorScatterPlotAdapter.scatterPlot - .setCameraParametersForNextCameraCreation(null, true); - } - - private setupUIControls() { - // View controls - this.querySelector('#reset-zoom').addEventListener('click', () => { - this.projectorScatterPlotAdapter.scatterPlot.resetZoom(); - this.projectorScatterPlotAdapter.scatterPlot.startOrbitAnimation(); - }); - - let selectModeButton = this.querySelector('#selectMode'); - selectModeButton.addEventListener('click', (event) => { - this.setMouseMode( - (selectModeButton as any).active ? 
MouseMode.AREA_SELECT : - MouseMode.CAMERA_AND_CLICK_SELECT); - }); - let nightModeButton = this.querySelector('#nightDayMode'); - nightModeButton.addEventListener('click', () => { - this.projectorScatterPlotAdapter.scatterPlot.setDayNightMode( - (nightModeButton as any).active); - }); - - const labels3DModeButton = this.get3DLabelModeButton(); - labels3DModeButton.addEventListener('click', () => { - this.projectorScatterPlotAdapter.set3DLabelMode(this.get3DLabelMode()); - }); - - window.addEventListener('resize', () => { - const container = this.querySelector('#container') as HTMLDivElement; - const parentHeight = (container.parentNode as HTMLElement).clientHeight; - container.style.height = parentHeight + 'px'; - this.projectorScatterPlotAdapter.resize(); - }); - - { - this.projectorScatterPlotAdapter = new ProjectorScatterPlotAdapter( - this.getScatterContainer(), this as ProjectorEventContext); - this.projectorScatterPlotAdapter.setLabelPointAccessor( - this.selectedLabelOption); - } - - this.projectorScatterPlotAdapter.scatterPlot.onCameraMove( - (cameraPosition: THREE.Vector3, cameraTarget: THREE.Vector3) => - this.bookmarkPanel.clearStateSelection()); - - this.registerHoverListener( - (hoverIndex: number) => this.onHover(hoverIndex)); - - this.registerSelectionChangedListener( - (selectedPointIndices: number[], - neighborsOfFirstPoint: knn.NearestEntry[]) => - this.onSelectionChanged( - selectedPointIndices, neighborsOfFirstPoint)); - } - - private onHover(hoverIndex: number) { - this.hoverPointIndex = hoverIndex; - let hoverText = null; - if (hoverIndex != null) { - const point = this.dataSet.points[hoverIndex]; - if (point.metadata[this.selectedLabelOption]) { - hoverText = point.metadata[this.selectedLabelOption].toString(); - } - } - if (this.selectedPointIndices.length === 0) { - this.statusBar.style.display = hoverText ? 
null : 'none'; - this.statusBar.innerText = hoverText; - } - } - - private getScatterContainer(): HTMLDivElement { - return this.querySelector('#scatter') as HTMLDivElement; - } - - private onSelectionChanged( - selectedPointIndices: number[], - neighborsOfFirstPoint: knn.NearestEntry[]) { - this.selectedPointIndices = selectedPointIndices; - this.neighborsOfFirstPoint = neighborsOfFirstPoint; - let totalNumPoints = - this.selectedPointIndices.length + neighborsOfFirstPoint.length; - this.statusBar.innerText = `Selected ${totalNumPoints} points`; - this.statusBar.style.display = totalNumPoints > 0 ? null : 'none'; - } - - setProjection(projection: Projection) { - this.projection = projection; - if (projection != null) { - this.analyticsLogger.logProjectionChanged(projection.projectionType); - } - this.notifyProjectionChanged(projection); - } - - notifyProjectionPositionsUpdated() { - this.projectorScatterPlotAdapter.notifyProjectionPositionsUpdated(); - } - - /** - * Gets the current view of the embedding and saves it as a State object. - */ - getCurrentState(): State { - const state = new State(); - - // Save the individual datapoint projections. 
- state.projections = []; - for (let i = 0; i < this.dataSet.points.length; i++) { - const point = this.dataSet.points[i]; - const projections: {[key: string]: number} = {}; - const keys = Object.keys(point.projections); - for (let j = 0; j < keys.length; ++j) { - projections[keys[j]] = point.projections[keys[j]]; - } - state.projections.push(projections); - } - state.selectedProjection = this.projection.projectionType; - state.dataSetDimensions = this.dataSet.dim; - state.tSNEIteration = this.dataSet.tSNEIteration; - state.selectedPoints = this.selectedPointIndices; - state.filteredPoints = this.dataSetFilterIndices; - this.projectorScatterPlotAdapter.populateBookmarkFromUI(state); - state.selectedColorOptionName = this.dataPanel.selectedColorOptionName; - state.forceCategoricalColoring = this.dataPanel.forceCategoricalColoring; - state.selectedLabelOption = this.selectedLabelOption; - this.projectionsPanel.populateBookmarkFromUI(state); - return state; - } - - /** Loads a State object into the world. 
*/ - loadState(state: State) { - this.setProjection(null); - { - this.projectionsPanel.disablePolymerChangesTriggerReprojection(); - if (this.dataSetBeforeFilter != null) { - this.resetFilterDataset(); - } - if (state.filteredPoints != null) { - this.filterDataset(state.filteredPoints); - } - this.projectionsPanel.enablePolymerChangesTriggerReprojection(); - } - for (let i = 0; i < state.projections.length; i++) { - const point = this.dataSet.points[i]; - const projection = state.projections[i]; - const keys = Object.keys(projection); - for (let j = 0; j < keys.length; ++j) { - point.projections[keys[j]] = projection[keys[j]]; - } - } - this.dataSet.hasTSNERun = (state.selectedProjection === 'tsne'); - this.dataSet.tSNEIteration = state.tSNEIteration; - this.projectionsPanel.restoreUIFromBookmark(state); - this.inspectorPanel.restoreUIFromBookmark(state); - this.dataPanel.selectedColorOptionName = state.selectedColorOptionName; - this.dataPanel.setForceCategoricalColoring( - !!state.forceCategoricalColoring); - this.selectedLabelOption = state.selectedLabelOption; - this.projectorScatterPlotAdapter.restoreUIFromBookmark(state); - { - const dimensions = stateGetAccessorDimensions(state); - const components = - data.getProjectionComponents(state.selectedProjection, dimensions); - const projection = new Projection( - state.selectedProjection, components, dimensions.length, - this.dataSet); - this.setProjection(projection); - } - this.notifySelectionChanged(state.selectedPoints); - } -} - -document.registerElement(Projector.prototype.is, Projector); diff --git a/tensorflow/tensorboard/components/vz_sorting/BUILD b/tensorflow/tensorboard/components/vz_sorting/BUILD deleted file mode 100644 index e06b8ae1979..00000000000 --- a/tensorflow/tensorboard/components/vz_sorting/BUILD +++ /dev/null @@ -1,28 +0,0 @@ -package(default_visibility = ["//tensorflow/tensorboard:internal"]) - -load("//tensorflow/tensorboard/defs:defs.bzl", "tensorboard_webcomponent_library") 
-load("//tensorflow/tensorboard/defs:web.bzl", "ts_web_library") - -licenses(["notice"]) # Apache 2.0 - -ts_web_library( - name = "vz_sorting", - srcs = [ - "sorting.ts", - "vz-sorting.html", - ], - path = "/vz-sorting", - visibility = ["//visibility:public"], -) - -tensorboard_webcomponent_library( - name = "legacy", - srcs = [":vz_sorting"], - destdir = "vz-sorting", -) - -filegroup( - name = "all_files", - srcs = glob(["**"]), - tags = ["notsan"], -) diff --git a/tensorflow/tensorboard/components/vz_sorting/sorting.ts b/tensorflow/tensorboard/components/vz_sorting/sorting.ts deleted file mode 100644 index 061184d24bf..00000000000 --- a/tensorflow/tensorboard/components/vz_sorting/sorting.ts +++ /dev/null @@ -1,107 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -/** - * Compares tag names asciinumerically broken into components. - * - *

This is the comparison function used for sorting most string values in - * TensorBoard. Unlike the standard asciibetical comparator, this function - * knows that 'a10b' > 'a2b'. Fixed point and engineering notation are - * supported. This function also splits the input by slash and underscore to - * perform array comparison. Therefore it knows that 'a/a' < 'a+/a' even - * though '+' < '/' in the ASCII table. - */ -export function compareTagNames(a, b: string): number { - let ai = 0; - let bi = 0; - while (true) { - if (ai === a.length) { - return bi === b.length ? 0 : -1; - } - if (bi === b.length) { - return 1; - } - if (isDigit(a[ai]) && isDigit(b[bi])) { - const ais = ai; - const bis = bi; - ai = consumeNumber(a, ai + 1); - bi = consumeNumber(b, bi + 1); - const an = parseFloat(a.slice(ais, ai)); - const bn = parseFloat(b.slice(bis, bi)); - if (an < bn) { - return -1; - } - if (an > bn) { - return 1; - } - continue; - } - if (isBreak(a[ai])) { - if (!isBreak(b[bi])) { - return -1; - } - } else if (isBreak(b[bi])) { - return 1; - } else if (a[ai] < b[bi]) { - return -1; - } else if (a[ai] > b[bi]) { - return 1; - } - ai++; - bi++; - } -} - -function consumeNumber(s: string, i: number): number { - enum State { NATURAL, REAL, EXPONENT_SIGN, EXPONENT } - let state = State.NATURAL; - for (; i < s.length; i++) { - if (state === State.NATURAL) { - if (s[i] === '.') { - state = State.REAL; - } else if (s[i] === 'e' || s[i] === 'E') { - state = State.EXPONENT_SIGN; - } else if (!isDigit(s[i])) { - break; - } - } else if (state === State.REAL) { - if (s[i] === 'e' || s[i] === 'E') { - state = State.EXPONENT_SIGN; - } else if (!isDigit(s[i])) { - break; - } - } else if (state === State.EXPONENT_SIGN) { - if (isDigit(s[i]) || s[i] === '+' || s[i] === '-') { - state = State.EXPONENT; - } else { - break; - } - } else if (state === State.EXPONENT) { - if (!isDigit(s[i])) { - break; - } - } - } - return i; -} - -function isDigit(c: string): boolean { - return '0' <= c && c <= 
'9'; -} - -function isBreak(c: string): boolean { - // TODO(jart): Remove underscore when people stop using it like a slash. - return c === '/' || c === '_' || isDigit(c); -} diff --git a/tensorflow/tensorboard/components/vz_sorting/test/BUILD b/tensorflow/tensorboard/components/vz_sorting/test/BUILD deleted file mode 100644 index 929e80d3728..00000000000 --- a/tensorflow/tensorboard/components/vz_sorting/test/BUILD +++ /dev/null @@ -1,37 +0,0 @@ -package( - default_testonly = True, - default_visibility = ["//tensorflow/tensorboard:internal"], -) - -load("//tensorflow/tensorboard/defs:vulcanize.bzl", "tensorboard_html_binary") -load("//tensorflow/tensorboard/defs:web.bzl", "ts_web_library") - -licenses(["notice"]) # Apache 2.0 - -ts_web_library( - name = "test", - srcs = [ - "sortingTests.ts", - "tests.html", - ], - path = "/vz-sorting/test", - deps = [ - "//tensorflow/tensorboard/components/tf_imports:web_component_tester", - "//tensorflow/tensorboard/components/vz_sorting", - ], -) - -tensorboard_html_binary( - name = "devserver", - compilation_level = "WHITESPACE_ONLY", - input_path = "/vz-sorting/test/tests.html", - output_path = "/vz-sorting/test/tests.html", - deps = [":test"], -) - -filegroup( - name = "all_files", - testonly = 0, - srcs = glob(["**"]), - tags = ["notsan"], -) diff --git a/tensorflow/tensorboard/components/vz_sorting/test/sortingTests.ts b/tensorflow/tensorboard/components/vz_sorting/test/sortingTests.ts deleted file mode 100644 index 510685cb4b5..00000000000 --- a/tensorflow/tensorboard/components/vz_sorting/test/sortingTests.ts +++ /dev/null @@ -1,77 +0,0 @@ -/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -import {compareTagNames} from '../sorting'; - -describe('compareTagNames', () => { - - const assert = chai.assert; - const sortTagNames = (a) => a.sort(compareTagNames); - - it('is asciibetical', () => { - assert.deepEqual(sortTagNames(['a', 'b']), ['a', 'b']); - assert.deepEqual(sortTagNames(['a', 'B']), ['B', 'a']); - }); - - it('sorts integer portions', () => { - assert.deepEqual(['03', '1'].sort(), ['03', '1']); - assert.deepEqual(sortTagNames(['03', '1']), ['1', '03']); - assert.deepEqual(sortTagNames(['a03', 'a1']), ['a1', 'a03']); - assert.deepEqual(sortTagNames(['a03', 'b1']), ['a03', 'b1']); - assert.deepEqual(sortTagNames(['x0a03', 'x0a1']), ['x0a1', 'x0a03']); - assert.deepEqual(sortTagNames(['a/b/03', 'a/b/1']), ['a/b/1', 'a/b/03']); - }); - - it('sorts fixed point numbers', () => { - assert.deepEqual(sortTagNames(['a0.1', 'a0.01']), ['a0.01', 'a0.1']); - }); - - it('sorts engineering notation', () => { - assert.deepEqual(sortTagNames(['a1e9', 'a9e8']), ['a9e8', 'a1e9']); - assert.deepEqual(sortTagNames(['a1e+9', 'a9e+8']), ['a9e+8', 'a1e+9']); - assert.deepEqual(sortTagNames(['a1e+5', 'a9e-6']), ['a9e-6', 'a1e+5']); - assert.deepEqual(sortTagNames(['a1.0e9', 'a9.0e8']), ['a9.0e8', 'a1.0e9']); - assert.deepEqual( - sortTagNames(['a1.0e+9', 'a9.0e+8']), ['a9.0e+8', 'a1.0e+9']); - }); - - it('is componentized by slash', () => { - assert.deepEqual(['a+/a', 'a/a', 'ab/a'].sort(), ['a+/a', 'a/a', 'ab/a']); - assert.deepEqual( - sortTagNames(['a+/a', 'a/a', 'ab/a']), ['a/a', 'a+/a', 
'ab/a']); - }); - - it('is componentized by underscore', () => { - assert.deepEqual( - sortTagNames(['a+_a', 'a_a', 'ab_a']), ['a_a', 'a+_a', 'ab_a']); - assert.deepEqual( - sortTagNames(['a+/a', 'a_a', 'ab_a']), ['a_a', 'a+/a', 'ab_a']); - }); - - it('is componentized by number boundaries', () => { - assert.deepEqual( - sortTagNames(['a+0a', 'a0a', 'ab0a']), ['a0a', 'a+0a', 'ab0a']); - }); - - it('empty comes first', () => { - assert.deepEqual(sortTagNames(['a', '//', '/', '']), ['', '/', '//', 'a']); - }); - - it('decimal parsed correctly', () => { - assert.deepEqual(sortTagNames(['0.2', '0.03']), ['0.03', '0.2']); - assert.deepEqual(sortTagNames(['0..2', '0..03']), ['0..2', '0..03']); - assert.deepEqual(sortTagNames(['.2', '.03']), ['.2', '.03']); - }); -}); diff --git a/tensorflow/tensorboard/components/vz_sorting/test/tests.html b/tensorflow/tensorboard/components/vz_sorting/test/tests.html deleted file mode 100644 index f92c608cdb1..00000000000 --- a/tensorflow/tensorboard/components/vz_sorting/test/tests.html +++ /dev/null @@ -1,23 +0,0 @@ - - - - - - - - diff --git a/tensorflow/tensorboard/components/vz_sorting/vz-sorting.html b/tensorflow/tensorboard/components/vz_sorting/vz-sorting.html deleted file mode 100644 index 5ff6f311589..00000000000 --- a/tensorflow/tensorboard/components/vz_sorting/vz-sorting.html +++ /dev/null @@ -1,18 +0,0 @@ - - - diff --git a/tensorflow/tensorboard/defs/BUILD b/tensorflow/tensorboard/defs/BUILD deleted file mode 100644 index 92a2af34048..00000000000 --- a/tensorflow/tensorboard/defs/BUILD +++ /dev/null @@ -1,14 +0,0 @@ -package(default_visibility = ["//tensorflow/tensorboard:internal"]) - -licenses(["notice"]) # Apache 2.0 - -filegroup( - name = "ts_web_library_default_typings", - srcs = [ - # Ordering probably matters. 
- "@com_microsoft_typescript//:lib.es6.d.ts", - "@io_angular_clutz//:src/resources/closure.lib.d.ts", - "clutz.d.ts", - ], - visibility = ["//visibility:public"], -) diff --git a/tensorflow/tensorboard/defs/clutz.d.ts b/tensorflow/tensorboard/defs/clutz.d.ts deleted file mode 100644 index 47cf307d261..00000000000 --- a/tensorflow/tensorboard/defs/clutz.d.ts +++ /dev/null @@ -1,19 +0,0 @@ -// Copyright 2017 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// tslint:disable -declare namespace ಠ_ಠ.clutz { - interface IteratorIterable extends Iterator, Iterable {} - interface IIterableResult extends IteratorResult {} -} diff --git a/tensorflow/tensorboard/defs/hacks.bzl b/tensorflow/tensorboard/defs/hacks.bzl deleted file mode 100644 index f1d4be79061..00000000000 --- a/tensorflow/tensorboard/defs/hacks.bzl +++ /dev/null @@ -1,80 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -# TODO(jart): Merge this file into defs.bzl once that file is sync unified. - -def tensorboard_typescript_bundle( - name, - out, - namespace_srcs, - namespace_symbol_aliases={}, - namespace_symbol_aliases_public={}, - **kwargs): - """Rolls TypeScript ES6 modules into one vanilla source file without imports. - - This is a genrule wrapper that concatenates TypeScripts sources inside - namespace blocks while removing ^import lines. Because the sources themselves - are not parsed, the structure of the modules must be passed to this macro as - a Skylark data structure. - - Args: - name: Name of this build rule target. - out: Path of outputted TypeScript source file. - namespace_srcs: Multimap of namespace strings to build file targets. The - ordering of the dictionary and nested lists does not matter when - generating a typings file, but *does* matter when generating a source - file. - namespace_symbol_aliases: Map of namespace strings where each value is a - map of symbol names to fully qualified symbol names. - namespace_symbol_aliases_public: Same as namespace_symbol_aliases but the - symbol will be visible to other namespaces. 
- """ - cmd = ["(", "echo // GENERATED BY TENSORBOARD_TYPESCRIPT_BUNDLE"] - inputs = set() - for namespace, srcs in namespace_srcs.items(): - cmd.append("echo") - if out[-5:] == ".d.ts": - cmd.append("echo 'declare namespace %s {'" % namespace) - elif out[-3:] == ".ts": - cmd.append("echo 'module %s {'" % namespace) - else: - fail("'out' must end with .ts or .d.ts: " + out) - for symbol, canon in namespace_symbol_aliases.get(namespace, {}).items(): - cmd.append("echo 'import %s = %s;'" % (symbol, canon)) - for symbol, canon in namespace_symbol_aliases_public.get(namespace, - {}).items(): - cmd.append("echo 'export import %s = %s;'" % (symbol, canon)) - inputs += srcs - for src in srcs: - cmd.append("for f in $(locations %s); do" % src) - cmd.append(" echo") - cmd.append(" echo /////////////////////////////////////////////////////") - cmd.append(" echo // " + namespace) - cmd.append(" echo // $$f") - cmd.append(" echo /////////////////////////////////////////////////////") - cmd.append(" echo") - cmd.append(" sed 's!^import !// import !' $$f \\") - cmd.append(" | sed 's!^export declare !export !' \\") - cmd.append(" | sed '/^export .* from /d' \\") - cmd.append(" | sed '/^export {.*};$$/d'") - cmd.append("done") - cmd.append("echo '}'") - cmd.append(") >$@") - native.genrule( - name = name, - srcs = list(inputs), - outs = [out], - cmd = "\n".join(cmd), - **kwargs - ) diff --git a/tensorflow/tensorboard/defs/protos.bzl b/tensorflow/tensorboard/defs/protos.bzl deleted file mode 100644 index 6d1982e098d..00000000000 --- a/tensorflow/tensorboard/defs/protos.bzl +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the 'License'); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an 'AS IS' BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -load("@protobuf//:protobuf.bzl", "py_proto_library") - -def tb_proto_library(name, srcs = [], visibility = []): - py_proto_library( - name = name + "_py", - srcs = srcs, - srcs_version = "PY2AND3", - deps = ["@protobuf//:protobuf_python"], - protoc = "@protobuf//:protoc", - visibility = visibility, - default_runtime = "@protobuf//:protobuf_python", - testonly = 0, - ) \ No newline at end of file diff --git a/tensorflow/tensorboard/defs/vulcanize.bzl b/tensorflow/tensorboard/defs/vulcanize.bzl deleted file mode 100644 index 6ff49a35ed7..00000000000 --- a/tensorflow/tensorboard/defs/vulcanize.bzl +++ /dev/null @@ -1,125 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -load("//tensorflow/tensorboard/defs:defs.bzl", "legacy_js") -load("@io_bazel_rules_closure//closure/private:defs.bzl", "collect_js", "unfurl", "long_path") -load("//tensorflow/tensorboard/defs:web.bzl", "web_aspect") - -def _tensorboard_html_binary(ctx): - deps = unfurl(ctx.attr.deps, provider="webfiles") - manifests = set(order="topological") - files = set() - webpaths = set() - for dep in deps: - manifests += dep.webfiles.manifests - webpaths += dep.webfiles.webpaths - files += dep.data_runfiles.files - webpaths += [ctx.attr.output_path] - closure_js_library=collect_js( - ctx, unfurl(ctx.attr.deps, provider="closure_js_library")) - - # vulcanize - jslibs = depset(ctx.files._jslibs) + closure_js_library.srcs - ctx.action( - inputs=list(manifests | files | jslibs), - outputs=[ctx.outputs.html], - executable=ctx.executable._Vulcanize, - arguments=([ctx.attr.compilation_level, - "true" if ctx.attr.testonly else "false", - ctx.attr.input_path, - ctx.attr.output_path, - ctx.outputs.html.path] + - [f.path for f in jslibs] + - [f.path for f in manifests]), - progress_message="Vulcanizing %s" % ctx.attr.input_path) - - # webfiles manifest - manifest_srcs = [struct(path=ctx.outputs.html.path, - longpath=long_path(ctx, ctx.outputs.html), - webpath=ctx.attr.output_path)] - manifest = ctx.new_file(ctx.configuration.bin_dir, - "%s.pbtxt" % ctx.label.name) - ctx.file_action( - output=manifest, - content=struct( - label=str(ctx.label), - src=manifest_srcs).to_proto()) - manifests += [manifest] - - # webfiles server - params = struct( - label=str(ctx.label), - bind="[::]:6006", - manifest=[long_path(ctx, man) for man in manifests], - external_asset=[struct(webpath=k, path=v) - for k, v in ctx.attr.external_assets.items()]) - params_file = ctx.new_file(ctx.configuration.bin_dir, - "%s_server_params.pbtxt" % ctx.label.name) - ctx.file_action(output=params_file, content=params.to_proto()) - ctx.file_action( - executable=True, - output=ctx.outputs.executable, - 
content="#!/bin/sh\nexec %s %s" % ( - ctx.executable._WebfilesServer.short_path, - long_path(ctx, params_file))) - - transitive_runfiles = depset() - transitive_runfiles += ctx.attr._WebfilesServer.data_runfiles.files - for dep in deps: - transitive_runfiles += dep.data_runfiles.files - return struct( - files=depset([ctx.outputs.html]), - webfiles=struct( - manifest=manifest, - manifests=manifests, - webpaths=webpaths, - dummy=ctx.outputs.html), - runfiles=ctx.runfiles( - files=ctx.files.data + [manifest, - params_file, - ctx.outputs.html, - ctx.outputs.executable], - transitive_files=transitive_runfiles)) - -tensorboard_html_binary = rule( - implementation=_tensorboard_html_binary, - executable=True, - attrs={ - "compilation_level": attr.string(default="ADVANCED"), - "input_path": attr.string(mandatory=True), - "output_path": attr.string(mandatory=True), - "data": attr.label_list(cfg="data", allow_files=True), - "deps": attr.label_list( - aspects=[ - web_aspect, - legacy_js, - ], - mandatory=True), - "external_assets": attr.string_dict(default={"/_/runfiles": "."}), - "_jslibs": attr.label( - default=Label("//tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize:jslibs"), - allow_files=True), - "_Vulcanize": attr.label( - default=Label("//tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize:Vulcanize"), - executable=True, - cfg="host"), - "_WebfilesServer": attr.label( - default=Label( - "@io_bazel_rules_closure//java/io/bazel/rules/closure/webfiles/server:WebfilesServer"), - executable=True, - cfg="host"), - }, - outputs={ - "html": "%{name}.html", - }) diff --git a/tensorflow/tensorboard/defs/web.bzl b/tensorflow/tensorboard/defs/web.bzl deleted file mode 100644 index 103942b0a25..00000000000 --- a/tensorflow/tensorboard/defs/web.bzl +++ /dev/null @@ -1,419 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Same as web_library but supports TypeScript.""" - -load("//tensorflow/tensorboard/defs:defs.bzl", "legacy_js") - -load("//third_party:clutz.bzl", - "CLUTZ_ATTRIBUTES", - "CLUTZ_OUTPUTS", - "clutz_aspect", - "extract_dts_from_closure_libraries") - -load("@io_bazel_rules_closure//closure/private:defs.bzl", - "CLOSURE_LIBRARY_BASE_ATTR", - "CLOSURE_LIBRARY_DEPS_ATTR", - "collect_js", - "collect_runfiles", - "convert_path_to_es6_module_name", - "create_argfile", - "difference", - "long_path", - "unfurl") - -_ASPECT_SLURP_FILE_TYPE = FileType([ - ".html", ".js", ".css", ".gss", ".png", ".jpg", ".gif", ".ico", ".svg"]) - -_CLOSURE_WORKER = attr.label( - default=Label("@io_bazel_rules_closure//java/io/bazel/rules/closure:ClosureWorker"), - executable=True, - cfg="host") - -def _ts_web_library(ctx): - if not ctx.attr.srcs: - if ctx.attr.deps: - fail("deps can not be set when srcs is not") - if not ctx.attr.exports: - fail("exports must be set if srcs is not") - if ctx.attr.path: - if not ctx.attr.path.startswith("/"): - fail("webpath must start with /") - if ctx.attr.path != "/" and ctx.attr.path.endswith("/"): - fail("webpath must not end with / unless it is /") - if "//" in ctx.attr.path: - fail("webpath must not have //") - elif ctx.attr.srcs: - fail("path must be set when srcs is set") - if "*" in ctx.attr.suppress and len(ctx.attr.suppress) != 1: - fail("when \"*\" is suppressed no other items should be present") - - # 
process what came before - deps = unfurl(ctx.attr.deps, provider="webfiles") - webpaths = depset() - ts_typings = depset(ctx.files._default_typings) - ts_typings_paths = depset( - [long_path(ctx, f) for f in ctx.files._default_typings]) - ts_typings_execroots = depset() - aspect_runfiles = depset() - for dep in deps: - webpaths += dep.webfiles.webpaths - if hasattr(dep.webfiles, "ts_typings"): - ts_typings += dep.webfiles.ts_typings - if hasattr(dep.webfiles, "ts_typings_paths"): - ts_typings_paths += dep.webfiles.ts_typings_paths - if hasattr(dep.webfiles, "ts_typings_execroots"): - ts_typings_execroots += dep.webfiles.ts_typings_execroots - if hasattr(dep.webfiles, "aspect_runfiles"): - aspect_runfiles += dep.webfiles.aspect_runfiles - - # process what comes now - manifest_srcs = [] - new_webpaths = [] - ts_inputs = depset() - ts_outputs = [] - ts_files = list(ts_typings_paths) - new_typings = [] - new_typings_paths = [] - new_typings_execroot = struct(inputs=[]) - execroot = struct( - inputs=[(long_path(ctx, f), f.path) for f in ctx.files._default_typings], - outputs=[], - program=[ctx.executable._tsc.path, "-p"]) - web_srcs = [] - path = ctx.attr.path - strip = _get_strip(ctx) - for src in ctx.files.srcs: - suffix = _get_path_relative_to_package(src) - if strip: - if not suffix.startswith(strip): - fail("Relative src path not start with '%s': %s" % (strip, suffix)) - suffix = suffix[len(strip):] - webpath = "%s/%s" % ("" if path == "/" else path, suffix) - _add_webpath(ctx, src, webpath, webpaths, new_webpaths, manifest_srcs) - if suffix.endswith(".d.ts"): - web_srcs.append(src) - entry = (webpath[1:], src.path) - new_typings.append(src) - new_typings_paths.append(entry[0]) - new_typings_execroot.inputs.append(entry) - ts_inputs += [src] - ts_files.append(entry[0]) - execroot.inputs.append(entry) - elif suffix.endswith(".ts"): - noext = suffix[:-3] - js = ctx.new_file(ctx.bin_dir, "%s.js" % noext) - dts = ctx.new_file(ctx.bin_dir, "%s.d.ts" % noext) - 
webpath_js = webpath[:-3] + ".js" - webpath_dts = webpath[:-3] + ".d.ts" - _add_webpath(ctx, js, webpath_js, webpaths, new_webpaths, manifest_srcs) - _add_webpath(ctx, dts, webpath_dts, webpaths, new_webpaths, manifest_srcs) - ts_inputs += [src] - ts_outputs.append(js) - ts_outputs.append(dts) - web_srcs.append(dts) - web_srcs.append(js) - ts_files.append(webpath[1:]) - execroot.inputs.append((webpath[1:], src.path)) - execroot.outputs.append((webpath_js[1:], js.path)) - execroot.outputs.append((webpath_dts[1:], dts.path)) - new_typings.append(dts) - new_typings_paths.append(webpath_dts[1:]) - new_typings_execroot.inputs.append((webpath_dts[1:], dts.path)) - else: - web_srcs.append(src) - - # get typings for closure code - clutz_dts = extract_dts_from_closure_libraries(ctx) - if clutz_dts: - entry = (long_path(ctx, clutz_dts), clutz_dts.path) - ts_inputs += [clutz_dts] - ts_files.append(entry[0]) - execroot.inputs.append(entry) - - # compile typescript - workspace = "" - if ctx.label.workspace_root: - workspace = "/" + ctx.label.workspace_root - if execroot.outputs: - ts_config = _new_file(ctx, "-tsc.json") - execroot.inputs.append(("tsconfig.json", ts_config.path)) - ctx.file_action( - output=ts_config, - content=struct( - compilerOptions=struct( - baseUrl=".", - declaration=True, - inlineSourceMap=True, - inlineSources=True, - module="es6", - moduleResolution="node", - noResolve=True, - target="es5", - ), - files=ts_files, - ).to_json()) - er_config = _new_file(ctx, "-tsc-execroot.json") - ctx.file_action(output=er_config, content=execroot.to_json()) - ts_inputs += collect_runfiles([ctx.attr._tsc]) - ts_inputs += ctx.files._tsc - ts_inputs += ts_typings - ts_inputs += ts_typings_execroots - ts_inputs += [ts_config, er_config] - ctx.action( - inputs=list(ts_inputs), - outputs=ts_outputs, - executable=ctx.executable._execrooter, - arguments=[er_config.path] + [f.path for f in ts_typings_execroots], - progress_message="Compiling %d TypeScript files %s" % ( - 
len(ts_files), ctx.label)) - - # perform strict dependency checking - manifest = _make_manifest(ctx, manifest_srcs) - webpaths += new_webpaths - dummy, manifests = _run_webfiles_validator(ctx, web_srcs, deps, manifest) - web_srcs.append(dummy) - - # define development web server that only applies to this transitive closure - params = struct( - label=str(ctx.label), - bind="[::]:6006", - manifest=[long_path(ctx, man) for man in manifests], - external_asset=[struct(webpath=k, path=v) - for k, v in ctx.attr.external_assets.items()]) - params_file = _new_file(ctx, "-params.pbtxt") - ctx.file_action(output=params_file, content=params.to_proto()) - ctx.file_action( - executable=True, - output=ctx.outputs.executable, - content="#!/bin/sh\nexec %s %s" % ( - ctx.executable._WebfilesServer.short_path, - long_path(ctx, params_file))) - - if new_typings: - er_config = _new_file(ctx, "-typings-execroot.json") - ctx.file_action(output=er_config, content=new_typings_execroot.to_json()) - ts_typings += new_typings - ts_typings_paths += new_typings_paths - ts_typings_execroots += [er_config] - else: - ts_typings = depset() - ts_typings_paths = depset() - ts_typings_execroots = depset() - - # export data to parent rules - return struct( - files=depset(web_srcs + [dummy]), - exports=unfurl(ctx.attr.exports), - webfiles=struct( - manifest=manifest, - manifests=manifests, - webpaths=webpaths, - dummy=dummy, - ts_typings=ts_typings, - ts_typings_paths=ts_typings_paths, - ts_typings_execroots=ts_typings_execroots), - closure_js_library=collect_js( - ctx, unfurl(ctx.attr.deps, provider="closure_js_library")), - runfiles=ctx.runfiles( - files=ctx.files.srcs + ctx.files.data + ts_outputs + [ - manifest, - params_file, - ctx.outputs.executable, - dummy], - transitive_files=(collect_runfiles([ctx.attr._WebfilesServer]) | - collect_runfiles(deps) | - collect_runfiles(ctx.attr.data) | - aspect_runfiles))) - -def _web_aspect_impl(target, ctx): - if hasattr(target, "webfiles"): - return struct() 
- srcs = [] - deps = [] - if hasattr(ctx.rule.files, "srcs"): - srcs.extend(_ASPECT_SLURP_FILE_TYPE.filter(ctx.rule.files.srcs)) - for attr in ("deps", "sticky_deps", "module_deps"): - value = getattr(ctx.rule.attr, attr, None) - if value: - deps.extend(value) - deps = unfurl(deps, provider="webfiles") - webpaths = depset() - aspect_runfiles = depset(srcs) - for dep in deps: - webpaths += dep.webfiles.webpaths - if hasattr(dep.webfiles, "aspect_runfiles"): - aspect_runfiles += dep.webfiles.aspect_runfiles - manifest_srcs = [] - new_webpaths = [] - for src in srcs: - webpath = "/" + long_path(ctx, src) - _add_webpath(ctx, src, webpath, webpaths, new_webpaths, manifest_srcs) - webpaths += new_webpaths - manifest = _make_manifest(ctx, manifest_srcs) - dummy, manifests = _run_webfiles_validator(ctx, srcs, deps, manifest) - aspect_runfiles += [dummy, manifest] - return struct( - webfiles=struct( - manifest=manifest, - manifests=manifests, - webpaths=webpaths, - dummy=dummy, - aspect_runfiles=aspect_runfiles)) - -def _make_manifest(ctx, src_list): - manifest = _new_file(ctx, "-webfiles.pbtxt") - ctx.file_action( - output=manifest, - content=struct( - label=str(ctx.label), - src=src_list).to_proto()) - return manifest - -def _run_webfiles_validator(ctx, srcs, deps, manifest): - dummy = _new_file(ctx, "-webfiles.ignoreme") - manifests = depset(order="topological") - for dep in deps: - manifests += dep.webfiles.manifests - if srcs: - args = ["WebfilesValidator", - "--dummy", dummy.path, - "--target", manifest.path] - if hasattr(ctx, "attr") and hasattr(ctx.attr, "suppress"): - for category in ctx.attr.suppress: - args.append("--suppress") - args.append(category) - inputs = [manifest] - inputs.extend(srcs) - direct_manifests = depset() - for dep in deps: - inputs.append(dep.webfiles.dummy) - for f in dep.files: - inputs.append(f) - direct_manifests += [dep.webfiles.manifest] - inputs.append(dep.webfiles.manifest) - args.append("--direct_dep") - 
args.append(dep.webfiles.manifest.path) - for man in difference(manifests, direct_manifests): - inputs.append(man) - args.append("--transitive_dep") - args.append(man.path) - argfile = _new_file(ctx, "-webfiles-checker-args.txt") - ctx.file_action(output=argfile, content="\n".join(args)) - inputs.append(argfile) - ctx.action( - inputs=inputs, - outputs=[dummy], - executable=(getattr(ctx.executable, "_ClosureWorker", None) or - getattr(ctx.executable, "_ClosureWorkerAspect", None)), - arguments=["@@" + argfile.path], - mnemonic="Closure", - execution_requirements={"supports-workers": "1"}, - progress_message="Checking webfiles %s" % ctx.label) - else: - ctx.file_action(output=dummy, content="BOO!") - manifests += [manifest] - return dummy, manifests - -def _new_file(ctx, suffix): - return ctx.new_file(ctx.bin_dir, "%s%s" % (ctx.label.name, suffix)) - -def _add_webpath(ctx, src, webpath, webpaths, new_webpaths, manifest_srcs): - if webpath in new_webpaths: - _fail(ctx, "multiple srcs within %s define the webpath %s " % ( - ctx.label, webpath)) - if webpath in webpaths: - _fail(ctx, "webpath %s was defined by %s when already defined by deps" % ( - webpath, ctx.label)) - new_webpaths.append(webpath) - manifest_srcs.append(struct( - path=src.path, - longpath=long_path(ctx, src), - webpath=webpath)) - -def _fail(ctx, message): - if ctx.attr.suppress == ["*"]: - print(message) - else: - fail(message) - -def _get_path_relative_to_package(artifact): - """Returns file path relative to the package that declared it.""" - path = artifact.path - for prefix in (artifact.root.path, - artifact.owner.workspace_root if artifact.owner else '', - artifact.owner.package if artifact.owner else ''): - if prefix: - prefix = prefix + "/" - if not path.startswith(prefix): - fail("Path %s doesn't start with %s" % (path, prefix)) - path = path[len(prefix):] - return path - -def _get_strip(ctx): - strip = ctx.attr.strip_prefix - if strip: - if strip.startswith("/"): - _fail(ctx, "strip_prefix 
should not end with /") - strip = strip[1:] - if strip.endswith("/"): - _fail(ctx, "strip_prefix should not end with /") - else: - strip += "/" - return strip - -web_aspect = aspect( - implementation=_web_aspect_impl, - attr_aspects=["deps", "sticky_deps", "module_deps"], - attrs={"_ClosureWorkerAspect": _CLOSURE_WORKER}) - -ts_web_library = rule( - implementation=_ts_web_library, - executable=True, - attrs=CLUTZ_ATTRIBUTES + { - "path": attr.string(), - "srcs": attr.label_list(allow_files=True), - "deps": attr.label_list( - aspects=[ - web_aspect, - clutz_aspect, - legacy_js, - ]), - "exports": attr.label_list(), - "data": attr.label_list(cfg="data", allow_files=True), - "suppress": attr.string_list(), - "strip_prefix": attr.string(), - "external_assets": attr.string_dict(default={"/_/runfiles": "."}), - "clutz_entry_points": attr.string_list(), - "_execrooter": attr.label( - default=Label("//tensorflow/tensorboard/scripts:execrooter"), - executable=True, - cfg="host"), - "_tsc": attr.label( - default=Label("@com_microsoft_typescript//:tsc"), - allow_files=True, - executable=True, - cfg="host"), - "_default_typings": attr.label( - default=Label("//tensorflow/tensorboard:ts_web_library_default_typings"), - allow_files=True), - "_WebfilesServer": attr.label( - default=Label("@io_bazel_rules_closure//java/io/bazel/rules/closure/webfiles/server:WebfilesServer"), - executable=True, - cfg="host"), - "_ClosureWorker": _CLOSURE_WORKER, - "_closure_library_base": CLOSURE_LIBRARY_BASE_ATTR, - "_closure_library_deps": CLOSURE_LIBRARY_DEPS_ATTR, - }, - outputs=CLUTZ_OUTPUTS) diff --git a/tensorflow/tensorboard/defs/zipper.bzl b/tensorflow/tensorboard/defs/zipper.bzl deleted file mode 100644 index e98309ec9a5..00000000000 --- a/tensorflow/tensorboard/defs/zipper.bzl +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -load("@io_bazel_rules_closure//closure/private:defs.bzl", "unfurl", "long_path") - -def _tensorboard_zip_file(ctx): - deps = unfurl(ctx.attr.deps, provider="webfiles") - manifests = set(order="link") - files = set() - webpaths = set() - for dep in deps: - manifests += dep.webfiles.manifests - webpaths += dep.webfiles.webpaths - files += dep.data_runfiles.files - ctx.action( - inputs=list(manifests + files), - outputs=[ctx.outputs.zip], - executable=ctx.executable._Zipper, - arguments=([ctx.outputs.zip.path] + - [m.path for m in manifests]), - progress_message="Zipping %d files" % len(webpaths)) - transitive_runfiles = set() - for dep in deps: - transitive_runfiles += dep.data_runfiles.files - return struct( - files=set([ctx.outputs.zip]), - runfiles=ctx.runfiles( - files=ctx.files.data + [ctx.outputs.zip], - transitive_files=transitive_runfiles)) - -tensorboard_zip_file = rule( - implementation=_tensorboard_zip_file, - attrs={ - "data": attr.label_list(cfg="data", allow_files=True), - "deps": attr.label_list(providers=["webfiles"], mandatory=True), - "_Zipper": attr.label( - default=Label("//tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize:Zipper"), - executable=True, - cfg="host"), - }, - outputs={ - "zip": "%{name}.zip", - }) diff --git a/tensorflow/tensorboard/demo/BUILD b/tensorflow/tensorboard/demo/BUILD deleted file mode 100644 index b253572ec55..00000000000 --- a/tensorflow/tensorboard/demo/BUILD 
+++ /dev/null @@ -1,20 +0,0 @@ -package(default_visibility = ["//tensorflow/tensorboard:internal"]) - -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library") - -licenses(["notice"]) # Apache 2.0 - -# THIS PACKAGE HAS MOVED -# See tensorflow/tensorboard/components/tf_tensorboard:demo - -web_library( - name = "demo_data", - srcs = glob(["data/**"]), - path = "/", -) - -filegroup( - name = "all_files", - srcs = glob(["**"]), - tags = ["notsan"], -) diff --git a/tensorflow/tensorboard/demo/data/audio_run_run1_tag_au1_2Faudio_2F0.json b/tensorflow/tensorboard/demo/data/audio_run_run1_tag_au1_2Faudio_2F0.json deleted file mode 100644 index 7dfe32c7112..00000000000 --- a/tensorflow/tensorboard/demo/data/audio_run_run1_tag_au1_2Faudio_2F0.json +++ /dev/null @@ -1 +0,0 @@ -[{"query": "index=0&tag=au1%2Faudio%2F0&run=run1", "step": 0, "wall_time": 1461795049.203407, "content_type": "audio/wav"}] \ No newline at end of file diff --git a/tensorflow/tensorboard/demo/data/audio_run_run2_tag_au2_2Faudio_2F0.json b/tensorflow/tensorboard/demo/data/audio_run_run2_tag_au2_2Faudio_2F0.json deleted file mode 100644 index 13f9c2de426..00000000000 --- a/tensorflow/tensorboard/demo/data/audio_run_run2_tag_au2_2Faudio_2F0.json +++ /dev/null @@ -1 +0,0 @@ -[{"query": "index=0&tag=au2%2Faudio%2F0&run=run2", "step": 0, "wall_time": 1461795049.212815, "content_type": "audio/wav"}] \ No newline at end of file diff --git a/tensorflow/tensorboard/demo/data/compressedHistograms_run_run1_tag_histo1.json b/tensorflow/tensorboard/demo/data/compressedHistograms_run_run1_tag_histo1.json deleted file mode 100644 index 6ae6fbf880e..00000000000 --- a/tensorflow/tensorboard/demo/data/compressedHistograms_run_run1_tag_histo1.json +++ /dev/null @@ -1 +0,0 @@ -[[0.0, 0, [[0, -2.3150592308536755], [668, -2.0967547155036605], [1587, -1.4326244423655616], [3085, -0.8871306575801902], [5000, -0.09312398815580714], [6915, 0.2584093405812282], [8413, 0.8895470642005087], [9332, 1.3198979614453679], 
[10000, 1.6793308878855118]]], [100.0, 10, [[0, -1.3417572789138936], [668, -1.183563374619141], [1587, -0.48920418783271574], [3085, 0.29326906896076954], [5000, 0.56953784145381], [6915, 0.8684655583499333], [8413, 1.4133127368907181], [9332, 1.906140650457873], [10000, 2.135771998171255]]], [200.0, 20, [[0, -1.5066917525035333], [668, -1.3910909571770793], [1587, -0.902737218885874], [3085, -0.3807791904765027], [5000, 0.38900200905253046], [6915, 0.8209734209339482], [8413, 1.302385856695965], [9332, 1.9324626053521639], [10000, 2.957505317875451]]], [300.0, 30, [[0, -0.5430457051469562], [668, -0.4626161834245273], [1587, 0.21573949543027715], [3085, 0.37353741100174215], [5000, 0.6891407881591103], [6915, 1.0927156232630852], [8413, 1.2745337159550916], [9332, 1.4321116832891605], [10000, 2.1913774993059034]]], [400.0, 40, [[0, -0.3584790755077172], [668, -0.33301611509753215], [1587, -0.1089466072951948], [3085, 0.5792199847585249], [5000, 1.220854943811942], [6915, 1.759829438421432], [8413, 2.3072559906741614], [9332, 2.753036118353921], [10000, 3.0267252195784047]]]] \ No newline at end of file diff --git a/tensorflow/tensorboard/demo/data/compressedHistograms_run_run2_tag_histo1.json b/tensorflow/tensorboard/demo/data/compressedHistograms_run_run2_tag_histo1.json deleted file mode 100644 index 3ad520c5687..00000000000 --- a/tensorflow/tensorboard/demo/data/compressedHistograms_run_run2_tag_histo1.json +++ /dev/null @@ -1 +0,0 @@ -[[0.0, 0, [[0, -3.6801669545044846], [668, -3.192188140974744], [1587, -2.3414678549368806], [3085, -0.9632173471995873], [5000, -0.3214892636797772], [6915, 0.11870794142185205], [8413, 0.8895470642005087], [9332, 1.183563374619141], [10000, 2.665663810418372]]], [100.0, 10, [[0, -3.564793583751807], [668, -3.376844436865802], [1587, -1.0366615731293798], [3085, -0.27318696312672563], [5000, 0.9718642422053263], [6915, 2.5765662807928194], [8413, 3.1415385101545126], [9332, 4.085981768607621], [10000, 4.623079406808927]]], 
[200.0, 20, [[0, -2.235172510433281], [668, -2.004569042815611], [1587, -1.2015432383370985], [3085, 0.11835464933202625], [5000, 0.56953784145381], [6915, 1.202844810963146], [8413, 2.689066032283515], [9332, 2.8494015726499944], [10000, 3.481377676013788]]], [300.0, 30, [[0, -3.360113978269659], [668, -2.8293185004961043], [1587, -1.5992540502266783], [3085, 0.14393860259807117], [5000, 1.47723448201245], [6915, 1.9510057389110733], [8413, 2.833176104473626], [9332, 4.142405216576347], [10000, 4.706937777668589]]], [400.0, 40, [[0, -2.599286228987632], [668, -2.240365897443259], [1587, -1.5992540502266783], [3085, -0.9101893288861387], [5000, 0.7580548669750213], [6915, 1.6009864433919474], [8413, 2.3504002974280036], [9332, 2.7907805263353733], [10000, 3.5098048900144323]]]] \ No newline at end of file diff --git a/tensorflow/tensorboard/demo/data/compressedHistograms_run_run2_tag_histo2.json b/tensorflow/tensorboard/demo/data/compressedHistograms_run_run2_tag_histo2.json deleted file mode 100644 index a3802ba2365..00000000000 --- a/tensorflow/tensorboard/demo/data/compressedHistograms_run_run2_tag_histo2.json +++ /dev/null @@ -1 +0,0 @@ -[[0.0, 0, [[0, -1.9291158122759586], [668, -1.5970765333488954], [1587, -1.0923120348519078], [3085, -0.6688082872192093], [5000, 0.09312398815580714], [6915, 0.44532789251701854], [8413, 0.8238009655877649], [9332, 1.0357232383581656], [10000, 1.2741043689144438]]], [100.0, 10, [[0, -0.7780725642449806], [668, -0.7138496178727424], [1587, -0.5448932415735014], [3085, -0.24370397454796228], [5000, 0.42790220995778355], [6915, 0.6191730643365096], [8413, 0.752059342118037], [9332, 1.0451472255274825], [10000, 2.5559479569222825]]], [200.0, 20, [[0, -1.3876904425996377], [668, -1.1464188862638496], [1587, -0.4049955219067526], [3085, 0.04721394862139682], [5000, 0.56953784145381], [6915, 1.3221859041483333], [8413, 1.6188495656305735], [9332, 1.7613953069723651], [10000, 2.3257482385477384]]], [300.0, 30, [[0, 
-1.600772629982185], [668, -1.1548516185367033], [1587, -0.260387173785447], [3085, 0.17416570914366614], [5000, 0.47069243095356195], [6915, 1.1559276581637614], [8413, 2.0474031182051404], [9332, 2.18821711651116], [10000, 2.2393193406467518]]], [400.0, 40, [[0, -0.8286852465281818], [668, -0.7815041529866706], [1587, -0.3334896444053469], [3085, 0.21085213041026643], [5000, 0.5177616740489182], [6915, 1.077122434649409], [8413, 1.5898009703967424], [9332, 1.8859097291499742], [10000, 2.0954239138728523]]]] \ No newline at end of file diff --git a/tensorflow/tensorboard/demo/data/graph_run_run1.pbtxt b/tensorflow/tensorboard/demo/data/graph_run_run1.pbtxt deleted file mode 100644 index 2a6af328408..00000000000 --- a/tensorflow/tensorboard/demo/data/graph_run_run1.pbtxt +++ /dev/null @@ -1,9 +0,0 @@ -node { - name: "a" - op: "matmul" -} -node { - name: "b" - op: "matmul" - input: "a:0" -} diff --git a/tensorflow/tensorboard/demo/data/graph_run_run2.pbtxt b/tensorflow/tensorboard/demo/data/graph_run_run2.pbtxt deleted file mode 100644 index a5a4d65d5c6..00000000000 --- a/tensorflow/tensorboard/demo/data/graph_run_run2.pbtxt +++ /dev/null @@ -1,15 +0,0 @@ -node { - name: "a" - op: "matmul" -} -node { - name: "b" - op: "matmul" - input: "a:0" -} -node { - name: "c" - op: "matmul" - input: "a:0" - input: "b:0" -} diff --git a/tensorflow/tensorboard/demo/data/histograms_run_run1_tag_histo1.json b/tensorflow/tensorboard/demo/data/histograms_run_run1_tag_histo1.json deleted file mode 100644 index a5600a356e8..00000000000 --- a/tensorflow/tensorboard/demo/data/histograms_run_run1_tag_histo1.json +++ /dev/null @@ -1 +0,0 @@ -[[400.0, 40, [-0.3584790755077172, 3.0267252195784047, 20.0, 24.012225532303315, 48.29045006426564, [-0.35363819004775493, -0.29226296698161564, -0.19961953895336082, 0.3214892636797772, 0.5177616740489182, 0.56953784145381, 0.6264916255991911, 0.7580548669750213, 0.8338603536725235, 1.220854943811942, 1.3429404381931362, 1.47723448201245, 
1.624957930213695, 1.7874537232350647, 1.9661990955585713, 2.379100905625872, 2.6170109961884593, 3.1665833053880363], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 1.0, 1.0, 2.0]]]] \ No newline at end of file diff --git a/tensorflow/tensorboard/demo/data/histograms_run_run2_tag_histo1.json b/tensorflow/tensorboard/demo/data/histograms_run_run2_tag_histo1.json deleted file mode 100644 index 407c375d2fc..00000000000 --- a/tensorflow/tensorboard/demo/data/histograms_run_run2_tag_histo1.json +++ /dev/null @@ -1 +0,0 @@ -[[400.0, 40, [-2.599286228987632, 3.5098048900144323, 20.0, 10.792285491200078, 66.66796979177158, [-2.379100905625872, -1.9661990955585713, -1.624957930213695, -1.47723448201245, -1.109868130738129, -1.0089710279437536, -0.42790220995778355, -0.2195814928486969, 0.47069243095356195, 0.7580548669750213, 0.917246389039776, 1.3429404381931362, 1.624957930213695, 1.7874537232350647, 2.1628190051144287, 2.6170109961884593, 2.8787120958073054, 3.8315657995195243], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 1.0, 1.0, 2.0, 1.0, 1.0]]]] \ No newline at end of file diff --git a/tensorflow/tensorboard/demo/data/histograms_run_run2_tag_histo2.json b/tensorflow/tensorboard/demo/data/histograms_run_run2_tag_histo2.json deleted file mode 100644 index 752b621ab03..00000000000 --- a/tensorflow/tensorboard/demo/data/histograms_run_run2_tag_histo2.json +++ /dev/null @@ -1 +0,0 @@ -[[400.0, 40, [-0.8286852465281818, 2.0954239138728523, 20.0, 13.546880465642861, 24.14836803774091, [-0.7580548669750213, -0.38900200905253046, -0.06996543062044111, 0.07696197368248522, 0.19961953895336082, 0.2656936063469233, 0.29226296698161564, 0.5177616740489182, 0.7580548669750213, 0.917246389039776, 1.109868130738129, 1.220854943811942, 1.624957930213695, 2.1628190051144287], [2.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 1.0, 1.0, 1.0, 2.0, 3.0]]]] \ No newline at end of file diff --git 
a/tensorflow/tensorboard/demo/data/images_run_run1_tag_im1_2Fimage_2F0.json b/tensorflow/tensorboard/demo/data/images_run_run1_tag_im1_2Fimage_2F0.json deleted file mode 100644 index 814b4193c63..00000000000 --- a/tensorflow/tensorboard/demo/data/images_run_run1_tag_im1_2Fimage_2F0.json +++ /dev/null @@ -1 +0,0 @@ -[{"wall_time": 1459200389.088045, "width": 4, "height": 4, "step": 0, "query": "tag=im1%2Fimage%2F0&index=0&run=run1"}] \ No newline at end of file diff --git a/tensorflow/tensorboard/demo/data/images_run_run1_tag_im2_2Fimage_2F0.json b/tensorflow/tensorboard/demo/data/images_run_run1_tag_im2_2Fimage_2F0.json deleted file mode 100644 index 0c2bdcfc79c..00000000000 --- a/tensorflow/tensorboard/demo/data/images_run_run1_tag_im2_2Fimage_2F0.json +++ /dev/null @@ -1 +0,0 @@ -[{"wall_time": 1459200389.093653, "width": 4, "height": 4, "step": 0, "query": "tag=im2%2Fimage%2F0&index=0&run=run1"}] \ No newline at end of file diff --git a/tensorflow/tensorboard/demo/data/images_run_run2_tag_im1_2Fimage_2F0.json b/tensorflow/tensorboard/demo/data/images_run_run2_tag_im1_2Fimage_2F0.json deleted file mode 100644 index 3160aae366d..00000000000 --- a/tensorflow/tensorboard/demo/data/images_run_run2_tag_im1_2Fimage_2F0.json +++ /dev/null @@ -1 +0,0 @@ -[{"wall_time": 1459200389.117463, "width": 4, "height": 4, "step": 0, "query": "tag=im1%2Fimage%2F0&index=0&run=run2"}] \ No newline at end of file diff --git a/tensorflow/tensorboard/demo/data/individualAudio_index_0_tag_au1_2Faudio_2F0_run_run1.wav b/tensorflow/tensorboard/demo/data/individualAudio_index_0_tag_au1_2Faudio_2F0_run_run1.wav deleted file mode 100644 index f1d24adc0ce..00000000000 Binary files a/tensorflow/tensorboard/demo/data/individualAudio_index_0_tag_au1_2Faudio_2F0_run_run1.wav and /dev/null differ diff --git a/tensorflow/tensorboard/demo/data/individualAudio_index_0_tag_au2_2Faudio_2F0_run_run2.wav b/tensorflow/tensorboard/demo/data/individualAudio_index_0_tag_au2_2Faudio_2F0_run_run2.wav deleted 
file mode 100644 index 006c84338f7..00000000000 Binary files a/tensorflow/tensorboard/demo/data/individualAudio_index_0_tag_au2_2Faudio_2F0_run_run2.wav and /dev/null differ diff --git a/tensorflow/tensorboard/demo/data/individualImage_tag_im1_2Fimage_2F0_index_0_run_run1.png b/tensorflow/tensorboard/demo/data/individualImage_tag_im1_2Fimage_2F0_index_0_run_run1.png deleted file mode 100644 index 346fd0076be..00000000000 Binary files a/tensorflow/tensorboard/demo/data/individualImage_tag_im1_2Fimage_2F0_index_0_run_run1.png and /dev/null differ diff --git a/tensorflow/tensorboard/demo/data/individualImage_tag_im1_2Fimage_2F0_index_0_run_run2.png b/tensorflow/tensorboard/demo/data/individualImage_tag_im1_2Fimage_2F0_index_0_run_run2.png deleted file mode 100644 index 26d2d10acaf..00000000000 Binary files a/tensorflow/tensorboard/demo/data/individualImage_tag_im1_2Fimage_2F0_index_0_run_run2.png and /dev/null differ diff --git a/tensorflow/tensorboard/demo/data/individualImage_tag_im2_2Fimage_2F0_index_0_run_run1.png b/tensorflow/tensorboard/demo/data/individualImage_tag_im2_2Fimage_2F0_index_0_run_run1.png deleted file mode 100644 index 6c419062942..00000000000 Binary files a/tensorflow/tensorboard/demo/data/individualImage_tag_im2_2Fimage_2F0_index_0_run_run1.png and /dev/null differ diff --git a/tensorflow/tensorboard/demo/data/logdir b/tensorflow/tensorboard/demo/data/logdir deleted file mode 100644 index b6362b45d77..00000000000 --- a/tensorflow/tensorboard/demo/data/logdir +++ /dev/null @@ -1 +0,0 @@ -{"logdir": "/foo/some/fake/logdir"} \ No newline at end of file diff --git a/tensorflow/tensorboard/demo/data/runs.json b/tensorflow/tensorboard/demo/data/runs.json deleted file mode 100644 index e0903905429..00000000000 --- a/tensorflow/tensorboard/demo/data/runs.json +++ /dev/null @@ -1 +0,0 @@ -{"run1": {"scalars": ["foo/sin", "foo/cos", "foo/square", "bar/square"], "run_metadata": [], "compressedHistograms": ["histo1"], "images": ["im1/image/0", 
"im2/image/0"], "histograms": ["histo1"], "graph": true, "audio": ["au1/audio/0"]}, "run2": {"scalars": ["foo/cos", "foo/square", "bar/square"], "run_metadata": [], "compressedHistograms": ["histo2", "histo1"], "images": ["im1/image/0"], "histograms": ["histo2", "histo1"], "graph": true, "audio": ["au2/audio/0"]}} \ No newline at end of file diff --git a/tensorflow/tensorboard/demo/data/scalars.json b/tensorflow/tensorboard/demo/data/scalars.json deleted file mode 100644 index bc269395b68..00000000000 --- a/tensorflow/tensorboard/demo/data/scalars.json +++ /dev/null @@ -1 +0,0 @@ -{"run2": {"foo/cos": [[0.0, 0, 2.0], [10.0, 1, 1.0806045532226562], [20.0, 2, -0.832293689250946], [30.0, 3, -1.979984998703003], [40.0, 4, -1.3072872161865234]], "bar/square": [[0.0, 0, 0.0], [10.0, 1, 1.0], [20.0, 2, 4.0], [30.0, 3, 9.0], [40.0, 4, 16.0]], "foo/square": [[0.0, 0, 0.0], [10.0, 1, 2.0], [20.0, 2, 8.0], [30.0, 3, 18.0], [40.0, 4, 32.0]]}, "run1": {"foo/sin": [[0.0, 0, 0.0], [10.0, 1, 0.8414709568023682], [20.0, 2, 0.9092974066734314], [30.0, 3, 0.14112000167369843], [40.0, 4, -0.756802499294281]], "foo/cos": [[0.0, 0, 1.0], [10.0, 1, 0.5403022766113281], [20.0, 2, -0.416146844625473], [30.0, 3, -0.9899924993515015], [40.0, 4, -0.6536436080932617]], "bar/square": [[0.0, 0, 0.0], [10.0, 1, 1.0], [20.0, 2, 4.0], [30.0, 3, 9.0], [40.0, 4, 16.0]], "foo/square": [[0.0, 0, 0.0], [10.0, 1, 1.0], [20.0, 2, 4.0], [30.0, 3, 9.0], [40.0, 4, 16.0]]}} \ No newline at end of file diff --git a/tensorflow/tensorboard/demo/data/scalars_run_run1_tag_bar_2Fsquare.json b/tensorflow/tensorboard/demo/data/scalars_run_run1_tag_bar_2Fsquare.json deleted file mode 100644 index 6d584fb4a9e..00000000000 --- a/tensorflow/tensorboard/demo/data/scalars_run_run1_tag_bar_2Fsquare.json +++ /dev/null @@ -1 +0,0 @@ -[[0.0, 0, 0.0], [10.0, 1, 1.0], [20.0, 2, 4.0], [30.0, 3, 9.0], [40.0, 4, 16.0]] \ No newline at end of file diff --git a/tensorflow/tensorboard/demo/data/scalars_run_run1_tag_foo_2Fcos.json 
b/tensorflow/tensorboard/demo/data/scalars_run_run1_tag_foo_2Fcos.json deleted file mode 100644 index 025eaa16e93..00000000000 --- a/tensorflow/tensorboard/demo/data/scalars_run_run1_tag_foo_2Fcos.json +++ /dev/null @@ -1 +0,0 @@ -[[0.0, 0, 1.0], [10.0, 1, 0.5403022766113281], [20.0, 2, -0.416146844625473], [30.0, 3, -0.9899924993515015], [40.0, 4, -0.6536436080932617]] \ No newline at end of file diff --git a/tensorflow/tensorboard/demo/data/scalars_run_run1_tag_foo_2Fsin.json b/tensorflow/tensorboard/demo/data/scalars_run_run1_tag_foo_2Fsin.json deleted file mode 100644 index eae69dd78f3..00000000000 --- a/tensorflow/tensorboard/demo/data/scalars_run_run1_tag_foo_2Fsin.json +++ /dev/null @@ -1 +0,0 @@ -[[0.0, 0, 0.0], [10.0, 1, 0.8414709568023682], [20.0, 2, 0.9092974066734314], [30.0, 3, 0.14112000167369843], [40.0, 4, -0.756802499294281]] \ No newline at end of file diff --git a/tensorflow/tensorboard/demo/data/scalars_run_run1_tag_foo_2Fsquare.json b/tensorflow/tensorboard/demo/data/scalars_run_run1_tag_foo_2Fsquare.json deleted file mode 100644 index 6d584fb4a9e..00000000000 --- a/tensorflow/tensorboard/demo/data/scalars_run_run1_tag_foo_2Fsquare.json +++ /dev/null @@ -1 +0,0 @@ -[[0.0, 0, 0.0], [10.0, 1, 1.0], [20.0, 2, 4.0], [30.0, 3, 9.0], [40.0, 4, 16.0]] \ No newline at end of file diff --git a/tensorflow/tensorboard/demo/data/scalars_run_run2_tag_bar_2Fsquare.json b/tensorflow/tensorboard/demo/data/scalars_run_run2_tag_bar_2Fsquare.json deleted file mode 100644 index 6d584fb4a9e..00000000000 --- a/tensorflow/tensorboard/demo/data/scalars_run_run2_tag_bar_2Fsquare.json +++ /dev/null @@ -1 +0,0 @@ -[[0.0, 0, 0.0], [10.0, 1, 1.0], [20.0, 2, 4.0], [30.0, 3, 9.0], [40.0, 4, 16.0]] \ No newline at end of file diff --git a/tensorflow/tensorboard/demo/data/scalars_run_run2_tag_foo_2Fcos.json b/tensorflow/tensorboard/demo/data/scalars_run_run2_tag_foo_2Fcos.json deleted file mode 100644 index dd3593f9d10..00000000000 --- 
a/tensorflow/tensorboard/demo/data/scalars_run_run2_tag_foo_2Fcos.json +++ /dev/null @@ -1 +0,0 @@ -[[0.0, 0, 2.0], [10.0, 1, 1.0806045532226562], [20.0, 2, -0.832293689250946], [30.0, 3, -1.979984998703003], [40.0, 4, -1.3072872161865234]] \ No newline at end of file diff --git a/tensorflow/tensorboard/demo/data/scalars_run_run2_tag_foo_2Fsquare.json b/tensorflow/tensorboard/demo/data/scalars_run_run2_tag_foo_2Fsquare.json deleted file mode 100644 index 0ff9ef0551d..00000000000 --- a/tensorflow/tensorboard/demo/data/scalars_run_run2_tag_foo_2Fsquare.json +++ /dev/null @@ -1 +0,0 @@ -[[0.0, 0, 0.0], [10.0, 1, 2.0], [20.0, 2, 8.0], [30.0, 3, 18.0], [40.0, 4, 32.0]] \ No newline at end of file diff --git a/tensorflow/tensorboard/http_api.md b/tensorflow/tensorboard/http_api.md deleted file mode 100644 index c2885daf93c..00000000000 --- a/tensorflow/tensorboard/http_api.md +++ /dev/null @@ -1,402 +0,0 @@ -# Tensorboard client-server HTTP API - -## Runs, Tags, and Tag Types - -TensorBoard data is organized around the concept of a `run`, which represents -all the related data thrown off by a single execution of TensorFlow, a `tag`, -which groups values of data that come from the same source within a TensorFlow -run, and `tag types`, which are our way of distinguishing different types of -data that have fundamentally different representations and should be processed -on different code paths. For example, a "train" run may have a `scalars` -tag that represents the learning rate, another `scalars` tag that -represents the value of the objective function, a `histograms` tag that reveals -information on weights in a particular layer over time, and an `images` tag that -shows input images flowing into the system. The "eval" run might have an -entirely different set of tag names, or some duplicated tag names. - -The currently supported tag types are `scalars`, `images`, `audio`, -`histograms`, `graph` and `run_metadata`. 
Each tag type corresponds to a route -(documented below) for retrieving tag data of that type. - -All of the data provided comes from TensorFlow events files ('\*.tfevents\*'), -which are written using the SummaryWriter class -(tensorflow/python/training/summary_writer.py), and the data is generated by -summary ops (tensorflow/python/ops/summary_ops.py). The `scalars` come from the -`ScalarSummary` op, the `histograms` from the `HistogramSummary`, the `audio` -from the `AudioSummary`, and the `images` from `ImageSummary`. The tag type -`graph` is special in that it is not a collection of tags of that type, but a -boolean denoting if there is a graph definition associated with the run. The tag -is provided to the summary op (usually as a constant). - -## `data/logdir` - -Returns a JSON object with a key "logdir" that maps to the `logdir` argument -(string) with which Tensorboard started up. Example: -`{logdir: '/foo/logdir/argument'}` - -The `logdir` argument is the path of the directory that contains events files. - -## `data/plugins_listing` - -Returns a dict mapping from plugin name to a boolean indicating whether the -plugin is active. A plugin might be inactive, for instance, if it lacks relevant -data. Every plugin has a key. This route helps the frontend avoid issuing -requests to an inactive plugin - the routes of an inactive plugin do not work. - -## `data/runs` - -Returns an array containing the names of all the runs known to the -TensorBoard backend at this time. Each entry is a string corresponding -to a single run. - -We guarantee that as new runs are created in the log directory, they -will always appear at the end of the list returned by this route. That -is, the order of runs is persistent, and the result of this route is an -“append-only” list. 
- -Example response: - - ["train_run", "eval"] - -## `/data/plugin/scalars/tags` - -Returns a dictionary mapping from `run_name` (quoted string) to arrays of -`tag_name` (quoted string), where each array contains the names of all -scalar tags present in the corresponding run. Here is an example: - - { - "train_run": ["xent", "loss", "learning_rate"], - "eval": ["precision", "recall"] - } - -Note that runs without any scalar tags are included as keys with value the -empty array. - -## `/data/plugin/scalars/scalars?run=foo&tag=bar` - -Returns an array of event_accumulator.SimpleValueEvents ([wall_time, step, -value]) for the given run and tag. wall_time is seconds since epoch. - -Example: - - [ - [1443856985.705543, 1448, 0.7461960315704346], # wall_time, step, value - [1443857105.704628, 3438, 0.5427092909812927], - [1443857225.705133, 5417, 0.5457325577735901], - ... - ] - -If the format parameter is set to 'csv', the response will instead be in CSV -format: - - Wall time,step,value - 1443856985.705543,1448,0.7461960315704346 - 1443857105.704628,3438,0.5427092909812927 - 1443857225.705133,5417,0.5457325577735901 - -## `/data/plugin/histograms/tags` - -Returns a dictionary mapping from `run_name` (quoted string) to arrays of -`tag_name` (quoted string), where each array contains the names of all -histogram tags present in the corresponding run. Here is an example: - - { - "train_run": ["foo_histogram", "bar_histogram"], - "eval": ["foo_histogram", "bar_histogram"] - } - -Note that runs without any histogram tags are included as keys with -value the empty array. - -## `/data/plugin/histograms/histograms?run=foo&tag=bar` - -Returns an array of event_accumulator.HistogramEvents ([wall_time, step, -HistogramValue]) for the given run and tag. A HistogramValue is [min, max, num, -sum, sum_squares, bucket_limit, bucket]. wall_time is seconds since epoch. 
- -Annotated Example: (note - real data is higher precision) - - [ - [ - 1443871386.185149, # wall_time - 235166, # step - [ - -0.66, # minimum value - 0.44, # maximum value - 8.0, # number of items in the histogram - -0.80, # sum of items in the histogram - 0.73, # sum of squares of items in the histogram - [-0.68, -0.62, -0.292, -0.26, -0.11, -0.10, -0.08, -0.07, -0.05, - -0.0525, -0.0434, -0.039, -0.029, -0.026, 0.42, 0.47, 1.8e+308], - # the right edge of each bucket - [0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0, 0.0, - 1.0, 0.0] # the number of elements within each bucket - ] - ] - ] - -## `/data/plugin/distributions/tags` - -Returns a dictionary mapping from `run_name` (quoted string) to arrays of -`tag_name` (quoted string), where each array contains the names of all -distribution tags present in the corresponding run. Here is an example: - - { - "train_run": ["foo_histogram", "bar_histogram"], - "eval": ["foo_histogram", "bar_histogram"] - } - -Note that runs without any distribution tags are included as keys with -value the empty array. - -## `/data/plugin/distributions/distributions?run=foo&tag=bar` - -Returns an array of event_accumulator.CompressedHistogramEvents ([wall_time, -step, CompressedHistogramValues]) for the given run and tag. - -CompressedHistogramValues is a list of namedtuples with each tuple specifying -a basis point (bps) as well as an interpolated value of the histogram value -at that basis point. A basis point is 1/100 of a percent. - -The current compression strategy is to choose basis points that correspond to -the median and bands of 1SD, 2SD, and 3SDs around the median. Note that the -current compression strategy does not work well for representing multimodal -data -- this is something that will be improved in a later iteration. 
- -Annotated Example: (note - real data is higher precision) - - [ - [ - 1441154832.580509, # wall_time - 5, # step - [ [0, -3.67], # CompressedHistogramValue for 0th percentile - [2500, -4.19], # CompressedHistogramValue for 25th percentile - [5000, 6.29], - [7500, 1.64], - [10000, 3.67] - ] - ], - ... - ] - -## `/data/plugin/images/images?run=foo&tag=bar` - -Gets a sample of ImageMetadatas for the given run and tag. - -Returns an array of objects containing information about available images, -crucially including the query parameter that may be used to retrieve that image. -(See /data/plugin/images/individualImage for details.) - -For example: - - { - "width": 28, # width in pixels - "height": 28, # height in pixels - "wall_time": 1440210599.246, # time in seconds since epoch - "step": 63702821, # number of steps that have passed - "query": "index=0&tagname=input%2Fimage%2F2&run=train" - # param for /individualImage - } - -## `/data/plugin/images/individualImage?{{query}}` - -Retrieves an individual image. The image query should not be generated by the -frontend, but instead acquired from calling the /images route (the image -metadata objects contain the query to use). The response is the image itself -with mime-type 'image/png'. - -Note that the query is not guaranteed to always refer to the same image even -within a single run, as images may be removed from the sampling reservoir and -replaced with other images. (See Notes for details on the reservoir sampling.) - -An example call to this route would look like this: -/data/plugin/images/individualImage?index=0&tagname=input%2Fimage%2F2&run=train - -## `/data/plugin/images/tags` - -Returns a dictionary mapping from `run_name` (quoted string) to arrays of -`tag_name` (quoted string), where each array contains the names of all image -tags present in the corresponding run. 
Here is an example: - - { - "train": ["foo_image", "bar_image"], - "eval": ["foo_image", "bar_image"] - } - -Note that runs without any image tags are included as keys with value the empty -array. - -## `/data/plugin/audio/audio?run=foo&tag=bar` - -Gets a sample of AudioMetadatas for the given run and tag. - -Returns an array of objects containing information about available audio, -crucially including the query parameter that may be used to retrieve that audio. -(See /data/plugin/audio/individualAudio for details.) - -For example: - - { - "wall_time": 1440210599.246, # time in seconds since epoch - "step": 63702821, # number of steps that have passed - "content_type": "audio/wav" # the MIME-type of the audio - "query": "index=0&tagname=input%2Faudio%2F2&run=train" - # param for /individualAudio - } - -## `/data/plugin/audio/individualAudio?{{query}}` - -Retrieves an individual audio clip. The audio query should not be generated by -the frontend, but instead acquired from calling the /audio route (the audio -metadata objects contain the query to use). The response is the audio itself -with an appropriate Content-Type header set. - -Note that the query is not guaranteed to always refer to the same clip even -within a single run, as audio may be removed from the sampling reservoir and -replaced with other clips. (See Notes for details on the reservoir sampling.) - -An example call to this route would look like this: -/individualAudio?index=0&tagname=input%2Faudio%2F2&run=train - -## `/data/plugin/audio/tags` - -Returns a dictionary mapping from `run_name` (quoted string) to arrays of -`tag_name` (quoted string), where each array contains the names of all audio -tags present in the corresponding run. Here is an example: - - { - "train": ["foo_audio", "bar_audio"], - "eval": ["foo_audio", "bar_audio"], - } - -Note that runs without any audio tags are included as keys with value the empty -array. 
- -## `/data/plugin/graphs/runs` - -Returns a list of runs that have associated graphs. - -For example: - - ["train"] - -## `/data/plugin/graphs/graph?run=foo&limit_attr_size=1024&large_attrs_key=key` - -Returns the graph definition for the given run in pbtxt format. The -graph is composed of a list of nodes, where each node is a specific -TensorFlow operation which takes as inputs other nodes (operations). - -The query parameters `limit_attr_size` and `large_attrs_key` are optional. - -`limit_attr_size` specifies the maximum allowed size in bytes, before the -attribute is considered large and filtered out of the graph. If specified, -it must be an int and > 0. If not specified, no filtering is applied. - -`large_attrs_key` is the attribute key that will be used for storing -attributes that are too large. The value of this key (list of strings) -should be used by the client in order to determine which attributes -have been filtered. Must be specified if `limit_attr_size` is specified. - -For the query - - /data/plugin/graphs/graph?run=foo&limit_attr_size=1024&large_attrs_key=_too_large, - -here is an example pbtxt response of a graph with 3 nodes, where the second -node had two large attributes "a" and "b" that were filtered out (size > 1024): - - node { - op: "Input" - name: "A" - } - node { - op: "Input" - name: "B" - attr { - key: "small_attr" - value: { - s: "some string" - } - } - attr { - key: "_too_large" - value { - list { - s: "a" - s: "b" - } - } - } - } - node { - op: "MatMul" - name: "C" - input: "A" - input: "B" - } - -Prior to filtering, the original node "B" had the following content: - - node { - op: "Input" - name: "B" - attr { - key: "small_attr" - value: { - s: "some string" - } - } - attr { - key: "a" - value { Very large object... } - } - attr { - key: "b" - value { Very large object... 
} - } - } - -## `/data/run_metadata?run=foo&tag=bar` - -Given a run and tag, returns the metadata of a particular -`session.run()` as a gzipped, pbtxt serialized [`RunMetadata`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/protobuf/config.proto) -proto. For example: - - step_stats { - dev_stats { - device: "/job:localhost/replica:0/task:0/cpu:0" - node_stats { - node_name: "_SOURCE" - all_start_micros: 1458337695775395 - op_start_rel_micros: 11 - op_end_rel_micros: 12 - all_end_rel_micros: 38 - memory { - allocator_name: "cpu" - } - timeline_label: "_SOURCE = NoOp()" - scheduled_micros: 1458337695775363 - } - } - } - -## Notes - -All returned values, histograms, audio, and images are returned in the order -they were written by TensorFlow (which should correspond to increasing -`wall_time` order, but may not necessarily correspond to increasing step count -if the process had to restart from a previous checkpoint). - -The returned values may be downsampled using reservoir sampling, which is -configurable by the TensorBoard server. When downsampling occurs, the server -guarantees that different tags will all sample at the same sequence of indices, -so that if there are two tags `A` and `B` which are related so that `A[i] ~ -B[i]` for all `i`, then `D(A)[i] ~ D(B)[i]` for all `i`, where `D` represents -the downsampling operation. - -The reservoir sampling puts an upper bound on the number of items that will be -returned for a given run-tag combination, and guarantees that all items are -equally likely to be in the final sample (ie it is a uniform distribution over -the values), with the proviso that the most recent individual item is always -included in the sample. - -The reservoir sizes are configurable on a per-tag type basis. 
diff --git a/tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize/BUILD b/tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize/BUILD deleted file mode 100644 index f1f7746ff84..00000000000 --- a/tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize/BUILD +++ /dev/null @@ -1,56 +0,0 @@ -package(default_visibility = ["//tensorflow/tensorboard:internal"]) - -licenses(["notice"]) # Apache 2.0 - -java_binary( - name = "Vulcanize", - srcs = ["Vulcanize.java"], - jvm_flags = [ - "-Xss20m", # JSCompiler needs big stacks for recursive parsing - "-XX:+UseParallelGC", # Best GC when app isn't latency sensitive - "-Djava.util.logging.SimpleFormatter.format='%1$$tY-%1$$tm-%1$$td %1$$tH:%1$$tM:%1$$tS.%1$$tL %4$$-6s %5$$s%6$$s%n'", # Less log spam - ], - visibility = ["//visibility:public"], - deps = [ - "@com_google_guava", - "@com_google_protobuf_java", - "@io_bazel_rules_closure//closure/compiler", - "@io_bazel_rules_closure//java/io/bazel/rules/closure:webpath", - "@io_bazel_rules_closure//java/io/bazel/rules/closure/webfiles:build_info_java_proto", - "@io_bazel_rules_closure//java/org/jsoup/nodes", - "@org_jsoup", - ], -) - -java_binary( - name = "Zipper", - srcs = ["Zipper.java"], - visibility = ["//visibility:public"], - deps = [ - "@com_google_guava", - "@com_google_protobuf_java", - "@io_bazel_rules_closure//java/io/bazel/rules/closure/webfiles", - "@io_bazel_rules_closure//java/io/bazel/rules/closure/webfiles:build_info_java_proto", - ], -) - -# These JS files are always taken into consideration by the Closure Compiler -# when vulcanizing, per vulcanize.bzl. 
-filegroup( - name = "jslibs", - srcs = [ - # Ordering probably matters - "@com_google_javascript_closure_compiler_externs", - "@com_google_javascript_closure_compiler_externs_polymer", - "externs.js", - "@com_google_javascript_closure_library//:closure/goog/base.js", - "@com_google_javascript_closure_library//:closure/goog/deps.js", - ], - visibility = ["//visibility:public"], -) - -filegroup( - name = "all_files", - srcs = glob(["**"]), - tags = ["notsan"], -) diff --git a/tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize/Vulcanize.java b/tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize/Vulcanize.java deleted file mode 100644 index 533907dd64d..00000000000 --- a/tensorflow/tensorboard/java/org/tensorflow/tensorboard/vulcanize/Vulcanize.java +++ /dev/null @@ -1,546 +0,0 @@ -// Copyright 2017 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -package org.tensorflow.tensorboard.vulcanize; - -import static com.google.common.base.Preconditions.checkNotNull; -import static com.google.common.base.Verify.verify; -import static com.google.common.base.Verify.verifyNotNull; -import static java.nio.charset.StandardCharsets.UTF_8; - -import com.google.common.base.CharMatcher; -import com.google.common.base.Joiner; -import com.google.common.base.Optional; -import com.google.common.base.Splitter; -import com.google.common.collect.HashMultimap; -import com.google.common.collect.ImmutableMultimap; -import com.google.common.collect.ImmutableSet; -import com.google.common.collect.Iterables; -import com.google.common.collect.Lists; -import com.google.common.collect.Multimap; -import com.google.javascript.jscomp.CheckLevel; -import com.google.javascript.jscomp.CompilationLevel; -import com.google.javascript.jscomp.Compiler; -import com.google.javascript.jscomp.CompilerOptions; -import com.google.javascript.jscomp.DiagnosticGroup; -import com.google.javascript.jscomp.DiagnosticGroups; -import com.google.javascript.jscomp.DiagnosticType; -import com.google.javascript.jscomp.JSError; -import com.google.javascript.jscomp.ModuleIdentifier; -import com.google.javascript.jscomp.PropertyRenamingPolicy; -import com.google.javascript.jscomp.Result; -import com.google.javascript.jscomp.SourceFile; -import com.google.javascript.jscomp.WarningsGuard; -import com.google.protobuf.TextFormat; -import io.bazel.rules.closure.Webpath; -import io.bazel.rules.closure.webfiles.BuildInfo.Webfiles; -import io.bazel.rules.closure.webfiles.BuildInfo.WebfilesSource; -import java.io.ByteArrayInputStream; -import java.io.IOException; -import java.io.InputStream; -import java.nio.file.Files; -import java.nio.file.Path; -import java.nio.file.Paths; -import java.nio.file.StandardOpenOption; -import java.util.ArrayDeque; -import java.util.ArrayList; -import java.util.Collection; -import java.util.Deque; -import java.util.HashMap; -import 
java.util.HashSet; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.regex.Matcher; -import java.util.regex.Pattern; -import java.util.stream.Collectors; -import org.jsoup.Jsoup; -import org.jsoup.nodes.Attribute; -import org.jsoup.nodes.Comment; -import org.jsoup.nodes.DataNode; -import org.jsoup.nodes.Document; -import org.jsoup.nodes.Element; -import org.jsoup.nodes.Html5Printer; -import org.jsoup.nodes.Node; -import org.jsoup.nodes.TextNode; -import org.jsoup.parser.Parser; -import org.jsoup.parser.Tag; - -/** Simple one-off solution for TensorBoard vulcanization. */ -public final class Vulcanize { - - private static final Pattern IGNORE_PATHS_PATTERN = - Pattern.compile("/(?:polymer|marked-element)/.*"); - - private static final ImmutableSet EXTRA_JSDOC_TAGS = - ImmutableSet.of("attribute", "hero", "group", "required"); - - private static final Pattern WEBPATH_PATTERN = Pattern.compile("//~~WEBPATH~~([^\n]+)"); - - private static final Parser parser = Parser.htmlParser(); - private static final Map webfiles = new HashMap<>(); - private static final Set alreadyInlined = new HashSet<>(); - private static final Set legalese = new HashSet<>(); - private static final List licenses = new ArrayList<>(); - private static final List stack = new ArrayList<>(); - private static final List externs = new ArrayList<>(); - private static final List sourcesFromJsLibraries = new ArrayList<>(); - private static final Map sourcesFromScriptTags = new LinkedHashMap<>(); - private static final Map sourceTags = new LinkedHashMap<>(); - private static final Multimap suppressions = HashMultimap.create(); - private static CompilationLevel compilationLevel; - private static Webpath outputPath; - private static Node firstCompiledScript; - private static Node licenseComment; - private static int insideDemoSnippet; - private static boolean testOnly; - - public static void main(String[] args) throws IOException { - 
compilationLevel = CompilationLevel.fromString(args[0]); - testOnly = args[1].equals("true"); - Webpath inputPath = Webpath.get(args[2]); - outputPath = Webpath.get(args[3]); - Path output = Paths.get(args[4]); - for (int i = 5; i < args.length; i++) { - if (args[i].endsWith(".js")) { - String code = new String(Files.readAllBytes(Paths.get(args[i])), UTF_8); - SourceFile sourceFile = SourceFile.fromCode(args[i], code); - if (code.contains("@externs")) { - externs.add(sourceFile); - } else { - sourcesFromJsLibraries.add(sourceFile); - } - continue; - } - if (!args[i].endsWith(".pbtxt")) { - continue; - } - Webfiles manifest = loadWebfilesPbtxt(Paths.get(args[i])); - for (WebfilesSource src : manifest.getSrcList()) { - webfiles.put(Webpath.get(src.getWebpath()), Paths.get(src.getPath())); - } - } - stack.add(inputPath); - Document document = parse(Files.readAllBytes(webfiles.get(inputPath))); - transform(document); - compile(); - if (licenseComment != null) { - licenseComment.attr("comment", String.format("\n%s\n", Joiner.on("\n\n").join(licenses))); - } - Files.write( - output, - Html5Printer.stringify(document).getBytes(UTF_8), - StandardOpenOption.WRITE, - StandardOpenOption.CREATE, - StandardOpenOption.TRUNCATE_EXISTING); - } - - private static void transform(Node root) throws IOException { - Node node = checkNotNull(root); - Node newNode; - while (true) { - newNode = enterNode(node); - if (node.equals(root)) { - root = newNode; - } - node = newNode; - if (node.childNodeSize() > 0) { - node = node.childNode(0); - } else { - while (true) { - newNode = leaveNode(node); - if (node.equals(root)) { - root = newNode; - } - node = newNode; - if (node.equals(root)) { - return; - } - Node next = node.nextSibling(); - if (next == null) { - if (node.parentNode() == null) { - return; - } - node = verifyNotNull(node.parentNode(), "unexpected root: %s", node); - } else { - node = next; - break; - } - } - } - } - } - - private static Node enterNode(Node node) throws IOException 
{ - if (node.nodeName().equals("demo-snippet")) { - insideDemoSnippet++; - } - if (insideDemoSnippet > 0) { - return node; - } - if (node instanceof Element) { - if (!getAttrTransitive(node, "vulcanize-noinline").isPresent()) { - if (node.nodeName().equals("link") && node.attr("rel").equals("import")) { - // Inline HTML. - node = visitHtmlImport(node); - } else if (node.nodeName().equals("script") - && !shouldIgnoreUri(node.attr("src")) - && !node.hasAttr("jscomp-ignore")) { - node = visitScript(node); - } else if (node.nodeName().equals("link") - && node.attr("rel").equals("stylesheet") - && !node.attr("href").isEmpty() - && !shouldIgnoreUri(node.attr("href"))) { - node = visitStylesheet(node); - } - } - rootifyAttribute(node, "href"); - rootifyAttribute(node, "src"); - rootifyAttribute(node, "action"); - rootifyAttribute(node, "assetpath"); - } else if (node instanceof Comment) { - String text = ((Comment) node).getData(); - if (text.contains("@license")) { - handleLicense(text); - if (licenseComment == null) { - licenseComment = node; - } else { - node = replaceNode(node, new TextNode("", node.baseUri())); - } - } else { - node = replaceNode(node, new TextNode("", node.baseUri())); - } - } - return node; - } - - private static Node leaveNode(Node node) { - if (node instanceof Document) { - stack.remove(stack.size() - 1); - } else if (node.nodeName().equals("demo-snippet")) { - insideDemoSnippet--; - } - return node; - } - - private static Node visitHtmlImport(Node node) throws IOException { - Webpath href = me().lookup(Webpath.get(node.attr("href"))); - if (alreadyInlined.add(href)) { - stack.add(href); - Document subdocument = parse(Files.readAllBytes(getWebfile(href))); - for (Attribute attr : node.attributes()) { - subdocument.attr(attr.getKey(), attr.getValue()); - } - return replaceNode(node, subdocument); - } else { - return replaceNode(node, new TextNode("", node.baseUri())); - } - } - - private static Node visitScript(Node node) throws IOException { - 
Webpath path; - String script; - if (node.attr("src").isEmpty()) { - path = makeSyntheticName(".js"); - script = getInlineScriptFromNode(node); - } else { - path = me().lookup(Webpath.get(node.attr("src"))); - script = new String(Files.readAllBytes(getWebfile(path)), UTF_8); - } - if (node.attr("src").endsWith(".min.js") - || getAttrTransitive(node, "jscomp-nocompile").isPresent()) { - Node newScript = - new Element(Tag.valueOf("script"), node.baseUri(), node.attributes()) - .appendChild(new DataNode(script, node.baseUri())) - .removeAttr("src") - .removeAttr("jscomp-nocompile"); - if (firstCompiledScript != null) { - firstCompiledScript.before(newScript); - return replaceNode(node, new TextNode("", node.baseUri())); - } else { - return replaceNode(node, newScript); - } - } else { - if (firstCompiledScript == null) { - firstCompiledScript = node; - } - sourcesFromScriptTags.put(path, script); - sourceTags.put(path, node); - Optional suppress = getAttrTransitive(node, "jscomp-suppress"); - if (suppress.isPresent()) { - if (suppress.get().isEmpty()) { - suppressions.put(path, "*"); - } else { - suppressions.putAll(path, Splitter.on(' ').split(suppress.get())); - } - } - return node; - } - } - - private static Node visitStylesheet(Node node) throws IOException { - Webpath href = me().lookup(Webpath.get(node.attr("href"))); - return replaceNode( - node, - new Element(Tag.valueOf("style"), node.baseUri(), node.attributes()) - .appendChild( - new DataNode( - new String(Files.readAllBytes(getWebfile(href)), UTF_8), node.baseUri())) - .removeAttr("rel") - .removeAttr("href")); - } - - private static Optional getAttrTransitive(Node node, String attr) { - while (node != null) { - if (node.hasAttr(attr)) { - return Optional.of(node.attr(attr)); - } - node = node.parent(); - } - return Optional.absent(); - } - - private static Node replaceNode(Node oldNode, Node newNode) { - oldNode.replaceWith(newNode); - return newNode; - } - - private static Path getWebfile(Webpath path) { 
- return verifyNotNull(webfiles.get(path), "Bad ref: %s -> %s", me(), path); - } - - private static void compile() { - if (sourcesFromScriptTags.isEmpty()) { - return; - } - - CompilerOptions options = new CompilerOptions(); - compilationLevel.setOptionsForCompilationLevel(options); - - // Nice options. - options.setColorizeErrorOutput(true); - options.setContinueAfterErrors(true); - options.setLanguageIn(CompilerOptions.LanguageMode.ECMASCRIPT_2016); - options.setLanguageOut(CompilerOptions.LanguageMode.ECMASCRIPT5); - options.setGenerateExports(true); - options.setStrictModeInput(false); - options.setExtraAnnotationNames(EXTRA_JSDOC_TAGS); - - // So we can chop JS binary back up into the original script tags. - options.setPrintInputDelimiter(true); - options.setInputDelimiter("//~~WEBPATH~~%name%"); - - // Optimizations that are too advanced for us right now. - options.setPropertyRenaming(PropertyRenamingPolicy.OFF); - options.setCheckGlobalThisLevel(CheckLevel.OFF); - options.setRemoveUnusedPrototypeProperties(false); - options.setRemoveUnusedPrototypePropertiesInExterns(false); - options.setRemoveUnusedClassProperties(false); - - // Dependency management. - options.setClosurePass(true); - options.setManageClosureDependencies(true); - options.getDependencyOptions().setDependencyPruning(true); - options.getDependencyOptions().setDependencySorting(true); - options.getDependencyOptions().setMoocherDropping(false); - options.getDependencyOptions() - .setEntryPoints( - sourceTags - .keySet() - .stream() - .map(Webpath::toString) - .map(ModuleIdentifier::forFile) - .collect(Collectors.toList())); - - // Polymer pass. - options.setPolymerVersion(1); - - // Debug flags. 
- if (testOnly) { - options.setPrettyPrint(true); - options.setGeneratePseudoNames(true); - options.setExportTestFunctions(true); - } - - // Don't print warnings from " - sanitized = "<script>alert('xss')</script>" - self.assertEqual(text_plugin.markdown_and_sanitize(dangerous), sanitized) - - dangerous = textwrap.dedent("""\ - hello *you*""") - sanitized = '

hello you

' - self.assertEqual(text_plugin.markdown_and_sanitize(dangerous), sanitized) - - def testTableGeneration(self): - array2d = np.array([['one', 'two'], ['three', 'four']]) - expected_table = textwrap.dedent("""\ - - - - - - - - - - - -
onetwo
threefour
""") - self.assertEqual(text_plugin.make_table(array2d), expected_table) - - expected_table_with_headers = textwrap.dedent("""\ - - - - - - - - - - - - - - - - - -
c1c2
onetwo
threefour
""") - - actual_with_headers = text_plugin.make_table(array2d, headers=['c1', 'c2']) - self.assertEqual(actual_with_headers, expected_table_with_headers) - - array_1d = np.array(['one', 'two', 'three', 'four', 'five']) - expected_1d = textwrap.dedent("""\ - - - - - - - - - - - - - - - - - - -
one
two
three
four
five
""") - self.assertEqual(text_plugin.make_table(array_1d), expected_1d) - - expected_1d_with_headers = textwrap.dedent("""\ - - - - - - - - - - - - - - - - - - - - - - - -
X
one
two
three
four
five
""") - actual_1d_with_headers = text_plugin.make_table(array_1d, headers=['X']) - self.assertEqual(actual_1d_with_headers, expected_1d_with_headers) - - def testMakeTableExceptions(self): - # Verify that contents is being type-checked and shape-checked. - with self.assertRaises(ValueError): - text_plugin.make_table([]) - - with self.assertRaises(ValueError): - text_plugin.make_table('foo') - - with self.assertRaises(ValueError): - invalid_shape = np.full((3, 3, 3), 'nope', dtype=np.dtype('S3')) - text_plugin.make_table(invalid_shape) - - # Test headers exceptions in 2d array case. - test_array = np.full((3, 3), 'foo', dtype=np.dtype('S3')) - with self.assertRaises(ValueError): - # Headers is wrong type. - text_plugin.make_table(test_array, headers='foo') - with self.assertRaises(ValueError): - # Too many headers. - text_plugin.make_table(test_array, headers=['foo', 'bar', 'zod', 'zoink']) - with self.assertRaises(ValueError): - # headers is 2d - text_plugin.make_table(test_array, headers=test_array) - - # Also make sure the column counting logic works in the 1d array case. - test_array = np.array(['foo', 'bar', 'zod']) - with self.assertRaises(ValueError): - # Too many headers. - text_plugin.make_table(test_array, headers=test_array) - - def test_reduce_to_2d(self): - - def make_range_array(dim): - """Produce an incrementally increasing multidimensional array. - - Args: - dim: the number of dimensions for the array - - Returns: - An array of increasing integer elements, with dim dimensions and size - two in each dimension. - - Example: rangeArray(2) results in [[0,1],[2,3]]. - """ - return np.array(range(2**dim)).reshape([2] * dim) - - for i in range(2, 5): - actual = text_plugin.reduce_to_2d(make_range_array(i)) - expected = make_range_array(2) - np.testing.assert_array_equal(actual, expected) - - def test_text_array_to_html(self): - - convert = text_plugin.text_array_to_html - scalar = np.array('foo') - scalar_expected = '

foo

' - self.assertEqual(convert(scalar), scalar_expected) - - vector = np.array(['foo', 'bar']) - vector_expected = textwrap.dedent("""\ - - - - - - - - - -

foo

bar

""") - self.assertEqual(convert(vector), vector_expected) - - d2 = np.array([['foo', 'bar'], ['zoink', 'zod']]) - d2_expected = textwrap.dedent("""\ - - - - - - - - - - - -

foo

bar

zoink

zod

""") - self.assertEqual(convert(d2), d2_expected) - - d3 = np.array([[['foo', 'bar'], ['zoink', 'zod']], [['FOO', 'BAR'], - ['ZOINK', 'ZOD']]]) - - warning = text_plugin.markdown_and_sanitize(text_plugin.WARNING_TEMPLATE % - 3) - d3_expected = warning + textwrap.dedent("""\ - - - - - - - - - - - -

foo

bar

zoink

zod

""") - self.assertEqual(convert(d3), d3_expected) - - def testPluginIsActive(self): - plugin = text_plugin.TextPlugin() - multiplexer = event_multiplexer.EventMultiplexer() - plugin.get_plugin_apps(event_multiplexer.EventMultiplexer(), None) - - # The plugin is inactive because text summaries are not available. - self.assertFalse(plugin.is_active()) - - multiplexer.AddRunsFromDirectory(self.logdir) - multiplexer.Reload() - - # The plugin is active because text summaries are available. - self.assertTrue(self.plugin.is_active()) - - def testUnicode(self): - self.assertConverted(u'

Iñtërnâtiônàlizætiøn⚡💩

', - 'Iñtërnâtiônàlizætiøn⚡💩') - - -if __name__ == '__main__': - tf.test.main() diff --git a/tensorflow/tensorboard/scripts/BUILD b/tensorflow/tensorboard/scripts/BUILD deleted file mode 100644 index 05425ee61d0..00000000000 --- a/tensorflow/tensorboard/scripts/BUILD +++ /dev/null @@ -1,32 +0,0 @@ -# Description: -# Some useful scripts that are bundled with TensorBoard. - -package(default_visibility = ["//tensorflow/tensorboard:internal"]) - -licenses(["notice"]) # Apache 2.0 - -exports_files(["LICENSE"]) - -py_binary( - name = "generate_testdata", - srcs = ["generate_testdata.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow:tensorflow_py", - "//third_party/py/numpy", - "@six_archive//:six", - ], -) - -py_binary( - name = "execrooter", - srcs = ["execrooter.py"], - srcs_version = "PY2AND3", - visibility = ["//visibility:public"], -) - -filegroup( - name = "all_files", - srcs = glob(["*"]), - visibility = ["//tensorflow:__subpackages__"], -) diff --git a/tensorflow/tensorboard/scripts/__init__.py b/tensorflow/tensorboard/scripts/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/tensorflow/tensorboard/scripts/execrooter.py b/tensorflow/tensorboard/scripts/execrooter.py deleted file mode 100644 index 65569b91512..00000000000 --- a/tensorflow/tensorboard/scripts/execrooter.py +++ /dev/null @@ -1,85 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-"""Utility for running programs in a symlinked execroot.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import json -import os -import shutil -import subprocess -import sys -import tempfile - - -def run(inputs, program, outputs): - """Creates temp symlink tree, runs program, and copies back outputs. - - Args: - inputs: List of fake paths to real paths, which are used for symlink tree. - program: List containing real path of program and its arguments. The - execroot directory will be appended as the last argument. - outputs: List of fake outputted paths to copy back to real paths. - Returns: - 0 if succeeded or nonzero if failed. - """ - root = tempfile.mkdtemp() - try: - cwd = os.getcwd() - for fake, real in inputs: - parent = os.path.join(root, os.path.dirname(fake)) - if not os.path.exists(parent): - os.makedirs(parent) - os.symlink(os.path.join(cwd, real), os.path.join(root, fake)) - if subprocess.call(program + [root]) != 0: - return 1 - for fake, real in outputs: - shutil.copyfile(os.path.join(root, fake), real) - return 0 - finally: - shutil.rmtree(root) - - -def main(args): - """Invokes run function using a JSON file config. - - Args: - args: CLI args, which can be a JSON file containing an object whose - attributes are the parameters to the run function. If multiple JSON - files are passed, their contents are concatenated. - Returns: - 0 if succeeded or nonzero if failed. - Raises: - Exception: If input data is missing. 
- """ - if not args: - raise Exception('Please specify at least one JSON config path') - inputs = [] - program = [] - outputs = [] - for arg in args: - with open(arg) as fd: - config = json.load(fd) - inputs.extend(config.get('inputs', [])) - program.extend(config.get('program', [])) - outputs.extend(config.get('outputs', [])) - if not program: - raise Exception('Please specify a program') - return run(inputs, program, outputs) - - -if __name__ == '__main__': - sys.exit(main(sys.argv[1:])) diff --git a/tensorflow/tensorboard/scripts/generate_testdata.py b/tensorflow/tensorboard/scripts/generate_testdata.py deleted file mode 100644 index f191d16a82d..00000000000 --- a/tensorflow/tensorboard/scripts/generate_testdata.py +++ /dev/null @@ -1,225 +0,0 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Generate some standard test data for debugging TensorBoard. 
-""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import bisect -import math -import os -import os.path -import random -import shutil - -import numpy as np -from six.moves import xrange # pylint: disable=redefined-builtin -import tensorflow as tf - - -tf.flags.DEFINE_string("target", None, """The directory where serialized data -will be written""") - -tf.flags.DEFINE_boolean("overwrite", False, """Whether to remove and overwrite -TARGET if it already exists.""") - -FLAGS = tf.flags.FLAGS - -# Hardcode a start time and reseed so script always generates the same data. -_start_time = 0 -random.seed(0) - - -def _MakeHistogramBuckets(): - v = 1E-12 - buckets = [] - neg_buckets = [] - while v < 1E20: - buckets.append(v) - neg_buckets.append(-v) - v *= 1.1 - # Should include DBL_MAX, but won't bother for test data. - return neg_buckets[::-1] + [0] + buckets - - -def _MakeHistogram(values): - """Convert values into a histogram proto using logic from histogram.cc.""" - limits = _MakeHistogramBuckets() - counts = [0] * len(limits) - for v in values: - idx = bisect.bisect_left(limits, v) - counts[idx] += 1 - - limit_counts = [(limits[i], counts[i]) for i in xrange(len(limits)) - if counts[i]] - bucket_limit = [lc[0] for lc in limit_counts] - bucket = [lc[1] for lc in limit_counts] - sum_sq = sum(v * v for v in values) - return tf.HistogramProto( - min=min(values), - max=max(values), - num=len(values), - sum=sum(values), - sum_squares=sum_sq, - bucket_limit=bucket_limit, - bucket=bucket) - - -def WriteScalarSeries(writer, tag, f, n=5): - """Write a series of scalar events to writer, using f to create values.""" - step = 0 - wall_time = _start_time - for i in xrange(n): - v = f(i) - value = tf.Summary.Value(tag=tag, simple_value=v) - summary = tf.Summary(value=[value]) - event = tf.Event(wall_time=wall_time, step=step, summary=summary) - writer.add_event(event) - step += 1 - wall_time += 10 - - -def 
WriteHistogramSeries(writer, tag, mu_sigma_tuples, n=20): - """Write a sequence of normally distributed histograms to writer.""" - step = 0 - wall_time = _start_time - for [mean, stddev] in mu_sigma_tuples: - data = [random.normalvariate(mean, stddev) for _ in xrange(n)] - histo = _MakeHistogram(data) - summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=histo)]) - event = tf.Event(wall_time=wall_time, step=step, summary=summary) - writer.add_event(event) - step += 10 - wall_time += 100 - - -def WriteImageSeries(writer, tag, n_images=1): - """Write a few dummy images to writer.""" - step = 0 - session = tf.Session() - p = tf.placeholder("uint8", (1, 4, 4, 3)) - s = tf.summary.image(tag, p) - for _ in xrange(n_images): - im = np.random.random_integers(0, 255, (1, 4, 4, 3)) - summ = session.run(s, feed_dict={p: im}) - writer.add_summary(summ, step) - step += 20 - session.close() - - -def WriteAudioSeries(writer, tag, n_audio=1): - """Write a few dummy audio clips to writer.""" - step = 0 - session = tf.Session() - - min_frequency_hz = 440 - max_frequency_hz = 880 - sample_rate = 4000 - duration_frames = sample_rate // 2 # 0.5 seconds. - frequencies_per_run = 1 - num_channels = 2 - - p = tf.placeholder("float32", (frequencies_per_run, duration_frames, - num_channels)) - s = tf.summary.audio(tag, p, sample_rate) - - for _ in xrange(n_audio): - # Generate a different frequency for each channel to show stereo works. 
- frequencies = np.random.random_integers( - min_frequency_hz, - max_frequency_hz, - size=(frequencies_per_run, num_channels)) - tiled_frequencies = np.tile(frequencies, (1, duration_frames)) - tiled_increments = np.tile( - np.arange(0, duration_frames), - (num_channels, 1)).T.reshape(1, duration_frames * num_channels) - tones = np.sin(2.0 * np.pi * tiled_frequencies * tiled_increments / - sample_rate) - tones = tones.reshape(frequencies_per_run, duration_frames, num_channels) - - summ = session.run(s, feed_dict={p: tones}) - writer.add_summary(summ, step) - step += 20 - session.close() - - -def GenerateTestData(path): - """Generates the test data directory.""" - run1_path = os.path.join(path, "run1") - os.makedirs(run1_path) - writer1 = tf.summary.FileWriter(run1_path) - WriteScalarSeries(writer1, "foo/square", lambda x: x * x) - WriteScalarSeries(writer1, "bar/square", lambda x: x * x) - WriteScalarSeries(writer1, "foo/sin", math.sin) - WriteScalarSeries(writer1, "foo/cos", math.cos) - WriteHistogramSeries(writer1, "histo1", [[0, 1], [0.3, 1], [0.5, 1], [0.7, 1], - [1, 1]]) - WriteImageSeries(writer1, "im1") - WriteImageSeries(writer1, "im2") - WriteAudioSeries(writer1, "au1") - - run2_path = os.path.join(path, "run2") - os.makedirs(run2_path) - writer2 = tf.summary.FileWriter(run2_path) - WriteScalarSeries(writer2, "foo/square", lambda x: x * x * 2) - WriteScalarSeries(writer2, "bar/square", lambda x: x * x * 3) - WriteScalarSeries(writer2, "foo/cos", lambda x: math.cos(x) * 2) - WriteHistogramSeries(writer2, "histo1", [[0, 2], [0.3, 2], [0.5, 2], [0.7, 2], - [1, 2]]) - WriteHistogramSeries(writer2, "histo2", [[0, 1], [0.3, 1], [0.5, 1], [0.7, 1], - [1, 1]]) - WriteImageSeries(writer2, "im1") - WriteAudioSeries(writer2, "au2") - - graph_def = tf.GraphDef() - node1 = graph_def.node.add() - node1.name = "a" - node1.op = "matmul" - node2 = graph_def.node.add() - node2.name = "b" - node2.op = "matmul" - node2.input.extend(["a:0"]) - - writer1.add_graph(graph_def) - 
node3 = graph_def.node.add() - node3.name = "c" - node3.op = "matmul" - node3.input.extend(["a:0", "b:0"]) - writer2.add_graph(graph_def) - writer1.close() - writer2.close() - - -def main(unused_argv=None): - target = FLAGS.target - if not target: - print("The --target flag is required.") - return -1 - if os.path.exists(target): - if FLAGS.overwrite: - if os.path.isdir(target): - shutil.rmtree(target) - else: - os.remove(target) - else: - print("Refusing to overwrite target %s without --overwrite" % target) - return -2 - GenerateTestData(target) - - -if __name__ == "__main__": - tf.app.run() diff --git a/tensorflow/tools/api/golden/tensorflow.-l-m-d-b-reader.pbtxt b/tensorflow/tools/api/golden/tensorflow.-l-m-d-b-reader.pbtxt new file mode 100644 index 00000000000..f9b7e9bbca8 --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.-l-m-d-b-reader.pbtxt @@ -0,0 +1,46 @@ +path: "tensorflow.LMDBReader" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member { + name: "reader_ref" + mtype: "" + } + member { + name: "supports_serialize" + mtype: "" + } + member_method { + name: "__init__" + argspec: "args=[\'self\', \'name\', \'options\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " + } + member_method { + name: "num_records_produced" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "num_work_units_completed" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "read" + argspec: "args=[\'self\', \'queue\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "read_up_to" + argspec: "args=[\'self\', \'queue\', \'num_records\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "reset" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: 
"restore_state" + argspec: "args=[\'self\', \'state\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "serialize_state" + argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.bitwise.pbtxt b/tensorflow/tools/api/golden/tensorflow.bitwise.pbtxt new file mode 100644 index 00000000000..1e4d333cc0b --- /dev/null +++ b/tensorflow/tools/api/golden/tensorflow.bitwise.pbtxt @@ -0,0 +1,19 @@ +path: "tensorflow.bitwise" +tf_module { + member_method { + name: "bitwise_and" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "bitwise_or" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "bitwise_xor" + argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } + member_method { + name: "invert" + argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } +} diff --git a/tensorflow/tools/api/golden/tensorflow.pbtxt b/tensorflow/tools/api/golden/tensorflow.pbtxt index 342ee95f74d..88c171b7921 100644 --- a/tensorflow/tools/api/golden/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.pbtxt @@ -108,6 +108,10 @@ tf_module { name: "InteractiveSession" mtype: "" } + member { + name: "LMDBReader" + mtype: "" + } member { name: "LogMessage" mtype: "" @@ -252,6 +256,10 @@ tf_module { name: "bfloat16" mtype: "" } + member { + name: "bitwise" + mtype: "" + } member { name: "bool" mtype: "" diff --git a/tensorflow/tools/api/golden/tensorflow.train.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.pbtxt index 58fd5760c11..c2955379651 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.pbtxt @@ -304,6 +304,10 @@ tf_module { name: "import_meta_graph" argspec: 
"args=[\'meta_graph_or_file\', \'clear_devices\', \'import_scope\'], varargs=None, keywords=kwargs, defaults=[\'False\', \'None\'], " } + member_method { + name: "init_from_checkpoint" + argspec: "args=[\'ckpt_dir_or_file\', \'assignment_map\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "input_producer" argspec: "args=[\'input_tensor\', \'element_shape\', \'num_epochs\', \'shuffle\', \'seed\', \'capacity\', \'shared_name\', \'summary_name\', \'name\', \'cancel_op\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\', \'None\', \'32\', \'None\', \'None\', \'None\', \'None\'], " @@ -320,6 +324,18 @@ tf_module { name: "limit_epochs" argspec: "args=[\'tensor\', \'num_epochs\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " } + member_method { + name: "list_variables" + argspec: "args=[\'ckpt_dir_or_file\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "load_checkpoint" + argspec: "args=[\'ckpt_dir_or_file\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "load_variable" + argspec: "args=[\'ckpt_dir_or_file\', \'name\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "match_filenames_once" argspec: "args=[\'pattern\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/ci_build/Dockerfile.tensorboard b/tensorflow/tools/ci_build/Dockerfile.tensorboard deleted file mode 100644 index 9795872e2c4..00000000000 --- a/tensorflow/tools/ci_build/Dockerfile.tensorboard +++ /dev/null @@ -1,11 +0,0 @@ -FROM ubuntu:14.04 - -MAINTAINER Jan Prach - -# Copy and run the install scripts. 
-COPY install/*.sh /install/ -RUN /install/install_bootstrap_deb_packages.sh -RUN add-apt-repository -y ppa:openjdk-r/ppa && \ - add-apt-repository -y ppa:george-edison55/cmake-3.x -RUN /install/install_deb_packages.sh -RUN /install/install_tensorboard_packages.sh diff --git a/tensorflow/tools/ci_build/install/install_pip_packages.sh b/tensorflow/tools/ci_build/install/install_pip_packages.sh index c9867796f3a..7fcd235e625 100755 --- a/tensorflow/tools/ci_build/install/install_pip_packages.sh +++ b/tensorflow/tools/ci_build/install/install_pip_packages.sh @@ -85,6 +85,3 @@ pip2 install mock pip2 install portpicker pip3 install portpicker - -pip2 install backports.weakref==1.0rc1 -pip3 install backports.weakref==1.0rc1 diff --git a/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh b/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh index 33b3bc104bd..084ac49496c 100755 --- a/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh +++ b/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh @@ -89,6 +89,3 @@ pip3.5 install wheel==0.29.0 pip3.5 install portpicker pip3.5 install werkzeug - -pip3.5 install backports.weakref==1.0rc1 - diff --git a/tensorflow/tools/ci_build/osx/cpu/run_py2_cc_core.sh b/tensorflow/tools/ci_build/osx/cpu/run_py2_cc_core.sh index e5f4a22f7ad..6574fd144a2 100755 --- a/tensorflow/tools/ci_build/osx/cpu/run_py2_cc_core.sh +++ b/tensorflow/tools/ci_build/osx/cpu/run_py2_cc_core.sh @@ -36,5 +36,4 @@ bazel test --test_tag_filters=-gpu,-benchmark-test,-nomac \ --test_timeout 300,450,1200,3600 \ --test_size_filters=small,medium \ --jobs=${N_JOBS} --build_tests_only --test_output=errors -k -- \ - //tensorflow/... -//tensorflow/compiler/... -//tensorflow/contrib/... \ - -//tensorflow/tensorboard/... + //tensorflow/... -//tensorflow/compiler/... -//tensorflow/contrib/... 
diff --git a/tensorflow/tools/ci_build/update_version.sh b/tensorflow/tools/ci_build/update_version.sh index 682f5329f58..b707ee338a2 100755 --- a/tensorflow/tools/ci_build/update_version.sh +++ b/tensorflow/tools/ci_build/update_version.sh @@ -130,12 +130,6 @@ if [[ ${OLD_MAJOR} != ${MAJOR} ]] || [[ ${OLD_MINOR} != ${MINOR} ]]; then echo "Detected Major.Minor change. "\ "Updating pattern ${OLD_R_MAJOR_MINOR} to ${R_MAJOR_MINOR} in additional files" - # Update tensorflow/tensorboard/README.md - TENSORBOARD_README_MD="${TF_SRC_DIR}/tensorboard/README.md" - check_existence file "${TENSORBOARD_README_MD}" - sed -i -r -e "s/${OLD_R_MAJOR_MINOR}/${R_MAJOR_MINOR}/g" \ - "${TENSORBOARD_README_MD}" - # Update dockerfiles DEVEL_DOCKERFILE="${TF_SRC_DIR}/tools/docker/Dockerfile.devel" check_existence file "${DEVEL_DOCKERFILE}" diff --git a/tensorflow/tools/docs/BUILD b/tensorflow/tools/docs/BUILD index 8e27b133c2f..45722ec9ebd 100644 --- a/tensorflow/tools/docs/BUILD +++ b/tensorflow/tools/docs/BUILD @@ -37,6 +37,7 @@ py_library( srcs = ["parser.py"], srcs_version = "PY2AND3", visibility = ["//visibility:public"], + deps = ["@com_github_andreif_codegen"], ) py_test( @@ -44,7 +45,6 @@ py_test( size = "small", srcs = ["parser_test.py"], srcs_version = "PY2AND3", - tags = ["manual"], deps = [ ":parser", "//tensorflow/python:platform_test", @@ -78,13 +78,10 @@ py_test( size = "small", srcs = ["generate_lib_test.py"], srcs_version = "PY2AND3", - tags = ["manual"], deps = [ ":generate_lib", ":parser", - "//tensorflow:tensorflow_py", "//tensorflow/python:platform_test", - "//tensorflow/python/debug:debug_py", ], ) @@ -105,7 +102,6 @@ py_test( srcs = ["build_docs_test.py"], data = ["//tensorflow:docs_src"], srcs_version = "PY2AND3", - tags = ["manual"], deps = [ ":generate_lib", "//tensorflow:tensorflow_py", diff --git a/tensorflow/tools/docs/build_docs_test.py b/tensorflow/tools/docs/build_docs_test.py index d28dd93b9a8..ae293f65764 100644 --- 
a/tensorflow/tools/docs/build_docs_test.py +++ b/tensorflow/tools/docs/build_docs_test.py @@ -19,6 +19,8 @@ from __future__ import division from __future__ import print_function import os +import sys +import textwrap import tensorflow as tf from tensorflow.python import debug as tf_debug @@ -29,19 +31,40 @@ from tensorflow.tools.docs import generate_lib class Flags(object): resource_root = resource_loader.get_root_dir_with_all_resources() - src_dir = os.path.join(resource_root, 'third_party/tensorflow/docs_src') - base_dir = os.path.join(resource_root, 'third_party/tensorflow/') + src_dir = os.path.join(resource_root, 'tensorflow/docs_src') + base_dir = os.path.join(resource_root, 'tensorflow/') output_dir = googletest.GetTempDir() class BuildDocsTest(googletest.TestCase): def testBuildDocs(self): + if sys.version_info >= (3, 0): + print('Warning: Doc generation is not supported from python3.') + return + doc_generator = generate_lib.DocGenerator() doc_generator.set_py_modules([('tf', tf), ('tfdbg', tf_debug)]) - status = doc_generator.build(Flags()) + try: + status = doc_generator.build(Flags()) + except RuntimeError as e: + if not e.args[0].startswith('Modules nested too deep'): + raise + + msg = textwrap.dedent("""\ + %s + + **************************************************************** + If this test fails here, you have most likely introduced an + unsealed module. Make sure to use `remove_undocumented` or similar + utilities to avoid leaking symbols. See above for more information + on the exact point of failure. + **************************************************************** + """ % e.args[0]) + + raise RuntimeError(msg) if status: self.fail('Found %s Errors!' 
% status) diff --git a/tensorflow/tools/docs/generate_lib.py b/tensorflow/tools/docs/generate_lib.py index 99872e1d844..67a4ad0ec92 100644 --- a/tensorflow/tools/docs/generate_lib.py +++ b/tensorflow/tools/docs/generate_lib.py @@ -20,6 +20,7 @@ from __future__ import print_function import argparse import os +import sys import six @@ -415,6 +416,8 @@ class DocGenerator(object): """Main entry point for generating docs.""" def __init__(self): + if sys.version_info >= (3, 0): + print('Warning: Doc generation is not supported from python3.') self.argument_parser = argparse.ArgumentParser() self._py_modules = None self._private_map = _get_default_private_map() diff --git a/tensorflow/tools/docs/generate_lib_test.py b/tensorflow/tools/docs/generate_lib_test.py index 6e5deb6a36e..ea6d28a02b1 100644 --- a/tensorflow/tools/docs/generate_lib_test.py +++ b/tensorflow/tools/docs/generate_lib_test.py @@ -21,9 +21,6 @@ from __future__ import print_function import os import sys -import tensorflow as tf - -from tensorflow.python import debug as tf_debug from tensorflow.python.platform import googletest from tensorflow.tools.docs import generate_lib from tensorflow.tools.docs import parser @@ -54,22 +51,6 @@ class DummyVisitor(object): class GenerateTest(googletest.TestCase): - def test_extraction(self): - py_modules = [('tf', tf), ('tfdbg', tf_debug)] - - try: - generate_lib.extract(py_modules, - generate_lib._get_default_private_map(), - generate_lib._get_default_do_not_descend_map()) - except RuntimeError: - print('*****************************************************************') - print('If this test fails, you have most likely introduced an unsealed') - print('module. Make sure to use remove_undocumented or similar utilities') - print('to avoid leaking symbols. 
See below for more information on the') - print('failure.') - print('*****************************************************************') - raise - def test_write(self): module = sys.modules[__name__] diff --git a/tensorflow/tools/docs/parser_test.py b/tensorflow/tools/docs/parser_test.py index 3e02160130f..862f0acfa90 100644 --- a/tensorflow/tools/docs/parser_test.py +++ b/tensorflow/tools/docs/parser_test.py @@ -491,13 +491,13 @@ Returns: class TestParseFunctionDetails(googletest.TestCase): - def testParseFunctionDetails(self): + def test_parse_function_details(self): docstring, function_details = parser._parse_function_details(RELU_DOC) self.assertEqual(len(function_details), 2) args = function_details[0] self.assertEqual(args.keyword, 'Args') - self.assertEmpty(args.header) + self.assertEqual(len(args.header), 0) self.assertEqual(len(args.items), 2) self.assertEqual(args.items[0][0], 'features') self.assertEqual(args.items[1][0], 'name') @@ -515,5 +515,60 @@ class TestParseFunctionDetails(googletest.TestCase): docstring + ''.join(str(detail) for detail in function_details)) +class TestGenerateSignature(googletest.TestCase): + + def test_known_object(self): + if sys.version_info >= (3, 0): + print('Warning: Doc generation is not supported from python3.') + return + + known_object = object() + reverse_index = {id(known_object): 'location.of.object.in.api'} + + def example_fun(arg=known_object): # pylint: disable=unused-argument + pass + + sig = parser._generate_signature(example_fun, reverse_index) + self.assertEqual(sig, ['arg=location.of.object.in.api']) + + def test_literals(self): + if sys.version_info >= (3, 0): + print('Warning: Doc generation is not supported from python3.') + return + + def example_fun(a=5, b=5.0, c=None, d=True, e='hello', f=(1, (2, 3))): # pylint: disable=g-bad-name, unused-argument + pass + + sig = parser._generate_signature(example_fun, reverse_index={}) + self.assertEqual( + sig, ['a=5', 'b=5.0', 'c=None', 'd=True', "e='hello'", 
'f=(1, (2, 3))']) + + def test_dotted_name(self): + if sys.version_info >= (3, 0): + print('Warning: Doc generation is not supported from python3.') + return + + # pylint: disable=g-bad-name + class a(object): + + class b(object): + + class c(object): + + class d(object): + + def __init__(self, *args): + pass + # pylint: enable=g-bad-name + + e = {'f': 1} + + def example_fun(arg1=a.b.c.d, arg2=a.b.c.d(1, 2), arg3=e['f']): # pylint: disable=unused-argument + pass + + sig = parser._generate_signature(example_fun, reverse_index={}) + self.assertEqual(sig, ['arg1=a.b.c.d', 'arg2=a.b.c.d(1, 2)', "arg3=e['f']"]) + + if __name__ == '__main__': googletest.main() diff --git a/tensorflow/tools/graph_transforms/quantize_nodes.cc b/tensorflow/tools/graph_transforms/quantize_nodes.cc index da064377ac3..2b85e7e83c6 100644 --- a/tensorflow/tools/graph_transforms/quantize_nodes.cc +++ b/tensorflow/tools/graph_transforms/quantize_nodes.cc @@ -119,6 +119,13 @@ const std::vector& GetQuantizedOpList() { DT_QUINT8, {}, QuantizedOpInfo::CONTIGUOUS_MIN_MAX}, + {"ResizeBilinear", + {"align_corners"}, + {{"T", DT_QUINT8}}, + DT_QUINT8, + DT_QUINT8, + {1}, + QuantizedOpInfo::CONTIGUOUS_MIN_MAX}, {"Relu6", {}, {{"Tinput", DT_QUINT8}}, diff --git a/tensorflow/tools/graph_transforms/quantize_nodes_test.cc b/tensorflow/tools/graph_transforms/quantize_nodes_test.cc index d02655f3f9c..eca263a1ae0 100644 --- a/tensorflow/tools/graph_transforms/quantize_nodes_test.cc +++ b/tensorflow/tools/graph_transforms/quantize_nodes_test.cc @@ -106,8 +106,8 @@ class QuantizeNodesTest : public ::testing::Test { // Reshape is not included here because it can be added as part of the // quantization process. 
const std::set quantizable_ops = { - "Add", "BiasAdd", "Concat", "Conv2D", "MatMul", - "Relu", "Relu6", "AvgPool", "MaxPool", "Mul"}; + "Add", "BiasAdd", "Concat", "Conv2D", "MatMul", "Relu", + "Relu6", "ResizeBilinear", "AvgPool", "MaxPool", "Mul"}; for (const NodeDef& node : quantized_graph_def.node()) { EXPECT_EQ(0, quantizable_ops.count(node.op())) << "Found quantizable node " << node.op() << " for node named " @@ -652,6 +652,33 @@ class QuantizeNodesTest : public ::testing::Test { EXPECT_EQ("requantize_op", node_map.at("final_dequantize")->input(0)); } + void TestQuantizeResizeBilinear() { + auto root = tensorflow::Scope::NewRootScope(); + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + + Tensor size_tensor(DT_INT32, TensorShape({2})); + test::FillValues(&size_tensor, {256, 256}); + + Output constant_op = Const(root.WithOpName("size_tensor_op"), + Input::Initializer(size_tensor)); + + Output placeholder_op = + Placeholder(root.WithOpName("placeholder_op"), DT_FLOAT); + + Output resize_bilinear_op = ResizeBilinear( + root.WithOpName("resize_bilinear_op"), placeholder_op, constant_op); + + GraphDef float_graph_def; + TF_ASSERT_OK(root.ToGraphDef(&float_graph_def)); + + Tensor input_tensor(DT_FLOAT, {1, 128, 128, 3}); + test::FillFn(&input_tensor, [](int) { return 100.0f; }); + + TestQuantizedVersusFloatGraph(float_graph_def, + {{"placeholder_op", input_tensor}}, + {"resize_bilinear_op"}); + } + void TestRemoveRedundantQuantizationWithMultipleOutputs() { auto root = tensorflow::Scope::NewRootScope(); using namespace ::tensorflow::ops; // NOLINT(build/namespaces) @@ -1446,6 +1473,10 @@ TEST_F(QuantizeNodesTest, TestQuantizeAvgPool) { TestQuantizeAvgPool(); } TEST_F(QuantizeNodesTest, TestQuantizeReshape) { TestQuantizeReshape(); } +TEST_F(QuantizeNodesTest, TestQuantizeResizeBilinear) { + TestQuantizeResizeBilinear(); +} + TEST_F(QuantizeNodesTest, TestRemoveRedundantQuantization) { TestRemoveRedundantQuantization(); } diff --git 
a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD index 798338d7875..d3952b5cec5 100644 --- a/tensorflow/tools/pip_package/BUILD +++ b/tensorflow/tools/pip_package/BUILD @@ -75,10 +75,6 @@ py_binary( "//tensorflow/python/saved_model", "//tensorflow/python/tools:tools_pip", # These targets don't build on Windows yet. Exclude them for now. - # rules_closure currently doesn't build on Windows due to - # https://github.com/bazelbuild/rules_closure/pull/206 - # Since tensorboard dependes on rules_closure, exclude tensorboard until it's fixed. - # "//tensorflow/tensorboard", # "//tensorflow/contrib/ndlstm", # "//tensorflow/contrib/slim", # "//tensorflow/contrib/slim/python/slim/nets:nets_pip", @@ -113,15 +109,12 @@ filegroup( "@lmdb//:LICENSE", "@local_config_sycl//sycl:LICENSE.text", "@nanopb_git//:LICENSE.txt", - "@org_html5lib//:LICENSE", - "@org_mozilla_bleach//:LICENSE", - "@org_pocoo_werkzeug//:LICENSE", - "@org_pythonhosted_markdown//:LICENSE.md", "@png_archive//:LICENSE", "@protobuf//:LICENSE", "@six_archive//:LICENSE", "@snappy//:COPYING", "@zlib_archive//:zlib.h", + "@org_python_pypi_backports_weakref//:LICENSE", ] + if_not_windows([ "@nccl_archive//:LICENSE.txt", ]) + tf_additional_license_deps(), @@ -151,9 +144,13 @@ sh_binary( "//tensorflow/contrib/slim:slim", "//tensorflow/contrib/slim/python/slim/data:data_pip", "//tensorflow/contrib/slim/python/slim/nets:nets_pip", + "//tensorflow/contrib/tpu:tpu_estimator", + "//tensorflow/contrib/tpu:tpu_helper_library", + "//tensorflow/contrib/tpu:tpu_py", "//tensorflow/contrib/specs:specs", "//tensorflow/contrib/tensor_forest:init_py", "//tensorflow/contrib/tensor_forest/hybrid:hybrid_pip", + "//tensorflow/contrib/predictor:predictor_pip", "//tensorflow/examples/tutorials/mnist:package", "//tensorflow/python:distributed_framework_test_lib", "//tensorflow/python:meta_graph_testdata", @@ -161,7 +158,6 @@ sh_binary( "//tensorflow/python/debug:debug_pip", 
"//tensorflow/python/saved_model:saved_model", "//tensorflow/python/tools:tools_pip", - "//tensorflow/tensorboard", ], }) + if_mkl(["//third_party/mkl:intel_binary_blob"]), ) diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py index dd47b44001a..54ba8064e8a 100644 --- a/tensorflow/tools/pip_package/setup.py +++ b/tensorflow/tools/pip_package/setup.py @@ -34,12 +34,9 @@ _VERSION = '1.2.0' REQUIRED_PACKAGES = [ 'numpy >= 1.11.0', 'six >= 1.10.0', - 'protobuf >= 3.3.0', - 'werkzeug >= 0.11.10', - 'html5lib == 0.9999999', # identical to 1.0b8 - 'markdown == 2.2.0', - 'bleach == 1.5.0', + 'protobuf >= 3.2.0', 'backports.weakref == 1.0rc1', + 'tensorflow-tensorboard', ] project_name = 'tensorflow' @@ -59,7 +56,6 @@ else: # pylint: disable=line-too-long CONSOLE_SCRIPTS = [ - 'tensorboard = tensorflow.tensorboard.tensorboard:main', 'saved_model_cli = tensorflow.python.tools.saved_model_cli:main', ] # pylint: enable=line-too-long @@ -191,8 +187,6 @@ setup( package_data={ 'tensorflow': [ EXTENSION_NAME, - 'tensorboard/components/index.html', - 'tensorboard/TAG', ] + matches, }, zip_safe=False, diff --git a/tensorflow/tools/test/upload_test_benchmarks_index.yaml b/tensorflow/tools/test/upload_test_benchmarks_index.yaml index 8cd33a1da60..ec7f76f6663 100644 --- a/tensorflow/tools/test/upload_test_benchmarks_index.yaml +++ b/tensorflow/tools/test/upload_test_benchmarks_index.yaml @@ -27,7 +27,7 @@ indexes: properties: - name: test - name: start - direction: asc + direction: desc # Index to access a specific (test, entry, start) Entity, and also to be able to # fetch a range of (start, timing) graph values for a given (test, entry) pair diff --git a/tensorflow/tools/tfprof/g3doc/options.md b/tensorflow/tools/tfprof/g3doc/options.md index 78c72bf5edd..57f67c66fa0 100644 --- a/tensorflow/tools/tfprof/g3doc/options.md +++ b/tensorflow/tools/tfprof/g3doc/options.md @@ -49,7 +49,7 @@ In graph view, in means the number of hops in the graph. 
`-step`: Show the stats of the this step when multiple steps of RunMetadata were added. By default, show the average of all steps." -`-order_by`: Order the results by [name|depth|bytes|micros|params|float_ops|occurrence] +`-order_by`: Order the results by [name|depth|bytes|micros|accelerator_micros|cpu_micros|params|float_ops|occurrence] `-account_type_regexes`: Account and display the ops whose types match one of the type regexes specified. tfprof allow user to define extra op types for ops through tensorflow.tfprof.OpLog proto. regexes are comma-sperated. @@ -73,7 +73,7 @@ as long as they match the `-account_xxx` options. `-account_displayed_op_only`: If True, only account the statistics of ops eventually displayed. If False, account all op statistics matching -account_type_regexes recursively. `-select`: Comma-separated list of metrics to show: -[bytes|micros|params|float_ops|occurrence|tensor_value|device|op_types|input_shapes]. +[bytes|micros|accelerator_micros|cpu_micros|params|float_ops|occurrence|tensor_value|device|op_types|input_shapes]. `-output`: Output results as stdout, file or timeline. The format is ```output_type:key=value,key=value```. 
diff --git a/tensorflow/tools/tfprof/internal/BUILD b/tensorflow/tools/tfprof/internal/BUILD index 9b77b0fb3f2..c47008b0694 100644 --- a/tensorflow/tools/tfprof/internal/BUILD +++ b/tensorflow/tools/tfprof/internal/BUILD @@ -185,6 +185,7 @@ cc_library( ":tfprof_node_show", ":tfprof_options", ":tfprof_scope", + ":tfprof_show", ":tfprof_tensor", ":tfprof_timeline", ":tfprof_utils", diff --git a/tensorflow/tools/tfprof/internal/advisor/accelerator_utilization_checker.h b/tensorflow/tools/tfprof/internal/advisor/accelerator_utilization_checker.h index 8f256584f7b..fb7f65d7dc9 100644 --- a/tensorflow/tools/tfprof/internal/advisor/accelerator_utilization_checker.h +++ b/tensorflow/tools/tfprof/internal/advisor/accelerator_utilization_checker.h @@ -76,7 +76,7 @@ class AcceleratorUtilizationChecker : public Checker { if (execs.empty()) { return; } - if (!IsAcceleratorDevice(node->canonical_device())) { + if (!IsPlacedOnAccelerator(node->canonical_device())) { return; } diff --git a/tensorflow/tools/tfprof/internal/advisor/operation_checker.h b/tensorflow/tools/tfprof/internal/advisor/operation_checker.h index 78132e3a460..2a05f9bfd02 100644 --- a/tensorflow/tools/tfprof/internal/advisor/operation_checker.h +++ b/tensorflow/tools/tfprof/internal/advisor/operation_checker.h @@ -47,7 +47,7 @@ class OperationChecker : public Checker { if (node->op_attrs().find("data_format") != node->op_attrs().end()) { const AttrValue* attr_val = node->op_attrs().at("data_format"); if (attr_val->s() == "NHWC" && - IsAcceleratorDevice(node->canonical_device())) { + IsPlacedOnAccelerator(node->canonical_device())) { recommend_nchw = true; } } diff --git a/tensorflow/tools/tfprof/internal/tfprof_code.cc b/tensorflow/tools/tfprof/internal/tfprof_code.cc index f328e3b0cd3..70e341e8d4b 100644 --- a/tensorflow/tools/tfprof/internal/tfprof_code.cc +++ b/tensorflow/tools/tfprof/internal/tfprof_code.cc @@ -264,15 +264,9 @@ string TFCode::FormatNode(CodeNode* node, const Options& opts, int64 indent) { } 
attrs.push_back(memory); } - if (opts.select.find(kShown[1]) != opts.select.end()) { - string time = FormatTime(node->proto().total_exec_micros()); - if (node->account) { - time = FormatTime(node->proto().exec_micros()) + "/" + time; - } else { - time = "--/" + time; - } - attrs.push_back(time); - } + std::vector time_attrs = FormatTimes(node, opts); + attrs.insert(attrs.end(), time_attrs.begin(), time_attrs.end()); + if (opts.select.find(kShown[5]) != opts.select.end() && !node->node->devices().empty()) { attrs.push_back(str_util::Join(node->node->devices(), "|")); diff --git a/tensorflow/tools/tfprof/internal/tfprof_node.cc b/tensorflow/tools/tfprof/internal/tfprof_node.cc index 6353813a26c..72f74902b5a 100644 --- a/tensorflow/tools/tfprof/internal/tfprof_node.cc +++ b/tensorflow/tools/tfprof/internal/tfprof_node.cc @@ -19,6 +19,18 @@ limitations under the License. namespace tensorflow { namespace tfprof { +namespace { +bool CountAsAcceleratorTime(const string& device) { + return device.find("stream:all") != device.npos; +} + +bool CountAsCPUTime(const string& device) { + return RE2::FullMatch(device, ".*/(gpu|cpu):\\d+"); +} + +bool IsCanonicalDevice(const string& device) { return CountAsCPUTime(device); } + +} // namespace // Notes about start and end time from the NodeExecStats proto: // For GPU, there is no difference between op_end_rel_micros and // all_end_rel_micros. All are kernel times. @@ -52,13 +64,16 @@ void ExecStep::AddTimeStats(const string& dev, const NodeExecStats& step_stat) { latest_end_micros_ = std::max( latest_end_micros_, step_stat.all_start_micros() + op_end_rel_micros); - op_execs_[dev].push_back( - std::make_pair(step_stat.all_start_micros(), op_end_rel_micros)); - - // TODO(xpan): Can a stream only in stream:all or doesn't in stream at all? 
- if (dev.find("stream") != dev.npos && dev.find("stream:all") == dev.npos) { - gpu_kernel_execs_[dev].push_back( - std::make_pair(step_stat.all_start_micros(), op_end_rel_micros)); + const std::pair pair = + std::make_pair(step_stat.all_start_micros(), op_end_rel_micros); + if (CountAsAcceleratorTime(dev)) { + accelerator_execs_[dev].push_back(pair); + op_execs_[dev].push_back(pair); + } else if (CountAsCPUTime(dev)) { + // TODO(xpan): A while-loop can have multiple nodes sharing the + // same name. They shouldn't be counted in one node. + cpu_execs_[dev].push_back(pair); + op_execs_[dev].push_back(pair); + } } } @@ -113,8 +128,11 @@ void TFGraphNode::AddStepStat(int64 step, const string& device, const NodeExecStats& step_stat) { string dev = str_util::Lowercase(device); - // TODO(xpan): Test it. - if (RE2::FullMatch(dev, "/job:.*/replica:\\d+/task:\\d+/[a-z]+:\\d+")) { + // TODO(xpan): Make this more robust? + // See run_metadata_test.py + // It can be /job:0/replica:0/xxxx/gpu:0, or simply /gpu:0. + // It can have some ad-hoc suffix, such as /stream:xx or /memcpy:xx. + if (IsCanonicalDevice(device)) { if (!canonical_device_.empty()) { if (canonical_device_ != dev) { fprintf(stderr, "Unexpected: graph node changed device: %s->%s.\n", @@ -143,26 +161,15 @@ void TFGraphNode::AddStepStat(int64 step, const string& device, } int64 ExecStep::exec_micros() const { - int64 total = accelerator_exec_micros(); - if (total > 0) return total; - - // If there is no gpu kernel time, fall back to assume it runs on cpu. - // TODO(xpan): No way to track CPU async op timing accurately?
- if (op_execs_.size() > 1) { - fprintf(stderr, "Op: %s has over 1 no-gpu assignment\n", - node->name().c_str()); - } - for (const auto& execs : op_execs_) { - for (const auto& exec : execs.second) { - total += exec.second; - } - } - return total; + return accelerator_exec_micros() + cpu_exec_micros(); } int64 ExecStep::accelerator_exec_micros() const { int64 total = 0; - for (const auto& execs : gpu_kernel_execs_) { + // Normally, an op should only be scheduled on 1 accelerator device. + // Hence there should generally be 1 element in accelerator_execs_. + for (const auto& execs : accelerator_execs_) { + // An op can fire multiple kernel runs hence multiple elements here. for (const auto& exec : execs.second) { total += exec.second; } @@ -170,12 +177,20 @@ int64 ExecStep::accelerator_exec_micros() const { return total; } -bool IsCombinedGPUStream(const string& device) { - return device.find("stream:all") != device.npos; -} - -bool IsCPUDevice(const string& device) { - return device.find("cpu:0") != device.npos; +int64 ExecStep::cpu_exec_micros() const { + int64 total = 0; + // Here we use for loop just for consistent appearance with + // accelerator_execs. + // We only expect cpu_execs_ to have 1 element because an + // op can only be scheduled on 1 device. + for (const auto& execs : cpu_execs_) { + // We only expect exec to have 1 element because an op + // can only be scheduled once.
+ for (const auto& exec : execs.second) { + total += exec.second; + } + } + return total; } std::vector ShapeProtoToVec(const TensorShapeProto& shape_pb) { @@ -203,7 +218,7 @@ TensorShapeProto VecToShapeProto(const std::vector shape_vec) { return shape_pb; } -bool IsAcceleratorDevice(const string& device) { +bool IsPlacedOnAccelerator(const string& device) { return device.find("gpu") != device.npos; } } // namespace tfprof diff --git a/tensorflow/tools/tfprof/internal/tfprof_node.h b/tensorflow/tools/tfprof/internal/tfprof_node.h index d788f2acf4d..6fd65c1ee7e 100644 --- a/tensorflow/tools/tfprof/internal/tfprof_node.h +++ b/tensorflow/tools/tfprof/internal/tfprof_node.h @@ -63,9 +63,10 @@ class ExecStep { // The execution time of an op. If it runs on accelerator, then it's // accelerator_exec_micros(). Otherwise, it's CPU time. int64 exec_micros() const; - - // The execution time of an op. 0 if it runs on cpu. + // The accelerator execution time of an op. 0 if not run on accelerator. int64 accelerator_exec_micros() const; + // The cpu execution time of an op. + int64 cpu_exec_micros() const; const std::map>>& op_execs() const { @@ -92,10 +93,13 @@ class ExecStep { int64 all_start_micros_; int64 latest_end_micros_; // device -> vector of {op_start_micros, op_exec_micros} pairs. - // For accelerator op, op_start_micros and op_exec_micros are kernel time. - // For cpu op, op_start_micros and op_exec_micros are scheduling time. ( - // might include compute time if it's sync op). - std::map>> gpu_kernel_execs_; + // accelerator_execs: gpu:id/stream:all -> {op_start_micros, op_exec_micros} + // For accelerator, vector size can be larger than 1, multiple kernel fires. + std::map>> accelerator_execs_; + // cpu_execs: cpu/gpu:id -> {op_start_micros, op_exec_micros} + // For cpu, normally vector size is 1, that is only one run. + std::map>> cpu_execs_; + // combines accelerator_execs_ and cpu_execs_. std::map>> op_execs_; // All devices the op is associated with (e.g. 
gpu:0 (scheduling), // gpu:0:stream:xx (kernel exec), cpu:0 host) @@ -184,8 +188,10 @@ class TFGraphNode { return src_output_idx_; } - // This is time spent in kernel execution. - int64 kernel_exec_micros(int64 step) const { + // This is overall computation time, including both cpu and accelerator. + // Note, cpu and accelerator might or might not run in parallel. + int64 exec_micros(int64 step) const { + // Empty when no RunMetadata is provided. if (execs_.empty()) { return 0; } @@ -202,6 +208,46 @@ class TFGraphNode { return total_micros / execs_.size(); } + // This is accelerator computation time of a step, or average of + // multiple steps, when step < 0. + int64 accelerator_exec_micros(int64 step) const { + // Empty when no RunMetadata is provided. + if (execs_.empty()) { + return 0; + } + if (step >= 0) { + auto exec = execs_.find(step); + CHECK(exec != execs_.end()); + return exec->second.accelerator_exec_micros(); + } + + int64 total_micros = 0; + for (const auto& exec : execs_) { + total_micros += exec.second.accelerator_exec_micros(); + } + return total_micros / execs_.size(); + } + + // This is cpu computation time of a step, or average of + // multiple steps, when step < 0. + int64 cpu_exec_micros(int64 step) const { + // Empty when no RunMetadata is provided.
+ if (execs_.empty()) { + return 0; + } + if (step >= 0) { + auto exec = execs_.find(step); + CHECK(exec != execs_.end()); + return exec->second.cpu_exec_micros(); + } + + int64 total_micros = 0; + for (const auto& exec : execs_) { + total_micros += exec.second.cpu_exec_micros(); + } + return total_micros / execs_.size(); + } + int64 requested_bytes(int64 step) const { if (execs_.empty()) { return 0; @@ -325,12 +371,17 @@ class TFMultiGraphNode { public: TFMultiGraphNode(const string& name) : name_(name), - kernel_exec_micros_(0), + exec_micros_(0), + accelerator_exec_micros_(0), + cpu_exec_micros_(0), requested_bytes_(0), float_ops_(0) {} bool SnapshotNodes(int64 step, const std::vector& type_regexes) { - kernel_exec_micros_ = 0; + exec_micros_ = 0; + accelerator_exec_micros_ = 0; + cpu_exec_micros_ = 0; + requested_bytes_ = 0; float_ops_ = 0; op_types_.clear(); @@ -347,7 +398,10 @@ class TFMultiGraphNode { for (const TFGraphNode* node : nodes) { op_types_.insert(node->op_types().begin(), node->op_types().end()); - kernel_exec_micros_ += node->kernel_exec_micros(step); + exec_micros_ += node->exec_micros(step); + accelerator_exec_micros_ += node->accelerator_exec_micros(step); + cpu_exec_micros_ += node->cpu_exec_micros(step); + requested_bytes_ += node->requested_bytes(step); float_ops_ += node->float_ops(); if (node->shape().size() > 0) { @@ -382,7 +436,9 @@ class TFMultiGraphNode { const string& name() const { return name_; } - int64 kernel_exec_micros() const { return kernel_exec_micros_; } + int64 exec_micros() const { return exec_micros_; } + int64 accelerator_exec_micros() const { return accelerator_exec_micros_; } + int64 cpu_exec_micros() const { return cpu_exec_micros_; } int64 requested_bytes() const { return requested_bytes_; } @@ -424,7 +480,10 @@ class TFMultiGraphNode { const string name_; // Snapshot based on type_regexes std::set op_types_; - int64 kernel_exec_micros_; + int64 exec_micros_; + int64 accelerator_exec_micros_; + int64 
cpu_exec_micros_; + int64 requested_bytes_; int64 float_ops_; std::set devices_; @@ -436,9 +495,7 @@ class TFMultiGraphNode { std::map> children_; }; -bool IsCombinedGPUStream(const string& device); -bool IsCPUDevice(const string& device); -bool IsAcceleratorDevice(const string& device); +bool IsPlacedOnAccelerator(const string& device); } // namespace tfprof } // namespace tensorflow diff --git a/tensorflow/tools/tfprof/internal/tfprof_node_show.cc b/tensorflow/tools/tfprof/internal/tfprof_node_show.cc index 7b604e091a7..7b22e88e079 100644 --- a/tensorflow/tools/tfprof/internal/tfprof_node_show.cc +++ b/tensorflow/tools/tfprof/internal/tfprof_node_show.cc @@ -31,7 +31,11 @@ void ShowNode::ReInit(int64 step) { if (!node->canonical_device().empty()) { mutable_proto()->add_devices(node->canonical_device()); } - mutable_proto()->set_exec_micros(node->kernel_exec_micros(step)); + mutable_proto()->set_exec_micros(node->exec_micros(step)); + mutable_proto()->set_accelerator_exec_micros( + node->accelerator_exec_micros(step)); + mutable_proto()->set_cpu_exec_micros(node->cpu_exec_micros(step)); + mutable_proto()->set_requested_bytes(node->requested_bytes(step)); mutable_proto()->set_float_ops(node->float_ops()); @@ -69,6 +73,12 @@ void ShowNode::AggregateTotalStats(ShowNode* node) { TFGraphNodeProto* node_pb = node->mutable_proto(); mutable_proto()->set_total_exec_micros(proto().total_exec_micros() + node_pb->total_exec_micros()); + mutable_proto()->set_total_accelerator_exec_micros( + proto().total_accelerator_exec_micros() + + node_pb->total_accelerator_exec_micros()); + mutable_proto()->set_total_cpu_exec_micros(proto().total_cpu_exec_micros() + + node_pb->total_cpu_exec_micros()); + mutable_proto()->set_total_requested_bytes(proto().total_requested_bytes() + node_pb->total_requested_bytes()); mutable_proto()->set_total_parameters(proto().total_parameters() + @@ -80,6 +90,12 @@ void ShowNode::AggregateTotalStats(ShowNode* node) { void ShowNode::AddSelfToTotalStats() { 
mutable_proto()->set_total_exec_micros(proto().total_exec_micros() + proto().exec_micros()); + mutable_proto()->set_total_accelerator_exec_micros( + proto().total_accelerator_exec_micros() + + proto().accelerator_exec_micros()); + mutable_proto()->set_total_cpu_exec_micros(proto().total_cpu_exec_micros() + + proto().cpu_exec_micros()); + mutable_proto()->set_total_requested_bytes(proto().total_requested_bytes() + proto().requested_bytes()); mutable_proto()->set_total_parameters(proto().total_parameters() + @@ -90,6 +106,9 @@ void ShowNode::AddSelfToTotalStats() { void ShowNode::ResetTotalStats() { mutable_proto()->set_total_exec_micros(0); + mutable_proto()->set_total_accelerator_exec_micros(0); + mutable_proto()->set_total_cpu_exec_micros(0); + mutable_proto()->set_total_requested_bytes(0); mutable_proto()->set_total_parameters(0); mutable_proto()->set_total_float_ops(0); @@ -116,7 +135,10 @@ bool ShowMultiNode::ReInit(int64 step, } mutable_proto()->set_name(name()); - mutable_proto()->set_exec_micros(node->kernel_exec_micros()); + mutable_proto()->set_exec_micros(node->exec_micros()); + mutable_proto()->set_accelerator_exec_micros(node->accelerator_exec_micros()); + mutable_proto()->set_cpu_exec_micros(node->cpu_exec_micros()); + mutable_proto()->set_requested_bytes(node->requested_bytes()); mutable_proto()->set_float_ops(node->float_ops()); @@ -151,6 +173,12 @@ void ShowMultiNode::AggregateTotalStats(ShowMultiNode* node) { TFMultiGraphNodeProto* node_pb = node->mutable_proto(); mutable_proto()->set_total_exec_micros(proto().total_exec_micros() + node_pb->total_exec_micros()); + mutable_proto()->set_total_accelerator_exec_micros( + proto().total_accelerator_exec_micros() + + node_pb->total_accelerator_exec_micros()); + mutable_proto()->set_total_cpu_exec_micros(proto().total_cpu_exec_micros() + + node_pb->total_cpu_exec_micros()); + mutable_proto()->set_total_requested_bytes(proto().total_requested_bytes() + node_pb->total_requested_bytes()); 
mutable_proto()->set_total_parameters(proto().total_parameters() + @@ -162,6 +190,12 @@ void ShowMultiNode::AggregateTotalStats(ShowMultiNode* node) { void ShowMultiNode::AddSelfToTotalStats() { mutable_proto()->set_total_exec_micros(proto().total_exec_micros() + proto().exec_micros()); + mutable_proto()->set_total_accelerator_exec_micros( + proto().total_accelerator_exec_micros() + + proto().accelerator_exec_micros()); + mutable_proto()->set_total_cpu_exec_micros(proto().total_cpu_exec_micros() + + proto().cpu_exec_micros()); + mutable_proto()->set_total_requested_bytes(proto().total_requested_bytes() + proto().requested_bytes()); mutable_proto()->set_total_parameters(proto().total_parameters() + @@ -172,6 +206,9 @@ void ShowMultiNode::AddSelfToTotalStats() { void ShowMultiNode::ResetTotalStats() { mutable_proto()->set_total_exec_micros(0); + mutable_proto()->set_total_accelerator_exec_micros(0); + mutable_proto()->set_total_cpu_exec_micros(0); + mutable_proto()->set_total_requested_bytes(0); mutable_proto()->set_total_parameters(0); mutable_proto()->set_total_float_ops(0); diff --git a/tensorflow/tools/tfprof/internal/tfprof_op.cc b/tensorflow/tools/tfprof/internal/tfprof_op.cc index 655569f1a28..ac702320b37 100644 --- a/tensorflow/tools/tfprof/internal/tfprof_op.cc +++ b/tensorflow/tools/tfprof/internal/tfprof_op.cc @@ -26,6 +26,60 @@ limitations under the License. 
namespace tensorflow { namespace tfprof { +namespace { +string FormatTotalExecTime(const ShowMultiNode* node, + const ShowMultiNode* root) { + double accu_pct = 0.0; + double pct = 0.0; + if (node->proto().total_exec_micros() > 0) { + accu_pct = 100.0 * node->proto().total_exec_micros() / + root->proto().total_exec_micros(); + pct = + 100.0 * node->proto().exec_micros() / root->proto().total_exec_micros(); + } + + return strings::Printf( + "%30s", strings::Printf("%s (%.2f%%, %.2f%%)", + FormatTime(node->proto().exec_micros()).c_str(), + accu_pct, pct) + .c_str()); +} +string FormatCPUExecTime(const ShowMultiNode* node, const ShowMultiNode* root) { + double accu_pct = 0.0; + double pct = 0.0; + if (node->proto().total_cpu_exec_micros() > 0) { + accu_pct = 100.0 * node->proto().total_cpu_exec_micros() / + root->proto().total_cpu_exec_micros(); + pct = 100.0 * node->proto().cpu_exec_micros() / + root->proto().total_cpu_exec_micros(); + } + + return strings::Printf( + "%30s", + strings::Printf("%s (%.2f%%, %.2f%%)", + FormatTime(node->proto().cpu_exec_micros()).c_str(), + accu_pct, pct) + .c_str()); +} +string FormatAcceleratorExecTime(const ShowMultiNode* node, + const ShowMultiNode* root) { + double accu_pct = 0.0; + double pct = 0.0; + if (node->proto().total_accelerator_exec_micros() > 0) { + accu_pct = 100.0 * node->proto().total_accelerator_exec_micros() / + root->proto().total_accelerator_exec_micros(); + pct = 100.0 * node->proto().accelerator_exec_micros() / + root->proto().total_accelerator_exec_micros(); + } + + return strings::Printf( + "%30s", strings::Printf( + "%s (%.2f%%, %.2f%%)", + FormatTime(node->proto().accelerator_exec_micros()).c_str(), + accu_pct, pct) + .c_str()); +} +} // namespace void TFOp::AddNode(TFGraphNode* node) { const string& op = node->op(); @@ -168,22 +222,18 @@ string TFOp::FormatNode(OpNode* node, OpNode* root, const Options& opts) { } if (opts.select.find(kShown[1]) != opts.select.end()) { - double accu_pct = 0.0; - double pct = 
0.0; - if (node->proto().total_exec_micros() > 0) { - accu_pct = 100.0 * node->proto().total_exec_micros() / - root->proto().total_exec_micros(); - pct = 100.0 * node->proto().exec_micros() / - root->proto().total_exec_micros(); - } - - attrs.push_back(strings::Printf( - "%30s", strings::Printf("%s (%.2f%%, %.2f%%)", - FormatTime(node->proto().exec_micros()).c_str(), - accu_pct, pct) - .c_str())); + attrs.push_back(FormatTotalExecTime(node, root)); + attrs.push_back(FormatAcceleratorExecTime(node, root)); + attrs.push_back(FormatCPUExecTime(node, root)); + } + if (opts.select.find(kShown[9]) != opts.select.end() && + opts.select.find(kShown[1]) == opts.select.end()) { + attrs.push_back(FormatAcceleratorExecTime(node, root)); + } + if (opts.select.find(kShown[10]) != opts.select.end() && + opts.select.find(kShown[1]) == opts.select.end()) { + attrs.push_back(FormatCPUExecTime(node, root)); } - if (opts.select.find(kShown[2]) != opts.select.end()) { double accu_pct = 0.0; double pct = 0.0; diff --git a/tensorflow/tools/tfprof/internal/tfprof_options.h b/tensorflow/tools/tfprof/internal/tfprof_options.h index d8c172e0a2c..6c9db243422 100644 --- a/tensorflow/tools/tfprof/internal/tfprof_options.h +++ b/tensorflow/tools/tfprof/internal/tfprof_options.h @@ -46,15 +46,17 @@ static const char* const kOptions[] = { }; static const char* const kOrderBy[] = { - "name", "bytes", "micros", "params", "float_ops", "occurrence", + "name", "bytes", "micros", "accelerator_micros", + "cpu_micros", "params", "float_ops", "occurrence", }; // Append Only. // TODO(xpan): As we are adding more fields to be selected, we // need to have a way to tell users what fields are available in which view.
static const char* const kShown[] = { - "bytes", "micros", "params", "float_ops", "tensor_value", - "device", "op_types", "occurrence", "input_shapes"}; + "bytes", "micros", "params", "float_ops", "tensor_value", + "device", "op_types", "occurrence", "input_shapes", "accelerator_micros", + "cpu_micros"}; static const char* const kCmds[] = { "scope", "graph", "code", "op", "set", "help", diff --git a/tensorflow/tools/tfprof/internal/tfprof_show.cc b/tensorflow/tools/tfprof/internal/tfprof_show.cc index 517a09f0c74..eaab5b96087 100644 --- a/tensorflow/tools/tfprof/internal/tfprof_show.cc +++ b/tensorflow/tools/tfprof/internal/tfprof_show.cc @@ -156,13 +156,17 @@ string TFShow::FormatNode(ShowNode* node, const Options& opts) { info.push_back(memory); } if (opts.select.find(kShown[1]) != opts.select.end()) { - string time = FormatTime(node->proto().total_exec_micros()); - if (node->account) { - time = FormatTime(node->proto().exec_micros()) + "/" + time; - } else { - time = "--/" + time; - } - info.push_back(time); + info.push_back(FormatTotalExecTime(node, opts)); + info.push_back(FormatAcceleratorExecTime(node, opts)); + info.push_back(FormatCPUExecTime(node, opts)); + } + if (opts.select.find(kShown[9]) != opts.select.end() && + opts.select.find(kShown[1]) == opts.select.end()) { + info.push_back(FormatAcceleratorExecTime(node, opts)); + } + if (opts.select.find(kShown[10]) != opts.select.end() && + opts.select.find(kShown[1]) == opts.select.end()) { + info.push_back(FormatCPUExecTime(node, opts)); } if (opts.select.find(kShown[5]) != opts.select.end()) { if (node->proto().devices_size() > 0) { @@ -202,7 +206,17 @@ string TFShow::FormatLegend(const Options& opts) { legends.push_back("output bytes"); } if (opts.select.find(kShown[1]) != opts.select.end()) { - legends.push_back("execution time"); + legends.push_back("total execution time"); + legends.push_back("accelerator execution time"); + legends.push_back("cpu execution time"); + } + if 
(opts.select.find(kShown[9]) != opts.select.end() && + opts.select.find(kShown[1]) == opts.select.end()) { + legends.push_back("accelerator execution time"); + } + if (opts.select.find(kShown[10]) != opts.select.end() && + opts.select.find(kShown[1]) == opts.select.end()) { + legends.push_back("cpu execution time"); } if (opts.select.find(kShown[5]) != opts.select.end()) { legends.push_back("assigned devices"); diff --git a/tensorflow/tools/tfprof/internal/tfprof_show.h b/tensorflow/tools/tfprof/internal/tfprof_show.h index a337a584f7a..2c61b4fd732 100644 --- a/tensorflow/tools/tfprof/internal/tfprof_show.h +++ b/tensorflow/tools/tfprof/internal/tfprof_show.h @@ -101,6 +101,36 @@ class TFShow { checkpoint::CheckpointReader* ckpt_reader_; }; +template +string FormatTotalExecTime(const T* node, const Options& opts) { + string time = FormatTime(node->proto().total_exec_micros()); + if (node->account) { + time = FormatTime(node->proto().exec_micros()) + "/" + time; + } else { + time = "--/" + time; + } + return time; +} +template +string FormatCPUExecTime(const T* node, const Options& opts) { + string time = FormatTime(node->proto().total_cpu_exec_micros()); + if (node->account) { + time = FormatTime(node->proto().cpu_exec_micros()) + "/" + time; + } else { + time = "--/" + time; + } + return time; +} +template +string FormatAcceleratorExecTime(const T* node, const Options& opts) { + string time = FormatTime(node->proto().total_accelerator_exec_micros()); + if (node->account) { + time = FormatTime(node->proto().accelerator_exec_micros()) + "/" + time; + } else { + time = "--/" + time; + } + return time; +} } // namespace tfprof } // namespace tensorflow diff --git a/tensorflow/tools/tfprof/internal/tfprof_show_multi.cc b/tensorflow/tools/tfprof/internal/tfprof_show_multi.cc index 4714ffc33a6..97f204d25bf 100644 --- a/tensorflow/tools/tfprof/internal/tfprof_show_multi.cc +++ b/tensorflow/tools/tfprof/internal/tfprof_show_multi.cc @@ -108,7 +108,17 @@ string 
TFMultiShow::FormatLegend(const Options& opts) { legends.push_back("output bytes"); } if (opts.select.find(kShown[1]) != opts.select.end()) { - legends.push_back("execution time"); + legends.push_back("total execution time"); + legends.push_back("accelerator execution time"); + legends.push_back("cpu execution time"); + } + if (opts.select.find(kShown[9]) != opts.select.end() && + opts.select.find(kShown[1]) == opts.select.end()) { + legends.push_back("accelerator execution time"); + } + if (opts.select.find(kShown[10]) != opts.select.end() && + opts.select.find(kShown[1]) == opts.select.end()) { + legends.push_back("cpu execution time"); } if (opts.select.find(kShown[2]) != opts.select.end()) { legends.push_back("# parameters"); @@ -179,5 +189,24 @@ string TFMultiShow::FormatInputShapes(const TFMultiGraphNodeProto& proto) { return str_util::Join(input_types, "\n"); } +std::vector TFMultiShow::FormatTimes(const ShowMultiNode* node, + const Options& opts) { + std::vector attrs; + if (opts.select.find(kShown[1]) != opts.select.end()) { + attrs.push_back(FormatTotalExecTime(node, opts)); + attrs.push_back(FormatAcceleratorExecTime(node, opts)); + attrs.push_back(FormatCPUExecTime(node, opts)); + } + if (opts.select.find(kShown[9]) != opts.select.end() && + opts.select.find(kShown[1]) == opts.select.end()) { + attrs.push_back(FormatAcceleratorExecTime(node, opts)); + } + if (opts.select.find(kShown[10]) != opts.select.end() && + opts.select.find(kShown[1]) == opts.select.end()) { + attrs.push_back(FormatCPUExecTime(node, opts)); + } + return attrs; +} + } // namespace tfprof } // namespace tensorflow diff --git a/tensorflow/tools/tfprof/internal/tfprof_show_multi.h b/tensorflow/tools/tfprof/internal/tfprof_show_multi.h index 1181e45ee18..e6faf1231dc 100644 --- a/tensorflow/tools/tfprof/internal/tfprof_show_multi.h +++ b/tensorflow/tools/tfprof/internal/tfprof_show_multi.h @@ -30,6 +30,7 @@ limitations under the License. 
#include "tensorflow/tools/tfprof/internal/tfprof_node.h" #include "tensorflow/tools/tfprof/internal/tfprof_node_show.h" #include "tensorflow/tools/tfprof/internal/tfprof_options.h" +#include "tensorflow/tools/tfprof/internal/tfprof_show.h" #include "tensorflow/tools/tfprof/internal/tfprof_tensor.h" #include "tensorflow/tools/tfprof/internal/tfprof_timeline.h" #include "tensorflow/tools/tfprof/internal/tfprof_utils.h" @@ -67,6 +68,8 @@ class TFMultiShow { string FormatLegend(const Options& opts); string FormatInputShapes(const TFMultiGraphNodeProto& proto); + std::vector FormatTimes(const ShowMultiNode* node, + const Options& opts); template std::vector SortNodes(const std::vector& nodes, const Options& opts) { @@ -88,12 +91,18 @@ class TFMultiShow { return n1->proto().total_exec_micros() > n2->proto().total_exec_micros(); } else if (opts.order_by == kOrderBy[3]) { + return n1->proto().total_accelerator_exec_micros() > + n2->proto().total_accelerator_exec_micros(); + } else if (opts.order_by == kOrderBy[4]) { + return n1->proto().total_cpu_exec_micros() > + n2->proto().total_cpu_exec_micros(); + } else if (opts.order_by == kOrderBy[5]) { return n1->proto().total_parameters() > n2->proto().total_parameters(); - } else if (opts.order_by == kOrderBy[4]) { + } else if (opts.order_by == kOrderBy[6]) { return n1->proto().total_float_ops() > n2->proto().total_float_ops(); - } else if (opts.order_by == kOrderBy[5]) { + } else if (opts.order_by == kOrderBy[7]) { return n1->node->graph_nodes().size() > n2->node->graph_nodes().size(); } diff --git a/tensorflow/tools/tfprof/internal/tfprof_show_test.cc b/tensorflow/tools/tfprof/internal/tfprof_show_test.cc index 498477de0a0..478e269f878 100644 --- a/tensorflow/tools/tfprof/internal/tfprof_show_test.cc +++ b/tensorflow/tools/tfprof/internal/tfprof_show_test.cc @@ -38,7 +38,8 @@ class TFProfShowTest : public ::testing::Test { io::JoinPath(testing::TensorFlowSrcRoot(), "tools/tfprof/internal/testdata/graph.pbtxt"); 
std::unique_ptr graph_pb(new tensorflow::GraphDef()); - TF_CHECK_OK(ReadGraphDef(Env::Default(), graph_path, graph_pb.get())); + TF_CHECK_OK( + ReadProtoFile(Env::Default(), graph_path, graph_pb.get(), false)); std::unique_ptr run_meta_pb( new tensorflow::RunMetadata()); @@ -46,7 +47,7 @@ class TFProfShowTest : public ::testing::Test { io::JoinPath(testing::TensorFlowSrcRoot(), "tools/tfprof/internal/testdata/run_meta"); TF_CHECK_OK( - ReadBinaryProto(Env::Default(), run_meta_path, run_meta_pb.get())); + ReadProtoFile(Env::Default(), run_meta_path, run_meta_pb.get(), true)); std::unique_ptr op_log_pb(new OpLog()); string op_log_path = @@ -81,14 +82,73 @@ TEST_F(TFProfShowTest, DumpScopeMode) { string dump_str; TF_CHECK_OK(ReadFileToString(Env::Default(), dump_file, &dump_str)); EXPECT_EQ( - "node name | # parameters | # float_ops | output bytes | execution " - "time\n_TFProfRoot (--/370 params, --/0 flops, --/1.48KB, --/5us)\n " - "conv2d (--/140 params, --/0 flops, --/560B, --/2us)\n conv2d/bias " - "(5, 5/5 params, 0/0 flops, 20B/20B, 1us/1us)\n conv2d/kernel " - "(3x3x3x5, 135/135 params, 0/0 flops, 540B/540B, 1us/1us)\n conv2d_1 " - "(--/230 params, --/0 flops, --/920B, --/3us)\n conv2d_1/bias (5, 5/5 " - "params, 0/0 flops, 20B/20B, 1us/1us)\n conv2d_1/kernel (3x3x5x5, " - "225/225 params, 0/0 flops, 900B/900B, 2us/2us)\n", + "node name | # parameters | # float_ops | output bytes | total execution " + "time | accelerator execution time | cpu execution time\n_TFProfRoot " + "(--/370 params, --/0 flops, --/1.48KB, --/5us, --/0us, --/5us)\n " + "conv2d (--/140 params, --/0 flops, --/560B, --/2us, --/0us, --/2us)\n " + " conv2d/bias (5, 5/5 params, 0/0 flops, 20B/20B, 1us/1us, 0us/0us, " + "1us/1us)\n conv2d/kernel (3x3x3x5, 135/135 params, 0/0 flops, " + "540B/540B, 1us/1us, 0us/0us, 1us/1us)\n conv2d_1 (--/230 params, --/0 " + "flops, --/920B, --/3us, --/0us, --/3us)\n conv2d_1/bias (5, 5/5 " + "params, 0/0 flops, 20B/20B, 1us/1us, 0us/0us, 1us/1us)\n " + 
"conv2d_1/kernel (3x3x5x5, 225/225 params, 0/0 flops, 900B/900B, " + "2us/2us, 0us/0us, 2us/2us)\n", + dump_str); +} + +TEST_F(TFProfShowTest, DumpAcceleratorAndCPUMicros) { + string dump_file = io::JoinPath(testing::TmpDir(), "dump"); + Options opts( + 5, 0, 0, 0, 0, 0, -1, "cpu_micros", {".*"}, // accout_type_regexes + {".*"}, {""}, {".*"}, {""}, false, {"accelerator_micros", "cpu_micros"}, + "file", {{"outfile", dump_file}}); + tf_stats_->ShowGraphNode("scope", opts); + + string dump_str; + TF_CHECK_OK(ReadFileToString(Env::Default(), dump_file, &dump_str)); + EXPECT_EQ( + "node name | accelerator execution time | cpu execution " + "time\n_TFProfRoot (--/0us, --/97us)\n conv2d (0us/0us, 0us/76us)\n " + "conv2d/convolution (0us/0us, 60us/60us)\n conv2d/convolution/Shape " + "(0us/0us, 0us/0us)\n conv2d/convolution/dilation_rate (0us/0us, " + "0us/0us)\n conv2d/BiasAdd (0us/0us, 12us/12us)\n conv2d/bias " + "(0us/0us, 1us/2us)\n conv2d/bias/Assign (0us/0us, 0us/0us)\n " + "conv2d/bias/Initializer (0us/0us, 0us/0us)\n " + "conv2d/bias/Initializer/Const (0us/0us, 0us/0us)\n " + "conv2d/bias/read (0us/0us, 1us/1us)\n conv2d/kernel (0us/0us, " + "1us/2us)\n conv2d/kernel/Assign (0us/0us, 0us/0us)\n " + "conv2d/kernel/Initializer (0us/0us, 0us/0us)\n " + "conv2d/kernel/Initializer/random_uniform (0us/0us, 0us/0us)\n " + "conv2d/kernel/read (0us/0us, 1us/1us)\n conv2d_2 (0us/0us, 0us/15us)\n " + " conv2d_2/convolution (0us/0us, 13us/13us)\n " + "conv2d_2/convolution/Shape (0us/0us, 0us/0us)\n " + "conv2d_2/convolution/dilation_rate (0us/0us, 0us/0us)\n " + "conv2d_2/BiasAdd (0us/0us, 2us/2us)\n conv2d_1 (0us/0us, 0us/5us)\n " + "conv2d_1/bias (0us/0us, 1us/2us)\n conv2d_1/bias/Assign (0us/0us, " + "0us/0us)\n conv2d_1/bias/Initializer (0us/0us, 0us/0us)\n " + "conv2d_1/bias/Initializer/Const (0us/0us, 0us/0us)\n " + "conv2d_1/bias/read (0us/0us, 1us/1us)\n conv2d_1/kernel (0us/0us, " + "2us/3us)\n conv2d_1/kernel/Assign (0us/0us, 0us/0us)\n " + 
"conv2d_1/kernel/Initializer (0us/0us, 0us/0us)\n " + "conv2d_1/kernel/Initializer/random_uniform (0us/0us, 0us/0us)\n " + "conv2d_1/kernel/read (0us/0us, 1us/1us)\n init (0us/0us, 0us/0us)\n " + "save (0us/0us, 0us/0us)\n save/Assign (0us/0us, 0us/0us)\n " + "save/Assign_1 (0us/0us, 0us/0us)\n save/Assign_2 (0us/0us, " + "0us/0us)\n save/Assign_3 (0us/0us, 0us/0us)\n save/Const " + "(0us/0us, 0us/0us)\n save/RestoreV2 (0us/0us, 0us/0us)\n " + "save/RestoreV2/shape_and_slices (0us/0us, 0us/0us)\n " + "save/RestoreV2/tensor_names (0us/0us, 0us/0us)\n save/RestoreV2_1 " + "(0us/0us, 0us/0us)\n save/RestoreV2_1/shape_and_slices (0us/0us, " + "0us/0us)\n save/RestoreV2_1/tensor_names (0us/0us, 0us/0us)\n " + "save/RestoreV2_2 (0us/0us, 0us/0us)\n " + "save/RestoreV2_2/shape_and_slices (0us/0us, 0us/0us)\n " + "save/RestoreV2_2/tensor_names (0us/0us, 0us/0us)\n save/RestoreV2_3 " + "(0us/0us, 0us/0us)\n save/RestoreV2_3/shape_and_slices (0us/0us, " + "0us/0us)\n save/RestoreV2_3/tensor_names (0us/0us, 0us/0us)\n " + "save/SaveV2 (0us/0us, 0us/0us)\n save/SaveV2/shape_and_slices " + "(0us/0us, 0us/0us)\n save/SaveV2/tensor_names (0us/0us, 0us/0us)\n " + " save/control_dependency (0us/0us, 0us/0us)\n save/restore_all " + "(0us/0us, 0us/0us)\n zeros (0us/0us, 1us/1us)\n", dump_str); } @@ -104,14 +164,16 @@ TEST_F(TFProfShowTest, DumpOpMode) { string dump_str; TF_CHECK_OK(ReadFileToString(Env::Default(), dump_file, &dump_str)); EXPECT_EQ( - "nodename|outputbytes|executiontime|#parameters|#float_ops|opoccurrence|" - "inputshapes\nVariableV21.48KB(100.00%,17.10%),5us(100.00%,5.15%)," - "370params(100.00%,100.00%),0float_ops(100.00%,0.00%),4\n\ninput_type:\t(" - "*4)\texec_time:5us\n\nAssign0B(0.00%,0.00%),0us(94.85%,0.00%),0params(0." - "00%,0.00%),0float_ops(100.00%,0.00%),8\n\ninput_type:0:unknown,\t1:" - "unknown\t(*8)\texec_time:0us\n\nConst1.54KB(58.87%,17.74%),1us(80.41%,1." 
- "03%),0params(0.00%,0.00%),0float_ops(98.49%,0.00%),24\n\ninput_type:\t(*" - "24)\texec_time:1us\n\n", + "nodename|outputbytes|totalexecutiontime|acceleratorexecutiontime|" + "cpuexecutiontime|#parameters|#float_ops|opoccurrence|" + "inputshapes\nVariableV21.48KB(100.00%,17.10%),5us(100.00%,5.15%),0us(0." + "00%,0.00%),5us(100.00%,5.15%),370params(100.00%,100.00%),0float_ops(100." + "00%,0.00%),4\n\ninput_type:\t(*4)\texec_time:5us\n\nAssign0B(0.00%,0.00%" + "),0us(94.85%,0.00%),0us(0.00%,0.00%),0us(94.85%,0.00%),0params(0.00%,0." + "00%),0float_ops(100.00%,0.00%),8\n\ninput_type:0:unknown,\t1:unknown\t(*" + "8)\texec_time:0us\n\nConst1.54KB(58.87%,17.74%),1us(80.41%,1.03%),0us(0." + "00%,0.00%),1us(80.41%,1.03%),0params(0.00%,0.00%),0float_ops(98.49%,0." + "00%),24\n\ninput_type:\t(*24)\texec_time:1us\n\n", StringReplace(dump_str, " ", "")); } } // namespace tfprof diff --git a/tensorflow/tools/tfprof/internal/tfprof_stats_test.cc b/tensorflow/tools/tfprof/internal/tfprof_stats_test.cc index a1e500f9492..948ce49df4d 100644 --- a/tensorflow/tools/tfprof/internal/tfprof_stats_test.cc +++ b/tensorflow/tools/tfprof/internal/tfprof_stats_test.cc @@ -39,7 +39,8 @@ class TFProfStatsTest : public ::testing::Test { io::JoinPath(testing::TensorFlowSrcRoot(), "tools/tfprof/internal/testdata/graph.pbtxt"); std::unique_ptr graph_pb(new tensorflow::GraphDef()); - TF_CHECK_OK(ReadGraphDef(Env::Default(), graph_path, graph_pb.get())); + TF_CHECK_OK( + ReadProtoFile(Env::Default(), graph_path, graph_pb.get(), false)); std::unique_ptr run_meta_pb( new tensorflow::RunMetadata()); @@ -47,7 +48,7 @@ class TFProfStatsTest : public ::testing::Test { io::JoinPath(testing::TensorFlowSrcRoot(), "tools/tfprof/internal/testdata/run_meta"); TF_CHECK_OK( - ReadBinaryProto(Env::Default(), run_meta_path, run_meta_pb.get())); + ReadProtoFile(Env::Default(), run_meta_path, run_meta_pb.get(), true)); std::unique_ptr op_log_pb(new OpLog()); string op_log_path = @@ -88,26 +89,40 @@ 
TEST_F(TFProfStatsTest, CustomOpType) { "total_exec_micros: 1\n total_requested_bytes: 20\n " "total_parameters: 5\n devices: " "\"/job:localhost/replica:0/task:0/cpu:0\"\n float_ops: 0\n " - "total_float_ops: 0\n }\n children {\n name: \"conv2d/kernel\"\n " - "exec_micros: 1\n requested_bytes: 540\n parameters: 135\n " - "total_exec_micros: 1\n total_requested_bytes: 540\n " - "total_parameters: 135\n devices: " + "total_float_ops: 0\n accelerator_exec_micros: 0\n " + "cpu_exec_micros: 1\n total_accelerator_exec_micros: 0\n " + "total_cpu_exec_micros: 1\n }\n children {\n name: " + "\"conv2d/kernel\"\n exec_micros: 1\n requested_bytes: 540\n " + "parameters: 135\n total_exec_micros: 1\n total_requested_bytes: " + "540\n total_parameters: 135\n devices: " "\"/job:localhost/replica:0/task:0/cpu:0\"\n float_ops: 0\n " - "total_float_ops: 0\n }\n float_ops: 0\n total_float_ops: " - "0\n}\nchildren {\n name: \"conv2d_1\"\n exec_micros: 0\n " + "total_float_ops: 0\n accelerator_exec_micros: 0\n " + "cpu_exec_micros: 1\n total_accelerator_exec_micros: 0\n " + "total_cpu_exec_micros: 1\n }\n float_ops: 0\n total_float_ops: 0\n " + "accelerator_exec_micros: 0\n cpu_exec_micros: 0\n " + "total_accelerator_exec_micros: 0\n total_cpu_exec_micros: " + "2\n}\nchildren {\n name: \"conv2d_1\"\n exec_micros: 0\n " "requested_bytes: 0\n total_exec_micros: 3\n total_requested_bytes: " "920\n total_parameters: 230\n children {\n name: " "\"conv2d_1/bias\"\n exec_micros: 1\n requested_bytes: 20\n " "parameters: 5\n total_exec_micros: 1\n total_requested_bytes: " "20\n total_parameters: 5\n devices: " "\"/job:localhost/replica:0/task:0/cpu:0\"\n float_ops: 0\n " - "total_float_ops: 0\n }\n children {\n name: \"conv2d_1/kernel\"\n " - " exec_micros: 2\n requested_bytes: 900\n parameters: 225\n " - "total_exec_micros: 2\n total_requested_bytes: 900\n " - "total_parameters: 225\n devices: " + "total_float_ops: 0\n accelerator_exec_micros: 0\n " + "cpu_exec_micros: 1\n 
total_accelerator_exec_micros: 0\n " + "total_cpu_exec_micros: 1\n }\n children {\n name: " + "\"conv2d_1/kernel\"\n exec_micros: 2\n requested_bytes: 900\n " + "parameters: 225\n total_exec_micros: 2\n total_requested_bytes: " + "900\n total_parameters: 225\n devices: " "\"/job:localhost/replica:0/task:0/cpu:0\"\n float_ops: 0\n " - "total_float_ops: 0\n }\n float_ops: 0\n total_float_ops: " - "0\n}\nfloat_ops: 0\ntotal_float_ops: 0\n", + "total_float_ops: 0\n accelerator_exec_micros: 0\n " + "cpu_exec_micros: 2\n total_accelerator_exec_micros: 0\n " + "total_cpu_exec_micros: 2\n }\n float_ops: 0\n total_float_ops: 0\n " + "accelerator_exec_micros: 0\n cpu_exec_micros: 0\n " + "total_accelerator_exec_micros: 0\n total_cpu_exec_micros: " + "3\n}\nfloat_ops: 0\ntotal_float_ops: 0\naccelerator_exec_micros: " + "0\ncpu_exec_micros: 0\ntotal_accelerator_exec_micros: " + "0\ntotal_cpu_exec_micros: 5\n", &expected)); EXPECT_EQ(expected.DebugString(), root.DebugString()); } @@ -130,26 +145,40 @@ TEST_F(TFProfStatsTest, CheckPointOpType) { "total_exec_micros: 1\n total_requested_bytes: 20\n " "total_parameters: 5\n devices: " "\"/job:localhost/replica:0/task:0/cpu:0\"\n float_ops: 0\n " - "total_float_ops: 0\n }\n children {\n name: \"conv2d/kernel\"\n " - "exec_micros: 1\n requested_bytes: 540\n parameters: 135\n " - "total_exec_micros: 1\n total_requested_bytes: 540\n " - "total_parameters: 135\n devices: " + "total_float_ops: 0\n accelerator_exec_micros: 0\n " + "cpu_exec_micros: 1\n total_accelerator_exec_micros: 0\n " + "total_cpu_exec_micros: 1\n }\n children {\n name: " + "\"conv2d/kernel\"\n exec_micros: 1\n requested_bytes: 540\n " + "parameters: 135\n total_exec_micros: 1\n total_requested_bytes: " + "540\n total_parameters: 135\n devices: " "\"/job:localhost/replica:0/task:0/cpu:0\"\n float_ops: 0\n " - "total_float_ops: 0\n }\n float_ops: 0\n total_float_ops: " - "0\n}\nchildren {\n name: \"conv2d_1\"\n exec_micros: 0\n " + "total_float_ops: 0\n 
accelerator_exec_micros: 0\n " + "cpu_exec_micros: 1\n total_accelerator_exec_micros: 0\n " + "total_cpu_exec_micros: 1\n }\n float_ops: 0\n total_float_ops: 0\n " + "accelerator_exec_micros: 0\n cpu_exec_micros: 0\n " + "total_accelerator_exec_micros: 0\n total_cpu_exec_micros: " + "2\n}\nchildren {\n name: \"conv2d_1\"\n exec_micros: 0\n " "requested_bytes: 0\n total_exec_micros: 3\n total_requested_bytes: " "920\n total_parameters: 230\n children {\n name: " "\"conv2d_1/bias\"\n exec_micros: 1\n requested_bytes: 20\n " "parameters: 5\n total_exec_micros: 1\n total_requested_bytes: " "20\n total_parameters: 5\n devices: " "\"/job:localhost/replica:0/task:0/cpu:0\"\n float_ops: 0\n " - "total_float_ops: 0\n }\n children {\n name: \"conv2d_1/kernel\"\n " - " exec_micros: 2\n requested_bytes: 900\n parameters: 225\n " - "total_exec_micros: 2\n total_requested_bytes: 900\n " - "total_parameters: 225\n devices: " + "total_float_ops: 0\n accelerator_exec_micros: 0\n " + "cpu_exec_micros: 1\n total_accelerator_exec_micros: 0\n " + "total_cpu_exec_micros: 1\n }\n children {\n name: " + "\"conv2d_1/kernel\"\n exec_micros: 2\n requested_bytes: 900\n " + "parameters: 225\n total_exec_micros: 2\n total_requested_bytes: " + "900\n total_parameters: 225\n devices: " "\"/job:localhost/replica:0/task:0/cpu:0\"\n float_ops: 0\n " - "total_float_ops: 0\n }\n float_ops: 0\n total_float_ops: " - "0\n}\nfloat_ops: 0\ntotal_float_ops: 0\n", + "total_float_ops: 0\n accelerator_exec_micros: 0\n " + "cpu_exec_micros: 2\n total_accelerator_exec_micros: 0\n " + "total_cpu_exec_micros: 2\n }\n float_ops: 0\n total_float_ops: 0\n " + "accelerator_exec_micros: 0\n cpu_exec_micros: 0\n " + "total_accelerator_exec_micros: 0\n total_cpu_exec_micros: " + "3\n}\nfloat_ops: 0\ntotal_float_ops: 0\naccelerator_exec_micros: " + "0\ncpu_exec_micros: 0\ntotal_accelerator_exec_micros: " + "0\ntotal_cpu_exec_micros: 5\n", &expected)); EXPECT_EQ(expected.DebugString(), root.DebugString()); } @@ -163,10 
+192,11 @@ TEST_F(TFProfStatsTest, TestGraph) { TFGraphNodeProto expected; CHECK(protobuf::TextFormat::ParseFromString( - "name: \"_TFProfRoot\"\nexec_micros: 0\nrequested_bytes: 0\n" - "total_exec_micros: 97\ntotal_requested_bytes: " - "8656\ntotal_parameters: 370\nfloat_ops: " - "0\ntotal_float_ops: 34360\n", + "name: \"_TFProfRoot\"\nexec_micros: 0\nrequested_bytes: " + "0\ntotal_exec_micros: 97\ntotal_requested_bytes: " + "8656\ntotal_parameters: 370\nfloat_ops: 0\ntotal_float_ops: " + "34360\naccelerator_exec_micros: 0\ncpu_exec_micros: " + "0\ntotal_accelerator_exec_micros: 0\ntotal_cpu_exec_micros: 97\n", &expected)); EXPECT_EQ(expected.DebugString(), root.DebugString()); } @@ -186,28 +216,39 @@ TEST_F(TFProfStatsTest, TestFloatOps) { "\"/job:localhost/replica:0/task:0/cpu:0\"\n float_ops: 360\n " "total_float_ops: 360\n input_shapes {\n key: 0\n value {\n " "unknown_rank: true\n }\n }\n input_shapes {\n key: 1\n value " - "{\n unknown_rank: true\n }\n }\n}\nchildren {\n name: " + "{\n unknown_rank: true\n }\n }\n accelerator_exec_micros: 0\n " + " cpu_exec_micros: 12\n total_accelerator_exec_micros: 0\n " + "total_cpu_exec_micros: 12\n}\nchildren {\n name: " "\"conv2d/convolution\"\n exec_micros: 60\n requested_bytes: 1440\n " "total_exec_micros: 60\n total_requested_bytes: 1440\n " "total_parameters: 0\n devices: " "\"/job:localhost/replica:0/task:0/cpu:0\"\n float_ops: 19440\n " "total_float_ops: 19440\n input_shapes {\n key: 0\n value {\n " " unknown_rank: true\n }\n }\n input_shapes {\n key: 1\n " - "value {\n unknown_rank: true\n }\n }\n}\nchildren {\n name: " - "\"conv2d_2/BiasAdd\"\n exec_micros: 2\n requested_bytes: 640\n " - "total_exec_micros: 2\n total_requested_bytes: 640\n total_parameters: " - "0\n devices: \"/job:localhost/replica:0/task:0/cpu:0\"\n float_ops: " - "160\n total_float_ops: 160\n input_shapes {\n key: 0\n value " - "{\n unknown_rank: true\n }\n }\n input_shapes {\n key: 1\n " - " value {\n unknown_rank: true\n }\n }\n}\nchildren 
{\n " - "name: \"conv2d_2/convolution\"\n exec_micros: 13\n requested_bytes: " - "640\n total_exec_micros: 13\n total_requested_bytes: 640\n " + "value {\n unknown_rank: true\n }\n }\n " + "accelerator_exec_micros: 0\n cpu_exec_micros: 60\n " + "total_accelerator_exec_micros: 0\n total_cpu_exec_micros: " + "60\n}\nchildren {\n name: \"conv2d_2/BiasAdd\"\n exec_micros: 2\n " + "requested_bytes: 640\n total_exec_micros: 2\n total_requested_bytes: " + "640\n total_parameters: 0\n devices: " + "\"/job:localhost/replica:0/task:0/cpu:0\"\n float_ops: 160\n " + "total_float_ops: 160\n input_shapes {\n key: 0\n value {\n " + "unknown_rank: true\n }\n }\n input_shapes {\n key: 1\n value " + "{\n unknown_rank: true\n }\n }\n accelerator_exec_micros: 0\n " + " cpu_exec_micros: 2\n total_accelerator_exec_micros: 0\n " + "total_cpu_exec_micros: 2\n}\nchildren {\n name: " + "\"conv2d_2/convolution\"\n exec_micros: 13\n requested_bytes: 640\n " + "total_exec_micros: 13\n total_requested_bytes: 640\n " "total_parameters: 0\n devices: " "\"/job:localhost/replica:0/task:0/cpu:0\"\n float_ops: 14400\n " "total_float_ops: 14400\n input_shapes {\n key: 0\n value {\n " " unknown_rank: true\n }\n }\n input_shapes {\n key: 1\n " - "value {\n unknown_rank: true\n }\n }\n}\nfloat_ops: " - "0\ntotal_float_ops: 34360\n", + "value {\n unknown_rank: true\n }\n }\n " + "accelerator_exec_micros: 0\n cpu_exec_micros: 13\n " + "total_accelerator_exec_micros: 0\n total_cpu_exec_micros: " + "13\n}\nfloat_ops: 0\ntotal_float_ops: 34360\naccelerator_exec_micros: " + "0\ncpu_exec_micros: 0\ntotal_accelerator_exec_micros: " + "0\ntotal_cpu_exec_micros: 97\n", &expected)); EXPECT_EQ(expected.DebugString(), root.DebugString()); } @@ -223,7 +264,9 @@ TEST_F(TFProfStatsTest, TestAccountShownNameOnly) { CHECK(protobuf::TextFormat::ParseFromString( "name: \"_TFProfRoot\"\nexec_micros: 0\nrequested_bytes: " "0\ntotal_exec_micros: 0\ntotal_requested_bytes: 0\ntotal_parameters: " - "0\nfloat_ops: 
0\ntotal_float_ops: 0\n", + "0\nfloat_ops: 0\ntotal_float_ops: 0\naccelerator_exec_micros: " + "0\ncpu_exec_micros: 0\ntotal_accelerator_exec_micros: " + "0\ntotal_cpu_exec_micros: 0\n", &expected)); EXPECT_EQ(expected.DebugString(), root.DebugString()); } @@ -238,7 +281,9 @@ TEST_F(TFProfStatsTest, TestShowTensorValue) { CHECK(protobuf::TextFormat::ParseFromString( "name: \"_TFProfRoot\"\nexec_micros: 0\nrequested_bytes: " "0\ntotal_exec_micros: 97\ntotal_requested_bytes: " - "8656\ntotal_parameters: 370\nfloat_ops: 0\ntotal_float_ops: 34360\n", + "8656\ntotal_parameters: 370\nfloat_ops: 0\ntotal_float_ops: " + "34360\naccelerator_exec_micros: 0\ncpu_exec_micros: " + "0\ntotal_accelerator_exec_micros: 0\ntotal_cpu_exec_micros: 97\n", &expected)); EXPECT_EQ(expected.DebugString(), root.DebugString()); } diff --git a/tensorflow/tools/tfprof/internal/tfprof_tensor_test.cc b/tensorflow/tools/tfprof/internal/tfprof_tensor_test.cc index 3dd721cbcc8..698738c23cc 100644 --- a/tensorflow/tools/tfprof/internal/tfprof_tensor_test.cc +++ b/tensorflow/tools/tfprof/internal/tfprof_tensor_test.cc @@ -34,7 +34,8 @@ class TFProfTensorTest : public ::testing::Test { io::JoinPath(testing::TensorFlowSrcRoot(), "tools/tfprof/internal/testdata/graph.pbtxt"); std::unique_ptr graph_pb(new tensorflow::GraphDef()); - TF_CHECK_OK(ReadGraphDef(Env::Default(), graph_path, graph_pb.get())); + TF_CHECK_OK( + ReadProtoFile(Env::Default(), graph_path, graph_pb.get(), false)); std::unique_ptr run_meta_pb; std::unique_ptr op_log_pb; @@ -72,7 +73,9 @@ TEST_F(TFProfTensorTest, Basics) { "total_parameters: 5\n float_ops: 0\n total_float_ops: 0\n " "tensor_value {\n dtype: DT_FLOAT\n value_double: 0\n " "value_double: 0\n value_double: 0\n value_double: 0\n " - "value_double: 0\n }\n }\n children {\n name: " + "value_double: 0\n }\n accelerator_exec_micros: 0\n " + "cpu_exec_micros: 0\n total_accelerator_exec_micros: 0\n " + "total_cpu_exec_micros: 0\n }\n children {\n name: " "\"conv2d/kernel\"\n 
exec_micros: 0\n requested_bytes: 0\n " "parameters: 135\n total_exec_micros: 0\n total_requested_bytes: " "0\n total_parameters: 135\n float_ops: 0\n total_float_ops: " @@ -143,7 +146,11 @@ TEST_F(TFProfTensorTest, Basics) { "value_double: 0.19068\n value_double: 0.220352\n " "value_double: -0.255741\n value_double: 0.110853\n " "value_double: 0.146625\n value_double: 0.167754\n " - "value_double: 0.249554\n }\n }\n float_ops: 0\n total_float_ops: " + "value_double: 0.249554\n }\n accelerator_exec_micros: 0\n " + "cpu_exec_micros: 0\n total_accelerator_exec_micros: 0\n " + "total_cpu_exec_micros: 0\n }\n float_ops: 0\n total_float_ops: 0\n " + "accelerator_exec_micros: 0\n cpu_exec_micros: 0\n " + "total_accelerator_exec_micros: 0\n total_cpu_exec_micros: " "0\n}\nchildren {\n name: \"conv2d_1\"\n exec_micros: 0\n " "requested_bytes: 0\n total_exec_micros: 0\n total_requested_bytes: " "0\n total_parameters: 230\n children {\n name: \"conv2d_1/bias\"\n " @@ -152,7 +159,9 @@ TEST_F(TFProfTensorTest, Basics) { "total_parameters: 5\n float_ops: 0\n total_float_ops: 0\n " "tensor_value {\n dtype: DT_FLOAT\n value_double: 0\n " "value_double: 0\n value_double: 0\n value_double: 0\n " - "value_double: 0\n }\n }\n children {\n name: " + "value_double: 0\n }\n accelerator_exec_micros: 0\n " + "cpu_exec_micros: 0\n total_accelerator_exec_micros: 0\n " + "total_cpu_exec_micros: 0\n }\n children {\n name: " "\"conv2d_1/kernel\"\n exec_micros: 0\n requested_bytes: 0\n " "parameters: 225\n total_exec_micros: 0\n total_requested_bytes: " "0\n total_parameters: 225\n float_ops: 0\n total_float_ops: " @@ -268,9 +277,14 @@ TEST_F(TFProfTensorTest, Basics) { "value_double: 0.237298\n value_double: -0.0896481\n " "value_double: -0.0605349\n value_double: 0.231679\n " "value_double: -0.123842\n value_double: 0.0858642\n " - "value_double: 0.23111\n value_double: 0.0491742\n }\n }\n " - "float_ops: 0\n total_float_ops: 0\n}\nfloat_ops: 0\ntotal_float_ops: " - "0\n", + "value_double: 
0.23111\n value_double: 0.0491742\n }\n " + "accelerator_exec_micros: 0\n cpu_exec_micros: 0\n " + "total_accelerator_exec_micros: 0\n total_cpu_exec_micros: 0\n }\n " + "float_ops: 0\n total_float_ops: 0\n accelerator_exec_micros: 0\n " + "cpu_exec_micros: 0\n total_accelerator_exec_micros: 0\n " + "total_cpu_exec_micros: 0\n}\nfloat_ops: 0\ntotal_float_ops: " + "0\naccelerator_exec_micros: 0\ncpu_exec_micros: " + "0\ntotal_accelerator_exec_micros: 0\ntotal_cpu_exec_micros: 0\n", &expected)); EXPECT_EQ(expected.DebugString(), root.DebugString()); } diff --git a/tensorflow/tools/tfprof/internal/tfprof_timeline.cc b/tensorflow/tools/tfprof/internal/tfprof_timeline.cc index c98aa940c8c..c835da81624 100644 --- a/tensorflow/tools/tfprof/internal/tfprof_timeline.cc +++ b/tensorflow/tools/tfprof/internal/tfprof_timeline.cc @@ -192,9 +192,6 @@ void Timeline::AllocateTimeNodes(GraphNode* gnode) { const TFGraphNode* node = gnode->node; for (const auto& kernel_execs : node->op_execs(step_)) { const string& device = kernel_execs.first; - if (!IsCombinedGPUStream(device) && !IsCPUDevice(device)) { - continue; - } if (process_.find(device) == process_.end()) { int64 pid = AllocatePID(); diff --git a/tensorflow/tools/tfprof/internal/tfprof_timeline_test.cc b/tensorflow/tools/tfprof/internal/tfprof_timeline_test.cc index bcf2bf05946..0e9bb9658c8 100644 --- a/tensorflow/tools/tfprof/internal/tfprof_timeline_test.cc +++ b/tensorflow/tools/tfprof/internal/tfprof_timeline_test.cc @@ -39,7 +39,8 @@ class TFProfTimelineTest : public ::testing::Test { io::JoinPath(testing::TensorFlowSrcRoot(), "tools/tfprof/internal/testdata/graph.pbtxt"); std::unique_ptr graph_pb(new tensorflow::GraphDef()); - TF_CHECK_OK(ReadGraphDef(Env::Default(), graph_path, graph_pb.get())); + TF_CHECK_OK( + ReadProtoFile(Env::Default(), graph_path, graph_pb.get(), false)); std::unique_ptr run_meta_pb( new tensorflow::RunMetadata()); @@ -47,7 +48,7 @@ class TFProfTimelineTest : public ::testing::Test { 
io::JoinPath(testing::TensorFlowSrcRoot(), "tools/tfprof/internal/testdata/run_meta"); TF_CHECK_OK( - ReadBinaryProto(Env::Default(), run_meta_path, run_meta_pb.get())); + ReadProtoFile(Env::Default(), run_meta_path, run_meta_pb.get(), true)); tf_stats_.reset(new TFStats(std::move(graph_pb), std::move(run_meta_pb), nullptr, nullptr)); diff --git a/tensorflow/tools/tfprof/internal/tfprof_utils.cc b/tensorflow/tools/tfprof/internal/tfprof_utils.cc index 0bc12170125..8f208b75d35 100644 --- a/tensorflow/tools/tfprof/internal/tfprof_utils.cc +++ b/tensorflow/tools/tfprof/internal/tfprof_utils.cc @@ -72,19 +72,6 @@ string StringReplace(const string& str, const string& oldsub, return out; } -Status ReadGraphDef(Env* env, const string& fname, GraphDef* graph_def) { - string out; - Status s = ReadFileToString(env, fname, &out); - if (!s.ok()) return s; - if (protobuf::TextFormat::ParseFromString(out, graph_def)) { - return Status(); - } else if (ReadBinaryProto(tensorflow::Env::Default(), fname, graph_def) - .ok()) { - return Status(); - } - return errors::InvalidArgument("Cannot parse proto string."); -} - namespace { string StripQuote(const string& s) { int start = s.find_first_not_of("\"\'"); @@ -301,8 +288,8 @@ void PrintHelp() { "of times. Only available in op view.\n\n" " -step: Show the stats of a step when multiple steps of " "RunMetadata were added. By default (-1), show the average of all steps." - " -order_by: Order the results by [name|depth|bytes|micros|params|" - "float_ops]\n\n" + " -order_by: Order the results by [name|depth|bytes|micros" + "|accelerator_micros|cpu_micros|params|float_ops]\n\n" " -account_type_regexes: Account and display the ops whose types match " "one of the type regexes specified. tfprof " "allow user to define extra op types for ops " @@ -333,7 +320,8 @@ void PrintHelp() { "ops eventually displayed. 
If False, account all " "op statistics matching -account_type_regexes recursively.\n\n" " -select: Comma-separated list of metrics to show: [bytes|micros|" - "params|float_ops|tensor_value|device|op_types]." + "accelerator_micros|cpu_micros|params|float_ops|tensor_value|device|" + "op_types|input_shapes]." "\n\n" " -dump_to_file: Dump the output to a file, instead of terminal.\n\n" "" diff --git a/tensorflow/tools/tfprof/internal/tfprof_utils.h b/tensorflow/tools/tfprof/internal/tfprof_utils.h index afa7a58acd3..6130325babb 100644 --- a/tensorflow/tools/tfprof/internal/tfprof_utils.h +++ b/tensorflow/tools/tfprof/internal/tfprof_utils.h @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/protobuf/config.pb.h" #include "tensorflow/tools/tfprof/internal/tfprof_options.h" namespace tensorflow { @@ -40,7 +41,28 @@ tensorflow::Status ParseCmdLine(const string& line, string* cmd, string StringReplace(const string& str, const string& oldsub, const string& newsub); -Status ReadGraphDef(Env* env, const string& fname, GraphDef* graph_def); +template +Status ReadProtoFile(Env* env, const string& fname, T* proto, + bool binary_first) { + string out; + Status s = ReadFileToString(env, fname, &out); + if (!s.ok()) return s; + + if (binary_first) { + if (ReadBinaryProto(tensorflow::Env::Default(), fname, proto).ok()) { + return Status(); + } else if (protobuf::TextFormat::ParseFromString(out, proto)) { + return Status(); + } + } else { + if (protobuf::TextFormat::ParseFromString(out, proto)) { + return Status(); + } else if (ReadBinaryProto(tensorflow::Env::Default(), fname, proto).ok()) { + return Status(); + } + } + return errors::InvalidArgument("Cannot parse proto file."); +} void PrintHelp(); diff --git a/tensorflow/tools/tfprof/tfprof_main.cc b/tensorflow/tools/tfprof/tfprof_main.cc index ae02b526347..7a4e7e85ffa 
100644 --- a/tensorflow/tools/tfprof/tfprof_main.cc +++ b/tensorflow/tools/tfprof/tfprof_main.cc @@ -38,121 +38,114 @@ limitations under the License. #include "tensorflow/tools/tfprof/internal/tfprof_utils.h" #include "tensorflow/tools/tfprof/tfprof_log.pb.h" -using tensorflow::str_util::Split; - +namespace tensorflow { +namespace tfprof { void completion(const char* buf, linenoiseCompletions* lc) { - tensorflow::string buf_str = buf; + string buf_str = buf; if (buf_str.find(" ") == buf_str.npos) { - for (const char* opt : tensorflow::tfprof::kCmds) { - if (tensorflow::string(opt).find(buf_str) == 0) { + for (const char* opt : kCmds) { + if (string(opt).find(buf_str) == 0) { linenoiseAddCompletion(lc, opt); } } return; } - tensorflow::string prefix; + string prefix; int last_dash = buf_str.find_last_of(' '); - if (last_dash != tensorflow::string::npos) { + if (last_dash != string::npos) { prefix = buf_str.substr(0, last_dash + 1); - buf_str = buf_str.substr(last_dash + 1, tensorflow::kint32max); + buf_str = buf_str.substr(last_dash + 1, kint32max); } - for (const char* opt : tensorflow::tfprof::kOptions) { - if (tensorflow::string(opt).find(buf_str) == 0) { + for (const char* opt : kOptions) { + if (string(opt).find(buf_str) == 0) { linenoiseAddCompletion(lc, (prefix + opt).c_str()); } } } -int main(int argc, char** argv) { - tensorflow::string FLAGS_graph_path = ""; - tensorflow::string FLAGS_run_meta_path = ""; - tensorflow::string FLAGS_op_log_path = ""; - tensorflow::string FLAGS_checkpoint_path = ""; - tensorflow::int32 FLAGS_max_depth = 10; - tensorflow::int64 FLAGS_min_bytes = 0; - tensorflow::int64 FLAGS_min_micros = 0; - tensorflow::int64 FLAGS_min_params = 0; - tensorflow::int64 FLAGS_min_float_ops = 0; - tensorflow::int64 FLAGS_min_occurrence = 0; - tensorflow::int64 FLAGS_step = -1; - tensorflow::string FLAGS_order_by = "name"; - tensorflow::string FLAGS_account_type_regexes = ".*"; - tensorflow::string FLAGS_start_name_regexes = ".*"; - 
tensorflow::string FLAGS_trim_name_regexes = ""; - tensorflow::string FLAGS_show_name_regexes = ".*"; - tensorflow::string FLAGS_hide_name_regexes; +int Run(int argc, char** argv) { + string FLAGS_graph_path = ""; + string FLAGS_run_meta_path = ""; + string FLAGS_op_log_path = ""; + string FLAGS_checkpoint_path = ""; + int32 FLAGS_max_depth = 10; + int64 FLAGS_min_bytes = 0; + int64 FLAGS_min_micros = 0; + int64 FLAGS_min_params = 0; + int64 FLAGS_min_float_ops = 0; + int64 FLAGS_min_occurrence = 0; + int64 FLAGS_step = -1; + string FLAGS_order_by = "name"; + string FLAGS_account_type_regexes = ".*"; + string FLAGS_start_name_regexes = ".*"; + string FLAGS_trim_name_regexes = ""; + string FLAGS_show_name_regexes = ".*"; + string FLAGS_hide_name_regexes; bool FLAGS_account_displayed_op_only = false; - tensorflow::string FLAGS_select = "params"; - tensorflow::string FLAGS_output = ""; + string FLAGS_select = "micros"; + string FLAGS_output = ""; for (int i = 0; i < argc; i++) { fprintf(stderr, "%s\n", argv[i]); } - std::vector flag_list = { - tensorflow::Flag("graph_path", &FLAGS_graph_path, - "GraphDef proto text file name"), - tensorflow::Flag("run_meta_path", &FLAGS_run_meta_path, - "Comma-separated list of RunMetadata proto binary " - "files. 
Each file is given step number 0,1,2,etc"), - tensorflow::Flag("op_log_path", &FLAGS_op_log_path, - "tensorflow::tfprof::OpLog proto binary file name"), - tensorflow::Flag("checkpoint_path", &FLAGS_checkpoint_path, - "TensorFlow Checkpoint file name"), - tensorflow::Flag("max_depth", &FLAGS_max_depth, "max depth"), - tensorflow::Flag("min_bytes", &FLAGS_min_bytes, "min_bytes"), - tensorflow::Flag("min_micros", &FLAGS_min_micros, "min micros"), - tensorflow::Flag("min_params", &FLAGS_min_params, "min params"), - tensorflow::Flag("min_float_ops", &FLAGS_min_float_ops, "min float ops"), - tensorflow::Flag("min_occurrence", &FLAGS_min_occurrence, - "min occurrence"), - tensorflow::Flag("step", &FLAGS_step, - "The stats of which step to use. By default average"), - tensorflow::Flag("order_by", &FLAGS_order_by, "order by"), - tensorflow::Flag("account_type_regexes", &FLAGS_start_name_regexes, - "start name regexes"), - tensorflow::Flag("trim_name_regexes", &FLAGS_trim_name_regexes, - "trim name regexes"), - tensorflow::Flag("show_name_regexes", &FLAGS_show_name_regexes, - "show name regexes"), - tensorflow::Flag("hide_name_regexes", &FLAGS_hide_name_regexes, - "hide name regexes"), - tensorflow::Flag("account_displayed_op_only", - &FLAGS_account_displayed_op_only, - "account displayed op only"), - tensorflow::Flag("select", &FLAGS_select, "select"), - tensorflow::Flag("output", &FLAGS_output, "output"), + std::vector flag_list = { + Flag("graph_path", &FLAGS_graph_path, "GraphDef proto text file name"), + Flag("run_meta_path", &FLAGS_run_meta_path, + "Comma-separated list of RunMetadata proto binary " + "files. 
Each file is given step number 0,1,2,etc"), + Flag("op_log_path", &FLAGS_op_log_path, + "tensorflow::tfprof::OpLog proto binary file name"), + Flag("checkpoint_path", &FLAGS_checkpoint_path, + "TensorFlow Checkpoint file name"), + Flag("max_depth", &FLAGS_max_depth, "max depth"), + Flag("min_bytes", &FLAGS_min_bytes, "min_bytes"), + Flag("min_micros", &FLAGS_min_micros, "min micros"), + Flag("min_params", &FLAGS_min_params, "min params"), + Flag("min_float_ops", &FLAGS_min_float_ops, "min float ops"), + Flag("min_occurrence", &FLAGS_min_occurrence, "min occurrence"), + Flag("step", &FLAGS_step, + "The stats of which step to use. By default average"), + Flag("order_by", &FLAGS_order_by, "order by"), + Flag("account_type_regexes", &FLAGS_start_name_regexes, + "start name regexes"), + Flag("trim_name_regexes", &FLAGS_trim_name_regexes, "trim name regexes"), + Flag("show_name_regexes", &FLAGS_show_name_regexes, "show name regexes"), + Flag("hide_name_regexes", &FLAGS_hide_name_regexes, "hide name regexes"), + Flag("account_displayed_op_only", &FLAGS_account_displayed_op_only, + "account displayed op only"), + Flag("select", &FLAGS_select, "select"), + Flag("output", &FLAGS_output, "output"), }; - tensorflow::string usage = tensorflow::Flags::Usage(argv[0], flag_list); - bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); + string usage = Flags::Usage(argv[0], flag_list); + bool parse_ok = Flags::Parse(&argc, argv, flag_list); if (!parse_ok) { printf("%s", usage.c_str()); return (2); } - tensorflow::port::InitMain(argv[0], &argc, &argv); + port::InitMain(argv[0], &argc, &argv); fprintf(stderr, "%s\n", FLAGS_graph_path.c_str()); - std::vector account_type_regexes = - Split(FLAGS_account_type_regexes, ',', tensorflow::str_util::SkipEmpty()); - std::vector start_name_regexes = - Split(FLAGS_start_name_regexes, ',', tensorflow::str_util::SkipEmpty()); - std::vector trim_name_regexes = - Split(FLAGS_trim_name_regexes, ',', tensorflow::str_util::SkipEmpty()); - 
std::vector show_name_regexes = - Split(FLAGS_show_name_regexes, ',', tensorflow::str_util::SkipEmpty()); - std::vector hide_name_regexes = - Split(FLAGS_hide_name_regexes, ',', tensorflow::str_util::SkipEmpty()); - std::vector select = - Split(FLAGS_select, ',', tensorflow::str_util::SkipEmpty()); + std::vector account_type_regexes = + str_util::Split(FLAGS_account_type_regexes, ',', str_util::SkipEmpty()); + std::vector start_name_regexes = + str_util::Split(FLAGS_start_name_regexes, ',', str_util::SkipEmpty()); + std::vector trim_name_regexes = + str_util::Split(FLAGS_trim_name_regexes, ',', str_util::SkipEmpty()); + std::vector show_name_regexes = + str_util::Split(FLAGS_show_name_regexes, ',', str_util::SkipEmpty()); + std::vector hide_name_regexes = + str_util::Split(FLAGS_hide_name_regexes, ',', str_util::SkipEmpty()); + std::vector select = + str_util::Split(FLAGS_select, ',', str_util::SkipEmpty()); - tensorflow::string output_type; - std::map output_options; - tensorflow::Status s = tensorflow::tfprof::ParseOutput( - FLAGS_output, &output_type, &output_options); + string output_type; + std::map output_options; + Status s = ParseOutput(FLAGS_output, &output_type, &output_options); CHECK(s.ok()) << s.ToString(); - tensorflow::string cmd = ""; + string cmd = ""; if (argc == 1 && FLAGS_graph_path.empty()) { printf("1) go/tfprof: Tutorial.\n"); printf("2) tfprof help: Detail help information.\n"); @@ -168,44 +161,40 @@ int main(int argc, char** argv) { "Profiling everything!\n"); return 0; } else if (argc > 1) { - if (tensorflow::string(argv[1]) == tensorflow::tfprof::kCmds[5]) { - tensorflow::tfprof::PrintHelp(); + if (string(argv[1]) == kCmds[5]) { + PrintHelp(); return 0; } - if (tensorflow::string(argv[1]) == tensorflow::tfprof::kCmds[0] || - tensorflow::string(argv[1]) == tensorflow::tfprof::kCmds[1] || - tensorflow::string(argv[1]) == tensorflow::tfprof::kCmds[2] || - tensorflow::string(argv[1]) == tensorflow::tfprof::kCmds[3]) { + if (string(argv[1]) == 
kCmds[0] || string(argv[1]) == kCmds[1] || + string(argv[1]) == kCmds[2] || string(argv[1]) == kCmds[3]) { cmd = argv[1]; } } printf("Reading Files...\n"); - std::unique_ptr graph(new tensorflow::GraphDef()); - TF_CHECK_OK(tensorflow::tfprof::ReadGraphDef(tensorflow::Env::Default(), - FLAGS_graph_path, graph.get())); + std::unique_ptr graph(new GraphDef()); + TF_CHECK_OK( + ReadProtoFile(Env::Default(), FLAGS_graph_path, graph.get(), false)); - std::unique_ptr op_log( - new tensorflow::tfprof::OpLog()); + std::unique_ptr op_log(new OpLog()); if (!FLAGS_op_log_path.empty()) { - tensorflow::string op_log_str; - s = tensorflow::ReadFileToString(tensorflow::Env::Default(), - FLAGS_op_log_path, &op_log_str); + string op_log_str; + s = ReadFileToString(Env::Default(), FLAGS_op_log_path, &op_log_str); if (!s.ok()) { fprintf(stderr, "Failed to read op_log_path: %s\n", s.ToString().c_str()); return 1; } - if (!tensorflow::ParseProtoUnlimited(op_log.get(), op_log_str)) { + if (!ParseProtoUnlimited(op_log.get(), op_log_str)) { fprintf(stderr, "Failed to parse op_log_path\n"); return 1; } } - std::unique_ptr ckpt_reader; + std::unique_ptr ckpt_reader; TF_Status* status = TF_NewStatus(); if (!FLAGS_checkpoint_path.empty()) { - ckpt_reader.reset(new tensorflow::checkpoint::CheckpointReader( - FLAGS_checkpoint_path, status)); + ckpt_reader.reset( + new checkpoint::CheckpointReader(FLAGS_checkpoint_path, status)); if (TF_GetCode(status) != TF_OK) { fprintf(stderr, "%s\n", TF_Message(status)); TF_DeleteStatus(status); @@ -214,16 +203,14 @@ int main(int argc, char** argv) { TF_DeleteStatus(status); } - tensorflow::tfprof::TFStats tf_stat( - std::move(graph), nullptr, std::move(op_log), std::move(ckpt_reader)); + TFStats tf_stat(std::move(graph), nullptr, std::move(op_log), + std::move(ckpt_reader)); - std::vector run_meta_files = - Split(FLAGS_run_meta_path, ',', tensorflow::str_util::SkipEmpty()); + std::vector run_meta_files = + str_util::Split(FLAGS_run_meta_path, ',', 
str_util::SkipEmpty()); for (int i = 0; i < run_meta_files.size(); ++i) { - std::unique_ptr run_meta( - new tensorflow::RunMetadata()); - s = ReadBinaryProto(tensorflow::Env::Default(), run_meta_files[i], - run_meta.get()); + std::unique_ptr run_meta(new RunMetadata()); + s = ReadProtoFile(Env::Default(), run_meta_files[i], run_meta.get(), true); if (!s.ok()) { fprintf(stderr, "Failed to read run_meta_path %s. Status: %s\n", run_meta_files[i].c_str(), s.ToString().c_str()); @@ -232,19 +219,17 @@ int main(int argc, char** argv) { tf_stat.ParseRunMeta(i, std::move(run_meta)); } - tensorflow::tfprof::Options opts( - FLAGS_max_depth, FLAGS_min_bytes, FLAGS_min_micros, FLAGS_min_params, - FLAGS_min_float_ops, FLAGS_min_occurrence, FLAGS_step, FLAGS_order_by, - account_type_regexes, start_name_regexes, trim_name_regexes, - show_name_regexes, hide_name_regexes, FLAGS_account_displayed_op_only, - select, output_type, output_options); + Options opts(FLAGS_max_depth, FLAGS_min_bytes, FLAGS_min_micros, + FLAGS_min_params, FLAGS_min_float_ops, FLAGS_min_occurrence, + FLAGS_step, FLAGS_order_by, account_type_regexes, + start_name_regexes, trim_name_regexes, show_name_regexes, + hide_name_regexes, FLAGS_account_displayed_op_only, select, + output_type, output_options); - if (cmd == tensorflow::tfprof::kCmds[2] || - cmd == tensorflow::tfprof::kCmds[3]) { + if (cmd == kCmds[2] || cmd == kCmds[3]) { tf_stat.ShowMultiGraphNode(cmd, opts); return 0; - } else if (cmd == tensorflow::tfprof::kCmds[0] || - cmd == tensorflow::tfprof::kCmds[1]) { + } else if (cmd == kCmds[0] || cmd == kCmds[1]) { tf_stat.ShowGraphNode(cmd, opts); return 0; } @@ -253,7 +238,7 @@ int main(int argc, char** argv) { linenoiseHistoryLoad(".tfprof_history.txt"); for (char* line = nullptr; (line = linenoise("tfprof> ")) != nullptr;) { - tensorflow::string line_s = line; + string line_s = line; free(line); if (line_s.empty()) { @@ -263,24 +248,25 @@ int main(int argc, char** argv) { 
linenoiseHistoryAdd(line_s.c_str()); linenoiseHistorySave(".tfprof_history.txt"); - tensorflow::tfprof::Options new_opts = opts; - tensorflow::Status s = - tensorflow::tfprof::ParseCmdLine(line_s, &cmd, &new_opts); + Options new_opts = opts; + Status s = ParseCmdLine(line_s, &cmd, &new_opts); if (!s.ok()) { fprintf(stderr, "E: %s\n", s.ToString().c_str()); continue; } - if (cmd == tensorflow::tfprof::kCmds[4]) { + if (cmd == kCmds[4]) { opts = new_opts; - } else if (cmd == tensorflow::tfprof::kCmds[5]) { - tensorflow::tfprof::PrintHelp(); - } else if (cmd == tensorflow::tfprof::kCmds[2] || - cmd == tensorflow::tfprof::kCmds[3]) { + } else if (cmd == kCmds[5]) { + PrintHelp(); + } else if (cmd == kCmds[2] || cmd == kCmds[3]) { tf_stat.ShowMultiGraphNode(cmd, new_opts); - } else if (cmd == tensorflow::tfprof::kCmds[0] || - cmd == tensorflow::tfprof::kCmds[1]) { + } else if (cmd == kCmds[0] || cmd == kCmds[1]) { tf_stat.ShowGraphNode(cmd, new_opts); } } return 0; } +} // namespace tfprof +} // namespace tensorflow + +int main(int argc, char** argv) { return tensorflow::tfprof::Run(argc, argv); } diff --git a/tensorflow/tools/tfprof/tfprof_output.proto b/tensorflow/tools/tfprof/tfprof_output.proto index d00e93c939f..1ea956152c0 100644 --- a/tensorflow/tools/tfprof/tfprof_output.proto +++ b/tensorflow/tools/tfprof/tfprof_output.proto @@ -22,6 +22,9 @@ message TFGraphNodeProto { optional TFProfTensorProto tensor_value = 15; // op execution time. optional int64 exec_micros = 2; + optional int64 accelerator_exec_micros = 17; + optional int64 cpu_exec_micros = 18; + // Total requested bytes by the op. optional int64 requested_bytes = 3; // Number of parameters if available. @@ -36,6 +39,9 @@ message TFGraphNodeProto { // the node itself. The actual children depend on the data structure used // (scope, graph). 
optional int64 total_exec_micros = 6; + optional int64 total_accelerator_exec_micros = 19; + optional int64 total_cpu_exec_micros = 20; + optional int64 total_requested_bytes = 7; optional int64 total_parameters = 8; optional int64 total_float_ops = 14; @@ -64,6 +70,9 @@ message TFMultiGraphNodeProto { // code execution time. optional int64 exec_micros = 2; + optional int64 accelerator_exec_micros = 12; + optional int64 cpu_exec_micros = 13; + // Total requested bytes by the code. optional int64 requested_bytes = 3; // Number of parameters if available. @@ -74,6 +83,9 @@ message TFMultiGraphNodeProto { // The following are the aggregated stats from descendants. // The actual descendants depend on the data structure used. optional int64 total_exec_micros = 6; + optional int64 total_accelerator_exec_micros = 14; + optional int64 total_cpu_exec_micros = 15; + optional int64 total_requested_bytes = 7; optional int64 total_parameters = 8; optional int64 total_float_ops = 9; diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index ec5922ada8f..68bfefcf3c9 100644 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -4,14 +4,8 @@ load("//third_party/gpus:cuda_configure.bzl", "cuda_configure") load("//third_party/sycl:sycl_configure.bzl", "sycl_configure") load("@io_bazel_rules_closure//closure/private:java_import_external.bzl", "java_import_external") load("@io_bazel_rules_closure//closure:defs.bzl", "filegroup_external") -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library_external") load("//third_party/py:python_configure.bzl", "python_configure") -load("//third_party:polymer.bzl", "tensorboard_polymer_workspace") -load("//third_party:python.bzl", "tensorboard_python_workspace") -load("//third_party:js.bzl", "tensorboard_js_workspace") -load("//third_party:typings.bzl", "tensorboard_typings_workspace") - def _is_windows(repository_ctx): """Returns true if the host operating system is windows.""" @@ -150,12 +144,6 @@ def 
tf_workspace(path_prefix="", tf_repo_name=""): print("path_prefix was specified to tf_workspace but is no longer used " + "and will be removed in the future.") - # TODO(dandelion): Take these out when TB exits TF - tensorboard_polymer_workspace() - tensorboard_python_workspace() - tensorboard_typings_workspace() - tensorboard_js_workspace() - native.new_http_archive( name = "eigen_archive", urls = [ @@ -291,13 +279,46 @@ def tf_workspace(path_prefix="", tf_repo_name=""): name = "six_archive", urls = [ "http://mirror.bazel.build/pypi.python.org/packages/source/s/six/six-1.10.0.tar.gz", - "http://pypi.python.org/packages/source/s/six/six-1.10.0.tar.gz", + "https://pypi.python.org/packages/source/s/six/six-1.10.0.tar.gz", ], sha256 = "105f8d68616f8248e24bf0e9372ef04d3cc10104f1980f54d57b2ce73a5ad56a", strip_prefix = "six-1.10.0", build_file = str(Label("//third_party:six.BUILD")), ) + native.new_http_archive( + name = "org_python_pypi_backports_weakref", + urls = [ + "http://mirror.bazel.build/pypi.python.org/packages/bc/cc/3cdb0a02e7e96f6c70bd971bc8a90b8463fda83e264fa9c5c1c98ceabd81/backports.weakref-1.0rc1.tar.gz", + "https://pypi.python.org/packages/bc/cc/3cdb0a02e7e96f6c70bd971bc8a90b8463fda83e264fa9c5c1c98ceabd81/backports.weakref-1.0rc1.tar.gz", + ], + sha256 = "8813bf712a66b3d8b85dc289e1104ed220f1878cf981e2fe756dfaabe9a82892", + strip_prefix = "backports.weakref-1.0rc1/src", + build_file = str(Label("//third_party:backports_weakref.BUILD")), + ) + + native.new_http_archive( + name = "com_github_andreif_codegen", + urls = [ + "http://mirror.bazel.build/github.com/andreif/codegen/archive/1.0.tar.gz", + "https://github.com/andreif/codegen/archive/1.0.tar.gz", + ], + sha256 = "2dadd04a2802de27e0fe5a19b76538f6da9d39ff244036afa00c1bba754de5ee", + strip_prefix = "codegen-1.0", + build_file = str(Label("//third_party:codegen.BUILD")), + ) + + filegroup_external( + name = "org_python_license", + licenses = ["notice"], # Python 2.0 + sha256_urls = { + 
"b5556e921715ddb9242c076cae3963f483aa47266c5e37ea4c187f77cc79501c": [ + "http://mirror.bazel.build/docs.python.org/2.7/_sources/license.txt", + "https://docs.python.org/2.7/_sources/license.txt", + ], + }, + ) + native.bind( name = "six", actual = "@six_archive//:six", @@ -622,4 +643,3 @@ def tf_workspace(path_prefix="", tf_repo_name=""): strip_prefix = "pprof-c0fb62ec88c411cc91194465e54db2632845b650", build_file = str(Label("//third_party:pprof.BUILD")), ) - diff --git a/third_party/backports_weakref.BUILD b/third_party/backports_weakref.BUILD new file mode 100644 index 00000000000..0adfc5f0541 --- /dev/null +++ b/third_party/backports_weakref.BUILD @@ -0,0 +1,22 @@ +# Description: +# Backport of new features in Python's weakref module. + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Python 2.0 + +py_library( + name = "org_python_pypi_backports_weakref", + srcs = [ + "backports/__init__.py", + "backports/weakref.py", + ], + srcs_version = "PY2AND3", +) + +genrule( + name = "license", + srcs = ["@org_python_license"], + outs = ["LICENSE"], + cmd = "cp $< $@", +) diff --git a/third_party/bleach.BUILD b/third_party/bleach.BUILD deleted file mode 100644 index 1bf75b84a76..00000000000 --- a/third_party/bleach.BUILD +++ /dev/null @@ -1,20 +0,0 @@ -# Description: -# Build file for Bleach. -package(default_visibility = ["//visibility:public"]) - -licenses(["notice"]) # Apache 2.0 - -exports_files(["LICENSE"]) - -py_library( - name = "org_mozilla_bleach", - srcs = [ - "bleach/__init__.py", - "bleach/callbacks.py", - "bleach/encoding.py", - "bleach/sanitizer.py", - "bleach/version.py", - ], - srcs_version = "PY2AND3", - deps = ["@org_html5lib"], -) diff --git a/third_party/clutz.BUILD b/third_party/clutz.BUILD deleted file mode 100644 index 593b70366a3..00000000000 --- a/third_party/clutz.BUILD +++ /dev/null @@ -1,44 +0,0 @@ -# Description: -# Build tool for making TypeScript .d.ts files from Closure JavaScript. 
- -package(default_visibility = ["//visibility:public"]) - -licenses(["notice"]) # MIT - -exports_files([ - "LICENSE", - "src/resources/closure.lib.d.ts", -]) - -JVM_FLAGS = [ - "-Xss20m", # JSCompiler needs big stacks for recursive parsing - "-XX:+UseParallelGC", # Best GC when app isn't latency sensitive -] - -java_binary( - name = "clutz", - srcs = glob(["src/main/java/com/google/javascript/clutz/**/*.java"]), - jvm_flags = JVM_FLAGS, - main_class = "com.google.javascript.clutz.DeclarationGenerator", - deps = [ - "@args4j", - "@com_google_code_findbugs_jsr305", - "@com_google_code_gson", - "@com_google_guava", - "@com_google_javascript_closure_compiler", - ], -) - -java_binary( - name = "gents", - srcs = glob(["src/main/java/com/google/javascript/gents/**/*.java"]), - jvm_flags = JVM_FLAGS, - main_class = "com.google.javascript.gents.TypeScriptGenerator", - deps = [ - "@args4j", - "@com_google_code_findbugs_jsr305", - "@com_google_code_gson", - "@com_google_guava", - "@com_google_javascript_closure_compiler", - ], -) diff --git a/third_party/clutz.bzl b/third_party/clutz.bzl deleted file mode 100644 index f273c78c794..00000000000 --- a/third_party/clutz.bzl +++ /dev/null @@ -1,77 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Build definitions for TypeScript from Closure JavaScript libraries.""" - -load("@io_bazel_rules_closure//closure/private:defs.bzl", - "JS_FILE_TYPE", - "collect_js", - "unfurl") - -CLUTZ_ATTRIBUTES = { - "_clutz": attr.label( - default=Label("@io_angular_clutz//:clutz"), - executable=True, - cfg="host"), - "_clutz_externs": attr.label( - default=Label("@com_google_javascript_closure_compiler_externs"), - allow_files=True), -} - -def extract_dts_from_closure_libraries(ctx): - """Extracts type definitions from closure dependencies. - - This just generates one big .d.ts file for all transitive Closure sources, - and does not pass it down. That means each rule has to duplicate the effort, - but on the other hand allows transitive dependencies on shared rules without - causing duplicate definition errors. - - Args: - ctx: A Skylark context. - Returns: - The generated Clutz typings file, or None if there were no JS deps. - """ - deps = unfurl(ctx.attr.deps, provider="closure_js_library") - js = collect_js(ctx, deps) - if not js.srcs: - return None - js_typings = ctx.new_file(ctx.bin_dir, "%s-js-typings.d.ts" % ctx.label.name) - srcs = depset(JS_FILE_TYPE.filter(ctx.files._clutz_externs)) + js.srcs - args = ["-o", js_typings.path] - for src in srcs: - args.append(src.path) - if getattr(ctx.attr, "clutz_entry_points", None): - args.append("--closure_entry_points") - args.extend(ctx.attr.clutz_entry_points) - ctx.action( - inputs=list(srcs), - outputs=[js_typings], - executable=ctx.executable._clutz, - arguments=args, - mnemonic="Clutz", - progress_message="Running Clutz on %d JS files %s" % ( - len(srcs), ctx.label)) - return js_typings - -################################################################################ -# The following definitions are for API compatibility with internal clutz.bzl - -CLUTZ_OUTPUTS = {} - -def _clutz_aspect_impl(target, ctx): - return struct() - -clutz_aspect = aspect( - implementation=_clutz_aspect_impl, - attr_aspects=["exports"]) 
diff --git a/third_party/codegen.BUILD b/third_party/codegen.BUILD new file mode 100644 index 00000000000..df436c81635 --- /dev/null +++ b/third_party/codegen.BUILD @@ -0,0 +1,16 @@ +# -*- mode: python; -*- +# +# Description: +# Extension to ast that allow ast -> python code generation. + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # New BSD + +exports_files(["LICENSE"]) + +py_library( + name = "com_github_andreif_codegen", + srcs = glob(["codegen.py"]), + srcs_version = "PY2AND3", +) diff --git a/third_party/gpus/crosstool/remote.BUILD.tpl b/third_party/gpus/crosstool/remote.BUILD.tpl new file mode 100644 index 00000000000..b2316331db2 --- /dev/null +++ b/third_party/gpus/crosstool/remote.BUILD.tpl @@ -0,0 +1,10 @@ +# Description: +# Template for crosstool Build file to use a pre-generated config. +licenses(["restricted"]) + +package(default_visibility = ["//visibility:public"]) + +alias( + name = "toolchain", + actual = "%{remote_cuda_repo}:toolchain", +) diff --git a/third_party/gpus/cuda/remote.BUILD.tpl b/third_party/gpus/cuda/remote.BUILD.tpl new file mode 100644 index 00000000000..d88d512b90c --- /dev/null +++ b/third_party/gpus/cuda/remote.BUILD.tpl @@ -0,0 +1,105 @@ +# Description: +# Template for cuda Build file to use a pre-generated config. +licenses(["restricted"]) # MPL2, portions GPL v3, LGPL v3, BSD-like + +package(default_visibility = ["//visibility:public"]) + +config_setting( + name = "using_nvcc", + values = { + "define": "using_cuda_nvcc=true", + }, +) + +config_setting( + name = "using_clang", + values = { + "define": "using_cuda_clang=true", + }, +) + +# Equivalent to using_clang && -c opt. 
+config_setting( + name = "using_clang_opt", + values = { + "define": "using_cuda_clang=true", + "compilation_mode": "opt", + }, +) + +config_setting( + name = "darwin", + values = {"cpu": "darwin"}, + visibility = ["//visibility:public"], +) + +config_setting( + name = "freebsd", + values = {"cpu": "freebsd"}, + visibility = ["//visibility:public"], +) + +alias( + name = "cuda_headers", + actual = "%{remote_cuda_repo}cuda:cuda_headers", +) + +alias( + name = "cudart_static", + actual = "%{remote_cuda_repo}cuda:cudart_static", +) + +alias( + name = "cuda_driver", + actual = "%{remote_cuda_repo}cuda:cuda_driver", +) + +alias( + name = "cudart", + actual = "%{remote_cuda_repo}cuda:cudart", +) + +alias( + name = "cublas", + actual = "%{remote_cuda_repo}cuda:cublas", +) + +alias( + name = "cusolver", + actual = "%{remote_cuda_repo}cuda:cusolver", +) + +alias( + name = "cudnn", + actual = "%{remote_cuda_repo}cuda:cudnn", +) + +alias( + name = "cufft", + actual = "%{remote_cuda_repo}cuda:cufft", +) + +alias( + name = "curand", + actual = "%{remote_cuda_repo}cuda:curand", +) + +alias( + name = "cuda", + actual = "%{remote_cuda_repo}cuda:cuda", +) + +alias( + name = "cupti_headers", + actual = "%{remote_cuda_repo}cuda:cupti_headers", +) + +alias( + name = "cupti_dsos", + actual = "%{remote_cuda_repo}cuda:cupti_dsos", +) + +alias( + name = "libdevice_root", + actual = "%{remote_cuda_repo}cuda:libdevice_root", +) diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl index 61932a8e6d1..83a377dde5c 100644 --- a/third_party/gpus/cuda_configure.bzl +++ b/third_party/gpus/cuda_configure.bzl @@ -26,6 +26,7 @@ _TF_CUDA_VERSION = "TF_CUDA_VERSION" _TF_CUDNN_VERSION = "TF_CUDNN_VERSION" _CUDNN_INSTALL_PATH = "CUDNN_INSTALL_PATH" _TF_CUDA_COMPUTE_CAPABILITIES = "TF_CUDA_COMPUTE_CAPABILITIES" +_TF_CUDA_CONFIG_REPO = "TF_CUDA_CONFIG_REPO" _DEFAULT_CUDA_VERSION = "" _DEFAULT_CUDNN_VERSION = "" @@ -883,15 +884,16 @@ def _use_cuda_clang(repository_ctx): 
return enable_cuda == "1" return False -def _compute_cuda_extra_copts(repository_ctx, cuda_config): +def _compute_cuda_extra_copts(repository_ctx, compute_capabilities): if _use_cuda_clang(repository_ctx): - capability_flags = ["--cuda-gpu-arch=sm_" + cap.replace(".", "") for cap in cuda_config.compute_capabilities] + capability_flags = ["--cuda-gpu-arch=sm_" + + cap.replace(".", "") for cap in compute_capabilities] else: # Capabilities are handled in the "crosstool_wrapper_driver_is_not_gcc" for nvcc capability_flags = [] return str(capability_flags) -def _create_cuda_repository(repository_ctx): +def _create_local_cuda_repository(repository_ctx): """Creates the repository containing files set up to build with CUDA.""" cuda_config = _get_cuda_config(repository_ctx) @@ -939,7 +941,8 @@ def _create_cuda_repository(repository_ctx): _tpl(repository_ctx, "cuda:build_defs.bzl", { "%{cuda_is_configured}": "True", - "%{cuda_extra_copts}": _compute_cuda_extra_copts(repository_ctx, cuda_config), + "%{cuda_extra_copts}": _compute_cuda_extra_copts( + repository_ctx, cuda_config.compute_capabilities), }) _tpl(repository_ctx, "cuda:BUILD", @@ -999,18 +1002,43 @@ def _create_cuda_repository(repository_ctx): "%{cuda_toolkit_path}": cuda_config.cuda_toolkit_path, }) +def _create_remote_cuda_repository(repository_ctx, remote_config_repo): + """Creates pointers to a remotely configured repo set up to build with CUDA.""" + _tpl(repository_ctx, "cuda:build_defs.bzl", + { + "%{cuda_is_configured}": "True", + "%{cuda_extra_copts}": _compute_cuda_extra_copts( + repository_ctx, _compute_capabilities(repository_ctx)), + + }) + _tpl(repository_ctx, "cuda:remote.BUILD", + { + "%{remote_cuda_repo}": remote_config_repo, + }, "cuda/BUILD") + _tpl(repository_ctx, "crosstool:remote.BUILD", { + "%{remote_cuda_repo}": remote_config_repo, + }, "crosstool/BUILD") def _cuda_autoconf_impl(repository_ctx): """Implementation of the cuda_autoconf repository rule.""" if not _enable_cuda(repository_ctx): 
_create_dummy_repository(repository_ctx) else: - _create_cuda_repository(repository_ctx) - + if _TF_CUDA_CONFIG_REPO in repository_ctx.os.environ: + _create_remote_cuda_repository(repository_ctx, + repository_ctx.os.environ[_TF_CUDA_CONFIG_REPO]) + elif repository_ctx.attr.remote_config_repo != "": + _create_remote_cuda_repository(repository_ctx, + repository_ctx.attr.remote_config_repo) + else: + _create_local_cuda_repository(repository_ctx) cuda_configure = repository_rule( implementation = _cuda_autoconf_impl, + attrs = { + "remote_config_repo": attr.string(mandatory = False, default = ""), + }, environ = [ _GCC_HOST_COMPILER_PATH, "TF_NEED_CUDA", @@ -1019,6 +1047,7 @@ cuda_configure = repository_rule( _TF_CUDA_VERSION, _TF_CUDNN_VERSION, _TF_CUDA_COMPUTE_CAPABILITIES, + _TF_CUDA_CONFIG_REPO, ], ) @@ -1027,9 +1056,13 @@ cuda_configure = repository_rule( Add the following to your WORKSPACE FILE: ```python -cuda_configure(name = "local_config_cuda") +cuda_configure( + name = "local_config_cuda", + remote_config_repo = "@remote_cuda_config_tf//", +) ``` Args: name: A unique name for this workspace rule. + remote_config_repo: Location of a pre-generated config (optional). """ diff --git a/third_party/html5lib.BUILD b/third_party/html5lib.BUILD deleted file mode 100644 index 63aac14f155..00000000000 --- a/third_party/html5lib.BUILD +++ /dev/null @@ -1,17 +0,0 @@ -# Description: -# Import of html5lib library. - -package(default_visibility = ["//visibility:public"]) - -licenses(["notice"]) # BSD-like notice-style license, see LICENSE file - -exports_files(["LICENSE"]) - -py_library( - name = "org_html5lib", - srcs = glob(["html5lib/**/*.py"]), - srcs_version = "PY2AND3", - deps = [ - "@six_archive//:six", - ], -) diff --git a/third_party/js.bzl b/third_party/js.bzl deleted file mode 100644 index 2d2339c95e5..00000000000 --- a/third_party/js.bzl +++ /dev/null @@ -1,420 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-# -# Licensed under the Apache License, Version 2.0 (the 'License'); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an 'AS IS' BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# TensorBoard external JS dependencies (both infrastructure and frontend libs) -load("@io_bazel_rules_closure//closure:defs.bzl", "filegroup_external") -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library_external") - - - ############################################################################## - # TensorBoard Build Tools -def tensorboard_js_workspace(): - filegroup_external( - name = "org_nodejs", - # MIT with portions licensed: - # - MIT - # - Old MIT - # - 2-Clause-BSD - # - 3-Clause-BSD - # - ISC - # - Unicode - # - zlib - # - Artistic 2.0 - licenses = ["notice"], - sha256_urls_extract_macos = { - "47109a00cac344d80296c195451bb5eee7c21727fcef1594384ddfe1f852957a": [ - "http://mirror.bazel.build/nodejs.org/dist/v4.3.2/node-v4.3.2-darwin-x64.tar.xz", - "http://nodejs.org/dist/v4.3.2/node-v4.3.2-darwin-x64.tar.xz", - ], - }, - sha256_urls_windows = { - "3d4cfca9dcec556a077a2324bf5bd165ea3e6e64a2bfd7fc6e7a1f0dc4eb552b": [ - "http://mirror.bazel.build/raw.githubusercontent.com/nodejs/node/v4.3.2/LICENSE", - "https://raw.githubusercontent.com/nodejs/node/v4.3.2/LICENSE", - ], - "606c44c42d17866c017c50c0afadad411d9492ac4281d2431b937f881911614e": [ - "http://mirror.bazel.build/nodejs.org/dist/v4.3.2/win-x64/node.exe", - "http://nodejs.org/dist/v4.3.2/win-x64/node.exe", - ], - "451a40570099a95488d6438f175813629e0430f87f23c8659bc18dc42494820a": [ - 
"http://mirror.bazel.build/nodejs.org/dist/v4.3.2/win-x64/node.lib", - "http://nodejs.org/dist/v4.3.2/win-x64/node.lib", - ], - }, - sha256_urls_extract = { - "4350d0431b49697517c6cca5d66adf5f74eb9101c52f52ae959fa94225822d44": [ - "http://mirror.bazel.build/nodejs.org/dist/v4.3.2/node-v4.3.2-linux-x64.tar.xz", - "http://nodejs.org/dist/v4.3.2/node-v4.3.2-linux-x64.tar.xz", - ], - }, - strip_prefix = { - "node-v4.3.2-darwin-x64.tar.xz": "node-v4.3.2-darwin-x64", - "node-v4.3.2-linux-x64.tar.xz": "node-v4.3.2-linux-x64", - }, - executable = [ - "node", - "node.exe", - ], - ) - - filegroup_external( - name = "com_microsoft_typescript", - licenses = ["notice"], # Apache 2.0 - sha256_urls = { - "a7d00bfd54525bc694b6e32f64c7ebcf5e6b7ae3657be5cc12767bce74654a47": [ - "http://mirror.bazel.build/raw.githubusercontent.com/Microsoft/TypeScript/v2.3.1/LICENSE.txt", - "https://raw.githubusercontent.com/Microsoft/TypeScript/v2.3.1/LICENSE.txt", - ], - "8465342c318f9c4cf0a29b109fa63ee3742dd4dc7080d05d9fd8f604814d04cf": [ - "http://mirror.bazel.build/raw.githubusercontent.com/Microsoft/TypeScript/v2.3.1/lib/tsc.js", - "https://raw.githubusercontent.com/Microsoft/TypeScript/v2.3.1/lib/tsc.js", - ], - "a67e36da3029d232e4e938e61a0a3302f516d71e7100d54dbf5362ad8618e994": [ - "http://mirror.bazel.build/raw.githubusercontent.com/Microsoft/TypeScript/v2.3.1/lib/lib.es6.d.ts", - "https://raw.githubusercontent.com/Microsoft/TypeScript/v2.3.1/lib/lib.es6.d.ts", - ], - }, - extra_build_file_content = "\n".join([ - "sh_binary(", - " name = \"tsc\",", - " srcs = [\"tsc.sh\"],", - " data = [", - " \"tsc.js\",", - " \"@org_nodejs\",", - " ],", - ")", - "", - "genrule(", - " name = \"tsc_sh\",", - " outs = [\"tsc.sh\"],", - " cmd = \"cat >$@ <<'EOF'\\n\" +", - " \"#!/bin/bash\\n\" +", - " \"NODE=external/org_nodejs/bin/node\\n\" +", - " \"if [[ -e external/org_nodejs/node.exe ]]; then\\n\" +", - " \" NODE=external/org_nodejs/node.exe\\n\" +", - " \"fi\\n\" +", - " \"exec $${NODE} 
external/com_microsoft_typescript/tsc.js \\\"$$@\\\"\\n\" +", - " \"EOF\",", - " executable = True,", - ")", - ]), - ) - - - native.new_http_archive( - name = "io_angular_clutz", - build_file = "//third_party:clutz.BUILD", - sha256 = "2981de41d1ff4774b544423da9a2cd8beb3be649e95aef2ef2fd83957300b3fe", - strip_prefix = "clutz-b0db5ade9bb535d387f05292316c422790c9848e", - urls = [ - "http://mirror.bazel.build/github.com/angular/clutz/archive/b0db5ade9bb535d387f05292316c422790c9848e.tar.gz", # 2017-05-22 - "https://github.com/angular/clutz/archive/b0db5ade9bb535d387f05292316c422790c9848e.tar.gz", - ], - ) - - filegroup_external( - name = "com_google_javascript_closure_compiler_externs", - licenses = ["notice"], # Apache 2.0 - sha256_urls_extract = { - "0f515a6ebfa138490b3c5ea9f3591ea1a7e4a930d3074f18b3eca86084ad9b66": [ - "http://mirror.bazel.build/github.com/google/closure-compiler/archive/b37e6000001b0a6bf4c0be49024ebda14a8711d9.tar.gz", # 2017-06-02 - "https://github.com/google/closure-compiler/archive/b37e6000001b0a6bf4c0be49024ebda14a8711d9.tar.gz", - ], - }, - strip_prefix = {"b37e6000001b0a6bf4c0be49024ebda14a8711d9.tar.gz": "closure-compiler-b37e6000001b0a6bf4c0be49024ebda14a8711d9/externs"}, - ) - - filegroup_external( - name = "com_google_javascript_closure_compiler_externs_polymer", - licenses = ["notice"], # Apache 2.0 - sha256_urls = { - "23baad9a200a717a821c6df504c84d3a893d7ea9102b14876eb80097e3b94292": [ - "http://mirror.bazel.build/raw.githubusercontent.com/google/closure-compiler/0e8dc5597a295ee259e3fecd98d6535dc621232f/contrib/externs/polymer-1.0.js", # 2017-05-27 - "https://raw.githubusercontent.com/google/closure-compiler/0e8dc5597a295ee259e3fecd98d6535dc621232f/contrib/externs/polymer-1.0.js", - ], - }, - ) - - filegroup_external( - name = "org_threejs", - # no @license header - licenses = ["notice"], # MIT - sha256_urls = { - "7aff264bd84c90bed3c72a4dc31db8c19151853c6df6980f52b01d3e9872c82d": [ - 
"http://mirror.bazel.build/raw.githubusercontent.com/mrdoob/three.js/ad419d40bdaab80abbb34b8f359b4ee840033a02/build/three.js", - "https://raw.githubusercontent.com/mrdoob/three.js/ad419d40bdaab80abbb34b8f359b4ee840033a02/build/three.js", - ], - "0e98ded15bb7fe398a655667e76b39909d36c0973a8950d01c62f65f93161c27": [ - "http://mirror.bazel.build/raw.githubusercontent.com/mrdoob/three.js/ad419d40bdaab80abbb34b8f359b4ee840033a02/examples/js/controls/OrbitControls.js", - "https://raw.githubusercontent.com/mrdoob/three.js/ad419d40bdaab80abbb34b8f359b4ee840033a02/examples/js/controls/OrbitControls.js", - ], - }, - ) - - ############################################################################## - # TensorBoard JavaScript Production Dependencies - web_library_external( - name = "com_lodash", - licenses = ["notice"], # MIT - sha256 = "0e88207e5f90af4ce8790d6e1e7d09d2702d81bce0bafdc253d18c0a5bf7661e", - urls = [ - "http://mirror.bazel.build/github.com/lodash/lodash/archive/3.10.1.tar.gz", - "https://github.com/lodash/lodash/archive/3.10.1.tar.gz", - ], - strip_prefix = "lodash-3.10.1", - path = "/lodash", - srcs = ["lodash.js"], - ) - - filegroup_external( - name = "com_numericjs", - # no @license header - licenses = ["notice"], # MIT - sha256_urls = { - "0e94aada97f12dee6118064add9170484c55022f5d53206ee4407143cd36ddcd": [ - "http://mirror.bazel.build/raw.githubusercontent.com/sloisel/numeric/v1.2.6/license.txt", - "https://raw.githubusercontent.com/sloisel/numeric/v1.2.6/license.txt", - ], - "dfaca3b8485bee735788cc6eebca82ea25719adc1fb8911c7799c6bd5a95df3b": [ - "http://mirror.bazel.build/raw.githubusercontent.com/sloisel/numeric/v1.2.6/src/numeric.js", - "https://raw.githubusercontent.com/sloisel/numeric/v1.2.6/src/numeric.js", - ], - }, - ) - - filegroup_external( - name = "com_palantir_plottable", - # no @license header - licenses = ["notice"], # MIT - sha256_urls_extract = { - # Plottable doesn't have a release tarball on GitHub. 
Using the - # sources directly from git also requires running Node tooling - # beforehand to generate files. NPM is the only place to get it. - "e3159beb279391c47433789f22b32bac88488cfcad6c0b6ec8605ce6b0081b0d": [ - "http://mirror.bazel.build/registry.npmjs.org/plottable/-/plottable-3.1.0.tgz", - "https://registry.npmjs.org/plottable/-/plottable-3.1.0.tgz", - ], - }, - ) - - filegroup_external( - name = "io_github_cpettitt_dagre", - # no @license header - licenses = ["notice"], # MIT - sha256_urls = { - "6a349742a6cb219d5a2fc8d0844f6d89a6efc62e20c664450d884fc7ff2d6015": [ - "http://mirror.bazel.build/raw.githubusercontent.com/cpettitt/dagre/v0.7.4/LICENSE", - "https://raw.githubusercontent.com/cpettitt/dagre/v0.7.4/LICENSE", - ], - "7323829ddd77924a69e2b1235ded3eac30acd990da0f037e0fbd3c8e9035b50d": [ - "http://mirror.bazel.build/raw.githubusercontent.com/cpettitt/dagre/v0.7.4/dist/dagre.core.js", - "https://raw.githubusercontent.com/cpettitt/dagre/v0.7.4/dist/dagre.core.js", - ], - }, - ) - - filegroup_external( - name = "io_github_cpettitt_graphlib", - licenses = ["notice"], # MIT - sha256_urls = { - "6a349742a6cb219d5a2fc8d0844f6d89a6efc62e20c664450d884fc7ff2d6015": [ - "http://mirror.bazel.build/raw.githubusercontent.com/cpettitt/graphlib/v1.0.7/LICENSE", - "https://raw.githubusercontent.com/cpettitt/graphlib/v1.0.7/LICENSE", - ], - "772045d412b1513b549be991c2e1846c38019429d43974efcae943fbe83489bf": [ - "http://mirror.bazel.build/raw.githubusercontent.com/cpettitt/graphlib/v1.0.7/dist/graphlib.core.js", - "https://raw.githubusercontent.com/cpettitt/graphlib/v1.0.7/dist/graphlib.core.js", - ], - }, - ) - - filegroup_external( - name = "io_github_waylonflinn_weblas", - # no @license header - licenses = ["notice"], # MIT - sha256_urls = { - "633f2861a9a862b9cd7967e841e14dd3527912f209d6563595774fa31e3d84cb": [ - "http://mirror.bazel.build/raw.githubusercontent.com/waylonflinn/weblas/v0.9.0/LICENSES", - 
"https://raw.githubusercontent.com/waylonflinn/weblas/v0.9.0/LICENSE", - ], - "f138fce57f673ca8a633f4aee5ae5b6fcb6ad0de59069a42a74e996fd04d8fcc": [ - "http://mirror.bazel.build/raw.githubusercontent.com/waylonflinn/weblas/v0.9.0/dist/weblas.js", - "https://raw.githubusercontent.com/waylonflinn/weblas/v0.9.0/dist/weblas.js", - ], - }, - ) - - filegroup_external( - name = "org_d3js", - # no @license header - licenses = ["notice"], # BSD-3-Clause - sha256_urls_extract = { - "b5fac5b296bc196e6aa7b59f9e33986fc44d23d59a0e211705187be9e35b943d": [ - "http://mirror.bazel.build/github.com/d3/d3/releases/download/v4.8.0/d3.zip", - "https://github.com/d3/d3/releases/download/v4.8.0/d3.zip", - ], - }, - # TODO(jart): Use srcs=["d3.js"] instead of this once supported. - generated_rule_name = "all_files", - extra_build_file_content = "\n".join([ - "filegroup(", - " name = \"org_d3js\",", - " srcs = [\"d3.js\"],", - ")", - ]), - ) - - filegroup_external( - name = "org_chromium_catapult_vulcanized_trace_viewer", - licenses = ["notice"], # BSD-3-Clause - sha256_urls = { - "f0df289ba9d03d857ad1c2f5918861376b1510b71588ffc60eff5c7a7bfedb09": [ - "http://mirror.bazel.build/raw.githubusercontent.com/catapult-project/catapult/2f7ee994984f3ebd3dd3dc3e05777bf180ec2ee8/LICENSE", - "https://raw.githubusercontent.com/catapult-project/catapult/2f7ee994984f3ebd3dd3dc3e05777bf180ec2ee8/LICENSE", - ], - "9e99e79439ea5a1471bd4dd325bd6733e133bcb3da4df4b878ed6d2aec7c8d86": [ - "http://mirror.bazel.build/raw.githubusercontent.com/catapult-project/catapult/2f7ee994984f3ebd3dd3dc3e05777bf180ec2ee8/trace_viewer_full.html", - "https://raw.githubusercontent.com/catapult-project/catapult/2f7ee994984f3ebd3dd3dc3e05777bf180ec2ee8/trace_viewer_full.html" - ], - }, - ) - - ############################################################################## - # TensorBoard Testing Dependencies - web_library_external( - name = "org_npmjs_registry_accessibility_developer_tools", - licenses = ["notice"], # Apache License 
2.0 - sha256 = "1d6a72f401c9d53f68238c617dd43a05cd85ca5aa2e676a5b3c352711448e093", - urls = [ - "http://mirror.bazel.build/registry.npmjs.org/accessibility-developer-tools/-/accessibility-developer-tools-2.10.0.tgz", - "https://registry.npmjs.org/accessibility-developer-tools/-/accessibility-developer-tools-2.10.0.tgz", - ], - strip_prefix = "package", - path = "/accessibility-developer-tools", - suppress = ["strictDependencies"], - ) - - web_library_external( - name = "org_npmjs_registry_async", - licenses = ["notice"], # MIT - sha256 = "08655255ae810bf4d1cb1642df57658fcce823776d3ba8f4b46f4bbff6c87ece", - urls = [ - "http://mirror.bazel.build/registry.npmjs.org/async/-/async-1.5.0.tgz", - "https://registry.npmjs.org/async/-/async-1.5.0.tgz", - ], - strip_prefix = "package", - path = "/async", - ) - - web_library_external( - name = "org_npmjs_registry_chai", - licenses = ["notice"], # MIT - sha256 = "aca8137bed5bb295bd7173325b7ad604cd2aeb341d739232b4f9f0b26745be90", - urls = [ - "http://mirror.bazel.build/registry.npmjs.org/chai/-/chai-3.5.0.tgz", - "https://registry.npmjs.org/chai/-/chai-3.5.0.tgz", - ], - strip_prefix = "package", - path = "/chai", - ) - - web_library_external( - name = "org_npmjs_registry_mocha", - licenses = ["notice"], # MIT - sha256 = "13ef37a071196a2fba680799b906555d3f0ab61e80a7e8f73f93e77914590dd4", - urls = [ - "http://mirror.bazel.build/registry.npmjs.org/mocha/-/mocha-2.5.3.tgz", - "https://registry.npmjs.org/mocha/-/mocha-2.5.3.tgz", - ], - suppress = ["strictDependencies"], - strip_prefix = "package", - path = "/mocha", - ) - - web_library_external( - name = "org_npmjs_registry_sinon", - licenses = ["notice"], # BSD-3-Clause - sha256 = "49edb057695fc9019aae992bf7e677a07de7c6ce2bf9f9facde4a245045d1532", - urls = [ - "http://mirror.bazel.build/registry.npmjs.org/sinon/-/sinon-1.17.4.tgz", - "https://registry.npmjs.org/sinon/-/sinon-1.17.4.tgz", - ], - strip_prefix = "package/lib", - path = "/sinonjs", - ) - - web_library_external( - name 
= "org_npmjs_registry_sinon_chai", - licenses = ["notice"], # BSD-3-Clause - sha256 = "b85fc56f713832960b56fe9269ee4bb2cd41edd2ceb130b0936e5bdbed5dea63", - urls = [ - "http://mirror.bazel.build/registry.npmjs.org/sinon-chai/-/sinon-chai-2.8.0.tgz", - "https://registry.npmjs.org/sinon-chai/-/sinon-chai-2.8.0.tgz", - ], - strip_prefix = "package", - path = "/sinon-chai", - ) - - web_library_external( - name = "org_npmjs_registry_stacky", - licenses = ["notice"], # BSD-3-Clause - sha256 = "c659e60f7957d9d80c23a7aacc4d71b19c6421a08f91174c0062de369595acae", - urls = [ - "http://mirror.bazel.build/registry.npmjs.org/stacky/-/stacky-1.3.1.tgz", - "https://registry.npmjs.org/stacky/-/stacky-1.3.1.tgz", - ], - strip_prefix = "package", - path = "/stacky", - ) - - web_library_external( - name = "org_npmjs_registry_web_component_tester", - licenses = ["notice"], # BSD-3-Clause - sha256 = "9d4ebd4945df8a936916d4d32b7f280f2a3afa35f79e7ca8ad3ed0a42770c537", - urls = [ - "http://mirror.bazel.build/registry.npmjs.org/web-component-tester/-/web-component-tester-4.3.6.tgz", - "https://registry.npmjs.org/web-component-tester/-/web-component-tester-4.3.6.tgz", - ], - strip_prefix = "package", - path = "/web-component-tester", - suppress = [ - "absolutePaths", - "strictDependencies", - ], - deps = [ - "@com_lodash", - "@org_npmjs_registry_accessibility_developer_tools", - "@org_npmjs_registry_async", - "@org_npmjs_registry_chai", - "@org_npmjs_registry_mocha", - "@org_npmjs_registry_sinon", - "@org_npmjs_registry_sinon_chai", - "@org_npmjs_registry_stacky", - "@org_polymer_test_fixture", - ], - ) - - web_library_external( - name = "org_polymer_test_fixture", - licenses = ["notice"], # BSD-3-Clause - sha256 = "59d6cfb1187733b71275becfea181fe0aa1f734df5ff77f5850c806bbbf9a0d9", - strip_prefix = "test-fixture-2.0.1", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/test-fixture/archive/v2.0.1.tar.gz", - 
"https://github.com/PolymerElements/test-fixture/archive/v2.0.1.tar.gz", - ], - path = "/test-fixture", - exclude = ["test/**"], - ) - diff --git a/third_party/markdown.BUILD b/third_party/markdown.BUILD deleted file mode 100644 index fa3e85d5304..00000000000 --- a/third_party/markdown.BUILD +++ /dev/null @@ -1,15 +0,0 @@ -# Description: -# Markdown processor - -package(default_visibility = ["//visibility:public"]) - -# This software says they use a BSD license. -licenses(["notice"]) - -exports_files(["LICENSE.md"]) - -py_library( - name = "org_pythonhosted_markdown", - srcs = glob(["markdown/**/*.py"]), - srcs_version = "PY2AND3", -) diff --git a/third_party/polymer.bzl b/third_party/polymer.bzl deleted file mode 100644 index bd6e05803cf..00000000000 --- a/third_party/polymer.bzl +++ /dev/null @@ -1,1335 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the 'License'); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an 'AS IS' BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -# TensorBoard Polymer Dependencies - -load("@io_bazel_rules_closure//closure:defs.bzl", "web_library_external") - -def tensorboard_polymer_workspace(): - web_library_external( - name = "org_polymer_font_roboto", - licenses = ["notice"], # BSD-3-Clause - sha256 = "fae51429b56a4a4c15f1f0c23b733c7095940cc9c04c275fa7adb3bf055b23b3", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/font-roboto/archive/v1.0.1.tar.gz", - "https://github.com/PolymerElements/font-roboto/archive/v1.0.1.tar.gz", - ], - strip_prefix = "font-roboto-1.0.1", - path = "/font-roboto", - srcs = ["roboto.html"], - ) - - web_library_external( - name = "org_polymer_hydrolysis", - licenses = ["notice"], # BSD-3-Clause - sha256 = "703b50f6b00f9e0546b5a3451da57bb20f77a166e27e4967923b9e835bab9b80", - urls = [ - "http://mirror.bazel.build/github.com/Polymer/polymer-analyzer/archive/v1.19.3.tar.gz", - "https://github.com/Polymer/polymer-analyzer/archive/v1.19.3.tar.gz", - ], - strip_prefix = "polymer-analyzer-1.19.3", - path = "/hydrolysis", - srcs = [ - "hydrolysis-analyzer.html", - "hydrolysis.html", - "hydrolysis.js", - ], - deps = ["@org_polymer"], - ) - - web_library_external( - name = "org_polymer_iron_a11y_announcer", - licenses = ["notice"], # BSD-3-Clause - sha256 = "6bce143db7a374a68535ec8b861a5f30e81f2f1e4ee36a55bda2a891f6fd2818", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-a11y-announcer/archive/v1.0.5.tar.gz", - "https://github.com/PolymerElements/iron-a11y-announcer/archive/v1.0.5.tar.gz", - ], - strip_prefix = "iron-a11y-announcer-1.0.5", - path = "/iron-a11y-announcer", - srcs = ["iron-a11y-announcer.html"], - deps = ["@org_polymer"], - ) - - web_library_external( - name = "org_polymer_iron_a11y_keys_behavior", - licenses = ["notice"], # BSD-3-Clause - sha256 = "6823efc47a83208fd51d39c5a1d3eb0c0bebc705df1ce01310509da22a13ebd2", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-a11y-keys-behavior/archive/v1.1.8.tar.gz", - 
"https://github.com/PolymerElements/iron-a11y-keys-behavior/archive/v1.1.8.tar.gz", - ], - strip_prefix = "iron-a11y-keys-behavior-1.1.8", - path = "/iron-a11y-keys-behavior", - srcs = ["iron-a11y-keys-behavior.html"], - deps = ["@org_polymer"], - ) - - web_library_external( - name = "org_polymer_iron_ajax", - licenses = ["notice"], # BSD-3-Clause - sha256 = "9162d8af4611e911ac3ebbfc08bb7038ac04f6e79a9287b1476fe36ad6770bc5", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-ajax/archive/v1.2.0.tar.gz", - "https://github.com/PolymerElements/iron-ajax/archive/v1.2.0.tar.gz", - ], - strip_prefix = "iron-ajax-1.2.0", - path = "/iron-ajax", - srcs = [ - "iron-ajax.html", - "iron-request.html", - ], - deps = [ - "@org_polymer", - "@org_polymer_promise_polyfill", - ], - ) - - web_library_external( - name = "org_polymer_iron_autogrow_textarea", - licenses = ["notice"], # BSD-3-Clause - sha256 = "50bbb901d2c8f87462e3552e3d671a552faa12c37c485e548d7a234ebffbc427", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-autogrow-textarea/archive/v1.0.12.tar.gz", - "https://github.com/PolymerElements/iron-autogrow-textarea/archive/v1.0.12.tar.gz", - ], - strip_prefix = "iron-autogrow-textarea-1.0.12", - path = "/iron-autogrow-textarea", - srcs = ["iron-autogrow-textarea.html"], - deps = [ - "@org_polymer", - "@org_polymer_iron_behaviors", - "@org_polymer_iron_flex_layout", - "@org_polymer_iron_form_element_behavior", - "@org_polymer_iron_validatable_behavior", - ], - ) - - web_library_external( - name = "org_polymer_iron_behaviors", - licenses = ["notice"], # BSD-3-Clause - sha256 = "a1e8d4b7a13f3d36beba9c2a6b186ed33a53e6af2e79f98c1fcc7e85e7b53f89", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-behaviors/archive/v1.0.17.tar.gz", - "https://github.com/PolymerElements/iron-behaviors/archive/v1.0.17.tar.gz", - ], - strip_prefix = "iron-behaviors-1.0.17", - path = "/iron-behaviors", - srcs = [ - 
"iron-button-state.html", - "iron-control-state.html", - ], - deps = [ - "@org_polymer", - "@org_polymer_iron_a11y_keys_behavior", - ], - ) - - web_library_external( - name = "org_polymer_iron_checked_element_behavior", - licenses = ["notice"], # BSD-3-Clause - sha256 = "539a0e1c4df0bc702d3bd342388e4e56c77ec4c2066cce69e41426a69f92e8bd", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-checked-element-behavior/archive/v1.0.4.tar.gz", - "https://github.com/PolymerElements/iron-checked-element-behavior/archive/v1.0.4.tar.gz", - ], - strip_prefix = "iron-checked-element-behavior-1.0.4", - path = "/iron-checked-element-behavior", - srcs = ["iron-checked-element-behavior.html"], - deps = [ - "@org_polymer", - "@org_polymer_iron_form_element_behavior", - "@org_polymer_iron_validatable_behavior", - ], - ) - - web_library_external( - name = "org_polymer_iron_component_page", - licenses = ["notice"], # BSD-3-Clause - sha256 = "3636e8b9a1f229fc33b5aad3933bd02a9825f66e679a0be31855d7c8245c4b4b", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-component-page/archive/v1.1.4.tar.gz", - "https://github.com/PolymerElements/iron-component-page/archive/v1.1.4.tar.gz", - ], - strip_prefix = "iron-component-page-1.1.4", - path = "/iron-component-page", - srcs = ["iron-component-page.html"], - deps = [ - "@org_polymer", - "@org_polymer_hydrolysis", - "@org_polymer_iron_ajax", - "@org_polymer_iron_doc_viewer", - "@org_polymer_iron_flex_layout", - "@org_polymer_iron_icons", - "@org_polymer_iron_selector", - "@org_polymer_paper_header_panel", - "@org_polymer_paper_styles", - "@org_polymer_paper_toolbar", - ], - ) - - web_library_external( - name = "org_polymer_iron_collapse", - licenses = ["notice"], # BSD-3-Clause - sha256 = "275808994a609a2f9923e2dd2db1957945ab141ba840eadc33f19e1f406d600e", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-collapse/archive/v1.0.8.tar.gz", - 
"https://github.com/PolymerElements/iron-collapse/archive/v1.0.8.tar.gz", - ], - strip_prefix = "iron-collapse-1.0.8", - path = "/iron-collapse", - srcs = ["iron-collapse.html"], - deps = [ - "@org_polymer", - "@org_polymer_iron_resizable_behavior", - ], - ) - - web_library_external( - name = "org_polymer_iron_demo_helpers", - licenses = ["notice"], # BSD-3-Clause - sha256 = "aa7458492a6ac3d1f6344640a4c2ab07bce64e7ad0422b83b5d665707598cce6", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-demo-helpers/archive/v1.1.0.tar.gz", - "https://github.com/PolymerElements/iron-demo-helpers/archive/v1.1.0.tar.gz", - ], - strip_prefix = "iron-demo-helpers-1.1.0", - path = "/iron-demo-helpers", - srcs = [ - "demo-pages-shared-styles.html", - "demo-snippet.html", - ], - deps = [ - "@org_polymer", - "@org_polymer_iron_flex_layout", - "@org_polymer_iron_icons", - "@org_polymer_marked_element", - "@org_polymer_paper_icon_button", - "@org_polymer_paper_styles", - "@org_polymer_prism_element", - ], - ) - - web_library_external( - name = "org_polymer_iron_doc_viewer", - licenses = ["notice"], # BSD-3-Clause - sha256 = "f0e9dfbbcd94d7e88ce82cb61e615406ace63c185fee9396f7f182206ca5cc9a", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-doc-viewer/archive/v1.0.12.tar.gz", - "https://github.com/PolymerElements/iron-doc-viewer/archive/v1.0.12.tar.gz", - ], - strip_prefix = "iron-doc-viewer-1.0.12", - path = "/iron-doc-viewer", - srcs = [ - "iron-doc-property-styles.html", - "iron-doc-property.html", - "iron-doc-viewer-styles.html", - "iron-doc-viewer.html", - ], - deps = [ - "@org_polymer", - "@org_polymer_marked_element", - "@org_polymer_paper_button", - "@org_polymer_paper_styles", - "@org_polymer_prism_element", - ], - ) - - web_library_external( - name = "org_polymer_iron_dropdown", - licenses = ["notice"], # BSD-3-Clause - sha256 = "f7e4a31d096d10d8af1920397695cb17f3eb1cbe5e5ff91a861dabfcc085f376", - urls = [ - 
"http://mirror.bazel.build/github.com/PolymerElements/iron-dropdown/archive/v1.4.0.tar.gz", - "https://github.com/PolymerElements/iron-dropdown/archive/v1.4.0.tar.gz", - ], - strip_prefix = "iron-dropdown-1.4.0", - path = "/iron-dropdown", - srcs = [ - "iron-dropdown.html", - "iron-dropdown-scroll-manager.html", - ], - deps = [ - "@org_polymer", - "@org_polymer_iron_a11y_keys_behavior", - "@org_polymer_iron_behaviors", - "@org_polymer_iron_overlay_behavior", - "@org_polymer_iron_resizable_behavior", - "@org_polymer_neon_animation", - ], - ) - - web_library_external( - name = "org_polymer_iron_fit_behavior", - licenses = ["notice"], # BSD-3-Clause - sha256 = "10132a2ea309a37c4c07b8fead71f64abc588ee6107931e34680f5f36dd8291e", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-fit-behavior/archive/v1.2.5.tar.gz", - "https://github.com/PolymerElements/iron-fit-behavior/archive/v1.2.5.tar.gz", - ], - strip_prefix = "iron-fit-behavior-1.2.5", - path = "/iron-fit-behavior", - srcs = ["iron-fit-behavior.html"], - deps = ["@org_polymer"], - ) - - web_library_external( - name = "org_polymer_iron_flex_layout", - licenses = ["notice"], # BSD-3-Clause - sha256 = "79287f6ca1c2d4e003f68b88fe19d03a1b6a0011e2b4cae579fe4d1474163a2e", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-flex-layout/archive/v1.3.0.tar.gz", - "https://github.com/PolymerElements/iron-flex-layout/archive/v1.3.0.tar.gz", - ], - strip_prefix = "iron-flex-layout-1.3.0", - path = "/iron-flex-layout", - srcs = [ - "classes/iron-flex-layout.html", - "classes/iron-shadow-flex-layout.html", - "iron-flex-layout.html", - "iron-flex-layout-classes.html", - ], - deps = ["@org_polymer"], - ) - - web_library_external( - name = "org_polymer_iron_form_element_behavior", - licenses = ["notice"], # BSD-3-Clause - sha256 = "1dd9371c638e5bc2ecba8a64074aa680dfb8712198e9612f9ed24d387efc8f26", - urls = [ - 
"http://mirror.bazel.build/github.com/PolymerElements/iron-form-element-behavior/archive/v1.0.6.tar.gz", - "https://github.com/PolymerElements/iron-form-element-behavior/archive/v1.0.6.tar.gz", - ], - strip_prefix = "iron-form-element-behavior-1.0.6", - path = "/iron-form-element-behavior", - srcs = ["iron-form-element-behavior.html"], - deps = ["@org_polymer"], - ) - - web_library_external( - name = "org_polymer_iron_icon", - licenses = ["notice"], # BSD-3-Clause - sha256 = "9ed58a69159a02c07a6050d242e6d4e585a29f3245b8c8c390cfd52ddb786dc4", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-icon/archive/v1.0.11.tar.gz", - "https://github.com/PolymerElements/iron-icon/archive/v1.0.11.tar.gz", - ], - strip_prefix = "iron-icon-1.0.11", - path = "/iron-icon", - srcs = ["iron-icon.html"], - deps = [ - "@org_polymer", - "@org_polymer_iron_flex_layout", - "@org_polymer_iron_meta", - ], - ) - - web_library_external( - name = "org_polymer_iron_icons", - licenses = ["notice"], # BSD-3-Clause - sha256 = "3b18542c147c7923dc3a36b1a51984a73255d610f297d43c9aaccc52859bd0d0", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-icons/archive/v1.1.3.tar.gz", - "https://github.com/PolymerElements/iron-icons/archive/v1.1.3.tar.gz", - ], - strip_prefix = "iron-icons-1.1.3", - path = "/iron-icons", - srcs = [ - "av-icons.html", - "communication-icons.html", - "device-icons.html", - "editor-icons.html", - "hardware-icons.html", - "image-icons.html", - "iron-icons.html", - "maps-icons.html", - "notification-icons.html", - "places-icons.html", - "social-icons.html", - ], - deps = [ - "@org_polymer_iron_icon", - "@org_polymer_iron_iconset_svg", - ], - ) - - web_library_external( - name = "org_polymer_iron_iconset_svg", - licenses = ["notice"], # BSD-3-Clause - sha256 = "7e3925b7e63a7d22524c4b43ce16ab80d06a576649644783643c11a003284368", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-iconset-svg/archive/v1.1.0.tar.gz", - 
"https://github.com/PolymerElements/iron-iconset-svg/archive/v1.1.0.tar.gz", - ], - strip_prefix = "iron-iconset-svg-1.1.0", - path = "/iron-iconset-svg", - srcs = ["iron-iconset-svg.html"], - deps = [ - "@org_polymer", - "@org_polymer_iron_meta", - ], - ) - - web_library_external( - name = "org_polymer_iron_input", - licenses = ["notice"], # BSD-3-Clause - sha256 = "c505101ead08ab25526b1f49baecc8c28b4221b92a65e7334c783bdc81553c36", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-input/archive/1.0.10.tar.gz", - "https://github.com/PolymerElements/iron-input/archive/1.0.10.tar.gz", - ], - strip_prefix = "iron-input-1.0.10", - path = "/iron-input", - srcs = ["iron-input.html"], - deps = [ - "@org_polymer", - "@org_polymer_iron_a11y_announcer", - "@org_polymer_iron_validatable_behavior", - ], - ) - - web_library_external( - name = "org_polymer_iron_list", - licenses = ["notice"], # BSD-3-Clause - sha256 = "72a6530b9f0ad5557f5d287845792a0ada74d8b159198e27f940e226313dc116", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-list/archive/v1.3.9.tar.gz", - "https://github.com/PolymerElements/iron-list/archive/v1.3.9.tar.gz", - ], - strip_prefix = "iron-list-1.3.9", - path = "/iron-list", - srcs = ["iron-list.html"], - deps = [ - "@org_polymer", - "@org_polymer_iron_a11y_keys_behavior", - "@org_polymer_iron_resizable_behavior", - "@org_polymer_iron_scroll_target_behavior", - ], - ) - - web_library_external( - name = "org_polymer_iron_menu_behavior", - licenses = ["notice"], # BSD-3-Clause - sha256 = "ad27889343bc9a709258b073f69abc028bb1ffd3fdb975cd2d3939f7f5d7bb6c", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-menu-behavior/archive/v1.1.10.tar.gz", - "https://github.com/PolymerElements/iron-menu-behavior/archive/v1.1.10.tar.gz", - ], - strip_prefix = "iron-menu-behavior-1.1.10", - path = "/iron-menu-behavior", - srcs = [ - "iron-menu-behavior.html", - "iron-menubar-behavior.html", - ], - deps = [ - 
"@org_polymer", - "@org_polymer_iron_a11y_keys_behavior", - "@org_polymer_iron_selector", - ], - ) - - web_library_external( - name = "org_polymer_iron_meta", - licenses = ["notice"], # BSD-3-Clause - sha256 = "fb05e6031bae6b4effe5f15d44b3f548d5807f9e3b3aa2442ba17cf4b8b84361", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-meta/archive/v1.1.1.tar.gz", - "https://github.com/PolymerElements/iron-meta/archive/v1.1.1.tar.gz", - ], - strip_prefix = "iron-meta-1.1.1", - path = "/iron-meta", - srcs = ["iron-meta.html"], - deps = ["@org_polymer"], - ) - - web_library_external( - name = "org_polymer_iron_overlay_behavior", - licenses = ["notice"], # BSD-3-Clause - sha256 = "3df5b54ff2e0510c87a2aff8c9d730d3fe83d3d11277cc1a49fa29b549acb46c", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-overlay-behavior/archive/v1.10.1.tar.gz", - "https://github.com/PolymerElements/iron-overlay-behavior/archive/v1.10.1.tar.gz", - ], - strip_prefix = "iron-overlay-behavior-1.10.1", - path = "/iron-overlay-behavior", - srcs = [ - "iron-focusables-helper.html", - "iron-overlay-backdrop.html", - "iron-overlay-behavior.html", - "iron-overlay-manager.html", - ], - deps = [ - "@org_polymer", - "@org_polymer_iron_a11y_keys_behavior", - "@org_polymer_iron_fit_behavior", - "@org_polymer_iron_resizable_behavior", - ], - ) - - web_library_external( - name = "org_polymer_iron_range_behavior", - licenses = ["notice"], # BSD-3-Clause - sha256 = "b2f2b6d52284542330bd30b586e217926eb0adec5e13934a3cef557717c22dc2", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-range-behavior/archive/v1.0.4.tar.gz", - "https://github.com/PolymerElements/iron-range-behavior/archive/v1.0.4.tar.gz", - ], - strip_prefix = "iron-range-behavior-1.0.4", - path = "/iron-range-behavior", - srcs = ["iron-range-behavior.html"], - deps = ["@org_polymer"], - ) - - web_library_external( - name = "org_polymer_iron_resizable_behavior", - licenses = ["notice"], # 
BSD-3-Clause - sha256 = "a87a78ee9223c2f6afae7fc94a3ff91cbce6f7e2a7ed3f2979af7945c9281616", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-resizable-behavior/archive/v1.0.3.tar.gz", - "https://github.com/PolymerElements/iron-resizable-behavior/archive/v1.0.3.tar.gz", - ], - strip_prefix = "iron-resizable-behavior-1.0.3", - path = "/iron-resizable-behavior", - srcs = ["iron-resizable-behavior.html"], - deps = ["@org_polymer"], - ) - - web_library_external( - name = "org_polymer_iron_scroll_target_behavior", - licenses = ["notice"], # BSD-3-Clause - sha256 = "d0de0c804b1ec91d814754144afd9da1cdb082690de88bd5e47fd5f41990746f", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-scroll-target-behavior/archive/v1.0.3.tar.gz", - "https://github.com/PolymerElements/iron-scroll-target-behavior/archive/v1.0.3.tar.gz", - ], - strip_prefix = "iron-scroll-target-behavior-1.0.3", - path = "/iron-scroll-target-behavior", - srcs = ["iron-scroll-target-behavior.html"], - deps = ["@org_polymer"], - ) - - web_library_external( - name = "org_polymer_iron_selector", - licenses = ["notice"], # BSD-3-Clause - sha256 = "ba28a47443bad3b744611c9d7a79fb21dbdf2e35edc5ef8f812e2dcd72b16747", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-selector/archive/v1.5.2.tar.gz", - "https://github.com/PolymerElements/iron-selector/archive/v1.5.2.tar.gz", - ], - strip_prefix = "iron-selector-1.5.2", - path = "/iron-selector", - srcs = [ - "iron-multi-selectable.html", - "iron-selectable.html", - "iron-selection.html", - "iron-selector.html", - ], - deps = ["@org_polymer"], - ) - - web_library_external( - name = "org_polymer_iron_validatable_behavior", - licenses = ["notice"], # BSD-3-Clause - sha256 = "aef4901e68043824f36104799269573dd345ffaac494186e466fdc79c06fdb63", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/iron-validatable-behavior/archive/v1.1.1.tar.gz", - 
"https://github.com/PolymerElements/iron-validatable-behavior/archive/v1.1.1.tar.gz", - ], - strip_prefix = "iron-validatable-behavior-1.1.1", - path = "/iron-validatable-behavior", - srcs = ["iron-validatable-behavior.html"], - deps = [ - "@org_polymer", - "@org_polymer_iron_meta", - ], - ) - - web_library_external( - name = "org_polymer_marked", - licenses = ["notice"], # MIT - sha256 = "93d30bd593736ca440938d77808b7ef5972da0f3fcfe4ae63ae7b4ce117da2cb", - urls = [ - "http://mirror.bazel.build/github.com/chjj/marked/archive/v0.3.2.zip", - "https://github.com/chjj/marked/archive/v0.3.2.zip", - ], - strip_prefix = "marked-0.3.2", - path = "/marked", - srcs = ["lib/marked.js"], - ) - - web_library_external( - name = "org_polymer_marked_element", - licenses = ["notice"], # BSD-3-Clause - sha256 = "7547616df95f8b903757e6afbabfcdba5322c2bcec3f17c726b8bba5adf4bc5f", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/marked-element/archive/v1.1.3.tar.gz", - "https://github.com/PolymerElements/marked-element/archive/v1.1.3.tar.gz", - ], - strip_prefix = "marked-element-1.1.3", - path = "/marked-element", - srcs = [ - "marked-element.html", - "marked-import.html", - ], - deps = [ - "@org_polymer", - "@org_polymer_marked", - ], - ) - - web_library_external( - name = "org_polymer_neon_animation", - licenses = ["notice"], # BSD-3-Clause - sha256 = "8800c314a76b2da190a2b203259c1091f6d38e0057ed37c2a3d0b734980fa9a5", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/neon-animation/archive/v1.2.2.tar.gz", - "https://github.com/PolymerElements/neon-animation/archive/v1.2.2.tar.gz", - ], - strip_prefix = "neon-animation-1.2.2", - path = "/neon-animation", - srcs = [ - "animations/cascaded-animation.html", - "animations/fade-in-animation.html", - "animations/fade-out-animation.html", - "animations/hero-animation.html", - "animations/opaque-animation.html", - "animations/reverse-ripple-animation.html", - "animations/ripple-animation.html", - 
"animations/scale-down-animation.html", - "animations/scale-up-animation.html", - "animations/slide-down-animation.html", - "animations/slide-from-bottom-animation.html", - "animations/slide-from-left-animation.html", - "animations/slide-from-right-animation.html", - "animations/slide-from-top-animation.html", - "animations/slide-left-animation.html", - "animations/slide-right-animation.html", - "animations/slide-up-animation.html", - "animations/transform-animation.html", - "neon-animatable.html", - "neon-animatable-behavior.html", - "neon-animated-pages.html", - "neon-animation.html", - "neon-animation-behavior.html", - "neon-animation-runner-behavior.html", - "neon-animations.html", - "neon-shared-element-animatable-behavior.html", - "neon-shared-element-animation-behavior.html", - "web-animations.html", - ], - deps = [ - "@org_polymer", - "@org_polymer_iron_meta", - "@org_polymer_iron_resizable_behavior", - "@org_polymer_iron_selector", - "@org_polymer_web_animations_js", - ], - ) - - web_library_external( - name = "org_polymer_paper_behaviors", - licenses = ["notice"], # BSD-3-Clause - sha256 = "7cfcb9082ef9909da262df6b5c120bc62dbeaff278cb563e8fc60465ddd387e5", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-behaviors/archive/v1.0.12.tar.gz", - "https://github.com/PolymerElements/paper-behaviors/archive/v1.0.12.tar.gz", - ], - strip_prefix = "paper-behaviors-1.0.12", - path = "/paper-behaviors", - srcs = [ - "paper-button-behavior.html", - "paper-checked-element-behavior.html", - "paper-inky-focus-behavior.html", - "paper-ripple-behavior.html", - ], - deps = [ - "@org_polymer", - "@org_polymer_iron_behaviors", - "@org_polymer_iron_checked_element_behavior", - "@org_polymer_paper_ripple", - ], - ) - - web_library_external( - name = "org_polymer_paper_button", - licenses = ["notice"], # BSD-3-Clause - sha256 = "896c0a7e34bfcce63fc23c63e105ed9c4d62fa3a6385b7161e1e5cd4058820a6", - urls = [ - 
"http://mirror.bazel.build/github.com/PolymerElements/paper-button/archive/v1.0.11.tar.gz", - "https://github.com/PolymerElements/paper-button/archive/v1.0.11.tar.gz", - ], - strip_prefix = "paper-button-1.0.11", - path = "/paper-button", - srcs = ["paper-button.html"], - deps = [ - "@org_polymer", - "@org_polymer_iron_flex_layout", - "@org_polymer_paper_behaviors", - "@org_polymer_paper_material", - "@org_polymer_paper_ripple", - ], - ) - - web_library_external( - name = "org_polymer_paper_checkbox", - licenses = ["notice"], # BSD-3-Clause - sha256 = "6828a6954a048b1230fbd2606faffbae950ba1d042175b96ec50ae355786a166", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-checkbox/archive/v1.4.0.tar.gz", - "https://github.com/PolymerElements/paper-checkbox/archive/v1.4.0.tar.gz", - ], - strip_prefix = "paper-checkbox-1.4.0", - path = "/paper-checkbox", - srcs = ["paper-checkbox.html"], - deps = [ - "@org_polymer", - "@org_polymer_paper_behaviors", - "@org_polymer_paper_styles", - ], - ) - - web_library_external( - name = "org_polymer_paper_dialog", - licenses = ["notice"], # BSD-3-Clause - sha256 = "c6a9709e7f528d03dcd574503c18b72d4751ca30017346d16e6a791d37ed9259", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-dialog/archive/v1.0.4.tar.gz", - "https://github.com/PolymerElements/paper-dialog/archive/v1.0.4.tar.gz", - ], - strip_prefix = "paper-dialog-1.0.4", - path = "/paper-dialog", - srcs = ["paper-dialog.html"], - deps = [ - "@org_polymer", - "@org_polymer_neon_animation", - "@org_polymer_paper_dialog_behavior", - ], - ) - - web_library_external( - name = "org_polymer_paper_dialog_behavior", - licenses = ["notice"], # BSD-3-Clause - sha256 = "a7e0e27ce63554bc14f384cf94bcfa24da8dc5f5120dfd565f45e166261aee40", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-dialog-behavior/archive/v1.2.5.tar.gz", - "https://github.com/PolymerElements/paper-dialog-behavior/archive/v1.2.5.tar.gz", - ], - 
strip_prefix = "paper-dialog-behavior-1.2.5", - path = "/paper-dialog-behavior", - srcs = [ - "paper-dialog-behavior.html", - "paper-dialog-common.css", - "paper-dialog-shared-styles.html", - ], - deps = [ - "@org_polymer", - "@org_polymer_iron_flex_layout", - "@org_polymer_iron_overlay_behavior", - "@org_polymer_paper_styles", - ], - ) - - web_library_external( - name = "org_polymer_paper_dialog_scrollable", - licenses = ["notice"], # BSD-3-Clause - sha256 = "a2e69283e7674f782c44d811387a0f8da2d01fac0172743d1add65e253e6b5ff", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-dialog-scrollable/archive/1.1.5.tar.gz", - "https://github.com/PolymerElements/paper-dialog-scrollable/archive/1.1.5.tar.gz", - ], - strip_prefix = "paper-dialog-scrollable-1.1.5", - path = "/paper-dialog-scrollable", - srcs = ["paper-dialog-scrollable.html"], - deps = [ - "@org_polymer", - "@org_polymer_iron_flex_layout", - "@org_polymer_paper_dialog_behavior", - "@org_polymer_paper_styles", - ], - ) - - web_library_external( - name = "org_polymer_paper_dropdown_menu", - licenses = ["notice"], # BSD-3-Clause - sha256 = "9d88f654ec03ee9be211df9e69bede9e8a22b51bf1dbcc63b79762e4256d81ad", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-dropdown-menu/archive/v1.4.0.tar.gz", - "https://github.com/PolymerElements/paper-dropdown-menu/archive/v1.4.0.tar.gz", - ], - strip_prefix = "paper-dropdown-menu-1.4.0", - path = "/paper-dropdown-menu", - srcs = [ - "paper-dropdown-menu.html", - "paper-dropdown-menu-icons.html", - "paper-dropdown-menu-light.html", - "paper-dropdown-menu-shared-styles.html", - ], - deps = [ - "@org_polymer", - "@org_polymer_iron_a11y_keys_behavior", - "@org_polymer_iron_behaviors", - "@org_polymer_iron_form_element_behavior", - "@org_polymer_iron_icon", - "@org_polymer_iron_iconset_svg", - "@org_polymer_iron_validatable_behavior", - "@org_polymer_paper_behaviors", - "@org_polymer_paper_input", - "@org_polymer_paper_menu_button", - 
"@org_polymer_paper_ripple", - "@org_polymer_paper_styles", - ], - ) - - web_library_external( - name = "org_polymer_paper_header_panel", - licenses = ["notice"], # BSD-3-Clause - sha256 = "0db4bd8a4bf6f20dcd0dffb4f907b31c93a8647c9c021344239cf30b40b87075", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-header-panel/archive/v1.1.4.tar.gz", - "https://github.com/PolymerElements/paper-header-panel/archive/v1.1.4.tar.gz", - ], - strip_prefix = "paper-header-panel-1.1.4", - path = "/paper-header-panel", - srcs = ["paper-header-panel.html"], - deps = [ - "@org_polymer", - "@org_polymer_iron_flex_layout", - ], - ) - - web_library_external( - name = "org_polymer_paper_icon_button", - licenses = ["notice"], # BSD-3-Clause - sha256 = "9cba5bcfd6aeb4c41581c1392c678cf2278d360e9d122f4d9db54a9ebb404496", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-icon-button/archive/v1.1.3.tar.gz", - "https://github.com/PolymerElements/paper-icon-button/archive/v1.1.3.tar.gz", - ], - strip_prefix = "paper-icon-button-1.1.3", - path = "/paper-icon-button", - srcs = [ - "paper-icon-button.html", - "paper-icon-button-light.html", - ], - deps = [ - "@org_polymer", - "@org_polymer_iron_icon", - "@org_polymer_paper_behaviors", - "@org_polymer_paper_styles", - ], - ) - - web_library_external( - name = "org_polymer_paper_input", - licenses = ["notice"], # BSD-3-Clause - sha256 = "17c3dea9bb1c2026cc61324696c6c774214a0dc37686b91ca214a6af550994db", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-input/archive/v1.1.18.tar.gz", - "https://github.com/PolymerElements/paper-input/archive/v1.1.18.tar.gz", - ], - strip_prefix = "paper-input-1.1.18", - path = "/paper-input", - srcs = [ - "paper-input.html", - "paper-input-addon-behavior.html", - "paper-input-behavior.html", - "paper-input-char-counter.html", - "paper-input-container.html", - "paper-input-error.html", - "paper-textarea.html", - ], - deps = [ - "@org_polymer", - 
"@org_polymer_iron_a11y_keys_behavior", - "@org_polymer_iron_autogrow_textarea", - "@org_polymer_iron_behaviors", - "@org_polymer_iron_flex_layout", - "@org_polymer_iron_form_element_behavior", - "@org_polymer_iron_input", - "@org_polymer_paper_styles", - ], - ) - - web_library_external( - name = "org_polymer_paper_item", - licenses = ["notice"], # BSD-3-Clause - sha256 = "12ee0dcb61b0d5721c5988571f6974d7b2211e97724f4195893fbcc9058cdac8", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-item/archive/v1.1.4.tar.gz", - "https://github.com/PolymerElements/paper-item/archive/v1.1.4.tar.gz", - ], - strip_prefix = "paper-item-1.1.4", - path = "/paper-item", - srcs = [ - "paper-icon-item.html", - "paper-item.html", - "paper-item-behavior.html", - "paper-item-body.html", - "paper-item-shared-styles.html", - ], - deps = [ - "@org_polymer", - "@org_polymer_iron_behaviors", - "@org_polymer_iron_flex_layout", - "@org_polymer_paper_styles", - ], - ) - - web_library_external( - name = "org_polymer_paper_listbox", - licenses = ["notice"], # BSD-3-Clause - sha256 = "3cb35f4fe9a3f15185a9e91711dba8f27e9291c8cd371ebf1be21b8f1d5f65fb", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-listbox/archive/v1.1.2.tar.gz", - "https://github.com/PolymerElements/paper-listbox/archive/v1.1.2.tar.gz", - ], - strip_prefix = "paper-listbox-1.1.2", - path = "/paper-listbox", - srcs = ["paper-listbox.html"], - deps = [ - "@org_polymer", - "@org_polymer_iron_menu_behavior", - "@org_polymer_paper_styles", - ], - ) - - web_library_external( - name = "org_polymer_paper_material", - licenses = ["notice"], # BSD-3-Clause - sha256 = "09f6c8bd6ddbea2be541dc86306efe41cdfb31bec0b69d35a5dc29772bbc8506", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-material/archive/v1.0.6.tar.gz", - "https://github.com/PolymerElements/paper-material/archive/v1.0.6.tar.gz", - ], - strip_prefix = "paper-material-1.0.6", - path = "/paper-material", - srcs = 
[ - "paper-material.html", - "paper-material-shared-styles.html", - ], - deps = [ - "@org_polymer", - "@org_polymer_paper_styles", - ], - ) - - web_library_external( - name = "org_polymer_paper_menu", - licenses = ["notice"], # BSD-3-Clause - sha256 = "a3cee220926e315f7412236b3628288774694447c0da4428345f36d0f127ba3b", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-menu/archive/v1.2.2.tar.gz", - "https://github.com/PolymerElements/paper-menu/archive/v1.2.2.tar.gz", - ], - strip_prefix = "paper-menu-1.2.2", - path = "/paper-menu", - srcs = [ - "paper-menu.html", - "paper-menu-shared-styles.html", - "paper-submenu.html", - ], - deps = [ - "@org_polymer", - "@org_polymer_iron_behaviors", - "@org_polymer_iron_collapse", - "@org_polymer_iron_flex_layout", - "@org_polymer_iron_menu_behavior", - "@org_polymer_paper_styles", - ], - ) - - web_library_external( - name = "org_polymer_paper_menu_button", - licenses = ["notice"], # BSD-3-Clause - sha256 = "be3290c288a2bd4f9887213db22c75add99cc29ff4d088100c0bc4eb0e57997b", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-menu-button/archive/v1.5.1.tar.gz", - "https://github.com/PolymerElements/paper-menu-button/archive/v1.5.1.tar.gz", - ], - strip_prefix = "paper-menu-button-1.5.1", - path = "/paper-menu-button", - srcs = [ - "paper-menu-button.html", - "paper-menu-button-animations.html", - ], - deps = [ - "@org_polymer", - "@org_polymer_iron_a11y_keys_behavior", - "@org_polymer_iron_behaviors", - "@org_polymer_iron_dropdown", - "@org_polymer_neon_animation", - "@org_polymer_paper_styles", - ], - ) - - web_library_external( - name = "org_polymer_paper_progress", - licenses = ["notice"], # BSD-3-Clause - sha256 = "2b6776b2f023c1f344feea17ba29b58d879e46f8ed43b7256495054b5183fff6", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-progress/archive/v1.0.9.tar.gz", - "https://github.com/PolymerElements/paper-progress/archive/v1.0.9.tar.gz", - ], - strip_prefix = 
"paper-progress-1.0.9", - path = "/paper-progress", - srcs = ["paper-progress.html"], - deps = [ - "@org_polymer", - "@org_polymer_iron_flex_layout", - "@org_polymer_iron_range_behavior", - "@org_polymer_paper_styles", - ], - ) - - web_library_external( - name = "org_polymer_paper_radio_button", - licenses = ["notice"], # BSD-3-Clause - sha256 = "6e911d0c308aa388136b3af79d1bdcbe5a1f4159cbc79d71efb4ff3b6c0b4e91", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-radio-button/archive/v1.1.2.tar.gz", - "https://github.com/PolymerElements/paper-radio-button/archive/v1.1.2.tar.gz", - ], - strip_prefix = "paper-radio-button-1.1.2", - path = "/paper-radio-button", - srcs = ["paper-radio-button.html"], - deps = [ - "@org_polymer", - "@org_polymer_paper_behaviors", - "@org_polymer_paper_styles", - ], - ) - - web_library_external( - name = "org_polymer_paper_radio_group", - licenses = ["notice"], # BSD-3-Clause - sha256 = "7885ad1f81e9dcc03dcea4139b54a201ff55c18543770cd44f94530046c9e163", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-radio-group/archive/v1.0.9.tar.gz", - "https://github.com/PolymerElements/paper-radio-group/archive/v1.0.9.tar.gz", - ], - strip_prefix = "paper-radio-group-1.0.9", - path = "/paper-radio-group", - srcs = ["paper-radio-group.html"], - deps = [ - "@org_polymer", - "@org_polymer_iron_a11y_keys_behavior", - "@org_polymer_iron_selector", - "@org_polymer_paper_radio_button", - ], - ) - - web_library_external( - name = "org_polymer_paper_ripple", - licenses = ["notice"], # BSD-3-Clause - sha256 = "ba76bfb1c737260a8a103d3ca97faa1f7c3288c7db9b2519f401b7a782147c09", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-ripple/archive/v1.0.5.tar.gz", - "https://github.com/PolymerElements/paper-ripple/archive/v1.0.5.tar.gz", - ], - strip_prefix = "paper-ripple-1.0.5", - path = "/paper-ripple", - srcs = ["paper-ripple.html"], - deps = [ - "@org_polymer", - 
"@org_polymer_iron_a11y_keys_behavior", - ], - ) - - web_library_external( - name = "org_polymer_paper_slider", - licenses = ["notice"], # BSD-3-Clause - sha256 = "08e7c541dbf5d2e959208810bfc03188e82ced87e4d30d325172967f67962c3c", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-slider/archive/v1.0.10.tar.gz", - "https://github.com/PolymerElements/paper-slider/archive/v1.0.10.tar.gz", - ], - strip_prefix = "paper-slider-1.0.10", - path = "/paper-slider", - srcs = ["paper-slider.html"], - deps = [ - "@org_polymer", - "@org_polymer_iron_a11y_keys_behavior", - "@org_polymer_iron_flex_layout", - "@org_polymer_iron_form_element_behavior", - "@org_polymer_iron_range_behavior", - "@org_polymer_paper_behaviors", - "@org_polymer_paper_input", - "@org_polymer_paper_progress", - "@org_polymer_paper_styles", - ], - ) - - web_library_external( - name = "org_polymer_paper_spinner", - licenses = ["notice"], # BSD-3-Clause - sha256 = "6a752907fab7899cbeed15b478e7b9299047c15fbf9d1561d6eb4d204bdbd178", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-spinner/archive/v1.1.1.tar.gz", - "https://github.com/PolymerElements/paper-spinner/archive/v1.1.1.tar.gz", - ], - strip_prefix = "paper-spinner-1.1.1", - path = "/paper-spinner", - srcs = [ - "paper-spinner.html", "paper-spinner-behavior.html", - "paper-spinner-lite.html", "paper-spinner-styles.html" - ], - deps = [ - "@org_polymer", - "@org_polymer_iron_flex_layout", - "@org_polymer_paper_styles", - ], - ) - - web_library_external( - name = "org_polymer_paper_styles", - licenses = ["notice"], # BSD-3-Clause - sha256 = "6d26b0a4c286402098853dc7388f6b22f30dfb7a74e47b34992ac03380144bb2", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-styles/archive/v1.1.4.tar.gz", - "https://github.com/PolymerElements/paper-styles/archive/v1.1.4.tar.gz", - ], - strip_prefix = "paper-styles-1.1.4", - path = "/paper-styles", - srcs = [ - "classes/global.html", - 
"classes/shadow.html", - "classes/shadow-layout.html", - "classes/typography.html", - "color.html", - "default-theme.html", - "demo.css", - "demo-pages.html", - "paper-styles.html", - "paper-styles-classes.html", - "shadow.html", - "typography.html", - ], - deps = [ - "@org_polymer", - "@org_polymer_font_roboto", - "@org_polymer_iron_flex_layout", - ], - ) - - web_library_external( - name = "org_polymer_paper_tabs", - licenses = ["notice"], # BSD-3-Clause - sha256 = "c23b6a5221db35e5b1ed3eb8e8696b952572563e285adaec96aba1e3134db825", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-tabs/archive/v1.7.0.tar.gz", - "https://github.com/PolymerElements/paper-tabs/archive/v1.7.0.tar.gz", - ], - strip_prefix = "paper-tabs-1.7.0", - path = "/paper-tabs", - srcs = [ - "paper-tab.html", - "paper-tabs.html", - "paper-tabs-icons.html", - ], - deps = [ - "@org_polymer", - "@org_polymer_iron_behaviors", - "@org_polymer_iron_flex_layout", - "@org_polymer_iron_icon", - "@org_polymer_iron_iconset_svg", - "@org_polymer_iron_menu_behavior", - "@org_polymer_iron_resizable_behavior", - "@org_polymer_paper_behaviors", - "@org_polymer_paper_icon_button", - "@org_polymer_paper_styles", - ], - ) - - web_library_external( - name = "org_polymer_paper_toast", - licenses = ["notice"], # BSD-3-Clause - sha256 = "55f623712ed1f2bae6d6fadc522a2458e083ccd44cc0a907672547e7b10758a9", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-toast/archive/v1.3.0.tar.gz", - "https://github.com/PolymerElements/paper-toast/archive/v1.3.0.tar.gz", - ], - strip_prefix = "paper-toast-1.3.0", - path = "/paper-toast", - srcs = ["paper-toast.html"], - deps = [ - "@org_polymer", - "@org_polymer_iron_a11y_announcer", - "@org_polymer_iron_overlay_behavior", - ], - ) - - web_library_external( - name = "org_polymer_paper_toggle_button", - licenses = ["notice"], # BSD-3-Clause - sha256 = "4aa7cf0396fa2994a8bc2ac6e8428f48b07b945bb7c41bd52041ef5827b45de3", - urls = [ - 
"http://mirror.bazel.build/github.com/PolymerElements/paper-toggle-button/archive/v1.2.0.tar.gz", - "https://github.com/PolymerElements/paper-toggle-button/archive/v1.2.0.tar.gz", - ], - strip_prefix = "paper-toggle-button-1.2.0", - path = "/paper-toggle-button", - srcs = ["paper-toggle-button.html"], - deps = [ - "@org_polymer", - "@org_polymer_iron_flex_layout", - "@org_polymer_paper_behaviors", - "@org_polymer_paper_styles", - ], - ) - - web_library_external( - name = "org_polymer_paper_toolbar", - licenses = ["notice"], # BSD-3-Clause - sha256 = "dbddffc0654d9fb5fb48843087eebe16bf7a134902495a664c96c11bf8a2c63d", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-toolbar/archive/v1.1.4.tar.gz", - "https://github.com/PolymerElements/paper-toolbar/archive/v1.1.4.tar.gz", - ], - strip_prefix = "paper-toolbar-1.1.4", - path = "/paper-toolbar", - srcs = ["paper-toolbar.html"], - deps = [ - "@org_polymer", - "@org_polymer_iron_flex_layout", - "@org_polymer_paper_styles", - ], - ) - - web_library_external( - name = "org_polymer_paper_tooltip", - licenses = ["notice"], # BSD-3-Clause - sha256 = "4c6667acf01f73da14c3cbc0aa574bf14280304567987ee0314534328377d2ad", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/paper-tooltip/archive/v1.1.2.tar.gz", - "https://github.com/PolymerElements/paper-tooltip/archive/v1.1.2.tar.gz", - ], - strip_prefix = "paper-tooltip-1.1.2", - path = "/paper-tooltip", - srcs = ["paper-tooltip.html"], - deps = [ - "@org_polymer", - "@org_polymer_neon_animation", - ], - ) - - web_library_external( - name = "org_polymer", - licenses = ["notice"], # BSD-3-Clause - sha256 = "07a9e62ffb52193da3af09adda2fbac5cc690439978520e2d03e783863f65f91", - strip_prefix = "polymer-1.7.0", - urls = [ - "http://mirror.bazel.build/github.com/polymer/polymer/archive/v1.7.0.tar.gz", - "https://github.com/polymer/polymer/archive/v1.7.0.tar.gz", - ], - path = "/polymer", - srcs = [ - "polymer.html", - "polymer-micro.html", - 
"polymer-mini.html", - ], - ) - - web_library_external( - name = "org_polymer_prism", - licenses = ["notice"], # MIT - sha256 = "e06eb54f2a80e6b3cd0bd4d59f900423bcaee53fc03998a056df63740c684683", - urls = [ - "http://mirror.bazel.build/github.com/PrismJS/prism/archive/abee2b7587f1925e57777044270e2a1860810994.tar.gz", - "https://github.com/PrismJS/prism/archive/abee2b7587f1925e57777044270e2a1860810994.tar.gz", - ], - strip_prefix = "prism-abee2b7587f1925e57777044270e2a1860810994", - path = "/prism", - srcs = [ - "prism.js", - "themes/prism.css", - ], - ) - - web_library_external( - name = "org_polymer_prism_element", - licenses = ["notice"], # BSD-3-Clause - sha256 = "ad70bf9cd5bbdf525d465e1b0658867ab4022193eb9c74087a839044b46312b4", - urls = [ - "http://mirror.bazel.build/github.com/PolymerElements/prism-element/archive/1.0.4.tar.gz", - "https://github.com/PolymerElements/prism-element/archive/1.0.4.tar.gz", - ], - strip_prefix = "prism-element-1.0.4", - path = "/prism-element", - srcs = [ - "prism-highlighter.html", - "prism-import.html", - ], - deps = [ - "@org_polymer", - "@org_polymer_prism", - ], - ) - - web_library_external( - name = "org_polymer_promise_polyfill", - licenses = ["notice"], # BSD-3-Clause - sha256 = "4495450e5d884c3e16b537b43afead7f84d17c7dc061bcfcbf440eac083e4ef5", - strip_prefix = "promise-polyfill-1.0.0", - urls = [ - "http://mirror.bazel.build/github.com/PolymerLabs/promise-polyfill/archive/v1.0.0.tar.gz", - "https://github.com/PolymerLabs/promise-polyfill/archive/v1.0.0.tar.gz", - ], - path = "/promise-polyfill", - srcs = [ - "Promise.js", - "Promise-Statics.js", - "promise-polyfill.html", - "promise-polyfill-lite.html" - ], - deps = ["@org_polymer"], - ) - - web_library_external( - name = "org_polymer_web_animations_js", - licenses = ["notice"], # BSD-3-Clause - sha256 = "f8bd760cbdeba131f6790bd5abe170bcbf7b1755ff58ed16d0b82fa8a7f34a7f", - urls = [ - 
"http://mirror.bazel.build/github.com/web-animations/web-animations-js/archive/2.2.1.tar.gz", - "https://github.com/web-animations/web-animations-js/archive/2.2.1.tar.gz", - ], - strip_prefix = "web-animations-js-2.2.1", - path = "/web-animations-js", - srcs = ["web-animations-next-lite.min.js"], - ) - - web_library_external( - name = "org_polymer_webcomponentsjs", - licenses = ["notice"], # BSD-3-Clause - sha256 = "138c43306ee0a6d699ddca9b3c6b0f4982974ea8b7bdad291ea7276c72301df9", - urls = [ - "http://mirror.bazel.build/github.com/webcomponents/webcomponentsjs/archive/v0.7.22.tar.gz", - "https://github.com/webcomponents/webcomponentsjs/archive/v0.7.22.tar.gz", - ], - strip_prefix = "webcomponentsjs-0.7.22", - path = "/webcomponentsjs", - srcs = [ - "CustomElements.js", - "CustomElements.min.js", - "HTMLImports.js", - "HTMLImports.min.js", - "MutationObserver.js", - "MutationObserver.min.js", - "ShadowDOM.js", - "ShadowDOM.min.js", - "webcomponents.js", - "webcomponents.min.js", - "webcomponents-lite.js", - "webcomponents-lite.min.js", - ], - ) diff --git a/third_party/py/python_configure.bzl b/third_party/py/python_configure.bzl index b4a98af7b6e..19a6f1e749a 100644 --- a/third_party/py/python_configure.bzl +++ b/third_party/py/python_configure.bzl @@ -13,6 +13,7 @@ _NUMPY_INCLUDE_PATH = "NUMPY_INCLUDE_PATH" _PYTHON_BIN_PATH = "PYTHON_BIN_PATH" _PYTHON_INCLUDE_PATH = "PYTHON_INCLUDE_PATH" _PYTHON_LIB_PATH = "PYTHON_LIB_PATH" +_TF_PYTHON_CONFIG_REPO = "TF_PYTHON_CONFIG_REPO" def _tpl(repository_ctx, tpl, substitutions={}, out=None): @@ -278,18 +279,20 @@ def _create_local_python_repository(repository_ctx): }) -def _create_remote_python_repository(repository_ctx): +def _create_remote_python_repository(repository_ctx, remote_config_repo): """Creates pointers to a remotely configured repo set up to build with Python. 
""" _tpl(repository_ctx, "remote.BUILD", { - "%{REMOTE_PYTHON_REPO}": repository_ctx.attr.remote_config_repo, + "%{REMOTE_PYTHON_REPO}": remote_config_repo, }, "BUILD") def _python_autoconf_impl(repository_ctx): """Implementation of the python_autoconf repository rule.""" - if repository_ctx.attr.remote_config_repo != "": - _create_remote_python_repository(repository_ctx) + remote_config_repo = _get_env_var(repository_ctx, _TF_PYTHON_CONFIG_REPO, + repository_ctx.attr.remote_config_repo, False) + if remote_config_repo != "": + _create_remote_python_repository(repository_ctx, remote_config_repo) else: _create_local_python_repository(repository_ctx) @@ -307,6 +310,7 @@ python_configure = repository_rule( _PYTHON_INCLUDE_PATH, _PYTHON_LIB_PATH, _NUMPY_INCLUDE_PATH, + _TF_PYTHON_CONFIG_REPO, ], ) """Detects and configures the local Python. diff --git a/third_party/python.bzl b/third_party/python.bzl deleted file mode 100644 index 25c2ae3e780..00000000000 --- a/third_party/python.bzl +++ /dev/null @@ -1,62 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the 'License'); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an 'AS IS' BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# TensorBoard external dependencies that are used on the python side. -# Protobuf and six were deliberately left in the top-level workspace, as they -# are used in TensorFlow as well. 
- -def tensorboard_python_workspace(): - native.new_http_archive( - name = "org_pythonhosted_markdown", - urls = [ - "http://mirror.bazel.build/pypi.python.org/packages/1d/25/3f6d2cb31ec42ca5bd3bfbea99b63892b735d76e26f20dd2dcc34ffe4f0d/Markdown-2.6.8.tar.gz", - "https://pypi.python.org/packages/1d/25/3f6d2cb31ec42ca5bd3bfbea99b63892b735d76e26f20dd2dcc34ffe4f0d/Markdown-2.6.8.tar.gz", - ], - strip_prefix = "Markdown-2.6.8", - sha256 = "0ac8a81e658167da95d063a9279c9c1b2699f37c7c4153256a458b3a43860e33", - build_file = str(Label("//third_party:markdown.BUILD")), - ) - - native.new_http_archive( - name = "org_html5lib", - urls = [ - "http://mirror.bazel.build/github.com/html5lib/html5lib-python/archive/0.9999999.tar.gz", - "https://github.com/html5lib/html5lib-python/archive/0.9999999.tar.gz", # identical to 1.0b8 - ], - sha256 = "184257f98539159a433e2a2197309657ae1283b4c44dbd9c87b2f02ff36adce8", - strip_prefix = "html5lib-python-0.9999999", - build_file = str(Label("//third_party:html5lib.BUILD")), - ) - - native.new_http_archive( - name = "org_mozilla_bleach", - urls = [ - "http://mirror.bazel.build/github.com/mozilla/bleach/archive/v1.5.tar.gz", - "https://github.com/mozilla/bleach/archive/v1.5.tar.gz", - ], - strip_prefix = "bleach-1.5", - sha256 = "0d68713d02ba4148c417ab1637dd819333d96929a34401d0233947bec0881ad8", - build_file = str(Label("//third_party:bleach.BUILD")), - ) - - native.new_http_archive( - name = "org_pocoo_werkzeug", - urls = [ - "http://mirror.bazel.build/pypi.python.org/packages/b7/7f/44d3cfe5a12ba002b253f6985a4477edfa66da53787a2a838a40f6415263/Werkzeug-0.11.10.tar.gz", - "https://pypi.python.org/packages/b7/7f/44d3cfe5a12ba002b253f6985a4477edfa66da53787a2a838a40f6415263/Werkzeug-0.11.10.tar.gz", - ], - strip_prefix = "Werkzeug-0.11.10", - sha256 = "cc64dafbacc716cdd42503cf6c44cb5a35576443d82f29f6829e5c49264aeeee", - build_file = str(Label("//third_party:werkzeug.BUILD")), - ) \ No newline at end of file diff --git 
a/third_party/toolchains/cpus/BUILD b/third_party/toolchains/cpus/BUILD new file mode 100644 index 00000000000..c33dcd4e119 --- /dev/null +++ b/third_party/toolchains/cpus/BUILD @@ -0,0 +1,80 @@ +# A build file to configure cc toolchain for CPU build used with Bazel remote +# execution service +# DO NOT EDIT: automatically generated BUILD file + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "malloc", +) + +cc_library( + name = "stl", +) + +filegroup( + name = "empty", + srcs = [], +) + +filegroup( + name = "cc_wrapper", + srcs = ["cc_wrapper.sh"], +) + +# This is the entry point for --crosstool_top. Toolchains are found +# by lopping off the name of --crosstool_top and searching for +# the "${CPU}" entry in the toolchains attribute. +cc_toolchain_suite( + name = "toolchain", + toolchains = { + "k8|clang": ":cc-compiler-k8", + "armeabi-v7a|compiler": ":cc-compiler-armeabi-v7a", + "ios_x86_64|compiler": ":cc-compiler-ios_x86_64", + }, +) + +cc_toolchain( + name = "cc-compiler-k8", + all_files = ":empty", + compiler_files = ":empty", + cpu = "k8", + dwp_files = ":empty", + dynamic_runtime_libs = [":empty"], + linker_files = ":empty", + objcopy_files = ":empty", + static_runtime_libs = [":empty"], + strip_files = ":empty", + supports_param_files = 1, +) + +# Android tooling requires a default toolchain for the armeabi-v7a cpu. +cc_toolchain( + name = "cc-compiler-armeabi-v7a", + all_files = ":empty", + compiler_files = ":empty", + cpu = "local", + dwp_files = ":empty", + dynamic_runtime_libs = [":empty"], + linker_files = ":empty", + objcopy_files = ":empty", + static_runtime_libs = [":empty"], + strip_files = ":empty", + supports_param_files = 1, +) + +# ios crosstool configuration requires a default toolchain for the +# ios_x86_64 cpu. 
+cc_toolchain( + name = "cc-compiler-ios_x86_64", + all_files = ":empty", + compiler_files = ":empty", + cpu = "local", + dwp_files = ":empty", + dynamic_runtime_libs = [":empty"], + linker_files = ":empty", + objcopy_files = ":empty", + static_runtime_libs = [":empty"], + strip_files = ":empty", + supports_param_files = 1, +) diff --git a/third_party/toolchains/cpus/CROSSTOOL b/third_party/toolchains/cpus/CROSSTOOL new file mode 100644 index 00000000000..0e2d0119afa --- /dev/null +++ b/third_party/toolchains/cpus/CROSSTOOL @@ -0,0 +1,922 @@ +# A crosstool configuration for CPU build used with Bazel remote +# execution service +# DO NOT EDIT: automatically generated file + +major_version: "local" +minor_version: "" +default_target_cpu: "same_as_host" + +default_toolchain { + cpu: "k8" + toolchain_identifier: "linux_gnu_x86" +} + +default_toolchain { + cpu: "armeabi-v7a" + toolchain_identifier: "stub_armeabi-v7a" +} + +default_toolchain { + cpu: "x64_windows_msvc" + toolchain_identifier: "msvc_x64" +} + +default_toolchain { + cpu: "x64_windows_msys" + toolchain_identifier: "msys_x64" +} + +default_toolchain { + cpu: "s390x" + toolchain_identifier: "linux_gnu_x86" +} + +default_toolchain { + cpu: "ios_x86_64" + toolchain_identifier: "ios_x86_64" +} + +# Android tooling requires a default toolchain for the armeabi-v7a cpu. 
+toolchain { + abi_version: "armeabi-v7a" + abi_libc_version: "armeabi-v7a" + builtin_sysroot: "" + compiler: "compiler" + host_system_name: "armeabi-v7a" + needsPic: true + supports_gold_linker: false + supports_incremental_linker: false + supports_fission: false + supports_interface_shared_objects: false + supports_normalizing_ar: false + supports_start_end_lib: false + target_libc: "armeabi-v7a" + target_cpu: "armeabi-v7a" + target_system_name: "armeabi-v7a" + toolchain_identifier: "stub_armeabi-v7a" + + tool_path { name: "ar" path: "/bin/false" } + tool_path { name: "compat-ld" path: "/bin/false" } + tool_path { name: "cpp" path: "/bin/false" } + tool_path { name: "dwp" path: "/bin/false" } + tool_path { name: "gcc" path: "/bin/false" } + tool_path { name: "gcov" path: "/bin/false" } + tool_path { name: "ld" path: "/bin/false" } + + tool_path { name: "nm" path: "/bin/false" } + tool_path { name: "objcopy" path: "/bin/false" } + tool_path { name: "objdump" path: "/bin/false" } + tool_path { name: "strip" path: "/bin/false" } + linking_mode_flags { mode: DYNAMIC } +} + +toolchain { + toolchain_identifier: "ios_x86_64" + host_system_name: "x86_64-apple-macosx" + target_system_name: "x86_64-apple-ios" + target_cpu: "ios_x86_64" + target_libc: "ios" + compiler: "compiler" + abi_version: "local" + abi_libc_version: "local" + supports_gold_linker: false + supports_incremental_linker: false + supports_fission: false + supports_interface_shared_objects: false + supports_normalizing_ar: false + supports_start_end_lib: false + + tool_path { name: "ar" path: "/bin/false" } + tool_path { name: "compat-ld" path: "/bin/false" } + tool_path { name: "cpp" path: "/bin/false" } + tool_path { name: "dwp" path: "/bin/false" } + tool_path { name: "gcc" path: "/bin/false" } + tool_path { name: "gcov" path: "/bin/false" } + tool_path { name: "ld" path: "/bin/false" } + + tool_path { name: "nm" path: "/bin/false" } + tool_path { name: "objcopy" path: "/bin/false" } + tool_path { name: 
"objdump" path: "/bin/false" } + tool_path { name: "strip" path: "/bin/false" } + linking_mode_flags { mode: DYNAMIC } +} + +toolchain { + toolchain_identifier: "linux_gnu_x86" + abi_version: "clang" + abi_libc_version: "glibc_2.19" + builtin_sysroot: "" + compiler: "clang" + host_system_name: "i686-unknown-linux-gnu" + needsPic: true + supports_gold_linker: true + supports_incremental_linker: false + supports_fission: false + supports_interface_shared_objects: false + supports_normalizing_ar: false + supports_start_end_lib: true + target_libc: "glibc_2.19" + target_cpu: "k8" + target_system_name: "x86_64-unknown-linux-gnu" + cxx_flag: "-std=c++0x" + linker_flag: "-lstdc++" + linker_flag: "-lm" + linker_flag: "-fuse-ld=gold" + linker_flag: "-B/usr/local/bin" + linker_flag: "-B/usr/bin" + cxx_builtin_include_directory: "/usr/include/c++/4.9" + cxx_builtin_include_directory: "/usr/include/x86_64-linux-gnu/c++/4.9" + cxx_builtin_include_directory: "/usr/include/c++/4.9/backward" + cxx_builtin_include_directory: "/usr/local/include" + cxx_builtin_include_directory: "/usr/local/lib/clang/5.0.0/include" + cxx_builtin_include_directory: "/usr/include/x86_64-linux-gnu" + cxx_builtin_include_directory: "/usr/include" + objcopy_embed_flag: "-I" + objcopy_embed_flag: "binary" + unfiltered_cxx_flag: "-Wno-builtin-macro-redefined" + unfiltered_cxx_flag: "-D__DATE__=\"redacted\"" + unfiltered_cxx_flag: "-D__TIMESTAMP__=\"redacted\"" + unfiltered_cxx_flag: "-D__TIME__=\"redacted\"" + compiler_flag: "-U_FORTIFY_SOURCE" + compiler_flag: "-fstack-protector" + compiler_flag: "-Wall" + compiler_flag: "-B/usr/local/bin" + compiler_flag: "-B/usr/bin" + compiler_flag: "-fcolor-diagnostics" + compiler_flag: "-fno-omit-frame-pointer" + tool_path {name: "ld" path: "/usr/bin/ld" } + tool_path {name: "cpp" path: "/usr/bin/cpp" } + tool_path {name: "dwp" path: "/usr/bin/dwp" } + tool_path {name: "gcov" path: "/usr/bin/gcov" } + tool_path {name: "nm" path: "/usr/bin/nm" } + tool_path {name: 
"objcopy" path: "/usr/bin/objcopy" } + tool_path {name: "objdump" path: "/usr/bin/objdump" } + tool_path {name: "strip" path: "/usr/bin/strip" } + tool_path {name: "gcc" path: "/usr/local/bin/clang" } + tool_path {name: "ar" path: "/usr/bin/ar" } + + compilation_mode_flags { + mode: DBG + compiler_flag: "-g" + } + compilation_mode_flags { + mode: OPT + compiler_flag: "-g0" + compiler_flag: "-O2" + compiler_flag: "-D_FORTIFY_SOURCE=1" + compiler_flag: "-DNDEBUG" + compiler_flag: "-ffunction-sections" + compiler_flag: "-fdata-sections" + linker_flag: "-Wl,--gc-sections" + } + linking_mode_flags { mode: DYNAMIC } + + + feature { + name: 'coverage' + provides: 'profile' + flag_set { + action: 'preprocess-assemble' + action: 'c-compile' + action: 'c++-compile' + action: 'c++-header-parsing' + action: 'c++-header-preprocessing' + action: 'c++-module-compile' + flag_group { + flag: '-fprofile-arcs' + flag: '-ftest-coverage' + } + } + flag_set { + action: 'c++-link-interface-dynamic-library' + action: 'c++-link-dynamic-library' + action: 'c++-link-executable' + flag_group { + flag: '-lgcov' + } + } + } + +} + +toolchain { + toolchain_identifier: "msvc_x64" + host_system_name: "local" + target_system_name: "local" + + abi_version: "local" + abi_libc_version: "local" + target_cpu: "x64_windows" + compiler: "cl" + target_libc: "msvcrt140" + default_python_version: "python2.7" + + + + tool_path { + name: "ar" + path: "wrapper/bin/msvc_link.bat" + } + tool_path { + name: "cpp" + path: "wrapper/bin/msvc_cl.bat" + } + tool_path { + name: "gcc" + path: "wrapper/bin/msvc_cl.bat" + } + tool_path { + name: "gcov" + path: "wrapper/bin/msvc_nop.bat" + } + tool_path { + name: "ld" + path: "wrapper/bin/msvc_link.bat" + } + tool_path { + name: "nm" + path: "wrapper/bin/msvc_nop.bat" + } + tool_path { + name: "objcopy" + path: "wrapper/bin/msvc_nop.bat" + } + tool_path { + name: "objdump" + path: "wrapper/bin/msvc_nop.bat" + } + tool_path { + name: "strip" + path: 
"wrapper/bin/msvc_nop.bat" + } + supports_gold_linker: false + supports_start_end_lib: false + supports_interface_shared_objects: false + supports_incremental_linker: false + supports_normalizing_ar: true + needsPic: false + + # TODO(pcloudy): Review those flags below, they should be defined by cl.exe + compiler_flag: "/DOS_WINDOWS=OS_WINDOWS" + compiler_flag: "/DCOMPILER_MSVC" + + # Don't pollute with GDI macros in windows.h. + compiler_flag: "/DNOGDI" + # Don't define min/max macros in windows.h. + compiler_flag: "/DNOMINMAX" + compiler_flag: "/DPRAGMA_SUPPORTED" + # Platform defines. + compiler_flag: "/D_WIN32_WINNT=0x0600" + # Turn off warning messages. + compiler_flag: "/D_CRT_SECURE_NO_DEPRECATE" + compiler_flag: "/D_CRT_SECURE_NO_WARNINGS" + compiler_flag: "/D_SILENCE_STDEXT_HASH_DEPRECATION_WARNINGS" + # Use math constants (M_PI, etc.) from the math library + compiler_flag: "/D_USE_MATH_DEFINES" + + # Useful options to have on for compilation. + # Increase the capacity of object files to 2^32 sections. + compiler_flag: "/bigobj" + # Allocate 500MB for precomputed headers. + compiler_flag: "/Zm500" + # Use unsigned char by default. + compiler_flag: "/J" + # Use function level linking. + compiler_flag: "/Gy" + # Use string pooling. + compiler_flag: "/GF" + # Warning level 3 (could possibly go to 4 in the future). + compiler_flag: "/W3" + # Catch both asynchronous (structured) and synchronous (C++) exceptions. + compiler_flag: "/EHsc" + + # Globally disabled warnings. + # Don't warn about elements of array being be default initialized. + compiler_flag: "/wd4351" + # Don't warn about no matching delete found. + compiler_flag: "/wd4291" + # Don't warn about diamond inheritance patterns. + compiler_flag: "/wd4250" + # Don't warn about insecure functions (e.g. non _s functions). + compiler_flag: "/wd4996" + + linker_flag: "/MACHINE:X64" + + linker_flag: "/SUBSYSTEM:CONSOLE" + + # Suppress startup banner. 
+ feature { + name: "nologo" + flag_set { + action: "c-compile" + action: "c++-compile" + action: "c++-module-compile" + action: "c++-module-codegen" + action: "c++-header-parsing" + action: "c++-header-preprocessing" + action: "assemble" + action: "preprocess-assemble" + action: "c++-link-executable" + action: "c++-link-dynamic-library" + action: "c++-link-static-library" + action: "c++-link-alwayslink-static-library" + action: "c++-link-pic-static-library" + action: "c++-link-alwayslink-pic-static-library" + flag_group { + flag: "/nologo" + } + } + } + + feature { + name: "msvc_env" + env_set { + action: "c-compile" + action: "c++-compile" + action: "c++-module-compile" + action: "c++-module-codegen" + action: "c++-header-parsing" + action: "c++-header-preprocessing" + action: "assemble" + action: "preprocess-assemble" + action: "c++-link-executable" + action: "c++-link-dynamic-library" + action: "c++-link-static-library" + action: "c++-link-alwayslink-static-library" + action: "c++-link-pic-static-library" + action: "c++-link-alwayslink-pic-static-library" + env_entry { + key: "PATH" + value: "" + } + env_entry { + key: "INCLUDE" + value: "" + } + env_entry { + key: "LIB" + value: "" + } + env_entry { + key: "TMP" + value: "" + } + } + } + + feature { + name: 'include_paths' + flag_set { + action: 'preprocess-assemble' + action: 'c-compile' + action: 'c++-compile' + action: 'c++-header-parsing' + action: 'c++-header-preprocessing' + action: 'c++-module-compile' + flag_group { + flag: '/I%{quote_include_paths}' + } + flag_group { + flag: '/I%{include_paths}' + } + flag_group { + flag: '/I%{system_include_paths}' + } + } + } + + # Stop adding any flag for dotD file, Bazel knows how to parse the output of /showIncludes option + # TODO(bazel-team): Remove this empty feature. 
https://github.com/bazelbuild/bazel/issues/2868 + feature { + name: 'dependency_file' + } + + # Tell Bazel to parse the output of /showIncludes + feature { + name: 'parse_showincludes' + flag_set { + action: 'assemble' + action: 'preprocess-assemble' + action: 'c-compile' + action: 'c++-compile' + action: 'c++-module-compile' + action: 'c++-header-preprocessing' + action: 'c++-header-parsing' + flag_group { + flag: "/showIncludes" + } + } + } + + # Stop passing -frandom-seed option + feature { + name: 'random_seed' + } + + # This feature is just for enabling flag_set in action_config for -c and -o options during the transitional period + feature { + name: 'compile_action_flags_in_flag_set' + } + + action_config { + config_name: 'c-compile' + action_name: 'c-compile' + tool { + tool_path: 'wrapper/bin/msvc_cl.bat' + } + flag_set { + flag_group { + flag: '/c' + flag: '%{source_file}' + } + } + flag_set { + expand_if_all_available: 'output_object_file' + flag_group { + flag: '/Fo%{output_object_file}' + } + } + flag_set { + expand_if_all_available: 'output_assembly_file' + flag_group { + flag: '/Fa%{output_assembly_file}' + } + } + flag_set { + expand_if_all_available: 'output_preprocess_file' + flag_group { + flag: '/P' + flag: '/Fi%{output_preprocess_file}' + } + } + implies: 'nologo' + implies: 'msvc_env' + implies: 'parse_showincludes' + } + + action_config { + config_name: 'c++-compile' + action_name: 'c++-compile' + tool { + tool_path: 'wrapper/bin/msvc_cl.bat' + } + flag_set { + flag_group { + flag: '/c' + flag: '%{source_file}' + } + } + flag_set { + expand_if_all_available: 'output_object_file' + flag_group { + flag: '/Fo%{output_object_file}' + } + } + flag_set { + expand_if_all_available: 'output_assembly_file' + flag_group { + flag: '/Fa%{output_assembly_file}' + } + } + flag_set { + expand_if_all_available: 'output_preprocess_file' + flag_group { + flag: '/P' + flag: '/Fi%{output_preprocess_file}' + } + } + implies: 'nologo' + implies: 'msvc_env' + 
implies: 'parse_showincludes' + } + + action_config { + config_name: 'c++-link-executable' + action_name: 'c++-link-executable' + tool { + tool_path: 'wrapper/bin/msvc_link.bat' + } + implies: 'nologo' + implies: 'strip_debug_symbols' + implies: 'linkstamps' + implies: 'output_execpath_flags' + implies: 'input_param_flags' + implies: 'legacy_link_flags' + implies: 'linker_param_file' + implies: 'msvc_env' + } + + action_config { + config_name: 'c++-link-dynamic-library' + action_name: 'c++-link-dynamic-library' + tool { + tool_path: 'wrapper/bin/msvc_link.bat' + } + implies: 'nologo' + implies: 'strip_debug_symbols' + implies: 'shared_flag' + implies: 'linkstamps' + implies: 'output_execpath_flags' + implies: 'input_param_flags' + implies: 'has_configured_linker_path' + implies: 'legacy_link_flags' + implies: 'linker_param_file' + implies: 'msvc_env' + } + + action_config { + config_name: 'c++-link-static-library' + action_name: 'c++-link-static-library' + tool { + tool_path: 'wrapper/bin/msvc_link.bat' + } + implies: 'nologo' + implies: 'input_param_flags' + implies: 'linker_param_file' + implies: 'msvc_env' + } + + action_config { + config_name: 'c++-link-alwayslink-static-library' + action_name: 'c++-link-alwayslink-static-library' + tool { + tool_path: 'wrapper/bin/msvc_link.bat' + } + implies: 'nologo' + implies: 'input_param_flags' + implies: 'linker_param_file' + implies: 'msvc_env' + } + + # TODO(pcloudy): The following action_config is listed in MANDATORY_LINK_TARGET_TYPES. + # But do we really need them on Windows? 
+ action_config { + config_name: 'c++-link-pic-static-library' + action_name: 'c++-link-pic-static-library' + tool { + tool_path: 'wrapper/bin/msvc_link.bat' + } + implies: 'nologo' + implies: 'input_param_flags' + implies: 'linker_param_file' + implies: 'msvc_env' + } + + action_config { + config_name: 'c++-link-alwayslink-pic-static-library' + action_name: 'c++-link-alwayslink-pic-static-library' + tool { + tool_path: 'wrapper/bin/msvc_link.bat' + } + implies: 'nologo' + implies: 'input_param_flags' + implies: 'linker_param_file' + implies: 'msvc_env' + } + + action_config { + config_name: 'c++-link-interface-dynamic-library' + action_name: 'c++-link-interface-dynamic-library' + tool { + tool_path: 'wrapper/bin/msvc_link.bat' + } + implies: 'nologo' + implies: 'strip_debug_symbols' + implies: 'linker_param_file' + implies: 'msvc_env' + } + + feature { + name: 'generate_pdb_file' + requires: { + feature: 'dbg' + } + requires: { + feature: 'fastbuild' + } + } + + feature { + name: 'has_configured_linker_path' + } + + feature { + name: 'strip_debug_symbols' + flag_set { + action: 'c++-link-executable' + action: 'c++-link-dynamic-library' + action: 'c++-link-interface-dynamic-library' + flag_group { + expand_if_all_available: 'strip_debug_symbols' + flag: '-Wl,-S' + } + } + } + + feature { + name: 'shared_flag' + flag_set { + action: 'c++-link-dynamic-library' + flag_group { + flag: '/DLL' + } + } + } + + feature { + name: 'linkstamps' + flag_set { + action: 'c++-link-executable' + action: 'c++-link-dynamic-library' + expand_if_all_available: 'linkstamp_paths' + flag_group { + flag: '%{linkstamp_paths}' + } + } + } + + feature { + name: 'output_execpath_flags' + flag_set { + expand_if_all_available: 'output_execpath' + action: 'c++-link-executable' + action: 'c++-link-dynamic-library' + flag_group { + flag: '/OUT:%{output_execpath}' + } + } + } + + feature { + name: 'input_param_flags' + flag_set { + expand_if_all_available: 'library_search_directories' + action: 
'c++-link-executable' + action: 'c++-link-dynamic-library' + action: 'c++-link-static-library' + action: 'c++-link-alwayslink-static-library' + action: 'c++-link-pic-static-library' + action: 'c++-link-alwayslink-pic-static-library' + flag_group { + iterate_over: 'library_search_directories' + flag: "-L%{library_search_directories}" + } + } + flag_set { + expand_if_all_available: 'libopts' + action: 'c++-link-executable' + action: 'c++-link-dynamic-library' + action: 'c++-link-static-library' + action: 'c++-link-alwayslink-static-library' + action: 'c++-link-pic-static-library' + action: 'c++-link-alwayslink-pic-static-library' + flag_group { + flag: '%{libopts}' + } + } + flag_set { + expand_if_all_available: 'libraries_to_link' + action: 'c++-link-executable' + action: 'c++-link-dynamic-library' + action: 'c++-link-static-library' + action: 'c++-link-alwayslink-static-library' + action: 'c++-link-pic-static-library' + action: 'c++-link-alwayslink-pic-static-library' + flag_group { + iterate_over: 'libraries_to_link' + flag_group { + expand_if_equal: { + variable: 'libraries_to_link.type' + value: 'object_file_group' + } + iterate_over: 'libraries_to_link.object_files' + flag_group { + flag: '%{libraries_to_link.object_files}' + } + } + flag_group { + expand_if_equal: { + variable: 'libraries_to_link.type' + value: 'object_file' + } + flag_group { + flag: '%{libraries_to_link.name}' + } + } + flag_group { + expand_if_equal: { + variable: 'libraries_to_link.type' + value: 'interface_library' + } + flag_group { + expand_if_false: 'libraries_to_link.is_whole_archive' + flag: '%{libraries_to_link.name}' + } + flag_group { + expand_if_true: 'libraries_to_link.is_whole_archive' + flag: '/WHOLEARCHIVE:%{libraries_to_link.name}' + } + } + flag_group { + expand_if_equal: { + variable: 'libraries_to_link.type' + value: 'static_library' + } + flag_group { + expand_if_false: 'libraries_to_link.is_whole_archive' + flag: '%{libraries_to_link.name}' + } + flag_group { + 
expand_if_true: 'libraries_to_link.is_whole_archive' + flag: '/WHOLEARCHIVE:%{libraries_to_link.name}' + } + } + flag_group { + expand_if_equal: { + variable: 'libraries_to_link.type' + value: 'dynamic_library' + } + flag_group { + expand_if_false: 'libraries_to_link.is_whole_archive' + flag: '%{libraries_to_link.name}' + } + flag_group { + expand_if_true: 'libraries_to_link.is_whole_archive' + flag: '/WHOLEARCHIVE:%{libraries_to_link.name}' + } + } + flag_group { + expand_if_equal: { + variable: 'libraries_to_link.type' + value: 'versioned_dynamic_library' + } + flag_group { + expand_if_false: 'libraries_to_link.is_whole_archive' + flag: '%{libraries_to_link.name}' + } + flag_group { + expand_if_true: 'libraries_to_link.is_whole_archive' + flag: '/WHOLEARCHIVE:%{libraries_to_link.name}' + } + } + } + } + } + + feature { + name: 'legacy_link_flags' + flag_set { + expand_if_all_available: 'legacy_link_flags' + action: 'c++-link-executable' + action: 'c++-link-dynamic-library' + flag_group { + flag: '%{legacy_link_flags}' + } + } + } + + feature { + name: 'linker_param_file' + flag_set { + expand_if_all_available: 'linker_param_file' + action: 'c++-link-executable' + action: 'c++-link-dynamic-library' + flag_group { + flag: '-Wl,@%{linker_param_file}' + } + } + flag_set { + expand_if_all_available: 'linker_param_file' + action: 'c++-link-static-library' + action: 'c++-link-alwayslink-static-library' + action: 'c++-link-pic-static-library' + action: 'c++-link-alwayslink-pic-static-library' + flag_group { + flag: '@%{linker_param_file}' + } + } + } + + feature { + name: 'link_crt_library' + flag_set { + action: 'c-compile' + action: 'c++-compile' + flag_group { + # The flag is filled by cc_configure. + # The default option is /MT, set USE_DYNAMIC_CRT=1 to change it to /MD + flag: "" + } + } + flag_set { + action: 'c++-link-executable' + action: 'c++-link-dynamic-library' + flag_group { + # The flag is filled by cc_configure. 
+ # The default value is libcmt.lib, set USE_DYNAMIC_CRT=1 to change it to msvcrt.lib + flag: "/DEFAULTLIB:" + } + } + } + + feature { + name: 'link_crt_debug_library' + flag_set { + action: 'c-compile' + action: 'c++-compile' + flag_group { + # The flag is filled by cc_configure. + # The default option is /MTd, set USE_DYNAMIC_CRT=1 to change it to /MDd + flag: "" + } + } + flag_set { + action: 'c++-link-executable' + action: 'c++-link-dynamic-library' + flag_group { + # The flag is filled by cc_configure. + # The default value is libcmtd.lib, set USE_DYNAMIC_CRT=1 to change it to msvcrtd.lib + flag: "/DEFAULTLIB:" + } + } + } + + feature { + name: 'dbg' + flag_set { + action: 'c-compile' + action: 'c++-compile' + flag_group { + flag: "/Od" + flag: "/Z7" + # This will signal the wrapper that we are doing a debug build, which sets + # some internal state of the toolchain wrapper. It is intentionally a "-" + # flag to make this very obvious. + flag: "-g" + } + } + flag_set { + action: 'c++-link-executable' + action: 'c++-link-dynamic-library' + flag_group { + flag: "/DEBUG:FULL" + flag: "/INCREMENTAL:NO" + } + } + implies: 'link_crt_debug_library' + implies: 'generate_pdb_file' + } + + feature { + name: 'fastbuild' + flag_set { + action: 'c-compile' + action: 'c++-compile' + flag_group { + flag: "/Od" + flag: "/Z7" + } + } + flag_set { + action: 'c++-link-executable' + action: 'c++-link-dynamic-library' + flag_group { + flag: "/DEBUG:FASTLINK" + flag: "/INCREMENTAL:NO" + } + } + implies: 'link_crt_library' + implies: 'generate_pdb_file' + } + + feature { + name: 'opt' + flag_set { + action: 'c-compile' + action: 'c++-compile' + flag_group { + flag: "/O2" + } + } + implies: 'link_crt_library' + } + + compilation_mode_flags { + mode: DBG + compiler_flag: "-Xcompilation-mode=dbg" + linker_flag: "-Xcompilation-mode=dbg" + } + + compilation_mode_flags { + mode: FASTBUILD + compiler_flag: "-Xcompilation-mode=fastbuild" + linker_flag: "-Xcompilation-mode=fastbuild" + } + + 
compilation_mode_flags { + mode: OPT + compiler_flag: "-Xcompilation-mode=opt" + linker_flag: "-Xcompilation-mode=opt" + } + +} diff --git a/third_party/toolchains/cpus/py/BUILD b/third_party/toolchains/cpus/py/BUILD new file mode 100644 index 00000000000..d54eebb8e76 --- /dev/null +++ b/third_party/toolchains/cpus/py/BUILD @@ -0,0 +1,185 @@ +# A build file to configure python remote repository used with Bazel remote +# execution service +# DO NOT EDIT: automatically generated BUILD file + +licenses(["restricted"]) + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "python_headers", + hdrs = [":python_include"], + includes = ["python_include"], +) + +cc_library( + name = "numpy_headers", + hdrs = [":numpy_include"], + includes = ["numpy_include"], +) + +config_setting( + name = "windows", + values = {"cpu": "x64_windows"}, + visibility = ["//visibility:public"], +) + +genrule( + name = "python_include", + outs = [ + "python_include/code.h", + "python_include/dtoa.h", + "python_include/tupleobject.h", + "python_include/object.h", + "python_include/ast.h", + "python_include/pymacconfig.h", + "python_include/errcode.h", + "python_include/frameobject.h", + "python_include/pgenheaders.h", + "python_include/cellobject.h", + "python_include/intobject.h", + "python_include/pythread.h", + "python_include/cStringIO.h", + "python_include/boolobject.h", + "python_include/modsupport.h", + "python_include/import.h", + "python_include/pymath.h", + "python_include/node.h", + "python_include/funcobject.h", + "python_include/eval.h", + "python_include/longintrepr.h", + "python_include/floatobject.h", + "python_include/rangeobject.h", + "python_include/pyfpe.h", + "python_include/pystrcmp.h", + "python_include/dictobject.h", + "python_include/pyarena.h", + "python_include/objimpl.h", + "python_include/bitset.h", + "python_include/memoryobject.h", + "python_include/bytearrayobject.h", + "python_include/pydebug.h", + "python_include/pyerrors.h", + 
"python_include/weakrefobject.h", + "python_include/grammar.h", + "python_include/symtable.h", + "python_include/longobject.h", + "python_include/structmember.h", + "python_include/enumobject.h", + "python_include/classobject.h", + "python_include/unicodeobject.h", + "python_include/sliceobject.h", + "python_include/pystrtod.h", + "python_include/genobject.h", + "python_include/pymactoolbox.h", + "python_include/compile.h", + "python_include/pyexpat.h", + "python_include/asdl.h", + "python_include/codecs.h", + "python_include/pyctype.h", + "python_include/sysmodule.h", + "python_include/methodobject.h", + "python_include/graminit.h", + "python_include/cobject.h", + "python_include/intrcheck.h", + "python_include/pyport.h", + "python_include/warnings.h", + "python_include/osdefs.h", + "python_include/fileobject.h", + "python_include/stringobject.h", + "python_include/timefuncs.h", + "python_include/traceback.h", + "python_include/ceval.h", + "python_include/bytes_methods.h", + "python_include/pyconfig.h", + "python_include/Python.h", + "python_include/moduleobject.h", + "python_include/pystate.h", + "python_include/descrobject.h", + "python_include/ucnhash.h", + "python_include/pygetopt.h", + "python_include/pymem.h", + "python_include/complexobject.h", + "python_include/structseq.h", + "python_include/datetime.h", + "python_include/pythonrun.h", + "python_include/numpy/oldnumeric.h", + "python_include/numpy/npy_1_7_deprecated_api.h", + "python_include/numpy/ufunc_api.txt", + "python_include/numpy/multiarray_api.txt", + "python_include/numpy/halffloat.h", + "python_include/numpy/npy_common.h", + "python_include/numpy/utils.h", + "python_include/numpy/npy_interrupt.h", + "python_include/numpy/npy_endian.h", + "python_include/numpy/__ufunc_api.h", + "python_include/numpy/_neighborhood_iterator_imp.h", + "python_include/numpy/ufuncobject.h", + "python_include/numpy/ndarraytypes.h", + "python_include/numpy/npy_math.h", + "python_include/numpy/noprefix.h", + 
"python_include/numpy/npy_3kcompat.h", + "python_include/numpy/arrayscalars.h", + "python_include/numpy/npy_os.h", + "python_include/numpy/ndarrayobject.h", + "python_include/numpy/npy_no_deprecated_api.h", + "python_include/numpy/arrayobject.h", + "python_include/numpy/_numpyconfig.h", + "python_include/numpy/__multiarray_api.h", + "python_include/numpy/npy_cpu.h", + "python_include/numpy/old_defines.h", + "python_include/numpy/numpyconfig.h", + "python_include/pycapsule.h", + "python_include/setobject.h", + "python_include/listobject.h", + "python_include/bytesobject.h", + "python_include/pgen.h", + "python_include/patchlevel.h", + "python_include/opcode.h", + "python_include/parsetok.h", + "python_include/marshal.h", + "python_include/token.h", + "python_include/iterobject.h", + "python_include/abstract.h", + "python_include/py_curses.h", + "python_include/metagrammar.h", + "python_include/bufferobject.h", + "python_include/Python-ast.h", + ], + cmd = """ +cp "/usr/include/python2.7/code.h" "$(@D)/python_include/code.h" && cp "/usr/include/python2.7/dtoa.h" "$(@D)/python_include/dtoa.h" && cp "/usr/include/python2.7/tupleobject.h" "$(@D)/python_include/tupleobject.h" && cp "/usr/include/python2.7/object.h" "$(@D)/python_include/object.h" && cp "/usr/include/python2.7/ast.h" "$(@D)/python_include/ast.h" && cp "/usr/include/python2.7/pymacconfig.h" "$(@D)/python_include/pymacconfig.h" && cp "/usr/include/python2.7/errcode.h" "$(@D)/python_include/errcode.h" && cp "/usr/include/python2.7/frameobject.h" "$(@D)/python_include/frameobject.h" && cp "/usr/include/python2.7/pgenheaders.h" "$(@D)/python_include/pgenheaders.h" && cp "/usr/include/python2.7/cellobject.h" "$(@D)/python_include/cellobject.h" && cp "/usr/include/python2.7/intobject.h" "$(@D)/python_include/intobject.h" && cp "/usr/include/python2.7/pythread.h" "$(@D)/python_include/pythread.h" && cp "/usr/include/python2.7/cStringIO.h" "$(@D)/python_include/cStringIO.h" && cp 
"/usr/include/python2.7/boolobject.h" "$(@D)/python_include/boolobject.h" && cp "/usr/include/python2.7/modsupport.h" "$(@D)/python_include/modsupport.h" && cp "/usr/include/python2.7/import.h" "$(@D)/python_include/import.h" && cp "/usr/include/python2.7/pymath.h" "$(@D)/python_include/pymath.h" && cp "/usr/include/python2.7/node.h" "$(@D)/python_include/node.h" && cp "/usr/include/python2.7/funcobject.h" "$(@D)/python_include/funcobject.h" && cp "/usr/include/python2.7/eval.h" "$(@D)/python_include/eval.h" && cp "/usr/include/python2.7/longintrepr.h" "$(@D)/python_include/longintrepr.h" && cp "/usr/include/python2.7/floatobject.h" "$(@D)/python_include/floatobject.h" && cp "/usr/include/python2.7/rangeobject.h" "$(@D)/python_include/rangeobject.h" && cp "/usr/include/python2.7/pyfpe.h" "$(@D)/python_include/pyfpe.h" && cp "/usr/include/python2.7/pystrcmp.h" "$(@D)/python_include/pystrcmp.h" && cp "/usr/include/python2.7/dictobject.h" "$(@D)/python_include/dictobject.h" && cp "/usr/include/python2.7/pyarena.h" "$(@D)/python_include/pyarena.h" && cp "/usr/include/python2.7/objimpl.h" "$(@D)/python_include/objimpl.h" && cp "/usr/include/python2.7/bitset.h" "$(@D)/python_include/bitset.h" && cp "/usr/include/python2.7/memoryobject.h" "$(@D)/python_include/memoryobject.h" && cp "/usr/include/python2.7/bytearrayobject.h" "$(@D)/python_include/bytearrayobject.h" && cp "/usr/include/python2.7/pydebug.h" "$(@D)/python_include/pydebug.h" && cp "/usr/include/python2.7/pyerrors.h" "$(@D)/python_include/pyerrors.h" && cp "/usr/include/python2.7/weakrefobject.h" "$(@D)/python_include/weakrefobject.h" && cp "/usr/include/python2.7/grammar.h" "$(@D)/python_include/grammar.h" && cp "/usr/include/python2.7/symtable.h" "$(@D)/python_include/symtable.h" && cp "/usr/include/python2.7/longobject.h" "$(@D)/python_include/longobject.h" && cp "/usr/include/python2.7/structmember.h" "$(@D)/python_include/structmember.h" && cp "/usr/include/python2.7/enumobject.h" 
"$(@D)/python_include/enumobject.h" && cp "/usr/include/python2.7/classobject.h" "$(@D)/python_include/classobject.h" && cp "/usr/include/python2.7/unicodeobject.h" "$(@D)/python_include/unicodeobject.h" && cp "/usr/include/python2.7/sliceobject.h" "$(@D)/python_include/sliceobject.h" && cp "/usr/include/python2.7/pystrtod.h" "$(@D)/python_include/pystrtod.h" && cp "/usr/include/python2.7/genobject.h" "$(@D)/python_include/genobject.h" && cp "/usr/include/python2.7/pymactoolbox.h" "$(@D)/python_include/pymactoolbox.h" && cp "/usr/include/python2.7/compile.h" "$(@D)/python_include/compile.h" && cp "/usr/include/python2.7/pyexpat.h" "$(@D)/python_include/pyexpat.h" && cp "/usr/include/python2.7/asdl.h" "$(@D)/python_include/asdl.h" && cp "/usr/include/python2.7/codecs.h" "$(@D)/python_include/codecs.h" && cp "/usr/include/python2.7/pyctype.h" "$(@D)/python_include/pyctype.h" && cp "/usr/include/python2.7/sysmodule.h" "$(@D)/python_include/sysmodule.h" && cp "/usr/include/python2.7/methodobject.h" "$(@D)/python_include/methodobject.h" && cp "/usr/include/python2.7/graminit.h" "$(@D)/python_include/graminit.h" && cp "/usr/include/python2.7/cobject.h" "$(@D)/python_include/cobject.h" && cp "/usr/include/python2.7/intrcheck.h" "$(@D)/python_include/intrcheck.h" && cp "/usr/include/python2.7/pyport.h" "$(@D)/python_include/pyport.h" && cp "/usr/include/python2.7/warnings.h" "$(@D)/python_include/warnings.h" && cp "/usr/include/python2.7/osdefs.h" "$(@D)/python_include/osdefs.h" && cp "/usr/include/python2.7/fileobject.h" "$(@D)/python_include/fileobject.h" && cp "/usr/include/python2.7/stringobject.h" "$(@D)/python_include/stringobject.h" && cp "/usr/include/python2.7/timefuncs.h" "$(@D)/python_include/timefuncs.h" && cp "/usr/include/python2.7/traceback.h" "$(@D)/python_include/traceback.h" && cp "/usr/include/python2.7/ceval.h" "$(@D)/python_include/ceval.h" && cp "/usr/include/python2.7/bytes_methods.h" "$(@D)/python_include/bytes_methods.h" && cp 
"/usr/include/python2.7/pyconfig.h" "$(@D)/python_include/pyconfig.h" && cp "/usr/include/python2.7/Python.h" "$(@D)/python_include/Python.h" && cp "/usr/include/python2.7/moduleobject.h" "$(@D)/python_include/moduleobject.h" && cp "/usr/include/python2.7/pystate.h" "$(@D)/python_include/pystate.h" && cp "/usr/include/python2.7/descrobject.h" "$(@D)/python_include/descrobject.h" && cp "/usr/include/python2.7/ucnhash.h" "$(@D)/python_include/ucnhash.h" && cp "/usr/include/python2.7/pygetopt.h" "$(@D)/python_include/pygetopt.h" && cp "/usr/include/python2.7/pymem.h" "$(@D)/python_include/pymem.h" && cp "/usr/include/python2.7/complexobject.h" "$(@D)/python_include/complexobject.h" && cp "/usr/include/python2.7/structseq.h" "$(@D)/python_include/structseq.h" && cp "/usr/include/python2.7/datetime.h" "$(@D)/python_include/datetime.h" && cp "/usr/include/python2.7/pythonrun.h" "$(@D)/python_include/pythonrun.h" && cp "/usr/include/python2.7/numpy/oldnumeric.h" "$(@D)/python_include/numpy/oldnumeric.h" && cp "/usr/include/python2.7/numpy/npy_1_7_deprecated_api.h" "$(@D)/python_include/numpy/npy_1_7_deprecated_api.h" && cp "/usr/include/python2.7/numpy/ufunc_api.txt" "$(@D)/python_include/numpy/ufunc_api.txt" && cp "/usr/include/python2.7/numpy/multiarray_api.txt" "$(@D)/python_include/numpy/multiarray_api.txt" && cp "/usr/include/python2.7/numpy/halffloat.h" "$(@D)/python_include/numpy/halffloat.h" && cp "/usr/include/python2.7/numpy/npy_common.h" "$(@D)/python_include/numpy/npy_common.h" && cp "/usr/include/python2.7/numpy/utils.h" "$(@D)/python_include/numpy/utils.h" && cp "/usr/include/python2.7/numpy/npy_interrupt.h" "$(@D)/python_include/numpy/npy_interrupt.h" && cp "/usr/include/python2.7/numpy/npy_endian.h" "$(@D)/python_include/numpy/npy_endian.h" && cp "/usr/include/python2.7/numpy/__ufunc_api.h" "$(@D)/python_include/numpy/__ufunc_api.h" && cp "/usr/include/python2.7/numpy/_neighborhood_iterator_imp.h" "$(@D)/python_include/numpy/_neighborhood_iterator_imp.h" 
&& cp "/usr/include/python2.7/numpy/ufuncobject.h" "$(@D)/python_include/numpy/ufuncobject.h" && cp "/usr/include/python2.7/numpy/ndarraytypes.h" "$(@D)/python_include/numpy/ndarraytypes.h" && cp "/usr/include/python2.7/numpy/npy_math.h" "$(@D)/python_include/numpy/npy_math.h" && cp "/usr/include/python2.7/numpy/noprefix.h" "$(@D)/python_include/numpy/noprefix.h" && cp "/usr/include/python2.7/numpy/npy_3kcompat.h" "$(@D)/python_include/numpy/npy_3kcompat.h" && cp "/usr/include/python2.7/numpy/arrayscalars.h" "$(@D)/python_include/numpy/arrayscalars.h" && cp "/usr/include/python2.7/numpy/npy_os.h" "$(@D)/python_include/numpy/npy_os.h" && cp "/usr/include/python2.7/numpy/ndarrayobject.h" "$(@D)/python_include/numpy/ndarrayobject.h" && cp "/usr/include/python2.7/numpy/npy_no_deprecated_api.h" "$(@D)/python_include/numpy/npy_no_deprecated_api.h" && cp "/usr/include/python2.7/numpy/arrayobject.h" "$(@D)/python_include/numpy/arrayobject.h" && cp "/usr/include/python2.7/numpy/_numpyconfig.h" "$(@D)/python_include/numpy/_numpyconfig.h" && cp "/usr/include/python2.7/numpy/__multiarray_api.h" "$(@D)/python_include/numpy/__multiarray_api.h" && cp "/usr/include/python2.7/numpy/npy_cpu.h" "$(@D)/python_include/numpy/npy_cpu.h" && cp "/usr/include/python2.7/numpy/old_defines.h" "$(@D)/python_include/numpy/old_defines.h" && cp "/usr/include/python2.7/numpy/numpyconfig.h" "$(@D)/python_include/numpy/numpyconfig.h" && cp "/usr/include/python2.7/pycapsule.h" "$(@D)/python_include/pycapsule.h" && cp "/usr/include/python2.7/setobject.h" "$(@D)/python_include/setobject.h" && cp "/usr/include/python2.7/listobject.h" "$(@D)/python_include/listobject.h" && cp "/usr/include/python2.7/bytesobject.h" "$(@D)/python_include/bytesobject.h" && cp "/usr/include/python2.7/pgen.h" "$(@D)/python_include/pgen.h" && cp "/usr/include/python2.7/patchlevel.h" "$(@D)/python_include/patchlevel.h" && cp "/usr/include/python2.7/opcode.h" "$(@D)/python_include/opcode.h" && cp 
"/usr/include/python2.7/parsetok.h" "$(@D)/python_include/parsetok.h" && cp "/usr/include/python2.7/marshal.h" "$(@D)/python_include/marshal.h" && cp "/usr/include/python2.7/token.h" "$(@D)/python_include/token.h" && cp "/usr/include/python2.7/iterobject.h" "$(@D)/python_include/iterobject.h" && cp "/usr/include/python2.7/abstract.h" "$(@D)/python_include/abstract.h" && cp "/usr/include/python2.7/py_curses.h" "$(@D)/python_include/py_curses.h" && cp "/usr/include/python2.7/metagrammar.h" "$(@D)/python_include/metagrammar.h" && cp "/usr/include/python2.7/bufferobject.h" "$(@D)/python_include/bufferobject.h" && cp "/usr/include/python2.7/Python-ast.h" "$(@D)/python_include/Python-ast.h" """, +) + +genrule( + name = "numpy_include", + outs = [ + "numpy_include/numpy/oldnumeric.h", + "numpy_include/numpy/npy_1_7_deprecated_api.h", + "numpy_include/numpy/ufunc_api.txt", + "numpy_include/numpy/multiarray_api.txt", + "numpy_include/numpy/halffloat.h", + "numpy_include/numpy/npy_common.h", + "numpy_include/numpy/utils.h", + "numpy_include/numpy/npy_interrupt.h", + "numpy_include/numpy/npy_endian.h", + "numpy_include/numpy/__ufunc_api.h", + "numpy_include/numpy/_neighborhood_iterator_imp.h", + "numpy_include/numpy/ufuncobject.h", + "numpy_include/numpy/ndarraytypes.h", + "numpy_include/numpy/npy_math.h", + "numpy_include/numpy/noprefix.h", + "numpy_include/numpy/npy_3kcompat.h", + "numpy_include/numpy/arrayscalars.h", + "numpy_include/numpy/npy_os.h", + "numpy_include/numpy/ndarrayobject.h", + "numpy_include/numpy/npy_no_deprecated_api.h", + "numpy_include/numpy/arrayobject.h", + "numpy_include/numpy/_numpyconfig.h", + "numpy_include/numpy/__multiarray_api.h", + "numpy_include/numpy/npy_cpu.h", + "numpy_include/numpy/old_defines.h", + "numpy_include/numpy/numpyconfig.h", + ], + cmd = """ +cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/oldnumeric.h" "$(@D)/numpy_include/numpy/oldnumeric.h" && cp 
"/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/npy_1_7_deprecated_api.h" "$(@D)/numpy_include/numpy/npy_1_7_deprecated_api.h" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/ufunc_api.txt" "$(@D)/numpy_include/numpy/ufunc_api.txt" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/multiarray_api.txt" "$(@D)/numpy_include/numpy/multiarray_api.txt" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/halffloat.h" "$(@D)/numpy_include/numpy/halffloat.h" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/npy_common.h" "$(@D)/numpy_include/numpy/npy_common.h" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/utils.h" "$(@D)/numpy_include/numpy/utils.h" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/npy_interrupt.h" "$(@D)/numpy_include/numpy/npy_interrupt.h" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/npy_endian.h" "$(@D)/numpy_include/numpy/npy_endian.h" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/__ufunc_api.h" "$(@D)/numpy_include/numpy/__ufunc_api.h" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/_neighborhood_iterator_imp.h" "$(@D)/numpy_include/numpy/_neighborhood_iterator_imp.h" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/ufuncobject.h" "$(@D)/numpy_include/numpy/ufuncobject.h" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/ndarraytypes.h" "$(@D)/numpy_include/numpy/ndarraytypes.h" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/npy_math.h" "$(@D)/numpy_include/numpy/npy_math.h" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/noprefix.h" "$(@D)/numpy_include/numpy/noprefix.h" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/npy_3kcompat.h" "$(@D)/numpy_include/numpy/npy_3kcompat.h" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/arrayscalars.h" "$(@D)/numpy_include/numpy/arrayscalars.h" && cp 
"/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/npy_os.h" "$(@D)/numpy_include/numpy/npy_os.h" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/ndarrayobject.h" "$(@D)/numpy_include/numpy/ndarrayobject.h" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/npy_no_deprecated_api.h" "$(@D)/numpy_include/numpy/npy_no_deprecated_api.h" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/arrayobject.h" "$(@D)/numpy_include/numpy/arrayobject.h" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/_numpyconfig.h" "$(@D)/numpy_include/numpy/_numpyconfig.h" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/__multiarray_api.h" "$(@D)/numpy_include/numpy/__multiarray_api.h" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/npy_cpu.h" "$(@D)/numpy_include/numpy/npy_cpu.h" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/old_defines.h" "$(@D)/numpy_include/numpy/old_defines.h" && cp "/usr/lib/python2.7/dist-packages/numpy/core/include/numpy/numpyconfig.h" "$(@D)/numpy_include/numpy/numpyconfig.h" """, +) diff --git a/third_party/toolchains/gpus/crosstool/BUILD b/third_party/toolchains/gpus/crosstool/BUILD new file mode 100644 index 00000000000..a8c6b0f0291 --- /dev/null +++ b/third_party/toolchains/gpus/crosstool/BUILD @@ -0,0 +1,52 @@ +# A build file to configure cc toolchain for GPU build used with Bazel remote +# execution service +# DO NOT EDIT: automatically generated BUILD file + +licenses(["restricted"]) + +package(default_visibility = ["//visibility:public"]) + +cc_toolchain_suite( + name = "toolchain", + toolchains = { + "local|compiler": ":cc-compiler-local", + "darwin|compiler": ":cc-compiler-darwin", + }, +) + +cc_toolchain( + name = "cc-compiler-local", + all_files = ":empty", + compiler_files = ":empty", + cpu = "local", + dwp_files = ":empty", + dynamic_runtime_libs = [":empty"], + linker_files = ":empty", + objcopy_files = ":empty", + 
static_runtime_libs = [":empty"], + strip_files = ":empty", + # To support linker flags that need to go to the start of command line + # we need the toolchain to support parameter files. Parameter files are + # last on the command line and contain all shared libraries to link, so all + # regular options will be left of them. + supports_param_files = 1, +) + +cc_toolchain( + name = "cc-compiler-darwin", + all_files = ":empty", + compiler_files = ":empty", + cpu = "darwin", + dwp_files = ":empty", + dynamic_runtime_libs = [":empty"], + linker_files = ":empty", + objcopy_files = ":empty", + static_runtime_libs = [":empty"], + strip_files = ":empty", + supports_param_files = 0, +) + +filegroup( + name = "empty", + srcs = [], +) diff --git a/third_party/toolchains/gpus/crosstool/CROSSTOOL b/third_party/toolchains/gpus/crosstool/CROSSTOOL new file mode 100644 index 00000000000..224b8912f6d --- /dev/null +++ b/third_party/toolchains/gpus/crosstool/CROSSTOOL @@ -0,0 +1,302 @@ +# A crosstool configuration for GPU build used with Bazel remote +# execution service +# DO NOT EDIT: automatically generated file + +major_version: "local" +minor_version: "" +default_target_cpu: "same_as_host" + +default_toolchain { + cpu: "k8" + toolchain_identifier: "local_linux" +} +default_toolchain { + cpu: "piii" + toolchain_identifier: "local_linux" +} +default_toolchain { + cpu: "arm" + toolchain_identifier: "local_linux" +} +default_toolchain { + cpu: "darwin" + toolchain_identifier: "local_darwin" +} +default_toolchain { + cpu: "ppc" + toolchain_identifier: "local_linux" +} + +toolchain { + abi_version: "local" + abi_libc_version: "local" + compiler: "compiler" + host_system_name: "local" + needsPic: true + target_libc: "local" + target_cpu: "local" + target_system_name: "local" + toolchain_identifier: "local_linux" + + feature { + name: "c++11" + flag_set { + action: "c++-compile" + flag_group { + flag: "-std=c++11" + } + } + } + + feature { + name: "stdlib" + flag_set { + action: 
"c++-link-executable" + action: "c++-link-dynamic-library" + flag_group { + flag: "-lstdc++" + } + } + } + + feature { + name: "determinism" + flag_set { + action: "c-compile" + action: "c++-compile" + flag_group { + # Make C++ compilation deterministic. Use linkstamping instead of these + # compiler symbols. + flag: "-Wno-builtin-macro-redefined" + flag: "-D__DATE__=\"redacted\"" + flag: "-D__TIMESTAMP__=\"redacted\"" + flag: "-D__TIME__=\"redacted\"" + } + } + } + + feature { + name: "alwayslink" + flag_set { + action: "c++-link-dynamic-library" + action: "c++-link-executable" + flag_group { + flag: "-Wl,-no-as-needed" + } + } + } + + # This feature will be enabled for builds that support pic by bazel. + feature { + name: "pic" + flag_set { + action: "c-compile" + action: "c++-compile" + flag_group { + expand_if_all_available: "pic" + flag: "-fPIC" + } + flag_group { + expand_if_none_available: "pic" + flag: "-fPIE" + } + } + } + + # Security hardening on by default. + feature { + name: "hardening" + flag_set { + action: "c-compile" + action: "c++-compile" + flag_group { + # Conservative choice; -D_FORTIFY_SOURCE=2 may be unsafe in some cases. + # We need to undef it before redefining it as some distributions now + # have it enabled by default. + flag: "-U_FORTIFY_SOURCE" + flag: "-D_FORTIFY_SOURCE=1" + flag: "-fstack-protector" + } + } + flag_set { + action: "c++-link-dynamic-library" + flag_group { + flag: "-Wl,-z,relro,-z,now" + } + } + flag_set { + action: "c++-link-executable" + flag_group { + flag: "-pie" + flag: "-Wl,-z,relro,-z,now" + } + } + } + + feature { + name: "warnings" + flag_set { + action: "c-compile" + action: "c++-compile" + flag_group { + # All warnings are enabled. Maybe enable -Werror as well? + flag: "-Wall" + # TODO(ngiraldo): Some parts of the codebase set -Werror and hit this + # warning, so switch it off for now. + flag: "-Wno-invalid-partial-specialization" + } + } + } + + # Keep stack frames for debugging, even in opt mode. 
+ feature { + name: "frame-pointer" + flag_set { + action: "c-compile" + action: "c++-compile" + flag_group { + flag: "-fno-omit-frame-pointer" + } + } + } + + feature { + name: "build-id" + flag_set { + action: "c++-link-executable" + action: "c++-link-dynamic-library" + flag_group { + # Stamp the binary with a unique identifier. + flag: "-Wl,--build-id=md5" + flag: "-Wl,--hash-style=gnu" + } + } + } + + feature { + name: "no-canonical-prefixes" + flag_set { + action: "c-compile" + action: "c++-compile" + action: "c++-link-executable" + action: "c++-link-dynamic-library" + flag_group { + flag:"-no-canonical-prefixes" + } + } + } + + feature { + name: "disable-assertions" + flag_set { + action: "c-compile" + action: "c++-compile" + flag_group { + flag: "-DNDEBUG" + } + } + } + + feature { + name: "linker-bin-path" + + flag_set { + action: "c++-link-executable" + action: "c++-link-dynamic-library" + flag_group { + flag: "-B/usr/bin/" + } + } + } + + feature { + name: "common" + implies: "stdlib" + implies: "c++11" + implies: "determinism" + implies: "alwayslink" + implies: "hardening" + implies: "warnings" + implies: "frame-pointer" + implies: "build-id" + implies: "no-canonical-prefixes" + implies: "linker-bin-path" + } + + feature { + name: "opt" + implies: "common" + implies: "disable-assertions" + + flag_set { + action: "c-compile" + action: "c++-compile" + flag_group { + # No debug symbols. + # Maybe we should enable https://gcc.gnu.org/wiki/DebugFission for opt + # or even generally? However, that can't happen here, as it requires + # special handling in Bazel. + flag: "-g0" + + # Conservative choice for -O + # -O3 can increase binary size and even slow down the resulting binaries. + # Profile first and / or use FDO if you need better performance than this. + flag: "-O2" + + # Removal of unused code and data at link time (can this increase binary size in some cases?). 
+ flag: "-ffunction-sections" + flag: "-fdata-sections" + } + } + flag_set { + action: "c++-link-dynamic-library" + action: "c++-link-executable" + flag_group { + flag: "-Wl,--gc-sections" + } + } + } + + feature { + name: "fastbuild" + implies: "common" + } + + feature { + name: "dbg" + implies: "common" + flag_set { + action: "c-compile" + action: "c++-compile" + flag_group { + flag: "-g" + } + } + } + + # Set clang as a C/C++ compiler. + tool_path { name: "gcc" path: "/usr/local/bin/clang" } + + # Use the default system toolchain for everything else. + tool_path { name: "ar" path: "/usr/bin/ar" } + tool_path { name: "compat-ld" path: "/usr/bin/ld" } + tool_path { name: "cpp" path: "/usr/bin/cpp" } + tool_path { name: "dwp" path: "/usr/bin/dwp" } + tool_path { name: "gcov" path: "/usr/bin/gcov" } + tool_path { name: "ld" path: "/usr/bin/ld" } + tool_path { name: "nm" path: "/usr/bin/nm" } + tool_path { name: "objcopy" path: "/usr/bin/objcopy" } + tool_path { name: "objdump" path: "/usr/bin/objdump" } + tool_path { name: "strip" path: "/usr/bin/strip" } + + # Enabled dynamic linking. 
+ linking_mode_flags { mode: DYNAMIC } + + cxx_builtin_include_directory: "/usr/include/c++/5.4.0" + cxx_builtin_include_directory: "/usr/include/x86_64-linux-gnu/c++/5.4.0" + cxx_builtin_include_directory: "/usr/include/c++/5.4.0/backward" + cxx_builtin_include_directory: "/usr/local/include" + cxx_builtin_include_directory: "/usr/local/lib/clang/5.0.0/include" + cxx_builtin_include_directory: "/usr/include/x86_64-linux-gnu" + cxx_builtin_include_directory: "/usr/include" +} diff --git a/third_party/toolchains/gpus/cuda/BUILD b/third_party/toolchains/gpus/cuda/BUILD new file mode 100644 index 00000000000..9f8f8754ff7 --- /dev/null +++ b/third_party/toolchains/gpus/cuda/BUILD @@ -0,0 +1,1362 @@ +# A build file to configure cuda remote repository used with Bazel remote +# execution service +# DO NOT EDIT: automatically generated BUILD file + +licenses(["restricted"]) # MPL2, portions GPL v3, LGPL v3, BSD-like + +package(default_visibility = ["//visibility:public"]) + +config_setting( + name = "using_nvcc", + values = { + "define": "using_cuda_nvcc=true", + }, +) + +config_setting( + name = "using_clang", + values = { + "define": "using_cuda_clang=true", + }, +) + +# Equivalent to using_clang && -c opt. 
+config_setting( + name = "using_clang_opt", + values = { + "define": "using_cuda_clang=true", + "compilation_mode": "opt", + }, +) + +config_setting( + name = "darwin", + values = {"cpu": "darwin"}, + visibility = ["//visibility:public"], +) + +config_setting( + name = "freebsd", + values = {"cpu": "freebsd"}, + visibility = ["//visibility:public"], +) + +cc_library( + name = "cuda_headers", + hdrs = [ + "cuda/cuda_config.h", + ":cuda-include", + ":cudnn-include", + ], + includes = [ + ".", + "cuda/include", + ], + visibility = ["//visibility:public"], +) + +cc_library( + name = "cudart_static", + srcs = ["cuda/lib/libcudart_static.a"], + includes = [ + ".", + "cuda/include", + ], + linkopts = select({ + ":freebsd": [], + "//conditions:default": ["-ldl"], + }) + [ + "-lpthread", + "-lrt", + ], + visibility = ["//visibility:public"], +) + +cc_library( + name = "cuda_driver", + srcs = ["cuda/lib/libcuda.so"], + includes = [ + ".", + "cuda/include", + ], + visibility = ["//visibility:public"], +) + +cc_library( + name = "cudart", + srcs = ["cuda/lib/libcudart.so.8.0"], + data = ["cuda/lib/libcudart.so.8.0"], + includes = [ + ".", + "cuda/include", + ], + linkstatic = 1, + visibility = ["//visibility:public"], +) + +cc_library( + name = "cublas", + srcs = ["cuda/lib/libcublas.so.8.0"], + data = ["cuda/lib/libcublas.so.8.0"], + includes = [ + ".", + "cuda/include", + ], + linkstatic = 1, + visibility = ["//visibility:public"], +) + +cc_library( + name = "cusolver", + srcs = ["cuda/lib/libcusolver.so.8.0"], + data = ["cuda/lib/libcusolver.so.8.0"], + includes = [ + ".", + "cuda/include", + ], + linkopts = ["-lgomp"], + linkstatic = 1, + visibility = ["//visibility:public"], +) + +cc_library( + name = "cudnn", + srcs = ["cuda/lib/libcudnn.so.5"], + data = ["cuda/lib/libcudnn.so.5"], + includes = [ + ".", + "cuda/include", + ], + linkstatic = 1, + visibility = ["//visibility:public"], +) + +cc_library( + name = "cufft", + srcs = ["cuda/lib/libcufft.so.8.0"], + data = 
["cuda/lib/libcufft.so.8.0"], + includes = [ + ".", + "cuda/include", + ], + linkstatic = 1, + visibility = ["//visibility:public"], +) + +cc_library( + name = "curand", + srcs = ["cuda/lib/libcurand.so.8.0"], + data = ["cuda/lib/libcurand.so.8.0"], + includes = [ + ".", + "cuda/include", + ], + linkstatic = 1, + visibility = ["//visibility:public"], +) + +cc_library( + name = "cuda", + visibility = ["//visibility:public"], + deps = [ + ":cublas", + ":cuda_headers", + ":cudart", + ":cudnn", + ":cufft", + ":curand", + ], +) + +cc_library( + name = "cupti_headers", + hdrs = [ + "cuda/cuda_config.h", + ":cuda-extras", + ], + includes = [ + ".", + "cuda/extras/CUPTI/include/", + ], + visibility = ["//visibility:public"], +) + +cc_library( + name = "cupti_dsos", + data = ["cuda/lib/libcupti.so.8.0"], + includes = [ + ".", + "cuda/extras/CUPTI/include/", + ], + visibility = ["//visibility:public"], +) + +cc_library( + name = "libdevice_root", + data = [":cuda-nvvm"], + includes = [ + ".", + "cuda/extras/CUPTI/include/", + ], + visibility = ["//visibility:public"], +) + +genrule( + name = "cuda-include", + outs = [ + "cuda/include/math_functions.hpp", + "cuda/include/cufft.h", + "cuda/include/nvgraph.h", + "cuda/include/curand_normal.h", + "cuda/include/curand_uniform.h", + "cuda/include/nppi_data_exchange_and_initialization.h", + "cuda/include/cuda_gl_interop.h", + "cuda/include/nppi_compression_functions.h", + "cuda/include/npp.h", + "cuda/include/cuda.h", + "cuda/include/nppi_statistics_functions.h", + "cuda/include/vector_functions.hpp", + "cuda/include/sm_32_intrinsics.hpp", + "cuda/include/sm_32_intrinsics.h", + "cuda/include/curand_discrete.h", + "cuda/include/cuda_runtime.h", + "cuda/include/cufftXt.h", + "cuda/include/sm_61_intrinsics.h", + "cuda/include/texture_fetch_functions.h", + "cuda/include/curand_mrg32k3a.h", + "cuda/include/host_defines.h", + "cuda/include/common_functions.h", + "cuda/include/nppi_support_functions.h", + 
"cuda/include/nppi_linear_transforms.h", + "cuda/include/device_double_functions.hpp", + "cuda/include/math_constants.h", + "cuda/include/nvToolsExtSync.h", + "cuda/include/npps_initialization.h", + "cuda/include/cusolverSp_LOWLEVEL_PREVIEW.h", + "cuda/include/texture_indirect_functions.hpp", + "cuda/include/cudaProfiler.h", + "cuda/include/npps_filtering_functions.h", + "cuda/include/cusparse_v2.h", + "cuda/include/nppi.h", + "cuda/include/surface_indirect_functions.h", + "cuda/include/sm_30_intrinsics.h", + "cuda/include/device_double_functions.h", + "cuda/include/sm_35_intrinsics.h", + "cuda/include/cusolverSp.h", + "cuda/include/library_types.h", + "cuda/include/surface_indirect_functions.hpp", + "cuda/include/cudalibxt.h", + "cuda/include/channel_descriptor.h", + "cuda/include/device_functions_decls.h", + "cuda/include/curand_kernel.h", + "cuda/include/curand_mtgp32_host.h", + "cuda/include/nvToolsExtCuda.h", + "cuda/include/nvToolsExt.h", + "cuda/include/cuComplex.h", + "cuda/include/sm_32_atomic_functions.h", + "cuda/include/texture_indirect_functions.h", + "cuda/include/sm_32_atomic_functions.hpp", + "cuda/include/sm_20_intrinsics.hpp", + "cuda/include/device_launch_parameters.h", + "cuda/include/curand_mtgp32.h", + "cuda/include/texture_fetch_functions.hpp", + "cuda/include/cuda_occupancy.h", + "cuda/include/CL/opencl.h", + "cuda/include/CL/cl_platform.h", + "cuda/include/CL/cl_egl.h", + "cuda/include/CL/cl_gl.h", + "cuda/include/CL/cl.h", + "cuda/include/CL/cl_gl_ext.h", + "cuda/include/CL/cl_ext.h", + "cuda/include/CL/cl.hpp", + "cuda/include/host_config.h", + "cuda/include/cuda_surface_types.h", + "cuda/include/math_functions.h", + "cuda/include/nvToolsExtMeta.h", + "cuda/include/sm_20_atomic_functions.hpp", + "cuda/include/device_functions.h", + "cuda/include/device_types.h", + "cuda/include/npps_conversion_functions.h", + "cuda/include/curand_precalc.h", + "cuda/include/cusolverRf.h", + "cuda/include/sm_60_atomic_functions.hpp", + 
"cuda/include/cuviddec.h", + "cuda/include/curand_discrete2.h", + "cuda/include/device_functions.hpp", + "cuda/include/thrust/transform_scan.h", + "cuda/include/thrust/system_error.h", + "cuda/include/thrust/device_malloc.h", + "cuda/include/thrust/partition.h", + "cuda/include/thrust/unique.h", + "cuda/include/thrust/device_delete.h", + "cuda/include/thrust/execution_policy.h", + "cuda/include/thrust/adjacent_difference.h", + "cuda/include/thrust/sequence.h", + "cuda/include/thrust/merge.h", + "cuda/include/thrust/device_new.h", + "cuda/include/thrust/transform_reduce.h", + "cuda/include/thrust/device_vector.h", + "cuda/include/thrust/gather.h", + "cuda/include/thrust/sort.h", + "cuda/include/thrust/scan.h", + "cuda/include/thrust/detail/temporary_array.h", + "cuda/include/thrust/detail/util/align.h", + "cuda/include/thrust/detail/util/blocking.h", + "cuda/include/thrust/detail/transform.inl", + "cuda/include/thrust/detail/device_vector.inl", + "cuda/include/thrust/detail/binary_search.inl", + "cuda/include/thrust/detail/overlapped_copy.h", + "cuda/include/thrust/detail/vector_base.inl", + "cuda/include/thrust/detail/device_reference.inl", + "cuda/include/thrust/detail/functional/actor.h", + "cuda/include/thrust/detail/functional/value.h", + "cuda/include/thrust/detail/functional/operators.h", + "cuda/include/thrust/detail/functional/operators/logical_operators.h", + "cuda/include/thrust/detail/functional/operators/relational_operators.h", + "cuda/include/thrust/detail/functional/operators/assignment_operator.h", + "cuda/include/thrust/detail/functional/operators/bitwise_operators.h", + "cuda/include/thrust/detail/functional/operators/operator_adaptors.h", + "cuda/include/thrust/detail/functional/operators/arithmetic_operators.h", + "cuda/include/thrust/detail/functional/operators/compound_assignment_operators.h", + "cuda/include/thrust/detail/functional/argument.h", + "cuda/include/thrust/detail/functional/placeholder.h", + 
"cuda/include/thrust/detail/functional/actor.inl", + "cuda/include/thrust/detail/functional/composite.h", + "cuda/include/thrust/detail/static_map.h", + "cuda/include/thrust/detail/type_traits/has_nested_type.h", + "cuda/include/thrust/detail/type_traits/is_call_possible.h", + "cuda/include/thrust/detail/type_traits/function_traits.h", + "cuda/include/thrust/detail/type_traits/pointer_traits.h", + "cuda/include/thrust/detail/type_traits/has_member_function.h", + "cuda/include/thrust/detail/type_traits/algorithm/intermediate_type_from_function_and_iterators.h", + "cuda/include/thrust/detail/type_traits/minimum_type.h", + "cuda/include/thrust/detail/type_traits/has_trivial_assign.h", + "cuda/include/thrust/detail/type_traits/is_metafunction_defined.h", + "cuda/include/thrust/detail/type_traits/iterator/is_discard_iterator.h", + "cuda/include/thrust/detail/type_traits/iterator/is_output_iterator.h", + "cuda/include/thrust/detail/type_traits/result_of_adaptable_function.h", + "cuda/include/thrust/detail/reference.h", + "cuda/include/thrust/detail/inner_product.inl", + "cuda/include/thrust/detail/use_default.h", + "cuda/include/thrust/detail/sequence.inl", + "cuda/include/thrust/detail/sort.inl", + "cuda/include/thrust/detail/equal.inl", + "cuda/include/thrust/detail/execution_policy.h", + "cuda/include/thrust/detail/integer_traits.h", + "cuda/include/thrust/detail/type_traits.h", + "cuda/include/thrust/detail/reverse.inl", + "cuda/include/thrust/detail/tabulate.inl", + "cuda/include/thrust/detail/unique.inl", + "cuda/include/thrust/detail/scatter.inl", + "cuda/include/thrust/detail/set_operations.inl", + "cuda/include/thrust/detail/device_malloc.inl", + "cuda/include/thrust/detail/copy_if.inl", + "cuda/include/thrust/detail/fill.inl", + "cuda/include/thrust/detail/temporary_array.inl", + "cuda/include/thrust/detail/transform_scan.inl", + "cuda/include/thrust/detail/minmax.h", + "cuda/include/thrust/detail/swap.inl", + "cuda/include/thrust/detail/pointer.inl", + 
"cuda/include/thrust/detail/transform_reduce.inl", + "cuda/include/thrust/detail/config.h", + "cuda/include/thrust/detail/distance.inl", + "cuda/include/thrust/detail/pair.inl", + "cuda/include/thrust/detail/allocator/temporary_allocator.h", + "cuda/include/thrust/detail/allocator/tagged_allocator.h", + "cuda/include/thrust/detail/allocator/destroy_range.inl", + "cuda/include/thrust/detail/allocator/destroy_range.h", + "cuda/include/thrust/detail/allocator/no_throw_allocator.h", + "cuda/include/thrust/detail/allocator/default_construct_range.inl", + "cuda/include/thrust/detail/allocator/fill_construct_range.inl", + "cuda/include/thrust/detail/allocator/tagged_allocator.inl", + "cuda/include/thrust/detail/allocator/malloc_allocator.h", + "cuda/include/thrust/detail/allocator/allocator_traits.h", + "cuda/include/thrust/detail/allocator/copy_construct_range.h", + "cuda/include/thrust/detail/allocator/allocator_traits.inl", + "cuda/include/thrust/detail/allocator/default_construct_range.h", + "cuda/include/thrust/detail/allocator/copy_construct_range.inl", + "cuda/include/thrust/detail/allocator/malloc_allocator.inl", + "cuda/include/thrust/detail/allocator/temporary_allocator.inl", + "cuda/include/thrust/detail/allocator/fill_construct_range.h", + "cuda/include/thrust/detail/temporary_buffer.h", + "cuda/include/thrust/detail/reduce.inl", + "cuda/include/thrust/detail/device_new.inl", + "cuda/include/thrust/detail/pointer.h", + "cuda/include/thrust/detail/for_each.inl", + "cuda/include/thrust/detail/generate.inl", + "cuda/include/thrust/detail/dispatch/is_trivial_copy.h", + "cuda/include/thrust/detail/adjacent_difference.inl", + "cuda/include/thrust/detail/tuple_meta_transform.h", + "cuda/include/thrust/detail/functional.inl", + "cuda/include/thrust/detail/remove.inl", + "cuda/include/thrust/detail/tuple_transform.h", + "cuda/include/thrust/detail/merge.inl", + "cuda/include/thrust/detail/extrema.inl", + "cuda/include/thrust/detail/trivial_sequence.h", + 
"cuda/include/thrust/detail/vector_base.h", + "cuda/include/thrust/detail/count.inl", + "cuda/include/thrust/detail/uninitialized_copy.inl", + "cuda/include/thrust/detail/function.h", + "cuda/include/thrust/detail/swap_ranges.inl", + "cuda/include/thrust/detail/device_delete.inl", + "cuda/include/thrust/detail/static_assert.h", + "cuda/include/thrust/detail/logical.inl", + "cuda/include/thrust/detail/seq.h", + "cuda/include/thrust/detail/mpl/math.h", + "cuda/include/thrust/detail/mismatch.inl", + "cuda/include/thrust/detail/internal_functional.h", + "cuda/include/thrust/detail/get_iterator_value.h", + "cuda/include/thrust/detail/copy.inl", + "cuda/include/thrust/detail/copy.h", + "cuda/include/thrust/detail/complex/catrigf.h", + "cuda/include/thrust/detail/complex/cpowf.h", + "cuda/include/thrust/detail/complex/csqrtf.h", + "cuda/include/thrust/detail/complex/ccoshf.h", + "cuda/include/thrust/detail/complex/csinhf.h", + "cuda/include/thrust/detail/complex/clogf.h", + "cuda/include/thrust/detail/complex/ccosh.h", + "cuda/include/thrust/detail/complex/arithmetic.h", + "cuda/include/thrust/detail/complex/csqrt.h", + "cuda/include/thrust/detail/complex/cpow.h", + "cuda/include/thrust/detail/complex/complex.inl", + "cuda/include/thrust/detail/complex/math_private.h", + "cuda/include/thrust/detail/complex/c99math.h", + "cuda/include/thrust/detail/complex/cproj.h", + "cuda/include/thrust/detail/complex/catrig.h", + "cuda/include/thrust/detail/complex/ctanhf.h", + "cuda/include/thrust/detail/complex/cexpf.h", + "cuda/include/thrust/detail/complex/csinh.h", + "cuda/include/thrust/detail/complex/stream.h", + "cuda/include/thrust/detail/complex/ctanh.h", + "cuda/include/thrust/detail/complex/cexp.h", + "cuda/include/thrust/detail/complex/clog.h", + "cuda/include/thrust/detail/range/head_flags.h", + "cuda/include/thrust/detail/range/tail_flags.h", + "cuda/include/thrust/detail/execute_with_allocator.h", + "cuda/include/thrust/detail/integer_math.h", + 
"cuda/include/thrust/detail/swap.h", + "cuda/include/thrust/detail/uninitialized_fill.inl", + "cuda/include/thrust/detail/scan.inl", + "cuda/include/thrust/detail/gather.inl", + "cuda/include/thrust/detail/reference_forward_declaration.h", + "cuda/include/thrust/detail/numeric_traits.h", + "cuda/include/thrust/detail/reference.inl", + "cuda/include/thrust/detail/cstdint.h", + "cuda/include/thrust/detail/device_free.inl", + "cuda/include/thrust/detail/copy_if.h", + "cuda/include/thrust/detail/partition.inl", + "cuda/include/thrust/detail/find.inl", + "cuda/include/thrust/detail/config/forceinline.h", + "cuda/include/thrust/detail/config/debug.h", + "cuda/include/thrust/detail/config/config.h", + "cuda/include/thrust/detail/config/host_device.h", + "cuda/include/thrust/detail/config/host_system.h", + "cuda/include/thrust/detail/config/compiler.h", + "cuda/include/thrust/detail/config/device_system.h", + "cuda/include/thrust/detail/config/compiler_fence.h", + "cuda/include/thrust/detail/config/exec_check_disable.h", + "cuda/include/thrust/detail/config/simple_defines.h", + "cuda/include/thrust/detail/config/global_workarounds.h", + "cuda/include/thrust/detail/replace.inl", + "cuda/include/thrust/detail/device_ptr.inl", + "cuda/include/thrust/detail/tuple.inl", + "cuda/include/thrust/detail/malloc_and_free.h", + "cuda/include/thrust/detail/host_vector.inl", + "cuda/include/thrust/detail/raw_pointer_cast.h", + "cuda/include/thrust/detail/advance.inl", + "cuda/include/thrust/detail/contiguous_storage.h", + "cuda/include/thrust/detail/raw_reference_cast.h", + "cuda/include/thrust/detail/contiguous_storage.inl", + "cuda/include/thrust/reverse.h", + "cuda/include/thrust/device_malloc_allocator.h", + "cuda/include/thrust/scatter.h", + "cuda/include/thrust/pair.h", + "cuda/include/thrust/advance.h", + "cuda/include/thrust/find.h", + "cuda/include/thrust/device_ptr.h", + "cuda/include/thrust/generate.h", + "cuda/include/thrust/uninitialized_fill.h", + 
"cuda/include/thrust/system/system_error.h", + "cuda/include/thrust/system/detail/bad_alloc.h", + "cuda/include/thrust/system/detail/adl/transform_scan.h", + "cuda/include/thrust/system/detail/adl/unique_by_key.h", + "cuda/include/thrust/system/detail/adl/partition.h", + "cuda/include/thrust/system/detail/adl/unique.h", + "cuda/include/thrust/system/detail/adl/adjacent_difference.h", + "cuda/include/thrust/system/detail/adl/sequence.h", + "cuda/include/thrust/system/detail/adl/merge.h", + "cuda/include/thrust/system/detail/adl/transform_reduce.h", + "cuda/include/thrust/system/detail/adl/gather.h", + "cuda/include/thrust/system/detail/adl/sort.h", + "cuda/include/thrust/system/detail/adl/scan.h", + "cuda/include/thrust/system/detail/adl/temporary_buffer.h", + "cuda/include/thrust/system/detail/adl/scan_by_key.h", + "cuda/include/thrust/system/detail/adl/reverse.h", + "cuda/include/thrust/system/detail/adl/assign_value.h", + "cuda/include/thrust/system/detail/adl/scatter.h", + "cuda/include/thrust/system/detail/adl/find.h", + "cuda/include/thrust/system/detail/adl/generate.h", + "cuda/include/thrust/system/detail/adl/uninitialized_fill.h", + "cuda/include/thrust/system/detail/adl/remove.h", + "cuda/include/thrust/system/detail/adl/tabulate.h", + "cuda/include/thrust/system/detail/adl/for_each.h", + "cuda/include/thrust/system/detail/adl/reduce_by_key.h", + "cuda/include/thrust/system/detail/adl/reduce.h", + "cuda/include/thrust/system/detail/adl/equal.h", + "cuda/include/thrust/system/detail/adl/copy.h", + "cuda/include/thrust/system/detail/adl/swap_ranges.h", + "cuda/include/thrust/system/detail/adl/uninitialized_copy.h", + "cuda/include/thrust/system/detail/adl/binary_search.h", + "cuda/include/thrust/system/detail/adl/set_operations.h", + "cuda/include/thrust/system/detail/adl/mismatch.h", + "cuda/include/thrust/system/detail/adl/extrema.h", + "cuda/include/thrust/system/detail/adl/count.h", + "cuda/include/thrust/system/detail/adl/replace.h", + 
"cuda/include/thrust/system/detail/adl/get_value.h", + "cuda/include/thrust/system/detail/adl/inner_product.h", + "cuda/include/thrust/system/detail/adl/copy_if.h", + "cuda/include/thrust/system/detail/adl/logical.h", + "cuda/include/thrust/system/detail/adl/iter_swap.h", + "cuda/include/thrust/system/detail/adl/malloc_and_free.h", + "cuda/include/thrust/system/detail/adl/fill.h", + "cuda/include/thrust/system/detail/adl/transform.h", + "cuda/include/thrust/system/detail/errno.h", + "cuda/include/thrust/system/detail/error_category.inl", + "cuda/include/thrust/system/detail/sequential/transform_scan.h", + "cuda/include/thrust/system/detail/sequential/unique_by_key.h", + "cuda/include/thrust/system/detail/sequential/stable_primitive_sort.h", + "cuda/include/thrust/system/detail/sequential/stable_primitive_sort.inl", + "cuda/include/thrust/system/detail/sequential/stable_merge_sort.h", + "cuda/include/thrust/system/detail/sequential/sort.inl", + "cuda/include/thrust/system/detail/sequential/partition.h", + "cuda/include/thrust/system/detail/sequential/unique.h", + "cuda/include/thrust/system/detail/sequential/execution_policy.h", + "cuda/include/thrust/system/detail/sequential/adjacent_difference.h", + "cuda/include/thrust/system/detail/sequential/sequence.h", + "cuda/include/thrust/system/detail/sequential/merge.h", + "cuda/include/thrust/system/detail/sequential/transform_reduce.h", + "cuda/include/thrust/system/detail/sequential/gather.h", + "cuda/include/thrust/system/detail/sequential/sort.h", + "cuda/include/thrust/system/detail/sequential/copy_backward.h", + "cuda/include/thrust/system/detail/sequential/stable_radix_sort.inl", + "cuda/include/thrust/system/detail/sequential/scan.h", + "cuda/include/thrust/system/detail/sequential/temporary_buffer.h", + "cuda/include/thrust/system/detail/sequential/scan_by_key.h", + "cuda/include/thrust/system/detail/sequential/reverse.h", + "cuda/include/thrust/system/detail/sequential/assign_value.h", + 
"cuda/include/thrust/system/detail/sequential/scatter.h", + "cuda/include/thrust/system/detail/sequential/find.h", + "cuda/include/thrust/system/detail/sequential/stable_merge_sort.inl", + "cuda/include/thrust/system/detail/sequential/merge.inl", + "cuda/include/thrust/system/detail/sequential/generate.h", + "cuda/include/thrust/system/detail/sequential/uninitialized_fill.h", + "cuda/include/thrust/system/detail/sequential/general_copy.h", + "cuda/include/thrust/system/detail/sequential/insertion_sort.h", + "cuda/include/thrust/system/detail/sequential/remove.h", + "cuda/include/thrust/system/detail/sequential/tabulate.h", + "cuda/include/thrust/system/detail/sequential/for_each.h", + "cuda/include/thrust/system/detail/sequential/reduce_by_key.h", + "cuda/include/thrust/system/detail/sequential/reduce.h", + "cuda/include/thrust/system/detail/sequential/equal.h", + "cuda/include/thrust/system/detail/sequential/stable_radix_sort.h", + "cuda/include/thrust/system/detail/sequential/copy.inl", + "cuda/include/thrust/system/detail/sequential/copy.h", + "cuda/include/thrust/system/detail/sequential/swap_ranges.h", + "cuda/include/thrust/system/detail/sequential/uninitialized_copy.h", + "cuda/include/thrust/system/detail/sequential/binary_search.h", + "cuda/include/thrust/system/detail/sequential/set_operations.h", + "cuda/include/thrust/system/detail/sequential/mismatch.h", + "cuda/include/thrust/system/detail/sequential/extrema.h", + "cuda/include/thrust/system/detail/sequential/count.h", + "cuda/include/thrust/system/detail/sequential/trivial_copy.h", + "cuda/include/thrust/system/detail/sequential/replace.h", + "cuda/include/thrust/system/detail/sequential/get_value.h", + "cuda/include/thrust/system/detail/sequential/inner_product.h", + "cuda/include/thrust/system/detail/sequential/copy_if.h", + "cuda/include/thrust/system/detail/sequential/logical.h", + "cuda/include/thrust/system/detail/sequential/iter_swap.h", + 
"cuda/include/thrust/system/detail/sequential/malloc_and_free.h", + "cuda/include/thrust/system/detail/sequential/fill.h", + "cuda/include/thrust/system/detail/sequential/transform.h", + "cuda/include/thrust/system/detail/error_condition.inl", + "cuda/include/thrust/system/detail/internal/decompose.h", + "cuda/include/thrust/system/detail/error_code.inl", + "cuda/include/thrust/system/detail/generic/transform_scan.h", + "cuda/include/thrust/system/detail/generic/memory.inl", + "cuda/include/thrust/system/detail/generic/transform.inl", + "cuda/include/thrust/system/detail/generic/binary_search.inl", + "cuda/include/thrust/system/detail/generic/scan_by_key.inl", + "cuda/include/thrust/system/detail/generic/unique_by_key.h", + "cuda/include/thrust/system/detail/generic/inner_product.inl", + "cuda/include/thrust/system/detail/generic/select_system.h", + "cuda/include/thrust/system/detail/generic/sequence.inl", + "cuda/include/thrust/system/detail/generic/sort.inl", + "cuda/include/thrust/system/detail/generic/equal.inl", + "cuda/include/thrust/system/detail/generic/partition.h", + "cuda/include/thrust/system/detail/generic/unique.h", + "cuda/include/thrust/system/detail/generic/adjacent_difference.h", + "cuda/include/thrust/system/detail/generic/tag.h", + "cuda/include/thrust/system/detail/generic/unique_by_key.inl", + "cuda/include/thrust/system/detail/generic/sequence.h", + "cuda/include/thrust/system/detail/generic/type_traits.h", + "cuda/include/thrust/system/detail/generic/merge.h", + "cuda/include/thrust/system/detail/generic/reverse.inl", + "cuda/include/thrust/system/detail/generic/tabulate.inl", + "cuda/include/thrust/system/detail/generic/unique.inl", + "cuda/include/thrust/system/detail/generic/scatter.inl", + "cuda/include/thrust/system/detail/generic/set_operations.inl", + "cuda/include/thrust/system/detail/generic/copy_if.inl", + "cuda/include/thrust/system/detail/generic/transform_reduce.h", + 
"cuda/include/thrust/system/detail/generic/transform_scan.inl", + "cuda/include/thrust/system/detail/generic/gather.h", + "cuda/include/thrust/system/detail/generic/reduce_by_key.inl", + "cuda/include/thrust/system/detail/generic/transform_reduce.inl", + "cuda/include/thrust/system/detail/generic/sort.h", + "cuda/include/thrust/system/detail/generic/distance.inl", + "cuda/include/thrust/system/detail/generic/scan.h", + "cuda/include/thrust/system/detail/generic/temporary_buffer.h", + "cuda/include/thrust/system/detail/generic/reduce.inl", + "cuda/include/thrust/system/detail/generic/scan_by_key.h", + "cuda/include/thrust/system/detail/generic/reverse.h", + "cuda/include/thrust/system/detail/generic/temporary_buffer.inl", + "cuda/include/thrust/system/detail/generic/scatter.h", + "cuda/include/thrust/system/detail/generic/generate.inl", + "cuda/include/thrust/system/detail/generic/adjacent_difference.inl", + "cuda/include/thrust/system/detail/generic/remove.inl", + "cuda/include/thrust/system/detail/generic/advance.h", + "cuda/include/thrust/system/detail/generic/find.h", + "cuda/include/thrust/system/detail/generic/merge.inl", + "cuda/include/thrust/system/detail/generic/scalar/binary_search.inl", + "cuda/include/thrust/system/detail/generic/scalar/binary_search.h", + "cuda/include/thrust/system/detail/generic/extrema.inl", + "cuda/include/thrust/system/detail/generic/generate.h", + "cuda/include/thrust/system/detail/generic/uninitialized_fill.h", + "cuda/include/thrust/system/detail/generic/count.inl", + "cuda/include/thrust/system/detail/generic/remove.h", + "cuda/include/thrust/system/detail/generic/uninitialized_copy.inl", + "cuda/include/thrust/system/detail/generic/tabulate.h", + "cuda/include/thrust/system/detail/generic/for_each.h", + "cuda/include/thrust/system/detail/generic/distance.h", + "cuda/include/thrust/system/detail/generic/swap_ranges.inl", + "cuda/include/thrust/system/detail/generic/reduce_by_key.h", + 
"cuda/include/thrust/system/detail/generic/reduce.h", + "cuda/include/thrust/system/detail/generic/equal.h", + "cuda/include/thrust/system/detail/generic/mismatch.inl", + "cuda/include/thrust/system/detail/generic/copy.inl", + "cuda/include/thrust/system/detail/generic/copy.h", + "cuda/include/thrust/system/detail/generic/swap_ranges.h", + "cuda/include/thrust/system/detail/generic/uninitialized_copy.h", + "cuda/include/thrust/system/detail/generic/binary_search.h", + "cuda/include/thrust/system/detail/generic/set_operations.h", + "cuda/include/thrust/system/detail/generic/uninitialized_fill.inl", + "cuda/include/thrust/system/detail/generic/mismatch.h", + "cuda/include/thrust/system/detail/generic/scan.inl", + "cuda/include/thrust/system/detail/generic/gather.inl", + "cuda/include/thrust/system/detail/generic/extrema.h", + "cuda/include/thrust/system/detail/generic/count.h", + "cuda/include/thrust/system/detail/generic/replace.h", + "cuda/include/thrust/system/detail/generic/inner_product.h", + "cuda/include/thrust/system/detail/generic/copy_if.h", + "cuda/include/thrust/system/detail/generic/logical.h", + "cuda/include/thrust/system/detail/generic/partition.inl", + "cuda/include/thrust/system/detail/generic/memory.h", + "cuda/include/thrust/system/detail/generic/find.inl", + "cuda/include/thrust/system/detail/generic/replace.inl", + "cuda/include/thrust/system/detail/generic/advance.inl", + "cuda/include/thrust/system/detail/generic/fill.h", + "cuda/include/thrust/system/detail/generic/transform.h", + "cuda/include/thrust/system/detail/system_error.inl", + "cuda/include/thrust/system/omp/execution_policy.h", + "cuda/include/thrust/system/omp/vector.h", + "cuda/include/thrust/system/omp/detail/transform_scan.h", + "cuda/include/thrust/system/omp/detail/memory.inl", + "cuda/include/thrust/system/omp/detail/reduce_intervals.inl", + "cuda/include/thrust/system/omp/detail/unique_by_key.h", + "cuda/include/thrust/system/omp/detail/sort.inl", + 
"cuda/include/thrust/system/omp/detail/partition.h", + "cuda/include/thrust/system/omp/detail/unique.h", + "cuda/include/thrust/system/omp/detail/execution_policy.h", + "cuda/include/thrust/system/omp/detail/adjacent_difference.h", + "cuda/include/thrust/system/omp/detail/unique_by_key.inl", + "cuda/include/thrust/system/omp/detail/sequence.h", + "cuda/include/thrust/system/omp/detail/merge.h", + "cuda/include/thrust/system/omp/detail/unique.inl", + "cuda/include/thrust/system/omp/detail/copy_if.inl", + "cuda/include/thrust/system/omp/detail/transform_reduce.h", + "cuda/include/thrust/system/omp/detail/gather.h", + "cuda/include/thrust/system/omp/detail/reduce_by_key.inl", + "cuda/include/thrust/system/omp/detail/sort.h", + "cuda/include/thrust/system/omp/detail/scan.h", + "cuda/include/thrust/system/omp/detail/temporary_buffer.h", + "cuda/include/thrust/system/omp/detail/default_decomposition.h", + "cuda/include/thrust/system/omp/detail/reduce.inl", + "cuda/include/thrust/system/omp/detail/scan_by_key.h", + "cuda/include/thrust/system/omp/detail/reverse.h", + "cuda/include/thrust/system/omp/detail/assign_value.h", + "cuda/include/thrust/system/omp/detail/scatter.h", + "cuda/include/thrust/system/omp/detail/for_each.inl", + "cuda/include/thrust/system/omp/detail/default_decomposition.inl", + "cuda/include/thrust/system/omp/detail/remove.inl", + "cuda/include/thrust/system/omp/detail/vector.inl", + "cuda/include/thrust/system/omp/detail/find.h", + "cuda/include/thrust/system/omp/detail/generate.h", + "cuda/include/thrust/system/omp/detail/uninitialized_fill.h", + "cuda/include/thrust/system/omp/detail/remove.h", + "cuda/include/thrust/system/omp/detail/tabulate.h", + "cuda/include/thrust/system/omp/detail/for_each.h", + "cuda/include/thrust/system/omp/detail/reduce_by_key.h", + "cuda/include/thrust/system/omp/detail/reduce.h", + "cuda/include/thrust/system/omp/detail/equal.h", + "cuda/include/thrust/system/omp/detail/copy.inl", + 
"cuda/include/thrust/system/omp/detail/copy.h", + "cuda/include/thrust/system/omp/detail/swap_ranges.h", + "cuda/include/thrust/system/omp/detail/uninitialized_copy.h", + "cuda/include/thrust/system/omp/detail/binary_search.h", + "cuda/include/thrust/system/omp/detail/set_operations.h", + "cuda/include/thrust/system/omp/detail/mismatch.h", + "cuda/include/thrust/system/omp/detail/extrema.h", + "cuda/include/thrust/system/omp/detail/count.h", + "cuda/include/thrust/system/omp/detail/replace.h", + "cuda/include/thrust/system/omp/detail/get_value.h", + "cuda/include/thrust/system/omp/detail/inner_product.h", + "cuda/include/thrust/system/omp/detail/copy_if.h", + "cuda/include/thrust/system/omp/detail/logical.h", + "cuda/include/thrust/system/omp/detail/partition.inl", + "cuda/include/thrust/system/omp/detail/iter_swap.h", + "cuda/include/thrust/system/omp/detail/par.h", + "cuda/include/thrust/system/omp/detail/reduce_intervals.h", + "cuda/include/thrust/system/omp/detail/malloc_and_free.h", + "cuda/include/thrust/system/omp/detail/fill.h", + "cuda/include/thrust/system/omp/detail/transform.h", + "cuda/include/thrust/system/omp/memory.h", + "cuda/include/thrust/system/tbb/execution_policy.h", + "cuda/include/thrust/system/tbb/vector.h", + "cuda/include/thrust/system/tbb/detail/transform_scan.h", + "cuda/include/thrust/system/tbb/detail/memory.inl", + "cuda/include/thrust/system/tbb/detail/unique_by_key.h", + "cuda/include/thrust/system/tbb/detail/sort.inl", + "cuda/include/thrust/system/tbb/detail/partition.h", + "cuda/include/thrust/system/tbb/detail/unique.h", + "cuda/include/thrust/system/tbb/detail/execution_policy.h", + "cuda/include/thrust/system/tbb/detail/adjacent_difference.h", + "cuda/include/thrust/system/tbb/detail/unique_by_key.inl", + "cuda/include/thrust/system/tbb/detail/sequence.h", + "cuda/include/thrust/system/tbb/detail/merge.h", + "cuda/include/thrust/system/tbb/detail/unique.inl", + "cuda/include/thrust/system/tbb/detail/copy_if.inl", + 
"cuda/include/thrust/system/tbb/detail/transform_reduce.h", + "cuda/include/thrust/system/tbb/detail/gather.h", + "cuda/include/thrust/system/tbb/detail/reduce_by_key.inl", + "cuda/include/thrust/system/tbb/detail/sort.h", + "cuda/include/thrust/system/tbb/detail/scan.h", + "cuda/include/thrust/system/tbb/detail/temporary_buffer.h", + "cuda/include/thrust/system/tbb/detail/reduce.inl", + "cuda/include/thrust/system/tbb/detail/scan_by_key.h", + "cuda/include/thrust/system/tbb/detail/reverse.h", + "cuda/include/thrust/system/tbb/detail/assign_value.h", + "cuda/include/thrust/system/tbb/detail/scatter.h", + "cuda/include/thrust/system/tbb/detail/for_each.inl", + "cuda/include/thrust/system/tbb/detail/remove.inl", + "cuda/include/thrust/system/tbb/detail/vector.inl", + "cuda/include/thrust/system/tbb/detail/find.h", + "cuda/include/thrust/system/tbb/detail/merge.inl", + "cuda/include/thrust/system/tbb/detail/generate.h", + "cuda/include/thrust/system/tbb/detail/uninitialized_fill.h", + "cuda/include/thrust/system/tbb/detail/remove.h", + "cuda/include/thrust/system/tbb/detail/tabulate.h", + "cuda/include/thrust/system/tbb/detail/for_each.h", + "cuda/include/thrust/system/tbb/detail/reduce_by_key.h", + "cuda/include/thrust/system/tbb/detail/reduce.h", + "cuda/include/thrust/system/tbb/detail/equal.h", + "cuda/include/thrust/system/tbb/detail/copy.inl", + "cuda/include/thrust/system/tbb/detail/copy.h", + "cuda/include/thrust/system/tbb/detail/swap_ranges.h", + "cuda/include/thrust/system/tbb/detail/uninitialized_copy.h", + "cuda/include/thrust/system/tbb/detail/binary_search.h", + "cuda/include/thrust/system/tbb/detail/set_operations.h", + "cuda/include/thrust/system/tbb/detail/mismatch.h", + "cuda/include/thrust/system/tbb/detail/scan.inl", + "cuda/include/thrust/system/tbb/detail/extrema.h", + "cuda/include/thrust/system/tbb/detail/count.h", + "cuda/include/thrust/system/tbb/detail/replace.h", + "cuda/include/thrust/system/tbb/detail/get_value.h", + 
"cuda/include/thrust/system/tbb/detail/inner_product.h", + "cuda/include/thrust/system/tbb/detail/copy_if.h", + "cuda/include/thrust/system/tbb/detail/logical.h", + "cuda/include/thrust/system/tbb/detail/partition.inl", + "cuda/include/thrust/system/tbb/detail/iter_swap.h", + "cuda/include/thrust/system/tbb/detail/par.h", + "cuda/include/thrust/system/tbb/detail/reduce_intervals.h", + "cuda/include/thrust/system/tbb/detail/malloc_and_free.h", + "cuda/include/thrust/system/tbb/detail/fill.h", + "cuda/include/thrust/system/tbb/detail/transform.h", + "cuda/include/thrust/system/tbb/memory.h", + "cuda/include/thrust/system/error_code.h", + "cuda/include/thrust/system/cpp/execution_policy.h", + "cuda/include/thrust/system/cpp/vector.h", + "cuda/include/thrust/system/cpp/detail/transform_scan.h", + "cuda/include/thrust/system/cpp/detail/memory.inl", + "cuda/include/thrust/system/cpp/detail/unique_by_key.h", + "cuda/include/thrust/system/cpp/detail/partition.h", + "cuda/include/thrust/system/cpp/detail/unique.h", + "cuda/include/thrust/system/cpp/detail/execution_policy.h", + "cuda/include/thrust/system/cpp/detail/adjacent_difference.h", + "cuda/include/thrust/system/cpp/detail/sequence.h", + "cuda/include/thrust/system/cpp/detail/merge.h", + "cuda/include/thrust/system/cpp/detail/transform_reduce.h", + "cuda/include/thrust/system/cpp/detail/gather.h", + "cuda/include/thrust/system/cpp/detail/sort.h", + "cuda/include/thrust/system/cpp/detail/scan.h", + "cuda/include/thrust/system/cpp/detail/temporary_buffer.h", + "cuda/include/thrust/system/cpp/detail/scan_by_key.h", + "cuda/include/thrust/system/cpp/detail/reverse.h", + "cuda/include/thrust/system/cpp/detail/assign_value.h", + "cuda/include/thrust/system/cpp/detail/scatter.h", + "cuda/include/thrust/system/cpp/detail/vector.inl", + "cuda/include/thrust/system/cpp/detail/find.h", + "cuda/include/thrust/system/cpp/detail/generate.h", + "cuda/include/thrust/system/cpp/detail/uninitialized_fill.h", + 
"cuda/include/thrust/system/cpp/detail/remove.h", + "cuda/include/thrust/system/cpp/detail/tabulate.h", + "cuda/include/thrust/system/cpp/detail/for_each.h", + "cuda/include/thrust/system/cpp/detail/reduce_by_key.h", + "cuda/include/thrust/system/cpp/detail/reduce.h", + "cuda/include/thrust/system/cpp/detail/equal.h", + "cuda/include/thrust/system/cpp/detail/copy.h", + "cuda/include/thrust/system/cpp/detail/swap_ranges.h", + "cuda/include/thrust/system/cpp/detail/uninitialized_copy.h", + "cuda/include/thrust/system/cpp/detail/binary_search.h", + "cuda/include/thrust/system/cpp/detail/set_operations.h", + "cuda/include/thrust/system/cpp/detail/mismatch.h", + "cuda/include/thrust/system/cpp/detail/extrema.h", + "cuda/include/thrust/system/cpp/detail/count.h", + "cuda/include/thrust/system/cpp/detail/replace.h", + "cuda/include/thrust/system/cpp/detail/get_value.h", + "cuda/include/thrust/system/cpp/detail/inner_product.h", + "cuda/include/thrust/system/cpp/detail/copy_if.h", + "cuda/include/thrust/system/cpp/detail/logical.h", + "cuda/include/thrust/system/cpp/detail/iter_swap.h", + "cuda/include/thrust/system/cpp/detail/par.h", + "cuda/include/thrust/system/cpp/detail/malloc_and_free.h", + "cuda/include/thrust/system/cpp/detail/fill.h", + "cuda/include/thrust/system/cpp/detail/transform.h", + "cuda/include/thrust/system/cpp/memory.h", + "cuda/include/thrust/system/cuda/execution_policy.h", + "cuda/include/thrust/system/cuda/vector.h", + "cuda/include/thrust/system/cuda/error.h", + "cuda/include/thrust/system/cuda/detail/copy_device_to_device.h", + "cuda/include/thrust/system/cuda/detail/transform_scan.h", + "cuda/include/thrust/system/cuda/detail/memory.inl", + "cuda/include/thrust/system/cuda/detail/cub/util_allocator.cuh", + "cuda/include/thrust/system/cuda/detail/cub/grid/grid_mapping.cuh", + "cuda/include/thrust/system/cuda/detail/cub/grid/grid_barrier.cuh", + "cuda/include/thrust/system/cuda/detail/cub/grid/grid_even_share.cuh", + 
"cuda/include/thrust/system/cuda/detail/cub/grid/grid_queue.cuh", + "cuda/include/thrust/system/cuda/detail/cub/util_device.cuh", + "cuda/include/thrust/system/cuda/detail/cub/device/device_run_length_encode.cuh", + "cuda/include/thrust/system/cuda/detail/cub/device/device_partition.cuh", + "cuda/include/thrust/system/cuda/detail/cub/device/device_radix_sort.cuh", + "cuda/include/thrust/system/cuda/detail/cub/device/dispatch/device_rle_dispatch.cuh", + "cuda/include/thrust/system/cuda/detail/cub/device/dispatch/device_histogram_dispatch.cuh", + "cuda/include/thrust/system/cuda/detail/cub/device/dispatch/device_reduce_by_key_dispatch.cuh", + "cuda/include/thrust/system/cuda/detail/cub/device/dispatch/device_scan_dispatch.cuh", + "cuda/include/thrust/system/cuda/detail/cub/device/dispatch/device_select_dispatch.cuh", + "cuda/include/thrust/system/cuda/detail/cub/device/dispatch/device_reduce_dispatch.cuh", + "cuda/include/thrust/system/cuda/detail/cub/device/dispatch/device_radix_sort_dispatch.cuh", + "cuda/include/thrust/system/cuda/detail/cub/device/device_scan.cuh", + "cuda/include/thrust/system/cuda/detail/cub/device/device_select.cuh", + "cuda/include/thrust/system/cuda/detail/cub/device/device_reduce.cuh", + "cuda/include/thrust/system/cuda/detail/cub/device/device_histogram.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block_range/block_range_reduce.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block_range/block_range_histo.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block_range/block_range_scan.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block_range/block_range_radix_sort_downsweep.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block_range/block_range_radix_sort_upsweep.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block_range/specializations/block_range_histo_satomic.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block_range/specializations/block_range_histo_sort.cuh", + 
"cuda/include/thrust/system/cuda/detail/cub/block_range/specializations/block_range_histo_gatomic.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block_range/block_range_select.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block_range/block_scan_prefix_operators.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block_range/block_range_reduce_by_key.cuh", + "cuda/include/thrust/system/cuda/detail/cub/util_macro.cuh", + "cuda/include/thrust/system/cuda/detail/cub/util_namespace.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block_sweep/block_radix_sort_upsweep.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block_sweep/block_histogram_sweep.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block_sweep/block_rle_sweep.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block_sweep/block_select_sweep.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block_sweep/block_scan_sweep.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block_sweep/block_reduce_sweep.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block_sweep/specializations/block_histogram_satomic_sweep.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block_sweep/specializations/block_histogram_sort_sweep.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block_sweep/specializations/block_histogram_gatomic_sweep.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block_sweep/block_radix_sort_downsweep.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block_sweep/block_reduce_by_key_sweep.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block_sweep/block_scan_prefix_operators.cuh", + "cuda/include/thrust/system/cuda/detail/cub/util_type.cuh", + "cuda/include/thrust/system/cuda/detail/cub/host/spinlock.cuh", + "cuda/include/thrust/system/cuda/detail/cub/warp/warp_reduce.cuh", + "cuda/include/thrust/system/cuda/detail/cub/warp/warp_scan.cuh", + "cuda/include/thrust/system/cuda/detail/cub/warp/specializations/warp_reduce_shfl.cuh", + 
"cuda/include/thrust/system/cuda/detail/cub/warp/specializations/warp_reduce_smem.cuh", + "cuda/include/thrust/system/cuda/detail/cub/warp/specializations/warp_scan_shfl.cuh", + "cuda/include/thrust/system/cuda/detail/cub/warp/specializations/warp_scan_smem.cuh", + "cuda/include/thrust/system/cuda/detail/cub/util_ptx.cuh", + "cuda/include/thrust/system/cuda/detail/cub/util_debug.cuh", + "cuda/include/thrust/system/cuda/detail/cub/cub.cuh", + "cuda/include/thrust/system/cuda/detail/cub/iterator/transform_input_iterator.cuh", + "cuda/include/thrust/system/cuda/detail/cub/iterator/tex_obj_input_iterator.cuh", + "cuda/include/thrust/system/cuda/detail/cub/iterator/tex_ref_input_iterator.cuh", + "cuda/include/thrust/system/cuda/detail/cub/iterator/cache_modified_output_iterator.cuh", + "cuda/include/thrust/system/cuda/detail/cub/iterator/counting_input_iterator.cuh", + "cuda/include/thrust/system/cuda/detail/cub/iterator/cache_modified_input_iterator.cuh", + "cuda/include/thrust/system/cuda/detail/cub/iterator/arg_index_input_iterator.cuh", + "cuda/include/thrust/system/cuda/detail/cub/iterator/constant_input_iterator.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block/block_scan.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block/block_load.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block/block_discontinuity.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block/block_radix_rank.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block/block_shift.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block/block_store.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block/block_reduce.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block/block_exchange.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block/block_radix_sort.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block/block_histogram.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block/block_raking_layout.cuh", + 
"cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_reduce_warp_reductions.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_scan_warp_scans.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_reduce_raking_commutative_only.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_histogram_atomic.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_scan_raking.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_histogram_sort.cuh", + "cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_reduce_raking.cuh", + "cuda/include/thrust/system/cuda/detail/cub/thread/thread_load.cuh", + "cuda/include/thrust/system/cuda/detail/cub/thread/thread_store.cuh", + "cuda/include/thrust/system/cuda/detail/cub/thread/thread_scan.cuh", + "cuda/include/thrust/system/cuda/detail/cub/thread/thread_operators.cuh", + "cuda/include/thrust/system/cuda/detail/cub/thread/thread_reduce.cuh", + "cuda/include/thrust/system/cuda/detail/cub/util_arch.cuh", + "cuda/include/thrust/system/cuda/detail/reduce_intervals.inl", + "cuda/include/thrust/system/cuda/detail/copy_cross_system.inl", + "cuda/include/thrust/system/cuda/detail/unique_by_key.h", + "cuda/include/thrust/system/cuda/detail/bulk.h", + "cuda/include/thrust/system/cuda/detail/sort.inl", + "cuda/include/thrust/system/cuda/detail/partition.h", + "cuda/include/thrust/system/cuda/detail/unique.h", + "cuda/include/thrust/system/cuda/detail/execution_policy.h", + "cuda/include/thrust/system/cuda/detail/cuda_launch_config.h", + "cuda/include/thrust/system/cuda/detail/cub.h", + "cuda/include/thrust/system/cuda/detail/adjacent_difference.h", + "cuda/include/thrust/system/cuda/detail/sequence.h", + "cuda/include/thrust/system/cuda/detail/merge.h", + "cuda/include/thrust/system/cuda/detail/set_symmetric_difference.inl", + "cuda/include/thrust/system/cuda/detail/copy_if.inl", + 
"cuda/include/thrust/system/cuda/detail/transform_reduce.h", + "cuda/include/thrust/system/cuda/detail/error.inl", + "cuda/include/thrust/system/cuda/detail/gather.h", + "cuda/include/thrust/system/cuda/detail/reduce_by_key.inl", + "cuda/include/thrust/system/cuda/detail/sort.h", + "cuda/include/thrust/system/cuda/detail/synchronize.h", + "cuda/include/thrust/system/cuda/detail/scan.h", + "cuda/include/thrust/system/cuda/detail/temporary_indirect_permutation.h", + "cuda/include/thrust/system/cuda/detail/extern_shared_ptr.h", + "cuda/include/thrust/system/cuda/detail/detail/set_operation.inl", + "cuda/include/thrust/system/cuda/detail/detail/balanced_path.h", + "cuda/include/thrust/system/cuda/detail/detail/virtualized_smem_closure.h", + "cuda/include/thrust/system/cuda/detail/detail/stable_primitive_sort.h", + "cuda/include/thrust/system/cuda/detail/detail/set_operation.h", + "cuda/include/thrust/system/cuda/detail/detail/stable_primitive_sort.inl", + "cuda/include/thrust/system/cuda/detail/detail/stable_merge_sort.h", + "cuda/include/thrust/system/cuda/detail/detail/launch_closure.inl", + "cuda/include/thrust/system/cuda/detail/detail/merge.h", + "cuda/include/thrust/system/cuda/detail/detail/alignment.h", + "cuda/include/thrust/system/cuda/detail/detail/stable_radix_sort.inl", + "cuda/include/thrust/system/cuda/detail/detail/stable_sort_each.h", + "cuda/include/thrust/system/cuda/detail/detail/launch_calculator.inl", + "cuda/include/thrust/system/cuda/detail/detail/stable_merge_sort.inl", + "cuda/include/thrust/system/cuda/detail/detail/launch_closure.h", + "cuda/include/thrust/system/cuda/detail/detail/stable_radix_sort.h", + "cuda/include/thrust/system/cuda/detail/detail/uninitialized.h", + "cuda/include/thrust/system/cuda/detail/detail/cached_temporary_allocator.h", + "cuda/include/thrust/system/cuda/detail/detail/launch_calculator.h", + "cuda/include/thrust/system/cuda/detail/detail/stable_sort_each.inl", + 
"cuda/include/thrust/system/cuda/detail/temporary_buffer.h", + "cuda/include/thrust/system/cuda/detail/default_decomposition.h", + "cuda/include/thrust/system/cuda/detail/reduce.inl", + "cuda/include/thrust/system/cuda/detail/scan_by_key.h", + "cuda/include/thrust/system/cuda/detail/reverse.h", + "cuda/include/thrust/system/cuda/detail/assign_value.h", + "cuda/include/thrust/system/cuda/detail/scatter.h", + "cuda/include/thrust/system/cuda/detail/reduce_intervals.hpp", + "cuda/include/thrust/system/cuda/detail/for_each.inl", + "cuda/include/thrust/system/cuda/detail/default_decomposition.inl", + "cuda/include/thrust/system/cuda/detail/guarded_cuda_runtime_api.h", + "cuda/include/thrust/system/cuda/detail/adjacent_difference.inl", + "cuda/include/thrust/system/cuda/detail/vector.inl", + "cuda/include/thrust/system/cuda/detail/throw_on_error.h", + "cuda/include/thrust/system/cuda/detail/find.h", + "cuda/include/thrust/system/cuda/detail/terminate.h", + "cuda/include/thrust/system/cuda/detail/merge.inl", + "cuda/include/thrust/system/cuda/detail/trivial_copy.inl", + "cuda/include/thrust/system/cuda/detail/generate.h", + "cuda/include/thrust/system/cuda/detail/execute_on_stream.h", + "cuda/include/thrust/system/cuda/detail/uninitialized_fill.h", + "cuda/include/thrust/system/cuda/detail/remove.h", + "cuda/include/thrust/system/cuda/detail/tabulate.h", + "cuda/include/thrust/system/cuda/detail/for_each.h", + "cuda/include/thrust/system/cuda/detail/reduce_by_key.h", + "cuda/include/thrust/system/cuda/detail/decomposition.h", + "cuda/include/thrust/system/cuda/detail/reduce.h", + "cuda/include/thrust/system/cuda/detail/equal.h", + "cuda/include/thrust/system/cuda/detail/runtime_introspection.h", + "cuda/include/thrust/system/cuda/detail/copy.inl", + "cuda/include/thrust/system/cuda/detail/copy.h", + "cuda/include/thrust/system/cuda/detail/swap_ranges.h", + "cuda/include/thrust/system/cuda/detail/uninitialized_copy.h", + 
"cuda/include/thrust/system/cuda/detail/binary_search.h", + "cuda/include/thrust/system/cuda/detail/runtime_introspection.inl", + "cuda/include/thrust/system/cuda/detail/set_operations.h", + "cuda/include/thrust/system/cuda/detail/mismatch.h", + "cuda/include/thrust/system/cuda/detail/scan.inl", + "cuda/include/thrust/system/cuda/detail/synchronize.inl", + "cuda/include/thrust/system/cuda/detail/extrema.h", + "cuda/include/thrust/system/cuda/detail/set_union.inl", + "cuda/include/thrust/system/cuda/detail/set_intersection.inl", + "cuda/include/thrust/system/cuda/detail/count.h", + "cuda/include/thrust/system/cuda/detail/trivial_copy.h", + "cuda/include/thrust/system/cuda/detail/copy_device_to_device.inl", + "cuda/include/thrust/system/cuda/detail/replace.h", + "cuda/include/thrust/system/cuda/detail/bulk/malloc.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/algorithm.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/detail/config.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/detail/closure.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/detail/tail_flags.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/detail/terminate.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/detail/alignment.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/detail/guarded_cuda_runtime_api.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/detail/choose_sizes.inl", + "cuda/include/thrust/system/cuda/detail/bulk/detail/tuple_meta_transform.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/detail/cuda_task.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/detail/head_flags.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/detail/synchronize.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/detail/throw_on_error.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/detail/cuda_launcher/parameter_ptr.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/detail/cuda_launcher/cuda_launcher.hpp", + 
"cuda/include/thrust/system/cuda/detail/bulk/detail/cuda_launcher/triple_chevron_launcher.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/detail/cuda_launcher/runtime_introspection.inl", + "cuda/include/thrust/system/cuda/detail/bulk/detail/cuda_launcher/cuda_launch_config.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/detail/cuda_launcher/runtime_introspection.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/detail/async.inl", + "cuda/include/thrust/system/cuda/detail/bulk/detail/tuple_transform.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/detail/pointer_traits.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/detail/apply_from_tuple.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/detail/is_contiguous_iterator.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/iterator.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/choose_sizes.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/algorithm/copy.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/algorithm/merge.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/algorithm/accumulate.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/algorithm/scan.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/algorithm/detail/stable_merge_sort.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/algorithm/gather.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/algorithm/sort.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/algorithm/reduce.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/algorithm/scatter.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/algorithm/adjacent_difference.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/algorithm/reduce_by_key.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/algorithm/for_each.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/bulk.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/execution_policy.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/iterator/strided_iterator.hpp", + 
"cuda/include/thrust/system/cuda/detail/bulk/uninitialized.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/async.hpp", + "cuda/include/thrust/system/cuda/detail/bulk/future.hpp", + "cuda/include/thrust/system/cuda/detail/guarded_driver_types.h", + "cuda/include/thrust/system/cuda/detail/get_value.h", + "cuda/include/thrust/system/cuda/detail/inner_product.h", + "cuda/include/thrust/system/cuda/detail/copy_if.h", + "cuda/include/thrust/system/cuda/detail/logical.h", + "cuda/include/thrust/system/cuda/detail/iter_swap.h", + "cuda/include/thrust/system/cuda/detail/block/merge.h", + "cuda/include/thrust/system/cuda/detail/block/inclusive_scan.h", + "cuda/include/thrust/system/cuda/detail/block/merge.inl", + "cuda/include/thrust/system/cuda/detail/block/merging_sort.h", + "cuda/include/thrust/system/cuda/detail/block/exclusive_scan.h", + "cuda/include/thrust/system/cuda/detail/block/reduce.h", + "cuda/include/thrust/system/cuda/detail/block/copy.h", + "cuda/include/thrust/system/cuda/detail/block/odd_even_sort.h", + "cuda/include/thrust/system/cuda/detail/par.h", + "cuda/include/thrust/system/cuda/detail/copy_cross_system.h", + "cuda/include/thrust/system/cuda/detail/reduce_intervals.h", + "cuda/include/thrust/system/cuda/detail/malloc_and_free.h", + "cuda/include/thrust/system/cuda/detail/fill.h", + "cuda/include/thrust/system/cuda/detail/set_difference.inl", + "cuda/include/thrust/system/cuda/detail/transform.h", + "cuda/include/thrust/system/cuda/experimental/pinned_allocator.h", + "cuda/include/thrust/system/cuda/memory.h", + "cuda/include/thrust/remove.h", + "cuda/include/thrust/tabulate.h", + "cuda/include/thrust/for_each.h", + "cuda/include/thrust/distance.h", + "cuda/include/thrust/reduce.h", + "cuda/include/thrust/equal.h", + "cuda/include/thrust/complex.h", + "cuda/include/thrust/device_allocator.h", + "cuda/include/thrust/copy.h", + "cuda/include/thrust/uninitialized_copy.h", + "cuda/include/thrust/device_reference.h", + 
"cuda/include/thrust/binary_search.h", + "cuda/include/thrust/set_operations.h", + "cuda/include/thrust/swap.h", + "cuda/include/thrust/mismatch.h", + "cuda/include/thrust/extrema.h", + "cuda/include/thrust/count.h", + "cuda/include/thrust/device_free.h", + "cuda/include/thrust/random/discard_block_engine.h", + "cuda/include/thrust/random/normal_distribution.h", + "cuda/include/thrust/random/detail/linear_feedback_shift_engine_wordmask.h", + "cuda/include/thrust/random/detail/subtract_with_carry_engine.inl", + "cuda/include/thrust/random/detail/xor_combine_engine_max.h", + "cuda/include/thrust/random/detail/linear_congruential_engine_discard.h", + "cuda/include/thrust/random/detail/uniform_int_distribution.inl", + "cuda/include/thrust/random/detail/discard_block_engine.inl", + "cuda/include/thrust/random/detail/uniform_real_distribution.inl", + "cuda/include/thrust/random/detail/random_core_access.h", + "cuda/include/thrust/random/detail/mod.h", + "cuda/include/thrust/random/detail/linear_feedback_shift_engine.inl", + "cuda/include/thrust/random/detail/linear_congruential_engine.inl", + "cuda/include/thrust/random/detail/xor_combine_engine.inl", + "cuda/include/thrust/random/detail/normal_distribution.inl", + "cuda/include/thrust/random/detail/normal_distribution_base.h", + "cuda/include/thrust/random/uniform_int_distribution.h", + "cuda/include/thrust/random/linear_feedback_shift_engine.h", + "cuda/include/thrust/random/xor_combine_engine.h", + "cuda/include/thrust/random/subtract_with_carry_engine.h", + "cuda/include/thrust/random/linear_congruential_engine.h", + "cuda/include/thrust/random/uniform_real_distribution.h", + "cuda/include/thrust/functional.h", + "cuda/include/thrust/replace.h", + "cuda/include/thrust/device_new_allocator.h", + "cuda/include/thrust/host_vector.h", + "cuda/include/thrust/version.h", + "cuda/include/thrust/inner_product.h", + "cuda/include/thrust/iterator/iterator_traits.h", + "cuda/include/thrust/iterator/discard_iterator.h", + 
"cuda/include/thrust/iterator/retag.h", + "cuda/include/thrust/iterator/permutation_iterator.h", + "cuda/include/thrust/iterator/transform_iterator.h", + "cuda/include/thrust/iterator/detail/reverse_iterator.inl", + "cuda/include/thrust/iterator/detail/zip_iterator.inl", + "cuda/include/thrust/iterator/detail/counting_iterator.inl", + "cuda/include/thrust/iterator/detail/distance_from_result.h", + "cuda/include/thrust/iterator/detail/host_system_tag.h", + "cuda/include/thrust/iterator/detail/iterator_traversal_tags.h", + "cuda/include/thrust/iterator/detail/retag.h", + "cuda/include/thrust/iterator/detail/tagged_iterator.h", + "cuda/include/thrust/iterator/detail/iterator_traits.inl", + "cuda/include/thrust/iterator/detail/minimum_category.h", + "cuda/include/thrust/iterator/detail/discard_iterator_base.h", + "cuda/include/thrust/iterator/detail/iterator_category_to_traversal.h", + "cuda/include/thrust/iterator/detail/zip_iterator_base.h", + "cuda/include/thrust/iterator/detail/normal_iterator.h", + "cuda/include/thrust/iterator/detail/join_iterator.h", + "cuda/include/thrust/iterator/detail/device_system_tag.h", + "cuda/include/thrust/iterator/detail/universal_categories.h", + "cuda/include/thrust/iterator/detail/reverse_iterator_base.h", + "cuda/include/thrust/iterator/detail/minimum_system.h", + "cuda/include/thrust/iterator/detail/tuple_of_iterator_references.h", + "cuda/include/thrust/iterator/detail/is_iterator_category.h", + "cuda/include/thrust/iterator/detail/permutation_iterator_base.h", + "cuda/include/thrust/iterator/detail/any_assign.h", + "cuda/include/thrust/iterator/detail/any_system_tag.h", + "cuda/include/thrust/iterator/detail/is_trivial_iterator.h", + "cuda/include/thrust/iterator/detail/iterator_category_to_system.h", + "cuda/include/thrust/iterator/detail/iterator_adaptor_base.h", + "cuda/include/thrust/iterator/detail/constant_iterator_base.h", + "cuda/include/thrust/iterator/detail/transform_iterator.inl", + 
"cuda/include/thrust/iterator/detail/iterator_facade_category.h", + "cuda/include/thrust/iterator/detail/iterator_category_with_system_and_traversal.h", + "cuda/include/thrust/iterator/constant_iterator.h", + "cuda/include/thrust/iterator/counting_iterator.h", + "cuda/include/thrust/iterator/iterator_adaptor.h", + "cuda/include/thrust/iterator/iterator_facade.h", + "cuda/include/thrust/iterator/iterator_categories.h", + "cuda/include/thrust/iterator/reverse_iterator.h", + "cuda/include/thrust/iterator/zip_iterator.h", + "cuda/include/thrust/logical.h", + "cuda/include/thrust/tuple.h", + "cuda/include/thrust/memory.h", + "cuda/include/thrust/random.h", + "cuda/include/thrust/fill.h", + "cuda/include/thrust/transform.h", + "cuda/include/texture_types.h", + "cuda/include/nppversion.h", + "cuda/include/cuda_texture_types.h", + "cuda/include/fatbinary.h", + "cuda/include/cublasXt.h", + "cuda/include/cuda_fp16.h", + "cuda/include/vector_functions.h", + "cuda/include/cusparse.h", + "cuda/include/nppi_filtering_functions.h", + "cuda/include/nppi_morphological_operations.h", + "cuda/include/sobol_direction_vectors.h", + "cuda/include/nvblas.h", + "cuda/include/curand_mtgp32dc_p_11213.h", + "cuda/include/nvcuvid.h", + "cuda/include/cuda_runtime_api.h", + "cuda/include/curand_mtgp32_kernel.h", + "cuda/include/cublas_v2.h", + "cuda/include/builtin_types.h", + "cuda/include/nppi_geometry_transforms.h", + "cuda/include/npps_support_functions.h", + "cuda/include/cufftw.h", + "cuda/include/cuda_device_runtime_api.h", + "cuda/include/sm_30_intrinsics.hpp", + "cuda/include/vector_types.h", + "cuda/include/sm_35_atomic_functions.h", + "cuda/include/sm_20_intrinsics.h", + "cuda/include/driver_types.h", + "cuda/include/nvToolsExtCudaRt.h", + "cuda/include/curand_globals.h", + "cuda/include/device_atomic_functions.h", + "cuda/include/surface_types.h", + "cuda/include/nvrtc.h", + "cuda/include/nppdefs.h", + "cuda/include/sm_60_atomic_functions.h", + "cuda/include/driver_functions.h", + 
"cuda/include/cusolver_common.h", + "cuda/include/cublas.h", + "cuda/include/curand_lognormal.h", + "cuda/include/device_atomic_functions.hpp", + "cuda/include/crt/device_runtime.h", + "cuda/include/crt/storage_class.h", + "cuda/include/crt/func_macro.h", + "cuda/include/crt/host_runtime.h", + "cuda/include/nppi_arithmetic_and_logical_operations.h", + "cuda/include/npps_arithmetic_and_logical_operations.h", + "cuda/include/nppi_computer_vision.h", + "cuda/include/surface_functions.hpp", + "cuda/include/surface_functions.h", + "cuda/include/curand_normal_static.h", + "cuda/include/curand.h", + "cuda/include/math_functions_dbl_ptx3.h", + "cuda/include/curand_philox4x32_x.h", + "cuda/include/nppi_threshold_and_compare_operations.h", + "cuda/include/nvml.h", + "cuda/include/npps.h", + "cuda/include/cuda_vdpau_interop.h", + "cuda/include/sm_61_intrinsics.hpp", + "cuda/include/cublas_api.h", + "cuda/include/nppi_color_conversion.h", + "cuda/include/math_functions_dbl_ptx3.hpp", + "cuda/include/nppcore.h", + "cuda/include/cudaGL.h", + "cuda/include/fatBinaryCtl.h", + "cuda/include/npps_statistics_functions.h", + "cuda/include/cudaVDPAU.h", + "cuda/include/curand_poisson.h", + "cuda/include/cusolverDn.h", + "cuda/include/cuda_profiler_api.h", + "cuda/include/sm_20_atomic_functions.h", + "cuda/include/nvfunctional", + ], + cmd = """ +cp "/usr/local/cuda-8.0/include/math_functions.hpp" "$(@D)/cuda/include/math_functions.hpp" && cp "/usr/local/cuda-8.0/include/cufft.h" "$(@D)/cuda/include/cufft.h" && cp "/usr/local/cuda-8.0/include/nvgraph.h" "$(@D)/cuda/include/nvgraph.h" && cp "/usr/local/cuda-8.0/include/curand_normal.h" "$(@D)/cuda/include/curand_normal.h" && cp "/usr/local/cuda-8.0/include/curand_uniform.h" "$(@D)/cuda/include/curand_uniform.h" && cp "/usr/local/cuda-8.0/include/nppi_data_exchange_and_initialization.h" "$(@D)/cuda/include/nppi_data_exchange_and_initialization.h" && cp "/usr/local/cuda-8.0/include/cuda_gl_interop.h" "$(@D)/cuda/include/cuda_gl_interop.h" 
&& cp "/usr/local/cuda-8.0/include/nppi_compression_functions.h" "$(@D)/cuda/include/nppi_compression_functions.h" && cp "/usr/local/cuda-8.0/include/npp.h" "$(@D)/cuda/include/npp.h" && cp "/usr/local/cuda-8.0/include/cuda.h" "$(@D)/cuda/include/cuda.h" && cp "/usr/local/cuda-8.0/include/nppi_statistics_functions.h" "$(@D)/cuda/include/nppi_statistics_functions.h" && cp "/usr/local/cuda-8.0/include/vector_functions.hpp" "$(@D)/cuda/include/vector_functions.hpp" && cp "/usr/local/cuda-8.0/include/sm_32_intrinsics.hpp" "$(@D)/cuda/include/sm_32_intrinsics.hpp" && cp "/usr/local/cuda-8.0/include/sm_32_intrinsics.h" "$(@D)/cuda/include/sm_32_intrinsics.h" && cp "/usr/local/cuda-8.0/include/curand_discrete.h" "$(@D)/cuda/include/curand_discrete.h" && cp "/usr/local/cuda-8.0/include/cuda_runtime.h" "$(@D)/cuda/include/cuda_runtime.h" && cp "/usr/local/cuda-8.0/include/cufftXt.h" "$(@D)/cuda/include/cufftXt.h" && cp "/usr/local/cuda-8.0/include/sm_61_intrinsics.h" "$(@D)/cuda/include/sm_61_intrinsics.h" && cp "/usr/local/cuda-8.0/include/texture_fetch_functions.h" "$(@D)/cuda/include/texture_fetch_functions.h" && cp "/usr/local/cuda-8.0/include/curand_mrg32k3a.h" "$(@D)/cuda/include/curand_mrg32k3a.h" && cp "/usr/local/cuda-8.0/include/host_defines.h" "$(@D)/cuda/include/host_defines.h" && cp "/usr/local/cuda-8.0/include/common_functions.h" "$(@D)/cuda/include/common_functions.h" && cp "/usr/local/cuda-8.0/include/nppi_support_functions.h" "$(@D)/cuda/include/nppi_support_functions.h" && cp "/usr/local/cuda-8.0/include/nppi_linear_transforms.h" "$(@D)/cuda/include/nppi_linear_transforms.h" && cp "/usr/local/cuda-8.0/include/device_double_functions.hpp" "$(@D)/cuda/include/device_double_functions.hpp" && cp "/usr/local/cuda-8.0/include/math_constants.h" "$(@D)/cuda/include/math_constants.h" && cp "/usr/local/cuda-8.0/include/nvToolsExtSync.h" "$(@D)/cuda/include/nvToolsExtSync.h" && cp "/usr/local/cuda-8.0/include/npps_initialization.h" 
"$(@D)/cuda/include/npps_initialization.h" && cp "/usr/local/cuda-8.0/include/cusolverSp_LOWLEVEL_PREVIEW.h" "$(@D)/cuda/include/cusolverSp_LOWLEVEL_PREVIEW.h" && cp "/usr/local/cuda-8.0/include/texture_indirect_functions.hpp" "$(@D)/cuda/include/texture_indirect_functions.hpp" && cp "/usr/local/cuda-8.0/include/cudaProfiler.h" "$(@D)/cuda/include/cudaProfiler.h" && cp "/usr/local/cuda-8.0/include/npps_filtering_functions.h" "$(@D)/cuda/include/npps_filtering_functions.h" && cp "/usr/local/cuda-8.0/include/cusparse_v2.h" "$(@D)/cuda/include/cusparse_v2.h" && cp "/usr/local/cuda-8.0/include/nppi.h" "$(@D)/cuda/include/nppi.h" && cp "/usr/local/cuda-8.0/include/surface_indirect_functions.h" "$(@D)/cuda/include/surface_indirect_functions.h" && cp "/usr/local/cuda-8.0/include/sm_30_intrinsics.h" "$(@D)/cuda/include/sm_30_intrinsics.h" && cp "/usr/local/cuda-8.0/include/device_double_functions.h" "$(@D)/cuda/include/device_double_functions.h" && cp "/usr/local/cuda-8.0/include/sm_35_intrinsics.h" "$(@D)/cuda/include/sm_35_intrinsics.h" && cp "/usr/local/cuda-8.0/include/cusolverSp.h" "$(@D)/cuda/include/cusolverSp.h" && cp "/usr/local/cuda-8.0/include/library_types.h" "$(@D)/cuda/include/library_types.h" && cp "/usr/local/cuda-8.0/include/surface_indirect_functions.hpp" "$(@D)/cuda/include/surface_indirect_functions.hpp" && cp "/usr/local/cuda-8.0/include/cudalibxt.h" "$(@D)/cuda/include/cudalibxt.h" && cp "/usr/local/cuda-8.0/include/channel_descriptor.h" "$(@D)/cuda/include/channel_descriptor.h" && cp "/usr/local/cuda-8.0/include/device_functions_decls.h" "$(@D)/cuda/include/device_functions_decls.h" && cp "/usr/local/cuda-8.0/include/curand_kernel.h" "$(@D)/cuda/include/curand_kernel.h" && cp "/usr/local/cuda-8.0/include/curand_mtgp32_host.h" "$(@D)/cuda/include/curand_mtgp32_host.h" && cp "/usr/local/cuda-8.0/include/nvToolsExtCuda.h" "$(@D)/cuda/include/nvToolsExtCuda.h" && cp "/usr/local/cuda-8.0/include/nvToolsExt.h" "$(@D)/cuda/include/nvToolsExt.h" && cp 
"/usr/local/cuda-8.0/include/cuComplex.h" "$(@D)/cuda/include/cuComplex.h" && cp "/usr/local/cuda-8.0/include/sm_32_atomic_functions.h" "$(@D)/cuda/include/sm_32_atomic_functions.h" && cp "/usr/local/cuda-8.0/include/texture_indirect_functions.h" "$(@D)/cuda/include/texture_indirect_functions.h" && cp "/usr/local/cuda-8.0/include/sm_32_atomic_functions.hpp" "$(@D)/cuda/include/sm_32_atomic_functions.hpp" && cp "/usr/local/cuda-8.0/include/sm_20_intrinsics.hpp" "$(@D)/cuda/include/sm_20_intrinsics.hpp" && cp "/usr/local/cuda-8.0/include/device_launch_parameters.h" "$(@D)/cuda/include/device_launch_parameters.h" && cp "/usr/local/cuda-8.0/include/curand_mtgp32.h" "$(@D)/cuda/include/curand_mtgp32.h" && cp "/usr/local/cuda-8.0/include/texture_fetch_functions.hpp" "$(@D)/cuda/include/texture_fetch_functions.hpp" && cp "/usr/local/cuda-8.0/include/cuda_occupancy.h" "$(@D)/cuda/include/cuda_occupancy.h" && cp "/usr/local/cuda-8.0/include/CL/opencl.h" "$(@D)/cuda/include/CL/opencl.h" && cp "/usr/local/cuda-8.0/include/CL/cl_platform.h" "$(@D)/cuda/include/CL/cl_platform.h" && cp "/usr/local/cuda-8.0/include/CL/cl_egl.h" "$(@D)/cuda/include/CL/cl_egl.h" && cp "/usr/local/cuda-8.0/include/CL/cl_gl.h" "$(@D)/cuda/include/CL/cl_gl.h" && cp "/usr/local/cuda-8.0/include/CL/cl.h" "$(@D)/cuda/include/CL/cl.h" && cp "/usr/local/cuda-8.0/include/CL/cl_gl_ext.h" "$(@D)/cuda/include/CL/cl_gl_ext.h" && cp "/usr/local/cuda-8.0/include/CL/cl_ext.h" "$(@D)/cuda/include/CL/cl_ext.h" && cp "/usr/local/cuda-8.0/include/CL/cl.hpp" "$(@D)/cuda/include/CL/cl.hpp" && cp "/usr/local/cuda-8.0/include/host_config.h" "$(@D)/cuda/include/host_config.h" && cp "/usr/local/cuda-8.0/include/cuda_surface_types.h" "$(@D)/cuda/include/cuda_surface_types.h" && cp "/usr/local/cuda-8.0/include/math_functions.h" "$(@D)/cuda/include/math_functions.h" && cp "/usr/local/cuda-8.0/include/nvToolsExtMeta.h" "$(@D)/cuda/include/nvToolsExtMeta.h" && cp "/usr/local/cuda-8.0/include/sm_20_atomic_functions.hpp" 
"$(@D)/cuda/include/sm_20_atomic_functions.hpp" && cp "/usr/local/cuda-8.0/include/device_functions.h" "$(@D)/cuda/include/device_functions.h" && cp "/usr/local/cuda-8.0/include/device_types.h" "$(@D)/cuda/include/device_types.h" && cp "/usr/local/cuda-8.0/include/npps_conversion_functions.h" "$(@D)/cuda/include/npps_conversion_functions.h" && cp "/usr/local/cuda-8.0/include/curand_precalc.h" "$(@D)/cuda/include/curand_precalc.h" && cp "/usr/local/cuda-8.0/include/cusolverRf.h" "$(@D)/cuda/include/cusolverRf.h" && cp "/usr/local/cuda-8.0/include/sm_60_atomic_functions.hpp" "$(@D)/cuda/include/sm_60_atomic_functions.hpp" && cp "/usr/local/cuda-8.0/include/cuviddec.h" "$(@D)/cuda/include/cuviddec.h" && cp "/usr/local/cuda-8.0/include/curand_discrete2.h" "$(@D)/cuda/include/curand_discrete2.h" && cp "/usr/local/cuda-8.0/include/device_functions.hpp" "$(@D)/cuda/include/device_functions.hpp" && cp "/usr/local/cuda-8.0/include/thrust/transform_scan.h" "$(@D)/cuda/include/thrust/transform_scan.h" && cp "/usr/local/cuda-8.0/include/thrust/system_error.h" "$(@D)/cuda/include/thrust/system_error.h" && cp "/usr/local/cuda-8.0/include/thrust/device_malloc.h" "$(@D)/cuda/include/thrust/device_malloc.h" && cp "/usr/local/cuda-8.0/include/thrust/partition.h" "$(@D)/cuda/include/thrust/partition.h" && cp "/usr/local/cuda-8.0/include/thrust/unique.h" "$(@D)/cuda/include/thrust/unique.h" && cp "/usr/local/cuda-8.0/include/thrust/device_delete.h" "$(@D)/cuda/include/thrust/device_delete.h" && cp "/usr/local/cuda-8.0/include/thrust/execution_policy.h" "$(@D)/cuda/include/thrust/execution_policy.h" && cp "/usr/local/cuda-8.0/include/thrust/adjacent_difference.h" "$(@D)/cuda/include/thrust/adjacent_difference.h" && cp "/usr/local/cuda-8.0/include/thrust/sequence.h" "$(@D)/cuda/include/thrust/sequence.h" && cp "/usr/local/cuda-8.0/include/thrust/merge.h" "$(@D)/cuda/include/thrust/merge.h" && cp "/usr/local/cuda-8.0/include/thrust/device_new.h" "$(@D)/cuda/include/thrust/device_new.h" 
&& cp "/usr/local/cuda-8.0/include/thrust/transform_reduce.h" "$(@D)/cuda/include/thrust/transform_reduce.h" && cp "/usr/local/cuda-8.0/include/thrust/device_vector.h" "$(@D)/cuda/include/thrust/device_vector.h" && cp "/usr/local/cuda-8.0/include/thrust/gather.h" "$(@D)/cuda/include/thrust/gather.h" && cp "/usr/local/cuda-8.0/include/thrust/sort.h" "$(@D)/cuda/include/thrust/sort.h" && cp "/usr/local/cuda-8.0/include/thrust/scan.h" "$(@D)/cuda/include/thrust/scan.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/temporary_array.h" "$(@D)/cuda/include/thrust/detail/temporary_array.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/util/align.h" "$(@D)/cuda/include/thrust/detail/util/align.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/util/blocking.h" "$(@D)/cuda/include/thrust/detail/util/blocking.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/transform.inl" "$(@D)/cuda/include/thrust/detail/transform.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/device_vector.inl" "$(@D)/cuda/include/thrust/detail/device_vector.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/binary_search.inl" "$(@D)/cuda/include/thrust/detail/binary_search.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/overlapped_copy.h" "$(@D)/cuda/include/thrust/detail/overlapped_copy.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/vector_base.inl" "$(@D)/cuda/include/thrust/detail/vector_base.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/device_reference.inl" "$(@D)/cuda/include/thrust/detail/device_reference.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/functional/actor.h" "$(@D)/cuda/include/thrust/detail/functional/actor.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/functional/value.h" "$(@D)/cuda/include/thrust/detail/functional/value.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/functional/operators.h" "$(@D)/cuda/include/thrust/detail/functional/operators.h" && cp 
"/usr/local/cuda-8.0/include/thrust/detail/functional/operators/logical_operators.h" "$(@D)/cuda/include/thrust/detail/functional/operators/logical_operators.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/functional/operators/relational_operators.h" "$(@D)/cuda/include/thrust/detail/functional/operators/relational_operators.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/functional/operators/assignment_operator.h" "$(@D)/cuda/include/thrust/detail/functional/operators/assignment_operator.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/functional/operators/bitwise_operators.h" "$(@D)/cuda/include/thrust/detail/functional/operators/bitwise_operators.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/functional/operators/operator_adaptors.h" "$(@D)/cuda/include/thrust/detail/functional/operators/operator_adaptors.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/functional/operators/arithmetic_operators.h" "$(@D)/cuda/include/thrust/detail/functional/operators/arithmetic_operators.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/functional/operators/compound_assignment_operators.h" "$(@D)/cuda/include/thrust/detail/functional/operators/compound_assignment_operators.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/functional/argument.h" "$(@D)/cuda/include/thrust/detail/functional/argument.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/functional/placeholder.h" "$(@D)/cuda/include/thrust/detail/functional/placeholder.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/functional/actor.inl" "$(@D)/cuda/include/thrust/detail/functional/actor.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/functional/composite.h" "$(@D)/cuda/include/thrust/detail/functional/composite.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/static_map.h" "$(@D)/cuda/include/thrust/detail/static_map.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/type_traits/has_nested_type.h" "$(@D)/cuda/include/thrust/detail/type_traits/has_nested_type.h" && cp 
"/usr/local/cuda-8.0/include/thrust/detail/type_traits/is_call_possible.h" "$(@D)/cuda/include/thrust/detail/type_traits/is_call_possible.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/type_traits/function_traits.h" "$(@D)/cuda/include/thrust/detail/type_traits/function_traits.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/type_traits/pointer_traits.h" "$(@D)/cuda/include/thrust/detail/type_traits/pointer_traits.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/type_traits/has_member_function.h" "$(@D)/cuda/include/thrust/detail/type_traits/has_member_function.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/type_traits/algorithm/intermediate_type_from_function_and_iterators.h" "$(@D)/cuda/include/thrust/detail/type_traits/algorithm/intermediate_type_from_function_and_iterators.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/type_traits/minimum_type.h" "$(@D)/cuda/include/thrust/detail/type_traits/minimum_type.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/type_traits/has_trivial_assign.h" "$(@D)/cuda/include/thrust/detail/type_traits/has_trivial_assign.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/type_traits/is_metafunction_defined.h" "$(@D)/cuda/include/thrust/detail/type_traits/is_metafunction_defined.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/type_traits/iterator/is_discard_iterator.h" "$(@D)/cuda/include/thrust/detail/type_traits/iterator/is_discard_iterator.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/type_traits/iterator/is_output_iterator.h" "$(@D)/cuda/include/thrust/detail/type_traits/iterator/is_output_iterator.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/type_traits/result_of_adaptable_function.h" "$(@D)/cuda/include/thrust/detail/type_traits/result_of_adaptable_function.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/reference.h" "$(@D)/cuda/include/thrust/detail/reference.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/inner_product.inl" "$(@D)/cuda/include/thrust/detail/inner_product.inl" 
&& cp "/usr/local/cuda-8.0/include/thrust/detail/use_default.h" "$(@D)/cuda/include/thrust/detail/use_default.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/sequence.inl" "$(@D)/cuda/include/thrust/detail/sequence.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/sort.inl" "$(@D)/cuda/include/thrust/detail/sort.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/equal.inl" "$(@D)/cuda/include/thrust/detail/equal.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/execution_policy.h" "$(@D)/cuda/include/thrust/detail/execution_policy.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/integer_traits.h" "$(@D)/cuda/include/thrust/detail/integer_traits.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/type_traits.h" "$(@D)/cuda/include/thrust/detail/type_traits.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/reverse.inl" "$(@D)/cuda/include/thrust/detail/reverse.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/tabulate.inl" "$(@D)/cuda/include/thrust/detail/tabulate.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/unique.inl" "$(@D)/cuda/include/thrust/detail/unique.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/scatter.inl" "$(@D)/cuda/include/thrust/detail/scatter.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/set_operations.inl" "$(@D)/cuda/include/thrust/detail/set_operations.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/device_malloc.inl" "$(@D)/cuda/include/thrust/detail/device_malloc.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/copy_if.inl" "$(@D)/cuda/include/thrust/detail/copy_if.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/fill.inl" "$(@D)/cuda/include/thrust/detail/fill.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/temporary_array.inl" "$(@D)/cuda/include/thrust/detail/temporary_array.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/transform_scan.inl" "$(@D)/cuda/include/thrust/detail/transform_scan.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/minmax.h" 
"$(@D)/cuda/include/thrust/detail/minmax.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/swap.inl" "$(@D)/cuda/include/thrust/detail/swap.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/pointer.inl" "$(@D)/cuda/include/thrust/detail/pointer.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/transform_reduce.inl" "$(@D)/cuda/include/thrust/detail/transform_reduce.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/config.h" "$(@D)/cuda/include/thrust/detail/config.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/distance.inl" "$(@D)/cuda/include/thrust/detail/distance.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/pair.inl" "$(@D)/cuda/include/thrust/detail/pair.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/allocator/temporary_allocator.h" "$(@D)/cuda/include/thrust/detail/allocator/temporary_allocator.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/allocator/tagged_allocator.h" "$(@D)/cuda/include/thrust/detail/allocator/tagged_allocator.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/allocator/destroy_range.inl" "$(@D)/cuda/include/thrust/detail/allocator/destroy_range.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/allocator/destroy_range.h" "$(@D)/cuda/include/thrust/detail/allocator/destroy_range.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/allocator/no_throw_allocator.h" "$(@D)/cuda/include/thrust/detail/allocator/no_throw_allocator.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/allocator/default_construct_range.inl" "$(@D)/cuda/include/thrust/detail/allocator/default_construct_range.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/allocator/fill_construct_range.inl" "$(@D)/cuda/include/thrust/detail/allocator/fill_construct_range.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/allocator/tagged_allocator.inl" "$(@D)/cuda/include/thrust/detail/allocator/tagged_allocator.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/allocator/malloc_allocator.h" 
"$(@D)/cuda/include/thrust/detail/allocator/malloc_allocator.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/allocator/allocator_traits.h" "$(@D)/cuda/include/thrust/detail/allocator/allocator_traits.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/allocator/copy_construct_range.h" "$(@D)/cuda/include/thrust/detail/allocator/copy_construct_range.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/allocator/allocator_traits.inl" "$(@D)/cuda/include/thrust/detail/allocator/allocator_traits.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/allocator/default_construct_range.h" "$(@D)/cuda/include/thrust/detail/allocator/default_construct_range.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/allocator/copy_construct_range.inl" "$(@D)/cuda/include/thrust/detail/allocator/copy_construct_range.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/allocator/malloc_allocator.inl" "$(@D)/cuda/include/thrust/detail/allocator/malloc_allocator.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/allocator/temporary_allocator.inl" "$(@D)/cuda/include/thrust/detail/allocator/temporary_allocator.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/allocator/fill_construct_range.h" "$(@D)/cuda/include/thrust/detail/allocator/fill_construct_range.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/temporary_buffer.h" "$(@D)/cuda/include/thrust/detail/temporary_buffer.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/reduce.inl" "$(@D)/cuda/include/thrust/detail/reduce.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/device_new.inl" "$(@D)/cuda/include/thrust/detail/device_new.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/pointer.h" "$(@D)/cuda/include/thrust/detail/pointer.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/for_each.inl" "$(@D)/cuda/include/thrust/detail/for_each.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/generate.inl" "$(@D)/cuda/include/thrust/detail/generate.inl" && cp 
"/usr/local/cuda-8.0/include/thrust/detail/dispatch/is_trivial_copy.h" "$(@D)/cuda/include/thrust/detail/dispatch/is_trivial_copy.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/adjacent_difference.inl" "$(@D)/cuda/include/thrust/detail/adjacent_difference.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/tuple_meta_transform.h" "$(@D)/cuda/include/thrust/detail/tuple_meta_transform.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/functional.inl" "$(@D)/cuda/include/thrust/detail/functional.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/remove.inl" "$(@D)/cuda/include/thrust/detail/remove.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/tuple_transform.h" "$(@D)/cuda/include/thrust/detail/tuple_transform.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/merge.inl" "$(@D)/cuda/include/thrust/detail/merge.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/extrema.inl" "$(@D)/cuda/include/thrust/detail/extrema.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/trivial_sequence.h" "$(@D)/cuda/include/thrust/detail/trivial_sequence.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/vector_base.h" "$(@D)/cuda/include/thrust/detail/vector_base.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/count.inl" "$(@D)/cuda/include/thrust/detail/count.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/uninitialized_copy.inl" "$(@D)/cuda/include/thrust/detail/uninitialized_copy.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/function.h" "$(@D)/cuda/include/thrust/detail/function.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/swap_ranges.inl" "$(@D)/cuda/include/thrust/detail/swap_ranges.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/device_delete.inl" "$(@D)/cuda/include/thrust/detail/device_delete.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/static_assert.h" "$(@D)/cuda/include/thrust/detail/static_assert.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/logical.inl" "$(@D)/cuda/include/thrust/detail/logical.inl" && 
cp "/usr/local/cuda-8.0/include/thrust/detail/seq.h" "$(@D)/cuda/include/thrust/detail/seq.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/mpl/math.h" "$(@D)/cuda/include/thrust/detail/mpl/math.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/mismatch.inl" "$(@D)/cuda/include/thrust/detail/mismatch.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/internal_functional.h" "$(@D)/cuda/include/thrust/detail/internal_functional.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/get_iterator_value.h" "$(@D)/cuda/include/thrust/detail/get_iterator_value.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/copy.inl" "$(@D)/cuda/include/thrust/detail/copy.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/copy.h" "$(@D)/cuda/include/thrust/detail/copy.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/complex/catrigf.h" "$(@D)/cuda/include/thrust/detail/complex/catrigf.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/complex/cpowf.h" "$(@D)/cuda/include/thrust/detail/complex/cpowf.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/complex/csqrtf.h" "$(@D)/cuda/include/thrust/detail/complex/csqrtf.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/complex/ccoshf.h" "$(@D)/cuda/include/thrust/detail/complex/ccoshf.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/complex/csinhf.h" "$(@D)/cuda/include/thrust/detail/complex/csinhf.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/complex/clogf.h" "$(@D)/cuda/include/thrust/detail/complex/clogf.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/complex/ccosh.h" "$(@D)/cuda/include/thrust/detail/complex/ccosh.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/complex/arithmetic.h" "$(@D)/cuda/include/thrust/detail/complex/arithmetic.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/complex/csqrt.h" "$(@D)/cuda/include/thrust/detail/complex/csqrt.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/complex/cpow.h" "$(@D)/cuda/include/thrust/detail/complex/cpow.h" && cp 
"/usr/local/cuda-8.0/include/thrust/detail/complex/complex.inl" "$(@D)/cuda/include/thrust/detail/complex/complex.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/complex/math_private.h" "$(@D)/cuda/include/thrust/detail/complex/math_private.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/complex/c99math.h" "$(@D)/cuda/include/thrust/detail/complex/c99math.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/complex/cproj.h" "$(@D)/cuda/include/thrust/detail/complex/cproj.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/complex/catrig.h" "$(@D)/cuda/include/thrust/detail/complex/catrig.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/complex/ctanhf.h" "$(@D)/cuda/include/thrust/detail/complex/ctanhf.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/complex/cexpf.h" "$(@D)/cuda/include/thrust/detail/complex/cexpf.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/complex/csinh.h" "$(@D)/cuda/include/thrust/detail/complex/csinh.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/complex/stream.h" "$(@D)/cuda/include/thrust/detail/complex/stream.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/complex/ctanh.h" "$(@D)/cuda/include/thrust/detail/complex/ctanh.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/complex/cexp.h" "$(@D)/cuda/include/thrust/detail/complex/cexp.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/complex/clog.h" "$(@D)/cuda/include/thrust/detail/complex/clog.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/range/head_flags.h" "$(@D)/cuda/include/thrust/detail/range/head_flags.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/range/tail_flags.h" "$(@D)/cuda/include/thrust/detail/range/tail_flags.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/execute_with_allocator.h" "$(@D)/cuda/include/thrust/detail/execute_with_allocator.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/integer_math.h" "$(@D)/cuda/include/thrust/detail/integer_math.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/swap.h" 
"$(@D)/cuda/include/thrust/detail/swap.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/uninitialized_fill.inl" "$(@D)/cuda/include/thrust/detail/uninitialized_fill.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/scan.inl" "$(@D)/cuda/include/thrust/detail/scan.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/gather.inl" "$(@D)/cuda/include/thrust/detail/gather.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/reference_forward_declaration.h" "$(@D)/cuda/include/thrust/detail/reference_forward_declaration.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/numeric_traits.h" "$(@D)/cuda/include/thrust/detail/numeric_traits.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/reference.inl" "$(@D)/cuda/include/thrust/detail/reference.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/cstdint.h" "$(@D)/cuda/include/thrust/detail/cstdint.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/device_free.inl" "$(@D)/cuda/include/thrust/detail/device_free.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/copy_if.h" "$(@D)/cuda/include/thrust/detail/copy_if.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/partition.inl" "$(@D)/cuda/include/thrust/detail/partition.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/find.inl" "$(@D)/cuda/include/thrust/detail/find.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/config/forceinline.h" "$(@D)/cuda/include/thrust/detail/config/forceinline.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/config/debug.h" "$(@D)/cuda/include/thrust/detail/config/debug.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/config/config.h" "$(@D)/cuda/include/thrust/detail/config/config.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/config/host_device.h" "$(@D)/cuda/include/thrust/detail/config/host_device.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/config/host_system.h" "$(@D)/cuda/include/thrust/detail/config/host_system.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/config/compiler.h" 
"$(@D)/cuda/include/thrust/detail/config/compiler.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/config/device_system.h" "$(@D)/cuda/include/thrust/detail/config/device_system.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/config/compiler_fence.h" "$(@D)/cuda/include/thrust/detail/config/compiler_fence.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/config/exec_check_disable.h" "$(@D)/cuda/include/thrust/detail/config/exec_check_disable.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/config/simple_defines.h" "$(@D)/cuda/include/thrust/detail/config/simple_defines.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/config/global_workarounds.h" "$(@D)/cuda/include/thrust/detail/config/global_workarounds.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/replace.inl" "$(@D)/cuda/include/thrust/detail/replace.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/device_ptr.inl" "$(@D)/cuda/include/thrust/detail/device_ptr.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/tuple.inl" "$(@D)/cuda/include/thrust/detail/tuple.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/malloc_and_free.h" "$(@D)/cuda/include/thrust/detail/malloc_and_free.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/host_vector.inl" "$(@D)/cuda/include/thrust/detail/host_vector.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/raw_pointer_cast.h" "$(@D)/cuda/include/thrust/detail/raw_pointer_cast.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/advance.inl" "$(@D)/cuda/include/thrust/detail/advance.inl" && cp "/usr/local/cuda-8.0/include/thrust/detail/contiguous_storage.h" "$(@D)/cuda/include/thrust/detail/contiguous_storage.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/raw_reference_cast.h" "$(@D)/cuda/include/thrust/detail/raw_reference_cast.h" && cp "/usr/local/cuda-8.0/include/thrust/detail/contiguous_storage.inl" "$(@D)/cuda/include/thrust/detail/contiguous_storage.inl" && cp "/usr/local/cuda-8.0/include/thrust/reverse.h" 
"$(@D)/cuda/include/thrust/reverse.h" && cp "/usr/local/cuda-8.0/include/thrust/device_malloc_allocator.h" "$(@D)/cuda/include/thrust/device_malloc_allocator.h" && cp "/usr/local/cuda-8.0/include/thrust/scatter.h" "$(@D)/cuda/include/thrust/scatter.h" && cp "/usr/local/cuda-8.0/include/thrust/pair.h" "$(@D)/cuda/include/thrust/pair.h" && cp "/usr/local/cuda-8.0/include/thrust/advance.h" "$(@D)/cuda/include/thrust/advance.h" && cp "/usr/local/cuda-8.0/include/thrust/find.h" "$(@D)/cuda/include/thrust/find.h" && cp "/usr/local/cuda-8.0/include/thrust/device_ptr.h" "$(@D)/cuda/include/thrust/device_ptr.h" && cp "/usr/local/cuda-8.0/include/thrust/generate.h" "$(@D)/cuda/include/thrust/generate.h" && cp "/usr/local/cuda-8.0/include/thrust/uninitialized_fill.h" "$(@D)/cuda/include/thrust/uninitialized_fill.h" && cp "/usr/local/cuda-8.0/include/thrust/system/system_error.h" "$(@D)/cuda/include/thrust/system/system_error.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/bad_alloc.h" "$(@D)/cuda/include/thrust/system/detail/bad_alloc.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/transform_scan.h" "$(@D)/cuda/include/thrust/system/detail/adl/transform_scan.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/unique_by_key.h" "$(@D)/cuda/include/thrust/system/detail/adl/unique_by_key.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/partition.h" "$(@D)/cuda/include/thrust/system/detail/adl/partition.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/unique.h" "$(@D)/cuda/include/thrust/system/detail/adl/unique.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/adjacent_difference.h" "$(@D)/cuda/include/thrust/system/detail/adl/adjacent_difference.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/sequence.h" "$(@D)/cuda/include/thrust/system/detail/adl/sequence.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/merge.h" "$(@D)/cuda/include/thrust/system/detail/adl/merge.h" && cp 
"/usr/local/cuda-8.0/include/thrust/system/detail/adl/transform_reduce.h" "$(@D)/cuda/include/thrust/system/detail/adl/transform_reduce.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/gather.h" "$(@D)/cuda/include/thrust/system/detail/adl/gather.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/sort.h" "$(@D)/cuda/include/thrust/system/detail/adl/sort.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/scan.h" "$(@D)/cuda/include/thrust/system/detail/adl/scan.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/temporary_buffer.h" "$(@D)/cuda/include/thrust/system/detail/adl/temporary_buffer.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/scan_by_key.h" "$(@D)/cuda/include/thrust/system/detail/adl/scan_by_key.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/reverse.h" "$(@D)/cuda/include/thrust/system/detail/adl/reverse.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/assign_value.h" "$(@D)/cuda/include/thrust/system/detail/adl/assign_value.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/scatter.h" "$(@D)/cuda/include/thrust/system/detail/adl/scatter.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/find.h" "$(@D)/cuda/include/thrust/system/detail/adl/find.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/generate.h" "$(@D)/cuda/include/thrust/system/detail/adl/generate.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/uninitialized_fill.h" "$(@D)/cuda/include/thrust/system/detail/adl/uninitialized_fill.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/remove.h" "$(@D)/cuda/include/thrust/system/detail/adl/remove.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/tabulate.h" "$(@D)/cuda/include/thrust/system/detail/adl/tabulate.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/for_each.h" "$(@D)/cuda/include/thrust/system/detail/adl/for_each.h" && cp 
"/usr/local/cuda-8.0/include/thrust/system/detail/adl/reduce_by_key.h" "$(@D)/cuda/include/thrust/system/detail/adl/reduce_by_key.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/reduce.h" "$(@D)/cuda/include/thrust/system/detail/adl/reduce.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/equal.h" "$(@D)/cuda/include/thrust/system/detail/adl/equal.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/copy.h" "$(@D)/cuda/include/thrust/system/detail/adl/copy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/swap_ranges.h" "$(@D)/cuda/include/thrust/system/detail/adl/swap_ranges.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/uninitialized_copy.h" "$(@D)/cuda/include/thrust/system/detail/adl/uninitialized_copy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/binary_search.h" "$(@D)/cuda/include/thrust/system/detail/adl/binary_search.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/set_operations.h" "$(@D)/cuda/include/thrust/system/detail/adl/set_operations.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/mismatch.h" "$(@D)/cuda/include/thrust/system/detail/adl/mismatch.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/extrema.h" "$(@D)/cuda/include/thrust/system/detail/adl/extrema.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/count.h" "$(@D)/cuda/include/thrust/system/detail/adl/count.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/replace.h" "$(@D)/cuda/include/thrust/system/detail/adl/replace.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/get_value.h" "$(@D)/cuda/include/thrust/system/detail/adl/get_value.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/inner_product.h" "$(@D)/cuda/include/thrust/system/detail/adl/inner_product.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/copy_if.h" "$(@D)/cuda/include/thrust/system/detail/adl/copy_if.h" && cp 
"/usr/local/cuda-8.0/include/thrust/system/detail/adl/logical.h" "$(@D)/cuda/include/thrust/system/detail/adl/logical.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/iter_swap.h" "$(@D)/cuda/include/thrust/system/detail/adl/iter_swap.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/malloc_and_free.h" "$(@D)/cuda/include/thrust/system/detail/adl/malloc_and_free.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/fill.h" "$(@D)/cuda/include/thrust/system/detail/adl/fill.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/adl/transform.h" "$(@D)/cuda/include/thrust/system/detail/adl/transform.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/errno.h" "$(@D)/cuda/include/thrust/system/detail/errno.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/error_category.inl" "$(@D)/cuda/include/thrust/system/detail/error_category.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/transform_scan.h" "$(@D)/cuda/include/thrust/system/detail/sequential/transform_scan.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/unique_by_key.h" "$(@D)/cuda/include/thrust/system/detail/sequential/unique_by_key.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/stable_primitive_sort.h" "$(@D)/cuda/include/thrust/system/detail/sequential/stable_primitive_sort.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/stable_primitive_sort.inl" "$(@D)/cuda/include/thrust/system/detail/sequential/stable_primitive_sort.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/stable_merge_sort.h" "$(@D)/cuda/include/thrust/system/detail/sequential/stable_merge_sort.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/sort.inl" "$(@D)/cuda/include/thrust/system/detail/sequential/sort.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/partition.h" "$(@D)/cuda/include/thrust/system/detail/sequential/partition.h" && cp 
"/usr/local/cuda-8.0/include/thrust/system/detail/sequential/unique.h" "$(@D)/cuda/include/thrust/system/detail/sequential/unique.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/execution_policy.h" "$(@D)/cuda/include/thrust/system/detail/sequential/execution_policy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/adjacent_difference.h" "$(@D)/cuda/include/thrust/system/detail/sequential/adjacent_difference.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/sequence.h" "$(@D)/cuda/include/thrust/system/detail/sequential/sequence.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/merge.h" "$(@D)/cuda/include/thrust/system/detail/sequential/merge.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/transform_reduce.h" "$(@D)/cuda/include/thrust/system/detail/sequential/transform_reduce.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/gather.h" "$(@D)/cuda/include/thrust/system/detail/sequential/gather.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/sort.h" "$(@D)/cuda/include/thrust/system/detail/sequential/sort.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/copy_backward.h" "$(@D)/cuda/include/thrust/system/detail/sequential/copy_backward.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/stable_radix_sort.inl" "$(@D)/cuda/include/thrust/system/detail/sequential/stable_radix_sort.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/scan.h" "$(@D)/cuda/include/thrust/system/detail/sequential/scan.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/temporary_buffer.h" "$(@D)/cuda/include/thrust/system/detail/sequential/temporary_buffer.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/scan_by_key.h" "$(@D)/cuda/include/thrust/system/detail/sequential/scan_by_key.h" && cp 
"/usr/local/cuda-8.0/include/thrust/system/detail/sequential/reverse.h" "$(@D)/cuda/include/thrust/system/detail/sequential/reverse.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/assign_value.h" "$(@D)/cuda/include/thrust/system/detail/sequential/assign_value.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/scatter.h" "$(@D)/cuda/include/thrust/system/detail/sequential/scatter.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/find.h" "$(@D)/cuda/include/thrust/system/detail/sequential/find.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/stable_merge_sort.inl" "$(@D)/cuda/include/thrust/system/detail/sequential/stable_merge_sort.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/merge.inl" "$(@D)/cuda/include/thrust/system/detail/sequential/merge.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/generate.h" "$(@D)/cuda/include/thrust/system/detail/sequential/generate.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/uninitialized_fill.h" "$(@D)/cuda/include/thrust/system/detail/sequential/uninitialized_fill.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/general_copy.h" "$(@D)/cuda/include/thrust/system/detail/sequential/general_copy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/insertion_sort.h" "$(@D)/cuda/include/thrust/system/detail/sequential/insertion_sort.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/remove.h" "$(@D)/cuda/include/thrust/system/detail/sequential/remove.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/tabulate.h" "$(@D)/cuda/include/thrust/system/detail/sequential/tabulate.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/for_each.h" "$(@D)/cuda/include/thrust/system/detail/sequential/for_each.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/reduce_by_key.h" 
"$(@D)/cuda/include/thrust/system/detail/sequential/reduce_by_key.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/reduce.h" "$(@D)/cuda/include/thrust/system/detail/sequential/reduce.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/equal.h" "$(@D)/cuda/include/thrust/system/detail/sequential/equal.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/stable_radix_sort.h" "$(@D)/cuda/include/thrust/system/detail/sequential/stable_radix_sort.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/copy.inl" "$(@D)/cuda/include/thrust/system/detail/sequential/copy.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/copy.h" "$(@D)/cuda/include/thrust/system/detail/sequential/copy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/swap_ranges.h" "$(@D)/cuda/include/thrust/system/detail/sequential/swap_ranges.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/uninitialized_copy.h" "$(@D)/cuda/include/thrust/system/detail/sequential/uninitialized_copy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/binary_search.h" "$(@D)/cuda/include/thrust/system/detail/sequential/binary_search.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/set_operations.h" "$(@D)/cuda/include/thrust/system/detail/sequential/set_operations.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/mismatch.h" "$(@D)/cuda/include/thrust/system/detail/sequential/mismatch.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/extrema.h" "$(@D)/cuda/include/thrust/system/detail/sequential/extrema.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/count.h" "$(@D)/cuda/include/thrust/system/detail/sequential/count.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/trivial_copy.h" "$(@D)/cuda/include/thrust/system/detail/sequential/trivial_copy.h" && cp 
"/usr/local/cuda-8.0/include/thrust/system/detail/sequential/replace.h" "$(@D)/cuda/include/thrust/system/detail/sequential/replace.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/get_value.h" "$(@D)/cuda/include/thrust/system/detail/sequential/get_value.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/inner_product.h" "$(@D)/cuda/include/thrust/system/detail/sequential/inner_product.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/copy_if.h" "$(@D)/cuda/include/thrust/system/detail/sequential/copy_if.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/logical.h" "$(@D)/cuda/include/thrust/system/detail/sequential/logical.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/iter_swap.h" "$(@D)/cuda/include/thrust/system/detail/sequential/iter_swap.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/malloc_and_free.h" "$(@D)/cuda/include/thrust/system/detail/sequential/malloc_and_free.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/fill.h" "$(@D)/cuda/include/thrust/system/detail/sequential/fill.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/sequential/transform.h" "$(@D)/cuda/include/thrust/system/detail/sequential/transform.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/error_condition.inl" "$(@D)/cuda/include/thrust/system/detail/error_condition.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/internal/decompose.h" "$(@D)/cuda/include/thrust/system/detail/internal/decompose.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/error_code.inl" "$(@D)/cuda/include/thrust/system/detail/error_code.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/transform_scan.h" "$(@D)/cuda/include/thrust/system/detail/generic/transform_scan.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/memory.inl" "$(@D)/cuda/include/thrust/system/detail/generic/memory.inl" && cp 
"/usr/local/cuda-8.0/include/thrust/system/detail/generic/transform.inl" "$(@D)/cuda/include/thrust/system/detail/generic/transform.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/binary_search.inl" "$(@D)/cuda/include/thrust/system/detail/generic/binary_search.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/scan_by_key.inl" "$(@D)/cuda/include/thrust/system/detail/generic/scan_by_key.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/unique_by_key.h" "$(@D)/cuda/include/thrust/system/detail/generic/unique_by_key.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/inner_product.inl" "$(@D)/cuda/include/thrust/system/detail/generic/inner_product.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/select_system.h" "$(@D)/cuda/include/thrust/system/detail/generic/select_system.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/sequence.inl" "$(@D)/cuda/include/thrust/system/detail/generic/sequence.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/sort.inl" "$(@D)/cuda/include/thrust/system/detail/generic/sort.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/equal.inl" "$(@D)/cuda/include/thrust/system/detail/generic/equal.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/partition.h" "$(@D)/cuda/include/thrust/system/detail/generic/partition.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/unique.h" "$(@D)/cuda/include/thrust/system/detail/generic/unique.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/adjacent_difference.h" "$(@D)/cuda/include/thrust/system/detail/generic/adjacent_difference.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/tag.h" "$(@D)/cuda/include/thrust/system/detail/generic/tag.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/unique_by_key.inl" "$(@D)/cuda/include/thrust/system/detail/generic/unique_by_key.inl" 
&& cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/sequence.h" "$(@D)/cuda/include/thrust/system/detail/generic/sequence.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/type_traits.h" "$(@D)/cuda/include/thrust/system/detail/generic/type_traits.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/merge.h" "$(@D)/cuda/include/thrust/system/detail/generic/merge.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/reverse.inl" "$(@D)/cuda/include/thrust/system/detail/generic/reverse.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/tabulate.inl" "$(@D)/cuda/include/thrust/system/detail/generic/tabulate.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/unique.inl" "$(@D)/cuda/include/thrust/system/detail/generic/unique.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/scatter.inl" "$(@D)/cuda/include/thrust/system/detail/generic/scatter.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/set_operations.inl" "$(@D)/cuda/include/thrust/system/detail/generic/set_operations.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/copy_if.inl" "$(@D)/cuda/include/thrust/system/detail/generic/copy_if.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/transform_reduce.h" "$(@D)/cuda/include/thrust/system/detail/generic/transform_reduce.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/transform_scan.inl" "$(@D)/cuda/include/thrust/system/detail/generic/transform_scan.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/gather.h" "$(@D)/cuda/include/thrust/system/detail/generic/gather.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/reduce_by_key.inl" "$(@D)/cuda/include/thrust/system/detail/generic/reduce_by_key.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/transform_reduce.inl" 
"$(@D)/cuda/include/thrust/system/detail/generic/transform_reduce.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/sort.h" "$(@D)/cuda/include/thrust/system/detail/generic/sort.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/distance.inl" "$(@D)/cuda/include/thrust/system/detail/generic/distance.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/scan.h" "$(@D)/cuda/include/thrust/system/detail/generic/scan.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/temporary_buffer.h" "$(@D)/cuda/include/thrust/system/detail/generic/temporary_buffer.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/reduce.inl" "$(@D)/cuda/include/thrust/system/detail/generic/reduce.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/scan_by_key.h" "$(@D)/cuda/include/thrust/system/detail/generic/scan_by_key.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/reverse.h" "$(@D)/cuda/include/thrust/system/detail/generic/reverse.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/temporary_buffer.inl" "$(@D)/cuda/include/thrust/system/detail/generic/temporary_buffer.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/scatter.h" "$(@D)/cuda/include/thrust/system/detail/generic/scatter.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/generate.inl" "$(@D)/cuda/include/thrust/system/detail/generic/generate.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/adjacent_difference.inl" "$(@D)/cuda/include/thrust/system/detail/generic/adjacent_difference.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/remove.inl" "$(@D)/cuda/include/thrust/system/detail/generic/remove.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/advance.h" "$(@D)/cuda/include/thrust/system/detail/generic/advance.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/find.h" 
"$(@D)/cuda/include/thrust/system/detail/generic/find.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/merge.inl" "$(@D)/cuda/include/thrust/system/detail/generic/merge.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/scalar/binary_search.inl" "$(@D)/cuda/include/thrust/system/detail/generic/scalar/binary_search.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/scalar/binary_search.h" "$(@D)/cuda/include/thrust/system/detail/generic/scalar/binary_search.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/extrema.inl" "$(@D)/cuda/include/thrust/system/detail/generic/extrema.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/generate.h" "$(@D)/cuda/include/thrust/system/detail/generic/generate.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/uninitialized_fill.h" "$(@D)/cuda/include/thrust/system/detail/generic/uninitialized_fill.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/count.inl" "$(@D)/cuda/include/thrust/system/detail/generic/count.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/remove.h" "$(@D)/cuda/include/thrust/system/detail/generic/remove.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/uninitialized_copy.inl" "$(@D)/cuda/include/thrust/system/detail/generic/uninitialized_copy.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/tabulate.h" "$(@D)/cuda/include/thrust/system/detail/generic/tabulate.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/for_each.h" "$(@D)/cuda/include/thrust/system/detail/generic/for_each.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/distance.h" "$(@D)/cuda/include/thrust/system/detail/generic/distance.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/swap_ranges.inl" "$(@D)/cuda/include/thrust/system/detail/generic/swap_ranges.inl" && cp 
"/usr/local/cuda-8.0/include/thrust/system/detail/generic/reduce_by_key.h" "$(@D)/cuda/include/thrust/system/detail/generic/reduce_by_key.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/reduce.h" "$(@D)/cuda/include/thrust/system/detail/generic/reduce.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/equal.h" "$(@D)/cuda/include/thrust/system/detail/generic/equal.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/mismatch.inl" "$(@D)/cuda/include/thrust/system/detail/generic/mismatch.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/copy.inl" "$(@D)/cuda/include/thrust/system/detail/generic/copy.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/copy.h" "$(@D)/cuda/include/thrust/system/detail/generic/copy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/swap_ranges.h" "$(@D)/cuda/include/thrust/system/detail/generic/swap_ranges.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/uninitialized_copy.h" "$(@D)/cuda/include/thrust/system/detail/generic/uninitialized_copy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/binary_search.h" "$(@D)/cuda/include/thrust/system/detail/generic/binary_search.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/set_operations.h" "$(@D)/cuda/include/thrust/system/detail/generic/set_operations.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/uninitialized_fill.inl" "$(@D)/cuda/include/thrust/system/detail/generic/uninitialized_fill.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/mismatch.h" "$(@D)/cuda/include/thrust/system/detail/generic/mismatch.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/scan.inl" "$(@D)/cuda/include/thrust/system/detail/generic/scan.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/gather.inl" "$(@D)/cuda/include/thrust/system/detail/generic/gather.inl" && cp 
"/usr/local/cuda-8.0/include/thrust/system/detail/generic/extrema.h" "$(@D)/cuda/include/thrust/system/detail/generic/extrema.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/count.h" "$(@D)/cuda/include/thrust/system/detail/generic/count.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/replace.h" "$(@D)/cuda/include/thrust/system/detail/generic/replace.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/inner_product.h" "$(@D)/cuda/include/thrust/system/detail/generic/inner_product.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/copy_if.h" "$(@D)/cuda/include/thrust/system/detail/generic/copy_if.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/logical.h" "$(@D)/cuda/include/thrust/system/detail/generic/logical.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/partition.inl" "$(@D)/cuda/include/thrust/system/detail/generic/partition.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/memory.h" "$(@D)/cuda/include/thrust/system/detail/generic/memory.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/find.inl" "$(@D)/cuda/include/thrust/system/detail/generic/find.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/replace.inl" "$(@D)/cuda/include/thrust/system/detail/generic/replace.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/advance.inl" "$(@D)/cuda/include/thrust/system/detail/generic/advance.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/fill.h" "$(@D)/cuda/include/thrust/system/detail/generic/fill.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/generic/transform.h" "$(@D)/cuda/include/thrust/system/detail/generic/transform.h" && cp "/usr/local/cuda-8.0/include/thrust/system/detail/system_error.inl" "$(@D)/cuda/include/thrust/system/detail/system_error.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/execution_policy.h" 
"$(@D)/cuda/include/thrust/system/omp/execution_policy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/vector.h" "$(@D)/cuda/include/thrust/system/omp/vector.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/transform_scan.h" "$(@D)/cuda/include/thrust/system/omp/detail/transform_scan.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/memory.inl" "$(@D)/cuda/include/thrust/system/omp/detail/memory.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/reduce_intervals.inl" "$(@D)/cuda/include/thrust/system/omp/detail/reduce_intervals.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/unique_by_key.h" "$(@D)/cuda/include/thrust/system/omp/detail/unique_by_key.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/sort.inl" "$(@D)/cuda/include/thrust/system/omp/detail/sort.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/partition.h" "$(@D)/cuda/include/thrust/system/omp/detail/partition.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/unique.h" "$(@D)/cuda/include/thrust/system/omp/detail/unique.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/execution_policy.h" "$(@D)/cuda/include/thrust/system/omp/detail/execution_policy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/adjacent_difference.h" "$(@D)/cuda/include/thrust/system/omp/detail/adjacent_difference.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/unique_by_key.inl" "$(@D)/cuda/include/thrust/system/omp/detail/unique_by_key.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/sequence.h" "$(@D)/cuda/include/thrust/system/omp/detail/sequence.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/merge.h" "$(@D)/cuda/include/thrust/system/omp/detail/merge.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/unique.inl" "$(@D)/cuda/include/thrust/system/omp/detail/unique.inl" && cp 
"/usr/local/cuda-8.0/include/thrust/system/omp/detail/copy_if.inl" "$(@D)/cuda/include/thrust/system/omp/detail/copy_if.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/transform_reduce.h" "$(@D)/cuda/include/thrust/system/omp/detail/transform_reduce.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/gather.h" "$(@D)/cuda/include/thrust/system/omp/detail/gather.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/reduce_by_key.inl" "$(@D)/cuda/include/thrust/system/omp/detail/reduce_by_key.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/sort.h" "$(@D)/cuda/include/thrust/system/omp/detail/sort.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/scan.h" "$(@D)/cuda/include/thrust/system/omp/detail/scan.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/temporary_buffer.h" "$(@D)/cuda/include/thrust/system/omp/detail/temporary_buffer.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/default_decomposition.h" "$(@D)/cuda/include/thrust/system/omp/detail/default_decomposition.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/reduce.inl" "$(@D)/cuda/include/thrust/system/omp/detail/reduce.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/scan_by_key.h" "$(@D)/cuda/include/thrust/system/omp/detail/scan_by_key.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/reverse.h" "$(@D)/cuda/include/thrust/system/omp/detail/reverse.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/assign_value.h" "$(@D)/cuda/include/thrust/system/omp/detail/assign_value.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/scatter.h" "$(@D)/cuda/include/thrust/system/omp/detail/scatter.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/for_each.inl" "$(@D)/cuda/include/thrust/system/omp/detail/for_each.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/default_decomposition.inl" 
"$(@D)/cuda/include/thrust/system/omp/detail/default_decomposition.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/remove.inl" "$(@D)/cuda/include/thrust/system/omp/detail/remove.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/vector.inl" "$(@D)/cuda/include/thrust/system/omp/detail/vector.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/find.h" "$(@D)/cuda/include/thrust/system/omp/detail/find.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/generate.h" "$(@D)/cuda/include/thrust/system/omp/detail/generate.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/uninitialized_fill.h" "$(@D)/cuda/include/thrust/system/omp/detail/uninitialized_fill.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/remove.h" "$(@D)/cuda/include/thrust/system/omp/detail/remove.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/tabulate.h" "$(@D)/cuda/include/thrust/system/omp/detail/tabulate.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/for_each.h" "$(@D)/cuda/include/thrust/system/omp/detail/for_each.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/reduce_by_key.h" "$(@D)/cuda/include/thrust/system/omp/detail/reduce_by_key.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/reduce.h" "$(@D)/cuda/include/thrust/system/omp/detail/reduce.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/equal.h" "$(@D)/cuda/include/thrust/system/omp/detail/equal.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/copy.inl" "$(@D)/cuda/include/thrust/system/omp/detail/copy.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/copy.h" "$(@D)/cuda/include/thrust/system/omp/detail/copy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/swap_ranges.h" "$(@D)/cuda/include/thrust/system/omp/detail/swap_ranges.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/uninitialized_copy.h" 
"$(@D)/cuda/include/thrust/system/omp/detail/uninitialized_copy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/binary_search.h" "$(@D)/cuda/include/thrust/system/omp/detail/binary_search.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/set_operations.h" "$(@D)/cuda/include/thrust/system/omp/detail/set_operations.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/mismatch.h" "$(@D)/cuda/include/thrust/system/omp/detail/mismatch.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/extrema.h" "$(@D)/cuda/include/thrust/system/omp/detail/extrema.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/count.h" "$(@D)/cuda/include/thrust/system/omp/detail/count.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/replace.h" "$(@D)/cuda/include/thrust/system/omp/detail/replace.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/get_value.h" "$(@D)/cuda/include/thrust/system/omp/detail/get_value.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/inner_product.h" "$(@D)/cuda/include/thrust/system/omp/detail/inner_product.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/copy_if.h" "$(@D)/cuda/include/thrust/system/omp/detail/copy_if.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/logical.h" "$(@D)/cuda/include/thrust/system/omp/detail/logical.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/partition.inl" "$(@D)/cuda/include/thrust/system/omp/detail/partition.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/iter_swap.h" "$(@D)/cuda/include/thrust/system/omp/detail/iter_swap.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/par.h" "$(@D)/cuda/include/thrust/system/omp/detail/par.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/reduce_intervals.h" "$(@D)/cuda/include/thrust/system/omp/detail/reduce_intervals.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/malloc_and_free.h" 
"$(@D)/cuda/include/thrust/system/omp/detail/malloc_and_free.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/fill.h" "$(@D)/cuda/include/thrust/system/omp/detail/fill.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/detail/transform.h" "$(@D)/cuda/include/thrust/system/omp/detail/transform.h" && cp "/usr/local/cuda-8.0/include/thrust/system/omp/memory.h" "$(@D)/cuda/include/thrust/system/omp/memory.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/execution_policy.h" "$(@D)/cuda/include/thrust/system/tbb/execution_policy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/vector.h" "$(@D)/cuda/include/thrust/system/tbb/vector.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/transform_scan.h" "$(@D)/cuda/include/thrust/system/tbb/detail/transform_scan.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/memory.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/memory.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/unique_by_key.h" "$(@D)/cuda/include/thrust/system/tbb/detail/unique_by_key.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/sort.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/sort.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/partition.h" "$(@D)/cuda/include/thrust/system/tbb/detail/partition.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/unique.h" "$(@D)/cuda/include/thrust/system/tbb/detail/unique.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/execution_policy.h" "$(@D)/cuda/include/thrust/system/tbb/detail/execution_policy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/adjacent_difference.h" "$(@D)/cuda/include/thrust/system/tbb/detail/adjacent_difference.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/unique_by_key.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/unique_by_key.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/sequence.h" 
"$(@D)/cuda/include/thrust/system/tbb/detail/sequence.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/merge.h" "$(@D)/cuda/include/thrust/system/tbb/detail/merge.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/unique.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/unique.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/copy_if.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/copy_if.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/transform_reduce.h" "$(@D)/cuda/include/thrust/system/tbb/detail/transform_reduce.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/gather.h" "$(@D)/cuda/include/thrust/system/tbb/detail/gather.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/reduce_by_key.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/reduce_by_key.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/sort.h" "$(@D)/cuda/include/thrust/system/tbb/detail/sort.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/scan.h" "$(@D)/cuda/include/thrust/system/tbb/detail/scan.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/temporary_buffer.h" "$(@D)/cuda/include/thrust/system/tbb/detail/temporary_buffer.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/reduce.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/reduce.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/scan_by_key.h" "$(@D)/cuda/include/thrust/system/tbb/detail/scan_by_key.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/reverse.h" "$(@D)/cuda/include/thrust/system/tbb/detail/reverse.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/assign_value.h" "$(@D)/cuda/include/thrust/system/tbb/detail/assign_value.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/scatter.h" "$(@D)/cuda/include/thrust/system/tbb/detail/scatter.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/for_each.inl" 
"$(@D)/cuda/include/thrust/system/tbb/detail/for_each.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/remove.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/remove.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/vector.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/vector.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/find.h" "$(@D)/cuda/include/thrust/system/tbb/detail/find.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/merge.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/merge.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/generate.h" "$(@D)/cuda/include/thrust/system/tbb/detail/generate.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/uninitialized_fill.h" "$(@D)/cuda/include/thrust/system/tbb/detail/uninitialized_fill.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/remove.h" "$(@D)/cuda/include/thrust/system/tbb/detail/remove.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/tabulate.h" "$(@D)/cuda/include/thrust/system/tbb/detail/tabulate.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/for_each.h" "$(@D)/cuda/include/thrust/system/tbb/detail/for_each.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/reduce_by_key.h" "$(@D)/cuda/include/thrust/system/tbb/detail/reduce_by_key.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/reduce.h" "$(@D)/cuda/include/thrust/system/tbb/detail/reduce.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/equal.h" "$(@D)/cuda/include/thrust/system/tbb/detail/equal.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/copy.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/copy.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/copy.h" "$(@D)/cuda/include/thrust/system/tbb/detail/copy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/swap_ranges.h" "$(@D)/cuda/include/thrust/system/tbb/detail/swap_ranges.h" && cp 
"/usr/local/cuda-8.0/include/thrust/system/tbb/detail/uninitialized_copy.h" "$(@D)/cuda/include/thrust/system/tbb/detail/uninitialized_copy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/binary_search.h" "$(@D)/cuda/include/thrust/system/tbb/detail/binary_search.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/set_operations.h" "$(@D)/cuda/include/thrust/system/tbb/detail/set_operations.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/mismatch.h" "$(@D)/cuda/include/thrust/system/tbb/detail/mismatch.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/scan.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/scan.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/extrema.h" "$(@D)/cuda/include/thrust/system/tbb/detail/extrema.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/count.h" "$(@D)/cuda/include/thrust/system/tbb/detail/count.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/replace.h" "$(@D)/cuda/include/thrust/system/tbb/detail/replace.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/get_value.h" "$(@D)/cuda/include/thrust/system/tbb/detail/get_value.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/inner_product.h" "$(@D)/cuda/include/thrust/system/tbb/detail/inner_product.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/copy_if.h" "$(@D)/cuda/include/thrust/system/tbb/detail/copy_if.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/logical.h" "$(@D)/cuda/include/thrust/system/tbb/detail/logical.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/partition.inl" "$(@D)/cuda/include/thrust/system/tbb/detail/partition.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/iter_swap.h" "$(@D)/cuda/include/thrust/system/tbb/detail/iter_swap.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/par.h" "$(@D)/cuda/include/thrust/system/tbb/detail/par.h" && cp 
"/usr/local/cuda-8.0/include/thrust/system/tbb/detail/reduce_intervals.h" "$(@D)/cuda/include/thrust/system/tbb/detail/reduce_intervals.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/malloc_and_free.h" "$(@D)/cuda/include/thrust/system/tbb/detail/malloc_and_free.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/fill.h" "$(@D)/cuda/include/thrust/system/tbb/detail/fill.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/detail/transform.h" "$(@D)/cuda/include/thrust/system/tbb/detail/transform.h" && cp "/usr/local/cuda-8.0/include/thrust/system/tbb/memory.h" "$(@D)/cuda/include/thrust/system/tbb/memory.h" && cp "/usr/local/cuda-8.0/include/thrust/system/error_code.h" "$(@D)/cuda/include/thrust/system/error_code.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/execution_policy.h" "$(@D)/cuda/include/thrust/system/cpp/execution_policy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/vector.h" "$(@D)/cuda/include/thrust/system/cpp/vector.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/transform_scan.h" "$(@D)/cuda/include/thrust/system/cpp/detail/transform_scan.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/memory.inl" "$(@D)/cuda/include/thrust/system/cpp/detail/memory.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/unique_by_key.h" "$(@D)/cuda/include/thrust/system/cpp/detail/unique_by_key.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/partition.h" "$(@D)/cuda/include/thrust/system/cpp/detail/partition.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/unique.h" "$(@D)/cuda/include/thrust/system/cpp/detail/unique.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/execution_policy.h" "$(@D)/cuda/include/thrust/system/cpp/detail/execution_policy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/adjacent_difference.h" "$(@D)/cuda/include/thrust/system/cpp/detail/adjacent_difference.h" && cp 
"/usr/local/cuda-8.0/include/thrust/system/cpp/detail/sequence.h" "$(@D)/cuda/include/thrust/system/cpp/detail/sequence.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/merge.h" "$(@D)/cuda/include/thrust/system/cpp/detail/merge.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/transform_reduce.h" "$(@D)/cuda/include/thrust/system/cpp/detail/transform_reduce.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/gather.h" "$(@D)/cuda/include/thrust/system/cpp/detail/gather.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/sort.h" "$(@D)/cuda/include/thrust/system/cpp/detail/sort.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/scan.h" "$(@D)/cuda/include/thrust/system/cpp/detail/scan.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/temporary_buffer.h" "$(@D)/cuda/include/thrust/system/cpp/detail/temporary_buffer.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/scan_by_key.h" "$(@D)/cuda/include/thrust/system/cpp/detail/scan_by_key.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/reverse.h" "$(@D)/cuda/include/thrust/system/cpp/detail/reverse.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/assign_value.h" "$(@D)/cuda/include/thrust/system/cpp/detail/assign_value.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/scatter.h" "$(@D)/cuda/include/thrust/system/cpp/detail/scatter.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/vector.inl" "$(@D)/cuda/include/thrust/system/cpp/detail/vector.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/find.h" "$(@D)/cuda/include/thrust/system/cpp/detail/find.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/generate.h" "$(@D)/cuda/include/thrust/system/cpp/detail/generate.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/uninitialized_fill.h" "$(@D)/cuda/include/thrust/system/cpp/detail/uninitialized_fill.h" && cp 
"/usr/local/cuda-8.0/include/thrust/system/cpp/detail/remove.h" "$(@D)/cuda/include/thrust/system/cpp/detail/remove.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/tabulate.h" "$(@D)/cuda/include/thrust/system/cpp/detail/tabulate.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/for_each.h" "$(@D)/cuda/include/thrust/system/cpp/detail/for_each.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/reduce_by_key.h" "$(@D)/cuda/include/thrust/system/cpp/detail/reduce_by_key.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/reduce.h" "$(@D)/cuda/include/thrust/system/cpp/detail/reduce.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/equal.h" "$(@D)/cuda/include/thrust/system/cpp/detail/equal.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/copy.h" "$(@D)/cuda/include/thrust/system/cpp/detail/copy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/swap_ranges.h" "$(@D)/cuda/include/thrust/system/cpp/detail/swap_ranges.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/uninitialized_copy.h" "$(@D)/cuda/include/thrust/system/cpp/detail/uninitialized_copy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/binary_search.h" "$(@D)/cuda/include/thrust/system/cpp/detail/binary_search.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/set_operations.h" "$(@D)/cuda/include/thrust/system/cpp/detail/set_operations.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/mismatch.h" "$(@D)/cuda/include/thrust/system/cpp/detail/mismatch.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/extrema.h" "$(@D)/cuda/include/thrust/system/cpp/detail/extrema.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/count.h" "$(@D)/cuda/include/thrust/system/cpp/detail/count.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/replace.h" "$(@D)/cuda/include/thrust/system/cpp/detail/replace.h" && cp 
"/usr/local/cuda-8.0/include/thrust/system/cpp/detail/get_value.h" "$(@D)/cuda/include/thrust/system/cpp/detail/get_value.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/inner_product.h" "$(@D)/cuda/include/thrust/system/cpp/detail/inner_product.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/copy_if.h" "$(@D)/cuda/include/thrust/system/cpp/detail/copy_if.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/logical.h" "$(@D)/cuda/include/thrust/system/cpp/detail/logical.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/iter_swap.h" "$(@D)/cuda/include/thrust/system/cpp/detail/iter_swap.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/par.h" "$(@D)/cuda/include/thrust/system/cpp/detail/par.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/malloc_and_free.h" "$(@D)/cuda/include/thrust/system/cpp/detail/malloc_and_free.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/fill.h" "$(@D)/cuda/include/thrust/system/cpp/detail/fill.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/detail/transform.h" "$(@D)/cuda/include/thrust/system/cpp/detail/transform.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cpp/memory.h" "$(@D)/cuda/include/thrust/system/cpp/memory.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/execution_policy.h" "$(@D)/cuda/include/thrust/system/cuda/execution_policy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/vector.h" "$(@D)/cuda/include/thrust/system/cuda/vector.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/error.h" "$(@D)/cuda/include/thrust/system/cuda/error.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/copy_device_to_device.h" "$(@D)/cuda/include/thrust/system/cuda/detail/copy_device_to_device.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/transform_scan.h" "$(@D)/cuda/include/thrust/system/cuda/detail/transform_scan.h" && cp 
"/usr/local/cuda-8.0/include/thrust/system/cuda/detail/memory.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/memory.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/util_allocator.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/util_allocator.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/grid/grid_mapping.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/grid/grid_mapping.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/grid/grid_barrier.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/grid/grid_barrier.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/grid/grid_even_share.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/grid/grid_even_share.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/grid/grid_queue.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/grid/grid_queue.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/util_device.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/util_device.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/device/device_run_length_encode.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/device_run_length_encode.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/device/device_partition.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/device_partition.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/device/device_radix_sort.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/device_radix_sort.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/device/dispatch/device_rle_dispatch.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/dispatch/device_rle_dispatch.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/device/dispatch/device_histogram_dispatch.cuh" 
"$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/dispatch/device_histogram_dispatch.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/device/dispatch/device_reduce_by_key_dispatch.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/dispatch/device_reduce_by_key_dispatch.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/device/dispatch/device_scan_dispatch.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/dispatch/device_scan_dispatch.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/device/dispatch/device_select_dispatch.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/dispatch/device_select_dispatch.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/device/dispatch/device_reduce_dispatch.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/dispatch/device_reduce_dispatch.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/device/dispatch/device_radix_sort_dispatch.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/dispatch/device_radix_sort_dispatch.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/device/device_scan.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/device_scan.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/device/device_select.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/device_select.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/device/device_reduce.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/device_reduce.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/device/device_histogram.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/device/device_histogram.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block_range/block_range_reduce.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block_range/block_range_reduce.cuh" && cp 
"/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block_range/block_range_histo.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block_range/block_range_histo.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block_range/block_range_scan.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block_range/block_range_scan.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block_range/block_range_radix_sort_downsweep.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block_range/block_range_radix_sort_downsweep.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block_range/block_range_radix_sort_upsweep.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block_range/block_range_radix_sort_upsweep.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block_range/specializations/block_range_histo_satomic.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block_range/specializations/block_range_histo_satomic.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block_range/specializations/block_range_histo_sort.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block_range/specializations/block_range_histo_sort.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block_range/specializations/block_range_histo_gatomic.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block_range/specializations/block_range_histo_gatomic.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block_range/block_range_select.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block_range/block_range_select.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block_range/block_scan_prefix_operators.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block_range/block_scan_prefix_operators.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block_range/block_range_reduce_by_key.cuh" 
"$(@D)/cuda/include/thrust/system/cuda/detail/cub/block_range/block_range_reduce_by_key.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/util_macro.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/util_macro.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/util_namespace.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/util_namespace.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block_sweep/block_radix_sort_upsweep.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block_sweep/block_radix_sort_upsweep.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block_sweep/block_histogram_sweep.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block_sweep/block_histogram_sweep.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block_sweep/block_rle_sweep.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block_sweep/block_rle_sweep.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block_sweep/block_select_sweep.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block_sweep/block_select_sweep.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block_sweep/block_scan_sweep.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block_sweep/block_scan_sweep.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block_sweep/block_reduce_sweep.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block_sweep/block_reduce_sweep.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block_sweep/specializations/block_histogram_satomic_sweep.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block_sweep/specializations/block_histogram_satomic_sweep.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block_sweep/specializations/block_histogram_sort_sweep.cuh" 
"$(@D)/cuda/include/thrust/system/cuda/detail/cub/block_sweep/specializations/block_histogram_sort_sweep.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block_sweep/specializations/block_histogram_gatomic_sweep.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block_sweep/specializations/block_histogram_gatomic_sweep.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block_sweep/block_radix_sort_downsweep.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block_sweep/block_radix_sort_downsweep.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block_sweep/block_reduce_by_key_sweep.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block_sweep/block_reduce_by_key_sweep.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block_sweep/block_scan_prefix_operators.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block_sweep/block_scan_prefix_operators.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/util_type.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/util_type.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/host/spinlock.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/host/spinlock.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/warp/warp_reduce.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/warp/warp_reduce.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/warp/warp_scan.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/warp/warp_scan.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/warp/specializations/warp_reduce_shfl.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/warp/specializations/warp_reduce_shfl.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/warp/specializations/warp_reduce_smem.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/warp/specializations/warp_reduce_smem.cuh" && cp 
"/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/warp/specializations/warp_scan_shfl.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/warp/specializations/warp_scan_shfl.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/warp/specializations/warp_scan_smem.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/warp/specializations/warp_scan_smem.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/util_ptx.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/util_ptx.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/util_debug.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/util_debug.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/cub.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/cub.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/iterator/transform_input_iterator.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/iterator/transform_input_iterator.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/iterator/tex_obj_input_iterator.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/iterator/tex_obj_input_iterator.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/iterator/tex_ref_input_iterator.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/iterator/tex_ref_input_iterator.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/iterator/cache_modified_output_iterator.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/iterator/cache_modified_output_iterator.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/iterator/counting_input_iterator.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/iterator/counting_input_iterator.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/iterator/cache_modified_input_iterator.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/iterator/cache_modified_input_iterator.cuh" && cp 
"/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/iterator/arg_index_input_iterator.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/iterator/arg_index_input_iterator.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/iterator/constant_input_iterator.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/iterator/constant_input_iterator.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block/block_scan.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/block_scan.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block/block_load.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/block_load.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block/block_discontinuity.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/block_discontinuity.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block/block_radix_rank.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/block_radix_rank.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block/block_shift.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/block_shift.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block/block_store.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/block_store.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block/block_reduce.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/block_reduce.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block/block_exchange.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/block_exchange.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block/block_radix_sort.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/block_radix_sort.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block/block_histogram.cuh" 
"$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/block_histogram.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block/block_raking_layout.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/block_raking_layout.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block/specializations/block_reduce_warp_reductions.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_reduce_warp_reductions.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block/specializations/block_scan_warp_scans.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_scan_warp_scans.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block/specializations/block_reduce_raking_commutative_only.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_reduce_raking_commutative_only.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block/specializations/block_histogram_atomic.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_histogram_atomic.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block/specializations/block_scan_raking.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_scan_raking.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block/specializations/block_histogram_sort.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_histogram_sort.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/block/specializations/block_reduce_raking.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/block/specializations/block_reduce_raking.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/thread/thread_load.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/thread/thread_load.cuh" && cp 
"/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/thread/thread_store.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/thread/thread_store.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/thread/thread_scan.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/thread/thread_scan.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/thread/thread_operators.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/thread/thread_operators.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/thread/thread_reduce.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/thread/thread_reduce.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub/util_arch.cuh" "$(@D)/cuda/include/thrust/system/cuda/detail/cub/util_arch.cuh" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/reduce_intervals.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/reduce_intervals.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/copy_cross_system.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/copy_cross_system.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/unique_by_key.h" "$(@D)/cuda/include/thrust/system/cuda/detail/unique_by_key.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk.h" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/sort.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/sort.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/partition.h" "$(@D)/cuda/include/thrust/system/cuda/detail/partition.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/unique.h" "$(@D)/cuda/include/thrust/system/cuda/detail/unique.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/execution_policy.h" "$(@D)/cuda/include/thrust/system/cuda/detail/execution_policy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cuda_launch_config.h" 
"$(@D)/cuda/include/thrust/system/cuda/detail/cuda_launch_config.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/cub.h" "$(@D)/cuda/include/thrust/system/cuda/detail/cub.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/adjacent_difference.h" "$(@D)/cuda/include/thrust/system/cuda/detail/adjacent_difference.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/sequence.h" "$(@D)/cuda/include/thrust/system/cuda/detail/sequence.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/merge.h" "$(@D)/cuda/include/thrust/system/cuda/detail/merge.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/set_symmetric_difference.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/set_symmetric_difference.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/copy_if.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/copy_if.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/transform_reduce.h" "$(@D)/cuda/include/thrust/system/cuda/detail/transform_reduce.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/error.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/error.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/gather.h" "$(@D)/cuda/include/thrust/system/cuda/detail/gather.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/reduce_by_key.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/reduce_by_key.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/sort.h" "$(@D)/cuda/include/thrust/system/cuda/detail/sort.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/synchronize.h" "$(@D)/cuda/include/thrust/system/cuda/detail/synchronize.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/scan.h" "$(@D)/cuda/include/thrust/system/cuda/detail/scan.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/temporary_indirect_permutation.h" 
"$(@D)/cuda/include/thrust/system/cuda/detail/temporary_indirect_permutation.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/extern_shared_ptr.h" "$(@D)/cuda/include/thrust/system/cuda/detail/extern_shared_ptr.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/detail/set_operation.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/detail/set_operation.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/detail/balanced_path.h" "$(@D)/cuda/include/thrust/system/cuda/detail/detail/balanced_path.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/detail/virtualized_smem_closure.h" "$(@D)/cuda/include/thrust/system/cuda/detail/detail/virtualized_smem_closure.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/detail/stable_primitive_sort.h" "$(@D)/cuda/include/thrust/system/cuda/detail/detail/stable_primitive_sort.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/detail/set_operation.h" "$(@D)/cuda/include/thrust/system/cuda/detail/detail/set_operation.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/detail/stable_primitive_sort.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/detail/stable_primitive_sort.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/detail/stable_merge_sort.h" "$(@D)/cuda/include/thrust/system/cuda/detail/detail/stable_merge_sort.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/detail/launch_closure.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/detail/launch_closure.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/detail/merge.h" "$(@D)/cuda/include/thrust/system/cuda/detail/detail/merge.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/detail/alignment.h" "$(@D)/cuda/include/thrust/system/cuda/detail/detail/alignment.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/detail/stable_radix_sort.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/detail/stable_radix_sort.inl" 
&& cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/detail/stable_sort_each.h" "$(@D)/cuda/include/thrust/system/cuda/detail/detail/stable_sort_each.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/detail/launch_calculator.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/detail/launch_calculator.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/detail/stable_merge_sort.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/detail/stable_merge_sort.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/detail/launch_closure.h" "$(@D)/cuda/include/thrust/system/cuda/detail/detail/launch_closure.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/detail/stable_radix_sort.h" "$(@D)/cuda/include/thrust/system/cuda/detail/detail/stable_radix_sort.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/detail/uninitialized.h" "$(@D)/cuda/include/thrust/system/cuda/detail/detail/uninitialized.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/detail/cached_temporary_allocator.h" "$(@D)/cuda/include/thrust/system/cuda/detail/detail/cached_temporary_allocator.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/detail/launch_calculator.h" "$(@D)/cuda/include/thrust/system/cuda/detail/detail/launch_calculator.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/detail/stable_sort_each.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/detail/stable_sort_each.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/temporary_buffer.h" "$(@D)/cuda/include/thrust/system/cuda/detail/temporary_buffer.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/default_decomposition.h" "$(@D)/cuda/include/thrust/system/cuda/detail/default_decomposition.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/reduce.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/reduce.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/scan_by_key.h" 
"$(@D)/cuda/include/thrust/system/cuda/detail/scan_by_key.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/reverse.h" "$(@D)/cuda/include/thrust/system/cuda/detail/reverse.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/assign_value.h" "$(@D)/cuda/include/thrust/system/cuda/detail/assign_value.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/scatter.h" "$(@D)/cuda/include/thrust/system/cuda/detail/scatter.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/reduce_intervals.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/reduce_intervals.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/for_each.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/for_each.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/default_decomposition.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/default_decomposition.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/guarded_cuda_runtime_api.h" "$(@D)/cuda/include/thrust/system/cuda/detail/guarded_cuda_runtime_api.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/adjacent_difference.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/adjacent_difference.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/vector.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/vector.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/throw_on_error.h" "$(@D)/cuda/include/thrust/system/cuda/detail/throw_on_error.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/find.h" "$(@D)/cuda/include/thrust/system/cuda/detail/find.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/terminate.h" "$(@D)/cuda/include/thrust/system/cuda/detail/terminate.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/merge.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/merge.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/trivial_copy.inl" 
"$(@D)/cuda/include/thrust/system/cuda/detail/trivial_copy.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/generate.h" "$(@D)/cuda/include/thrust/system/cuda/detail/generate.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/execute_on_stream.h" "$(@D)/cuda/include/thrust/system/cuda/detail/execute_on_stream.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/uninitialized_fill.h" "$(@D)/cuda/include/thrust/system/cuda/detail/uninitialized_fill.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/remove.h" "$(@D)/cuda/include/thrust/system/cuda/detail/remove.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/tabulate.h" "$(@D)/cuda/include/thrust/system/cuda/detail/tabulate.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/for_each.h" "$(@D)/cuda/include/thrust/system/cuda/detail/for_each.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/reduce_by_key.h" "$(@D)/cuda/include/thrust/system/cuda/detail/reduce_by_key.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/decomposition.h" "$(@D)/cuda/include/thrust/system/cuda/detail/decomposition.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/reduce.h" "$(@D)/cuda/include/thrust/system/cuda/detail/reduce.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/equal.h" "$(@D)/cuda/include/thrust/system/cuda/detail/equal.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/runtime_introspection.h" "$(@D)/cuda/include/thrust/system/cuda/detail/runtime_introspection.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/copy.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/copy.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/copy.h" "$(@D)/cuda/include/thrust/system/cuda/detail/copy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/swap_ranges.h" "$(@D)/cuda/include/thrust/system/cuda/detail/swap_ranges.h" && cp 
"/usr/local/cuda-8.0/include/thrust/system/cuda/detail/uninitialized_copy.h" "$(@D)/cuda/include/thrust/system/cuda/detail/uninitialized_copy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/binary_search.h" "$(@D)/cuda/include/thrust/system/cuda/detail/binary_search.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/runtime_introspection.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/runtime_introspection.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/set_operations.h" "$(@D)/cuda/include/thrust/system/cuda/detail/set_operations.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/mismatch.h" "$(@D)/cuda/include/thrust/system/cuda/detail/mismatch.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/scan.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/scan.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/synchronize.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/synchronize.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/extrema.h" "$(@D)/cuda/include/thrust/system/cuda/detail/extrema.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/set_union.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/set_union.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/set_intersection.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/set_intersection.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/count.h" "$(@D)/cuda/include/thrust/system/cuda/detail/count.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/trivial_copy.h" "$(@D)/cuda/include/thrust/system/cuda/detail/trivial_copy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/copy_device_to_device.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/copy_device_to_device.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/replace.h" "$(@D)/cuda/include/thrust/system/cuda/detail/replace.h" && cp 
"/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/malloc.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/malloc.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/algorithm.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/algorithm.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/detail/config.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/detail/config.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/detail/closure.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/detail/closure.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/detail/tail_flags.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/detail/tail_flags.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/detail/terminate.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/detail/terminate.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/detail/alignment.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/detail/alignment.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/detail/guarded_cuda_runtime_api.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/detail/guarded_cuda_runtime_api.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/detail/choose_sizes.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/detail/choose_sizes.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/detail/tuple_meta_transform.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/detail/tuple_meta_transform.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/detail/cuda_task.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/detail/cuda_task.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/detail/head_flags.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/detail/head_flags.hpp" && cp 
"/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/detail/synchronize.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/detail/synchronize.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/detail/throw_on_error.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/detail/throw_on_error.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/detail/cuda_launcher/parameter_ptr.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/detail/cuda_launcher/parameter_ptr.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/detail/cuda_launcher/cuda_launcher.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/detail/cuda_launcher/cuda_launcher.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/detail/cuda_launcher/triple_chevron_launcher.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/detail/cuda_launcher/triple_chevron_launcher.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/detail/cuda_launcher/runtime_introspection.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/detail/cuda_launcher/runtime_introspection.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/detail/cuda_launcher/cuda_launch_config.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/detail/cuda_launcher/cuda_launch_config.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/detail/cuda_launcher/runtime_introspection.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/detail/cuda_launcher/runtime_introspection.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/detail/async.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/detail/async.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/detail/tuple_transform.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/detail/tuple_transform.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/detail/pointer_traits.hpp" 
"$(@D)/cuda/include/thrust/system/cuda/detail/bulk/detail/pointer_traits.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/detail/apply_from_tuple.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/detail/apply_from_tuple.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/detail/is_contiguous_iterator.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/detail/is_contiguous_iterator.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/iterator.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/iterator.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/choose_sizes.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/choose_sizes.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/algorithm/copy.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/algorithm/copy.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/algorithm/merge.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/algorithm/merge.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/algorithm/accumulate.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/algorithm/accumulate.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/algorithm/scan.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/algorithm/scan.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/algorithm/detail/stable_merge_sort.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/algorithm/detail/stable_merge_sort.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/algorithm/gather.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/algorithm/gather.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/algorithm/sort.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/algorithm/sort.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/algorithm/reduce.hpp" 
"$(@D)/cuda/include/thrust/system/cuda/detail/bulk/algorithm/reduce.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/algorithm/scatter.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/algorithm/scatter.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/algorithm/adjacent_difference.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/algorithm/adjacent_difference.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/algorithm/reduce_by_key.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/algorithm/reduce_by_key.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/algorithm/for_each.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/algorithm/for_each.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/bulk.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/bulk.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/execution_policy.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/execution_policy.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/iterator/strided_iterator.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/iterator/strided_iterator.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/uninitialized.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/uninitialized.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/async.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/async.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/bulk/future.hpp" "$(@D)/cuda/include/thrust/system/cuda/detail/bulk/future.hpp" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/guarded_driver_types.h" "$(@D)/cuda/include/thrust/system/cuda/detail/guarded_driver_types.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/get_value.h" "$(@D)/cuda/include/thrust/system/cuda/detail/get_value.h" && cp 
"/usr/local/cuda-8.0/include/thrust/system/cuda/detail/inner_product.h" "$(@D)/cuda/include/thrust/system/cuda/detail/inner_product.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/copy_if.h" "$(@D)/cuda/include/thrust/system/cuda/detail/copy_if.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/logical.h" "$(@D)/cuda/include/thrust/system/cuda/detail/logical.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/iter_swap.h" "$(@D)/cuda/include/thrust/system/cuda/detail/iter_swap.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/block/merge.h" "$(@D)/cuda/include/thrust/system/cuda/detail/block/merge.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/block/inclusive_scan.h" "$(@D)/cuda/include/thrust/system/cuda/detail/block/inclusive_scan.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/block/merge.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/block/merge.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/block/merging_sort.h" "$(@D)/cuda/include/thrust/system/cuda/detail/block/merging_sort.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/block/exclusive_scan.h" "$(@D)/cuda/include/thrust/system/cuda/detail/block/exclusive_scan.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/block/reduce.h" "$(@D)/cuda/include/thrust/system/cuda/detail/block/reduce.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/block/copy.h" "$(@D)/cuda/include/thrust/system/cuda/detail/block/copy.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/block/odd_even_sort.h" "$(@D)/cuda/include/thrust/system/cuda/detail/block/odd_even_sort.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/par.h" "$(@D)/cuda/include/thrust/system/cuda/detail/par.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/copy_cross_system.h" "$(@D)/cuda/include/thrust/system/cuda/detail/copy_cross_system.h" && cp 
"/usr/local/cuda-8.0/include/thrust/system/cuda/detail/reduce_intervals.h" "$(@D)/cuda/include/thrust/system/cuda/detail/reduce_intervals.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/malloc_and_free.h" "$(@D)/cuda/include/thrust/system/cuda/detail/malloc_and_free.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/fill.h" "$(@D)/cuda/include/thrust/system/cuda/detail/fill.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/set_difference.inl" "$(@D)/cuda/include/thrust/system/cuda/detail/set_difference.inl" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/detail/transform.h" "$(@D)/cuda/include/thrust/system/cuda/detail/transform.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/experimental/pinned_allocator.h" "$(@D)/cuda/include/thrust/system/cuda/experimental/pinned_allocator.h" && cp "/usr/local/cuda-8.0/include/thrust/system/cuda/memory.h" "$(@D)/cuda/include/thrust/system/cuda/memory.h" && cp "/usr/local/cuda-8.0/include/thrust/remove.h" "$(@D)/cuda/include/thrust/remove.h" && cp "/usr/local/cuda-8.0/include/thrust/tabulate.h" "$(@D)/cuda/include/thrust/tabulate.h" && cp "/usr/local/cuda-8.0/include/thrust/for_each.h" "$(@D)/cuda/include/thrust/for_each.h" && cp "/usr/local/cuda-8.0/include/thrust/distance.h" "$(@D)/cuda/include/thrust/distance.h" && cp "/usr/local/cuda-8.0/include/thrust/reduce.h" "$(@D)/cuda/include/thrust/reduce.h" && cp "/usr/local/cuda-8.0/include/thrust/equal.h" "$(@D)/cuda/include/thrust/equal.h" && cp "/usr/local/cuda-8.0/include/thrust/complex.h" "$(@D)/cuda/include/thrust/complex.h" && cp "/usr/local/cuda-8.0/include/thrust/device_allocator.h" "$(@D)/cuda/include/thrust/device_allocator.h" && cp "/usr/local/cuda-8.0/include/thrust/copy.h" "$(@D)/cuda/include/thrust/copy.h" && cp "/usr/local/cuda-8.0/include/thrust/uninitialized_copy.h" "$(@D)/cuda/include/thrust/uninitialized_copy.h" && cp "/usr/local/cuda-8.0/include/thrust/device_reference.h" 
"$(@D)/cuda/include/thrust/device_reference.h" && cp "/usr/local/cuda-8.0/include/thrust/binary_search.h" "$(@D)/cuda/include/thrust/binary_search.h" && cp "/usr/local/cuda-8.0/include/thrust/set_operations.h" "$(@D)/cuda/include/thrust/set_operations.h" && cp "/usr/local/cuda-8.0/include/thrust/swap.h" "$(@D)/cuda/include/thrust/swap.h" && cp "/usr/local/cuda-8.0/include/thrust/mismatch.h" "$(@D)/cuda/include/thrust/mismatch.h" && cp "/usr/local/cuda-8.0/include/thrust/extrema.h" "$(@D)/cuda/include/thrust/extrema.h" && cp "/usr/local/cuda-8.0/include/thrust/count.h" "$(@D)/cuda/include/thrust/count.h" && cp "/usr/local/cuda-8.0/include/thrust/device_free.h" "$(@D)/cuda/include/thrust/device_free.h" && cp "/usr/local/cuda-8.0/include/thrust/random/discard_block_engine.h" "$(@D)/cuda/include/thrust/random/discard_block_engine.h" && cp "/usr/local/cuda-8.0/include/thrust/random/normal_distribution.h" "$(@D)/cuda/include/thrust/random/normal_distribution.h" && cp "/usr/local/cuda-8.0/include/thrust/random/detail/linear_feedback_shift_engine_wordmask.h" "$(@D)/cuda/include/thrust/random/detail/linear_feedback_shift_engine_wordmask.h" && cp "/usr/local/cuda-8.0/include/thrust/random/detail/subtract_with_carry_engine.inl" "$(@D)/cuda/include/thrust/random/detail/subtract_with_carry_engine.inl" && cp "/usr/local/cuda-8.0/include/thrust/random/detail/xor_combine_engine_max.h" "$(@D)/cuda/include/thrust/random/detail/xor_combine_engine_max.h" && cp "/usr/local/cuda-8.0/include/thrust/random/detail/linear_congruential_engine_discard.h" "$(@D)/cuda/include/thrust/random/detail/linear_congruential_engine_discard.h" && cp "/usr/local/cuda-8.0/include/thrust/random/detail/uniform_int_distribution.inl" "$(@D)/cuda/include/thrust/random/detail/uniform_int_distribution.inl" && cp "/usr/local/cuda-8.0/include/thrust/random/detail/discard_block_engine.inl" "$(@D)/cuda/include/thrust/random/detail/discard_block_engine.inl" && cp 
"/usr/local/cuda-8.0/include/thrust/random/detail/uniform_real_distribution.inl" "$(@D)/cuda/include/thrust/random/detail/uniform_real_distribution.inl" && cp "/usr/local/cuda-8.0/include/thrust/random/detail/random_core_access.h" "$(@D)/cuda/include/thrust/random/detail/random_core_access.h" && cp "/usr/local/cuda-8.0/include/thrust/random/detail/mod.h" "$(@D)/cuda/include/thrust/random/detail/mod.h" && cp "/usr/local/cuda-8.0/include/thrust/random/detail/linear_feedback_shift_engine.inl" "$(@D)/cuda/include/thrust/random/detail/linear_feedback_shift_engine.inl" && cp "/usr/local/cuda-8.0/include/thrust/random/detail/linear_congruential_engine.inl" "$(@D)/cuda/include/thrust/random/detail/linear_congruential_engine.inl" && cp "/usr/local/cuda-8.0/include/thrust/random/detail/xor_combine_engine.inl" "$(@D)/cuda/include/thrust/random/detail/xor_combine_engine.inl" && cp "/usr/local/cuda-8.0/include/thrust/random/detail/normal_distribution.inl" "$(@D)/cuda/include/thrust/random/detail/normal_distribution.inl" && cp "/usr/local/cuda-8.0/include/thrust/random/detail/normal_distribution_base.h" "$(@D)/cuda/include/thrust/random/detail/normal_distribution_base.h" && cp "/usr/local/cuda-8.0/include/thrust/random/uniform_int_distribution.h" "$(@D)/cuda/include/thrust/random/uniform_int_distribution.h" && cp "/usr/local/cuda-8.0/include/thrust/random/linear_feedback_shift_engine.h" "$(@D)/cuda/include/thrust/random/linear_feedback_shift_engine.h" && cp "/usr/local/cuda-8.0/include/thrust/random/xor_combine_engine.h" "$(@D)/cuda/include/thrust/random/xor_combine_engine.h" && cp "/usr/local/cuda-8.0/include/thrust/random/subtract_with_carry_engine.h" "$(@D)/cuda/include/thrust/random/subtract_with_carry_engine.h" && cp "/usr/local/cuda-8.0/include/thrust/random/linear_congruential_engine.h" "$(@D)/cuda/include/thrust/random/linear_congruential_engine.h" && cp "/usr/local/cuda-8.0/include/thrust/random/uniform_real_distribution.h" 
"$(@D)/cuda/include/thrust/random/uniform_real_distribution.h" && cp "/usr/local/cuda-8.0/include/thrust/functional.h" "$(@D)/cuda/include/thrust/functional.h" && cp "/usr/local/cuda-8.0/include/thrust/replace.h" "$(@D)/cuda/include/thrust/replace.h" && cp "/usr/local/cuda-8.0/include/thrust/device_new_allocator.h" "$(@D)/cuda/include/thrust/device_new_allocator.h" && cp "/usr/local/cuda-8.0/include/thrust/host_vector.h" "$(@D)/cuda/include/thrust/host_vector.h" && cp "/usr/local/cuda-8.0/include/thrust/version.h" "$(@D)/cuda/include/thrust/version.h" && cp "/usr/local/cuda-8.0/include/thrust/inner_product.h" "$(@D)/cuda/include/thrust/inner_product.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/iterator_traits.h" "$(@D)/cuda/include/thrust/iterator/iterator_traits.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/discard_iterator.h" "$(@D)/cuda/include/thrust/iterator/discard_iterator.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/retag.h" "$(@D)/cuda/include/thrust/iterator/retag.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/permutation_iterator.h" "$(@D)/cuda/include/thrust/iterator/permutation_iterator.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/transform_iterator.h" "$(@D)/cuda/include/thrust/iterator/transform_iterator.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/reverse_iterator.inl" "$(@D)/cuda/include/thrust/iterator/detail/reverse_iterator.inl" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/zip_iterator.inl" "$(@D)/cuda/include/thrust/iterator/detail/zip_iterator.inl" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/counting_iterator.inl" "$(@D)/cuda/include/thrust/iterator/detail/counting_iterator.inl" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/distance_from_result.h" "$(@D)/cuda/include/thrust/iterator/detail/distance_from_result.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/host_system_tag.h" 
"$(@D)/cuda/include/thrust/iterator/detail/host_system_tag.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/iterator_traversal_tags.h" "$(@D)/cuda/include/thrust/iterator/detail/iterator_traversal_tags.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/retag.h" "$(@D)/cuda/include/thrust/iterator/detail/retag.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/tagged_iterator.h" "$(@D)/cuda/include/thrust/iterator/detail/tagged_iterator.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/iterator_traits.inl" "$(@D)/cuda/include/thrust/iterator/detail/iterator_traits.inl" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/minimum_category.h" "$(@D)/cuda/include/thrust/iterator/detail/minimum_category.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/discard_iterator_base.h" "$(@D)/cuda/include/thrust/iterator/detail/discard_iterator_base.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/iterator_category_to_traversal.h" "$(@D)/cuda/include/thrust/iterator/detail/iterator_category_to_traversal.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/zip_iterator_base.h" "$(@D)/cuda/include/thrust/iterator/detail/zip_iterator_base.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/normal_iterator.h" "$(@D)/cuda/include/thrust/iterator/detail/normal_iterator.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/join_iterator.h" "$(@D)/cuda/include/thrust/iterator/detail/join_iterator.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/device_system_tag.h" "$(@D)/cuda/include/thrust/iterator/detail/device_system_tag.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/universal_categories.h" "$(@D)/cuda/include/thrust/iterator/detail/universal_categories.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/reverse_iterator_base.h" "$(@D)/cuda/include/thrust/iterator/detail/reverse_iterator_base.h" && cp 
"/usr/local/cuda-8.0/include/thrust/iterator/detail/minimum_system.h" "$(@D)/cuda/include/thrust/iterator/detail/minimum_system.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/tuple_of_iterator_references.h" "$(@D)/cuda/include/thrust/iterator/detail/tuple_of_iterator_references.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/is_iterator_category.h" "$(@D)/cuda/include/thrust/iterator/detail/is_iterator_category.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/permutation_iterator_base.h" "$(@D)/cuda/include/thrust/iterator/detail/permutation_iterator_base.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/any_assign.h" "$(@D)/cuda/include/thrust/iterator/detail/any_assign.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/any_system_tag.h" "$(@D)/cuda/include/thrust/iterator/detail/any_system_tag.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/is_trivial_iterator.h" "$(@D)/cuda/include/thrust/iterator/detail/is_trivial_iterator.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/iterator_category_to_system.h" "$(@D)/cuda/include/thrust/iterator/detail/iterator_category_to_system.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/iterator_adaptor_base.h" "$(@D)/cuda/include/thrust/iterator/detail/iterator_adaptor_base.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/constant_iterator_base.h" "$(@D)/cuda/include/thrust/iterator/detail/constant_iterator_base.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/transform_iterator.inl" "$(@D)/cuda/include/thrust/iterator/detail/transform_iterator.inl" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/iterator_facade_category.h" "$(@D)/cuda/include/thrust/iterator/detail/iterator_facade_category.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/detail/iterator_category_with_system_and_traversal.h" "$(@D)/cuda/include/thrust/iterator/detail/iterator_category_with_system_and_traversal.h" && cp 
"/usr/local/cuda-8.0/include/thrust/iterator/constant_iterator.h" "$(@D)/cuda/include/thrust/iterator/constant_iterator.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/counting_iterator.h" "$(@D)/cuda/include/thrust/iterator/counting_iterator.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/iterator_adaptor.h" "$(@D)/cuda/include/thrust/iterator/iterator_adaptor.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/iterator_facade.h" "$(@D)/cuda/include/thrust/iterator/iterator_facade.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/iterator_categories.h" "$(@D)/cuda/include/thrust/iterator/iterator_categories.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/reverse_iterator.h" "$(@D)/cuda/include/thrust/iterator/reverse_iterator.h" && cp "/usr/local/cuda-8.0/include/thrust/iterator/zip_iterator.h" "$(@D)/cuda/include/thrust/iterator/zip_iterator.h" && cp "/usr/local/cuda-8.0/include/thrust/logical.h" "$(@D)/cuda/include/thrust/logical.h" && cp "/usr/local/cuda-8.0/include/thrust/tuple.h" "$(@D)/cuda/include/thrust/tuple.h" && cp "/usr/local/cuda-8.0/include/thrust/memory.h" "$(@D)/cuda/include/thrust/memory.h" && cp "/usr/local/cuda-8.0/include/thrust/random.h" "$(@D)/cuda/include/thrust/random.h" && cp "/usr/local/cuda-8.0/include/thrust/fill.h" "$(@D)/cuda/include/thrust/fill.h" && cp "/usr/local/cuda-8.0/include/thrust/transform.h" "$(@D)/cuda/include/thrust/transform.h" && cp "/usr/local/cuda-8.0/include/texture_types.h" "$(@D)/cuda/include/texture_types.h" && cp "/usr/local/cuda-8.0/include/nppversion.h" "$(@D)/cuda/include/nppversion.h" && cp "/usr/local/cuda-8.0/include/cuda_texture_types.h" "$(@D)/cuda/include/cuda_texture_types.h" && cp "/usr/local/cuda-8.0/include/fatbinary.h" "$(@D)/cuda/include/fatbinary.h" && cp "/usr/local/cuda-8.0/include/cublasXt.h" "$(@D)/cuda/include/cublasXt.h" && cp "/usr/local/cuda-8.0/include/cuda_fp16.h" "$(@D)/cuda/include/cuda_fp16.h" && cp "/usr/local/cuda-8.0/include/vector_functions.h" 
"$(@D)/cuda/include/vector_functions.h" && cp "/usr/local/cuda-8.0/include/cusparse.h" "$(@D)/cuda/include/cusparse.h" && cp "/usr/local/cuda-8.0/include/nppi_filtering_functions.h" "$(@D)/cuda/include/nppi_filtering_functions.h" && cp "/usr/local/cuda-8.0/include/nppi_morphological_operations.h" "$(@D)/cuda/include/nppi_morphological_operations.h" && cp "/usr/local/cuda-8.0/include/sobol_direction_vectors.h" "$(@D)/cuda/include/sobol_direction_vectors.h" && cp "/usr/local/cuda-8.0/include/nvblas.h" "$(@D)/cuda/include/nvblas.h" && cp "/usr/local/cuda-8.0/include/curand_mtgp32dc_p_11213.h" "$(@D)/cuda/include/curand_mtgp32dc_p_11213.h" && cp "/usr/local/cuda-8.0/include/nvcuvid.h" "$(@D)/cuda/include/nvcuvid.h" && cp "/usr/local/cuda-8.0/include/cuda_runtime_api.h" "$(@D)/cuda/include/cuda_runtime_api.h" && cp "/usr/local/cuda-8.0/include/curand_mtgp32_kernel.h" "$(@D)/cuda/include/curand_mtgp32_kernel.h" && cp "/usr/local/cuda-8.0/include/cublas_v2.h" "$(@D)/cuda/include/cublas_v2.h" && cp "/usr/local/cuda-8.0/include/builtin_types.h" "$(@D)/cuda/include/builtin_types.h" && cp "/usr/local/cuda-8.0/include/nppi_geometry_transforms.h" "$(@D)/cuda/include/nppi_geometry_transforms.h" && cp "/usr/local/cuda-8.0/include/npps_support_functions.h" "$(@D)/cuda/include/npps_support_functions.h" && cp "/usr/local/cuda-8.0/include/cufftw.h" "$(@D)/cuda/include/cufftw.h" && cp "/usr/local/cuda-8.0/include/cuda_device_runtime_api.h" "$(@D)/cuda/include/cuda_device_runtime_api.h" && cp "/usr/local/cuda-8.0/include/sm_30_intrinsics.hpp" "$(@D)/cuda/include/sm_30_intrinsics.hpp" && cp "/usr/local/cuda-8.0/include/vector_types.h" "$(@D)/cuda/include/vector_types.h" && cp "/usr/local/cuda-8.0/include/sm_35_atomic_functions.h" "$(@D)/cuda/include/sm_35_atomic_functions.h" && cp "/usr/local/cuda-8.0/include/sm_20_intrinsics.h" "$(@D)/cuda/include/sm_20_intrinsics.h" && cp "/usr/local/cuda-8.0/include/driver_types.h" "$(@D)/cuda/include/driver_types.h" && cp 
"/usr/local/cuda-8.0/include/nvToolsExtCudaRt.h" "$(@D)/cuda/include/nvToolsExtCudaRt.h" && cp "/usr/local/cuda-8.0/include/curand_globals.h" "$(@D)/cuda/include/curand_globals.h" && cp "/usr/local/cuda-8.0/include/device_atomic_functions.h" "$(@D)/cuda/include/device_atomic_functions.h" && cp "/usr/local/cuda-8.0/include/surface_types.h" "$(@D)/cuda/include/surface_types.h" && cp "/usr/local/cuda-8.0/include/nvrtc.h" "$(@D)/cuda/include/nvrtc.h" && cp "/usr/local/cuda-8.0/include/nppdefs.h" "$(@D)/cuda/include/nppdefs.h" && cp "/usr/local/cuda-8.0/include/sm_60_atomic_functions.h" "$(@D)/cuda/include/sm_60_atomic_functions.h" && cp "/usr/local/cuda-8.0/include/driver_functions.h" "$(@D)/cuda/include/driver_functions.h" && cp "/usr/local/cuda-8.0/include/cusolver_common.h" "$(@D)/cuda/include/cusolver_common.h" && cp "/usr/local/cuda-8.0/include/cublas.h" "$(@D)/cuda/include/cublas.h" && cp "/usr/local/cuda-8.0/include/curand_lognormal.h" "$(@D)/cuda/include/curand_lognormal.h" && cp "/usr/local/cuda-8.0/include/device_atomic_functions.hpp" "$(@D)/cuda/include/device_atomic_functions.hpp" && cp "/usr/local/cuda-8.0/include/crt/device_runtime.h" "$(@D)/cuda/include/crt/device_runtime.h" && cp "/usr/local/cuda-8.0/include/crt/storage_class.h" "$(@D)/cuda/include/crt/storage_class.h" && cp "/usr/local/cuda-8.0/include/crt/func_macro.h" "$(@D)/cuda/include/crt/func_macro.h" && cp "/usr/local/cuda-8.0/include/crt/host_runtime.h" "$(@D)/cuda/include/crt/host_runtime.h" && cp "/usr/local/cuda-8.0/include/nppi_arithmetic_and_logical_operations.h" "$(@D)/cuda/include/nppi_arithmetic_and_logical_operations.h" && cp "/usr/local/cuda-8.0/include/npps_arithmetic_and_logical_operations.h" "$(@D)/cuda/include/npps_arithmetic_and_logical_operations.h" && cp "/usr/local/cuda-8.0/include/nppi_computer_vision.h" "$(@D)/cuda/include/nppi_computer_vision.h" && cp "/usr/local/cuda-8.0/include/surface_functions.hpp" "$(@D)/cuda/include/surface_functions.hpp" && cp 
"/usr/local/cuda-8.0/include/surface_functions.h" "$(@D)/cuda/include/surface_functions.h" && cp "/usr/local/cuda-8.0/include/curand_normal_static.h" "$(@D)/cuda/include/curand_normal_static.h" && cp "/usr/local/cuda-8.0/include/curand.h" "$(@D)/cuda/include/curand.h" && cp "/usr/local/cuda-8.0/include/math_functions_dbl_ptx3.h" "$(@D)/cuda/include/math_functions_dbl_ptx3.h" && cp "/usr/local/cuda-8.0/include/curand_philox4x32_x.h" "$(@D)/cuda/include/curand_philox4x32_x.h" && cp "/usr/local/cuda-8.0/include/nppi_threshold_and_compare_operations.h" "$(@D)/cuda/include/nppi_threshold_and_compare_operations.h" && cp "/usr/local/cuda-8.0/include/nvml.h" "$(@D)/cuda/include/nvml.h" && cp "/usr/local/cuda-8.0/include/npps.h" "$(@D)/cuda/include/npps.h" && cp "/usr/local/cuda-8.0/include/cuda_vdpau_interop.h" "$(@D)/cuda/include/cuda_vdpau_interop.h" && cp "/usr/local/cuda-8.0/include/sm_61_intrinsics.hpp" "$(@D)/cuda/include/sm_61_intrinsics.hpp" && cp "/usr/local/cuda-8.0/include/cublas_api.h" "$(@D)/cuda/include/cublas_api.h" && cp "/usr/local/cuda-8.0/include/nppi_color_conversion.h" "$(@D)/cuda/include/nppi_color_conversion.h" && cp "/usr/local/cuda-8.0/include/math_functions_dbl_ptx3.hpp" "$(@D)/cuda/include/math_functions_dbl_ptx3.hpp" && cp "/usr/local/cuda-8.0/include/nppcore.h" "$(@D)/cuda/include/nppcore.h" && cp "/usr/local/cuda-8.0/include/cudaGL.h" "$(@D)/cuda/include/cudaGL.h" && cp "/usr/local/cuda-8.0/include/fatBinaryCtl.h" "$(@D)/cuda/include/fatBinaryCtl.h" && cp "/usr/local/cuda-8.0/include/npps_statistics_functions.h" "$(@D)/cuda/include/npps_statistics_functions.h" && cp "/usr/local/cuda-8.0/include/cudaVDPAU.h" "$(@D)/cuda/include/cudaVDPAU.h" && cp "/usr/local/cuda-8.0/include/curand_poisson.h" "$(@D)/cuda/include/curand_poisson.h" && cp "/usr/local/cuda-8.0/include/cusolverDn.h" "$(@D)/cuda/include/cusolverDn.h" && cp "/usr/local/cuda-8.0/include/cuda_profiler_api.h" "$(@D)/cuda/include/cuda_profiler_api.h" && cp 
"/usr/local/cuda-8.0/include/sm_20_atomic_functions.h" "$(@D)/cuda/include/sm_20_atomic_functions.h" && cp "/usr/local/cuda-8.0/include/nvfunctional" "$(@D)/cuda/include/nvfunctional" """, + local = 1, +) + +genrule( + name = "cuda-nvvm", + outs = [ + "cuda/nvvm/bin/cicc", + "cuda/nvvm/libdevice/libdevice.compute_50.10.bc", + "cuda/nvvm/libdevice/libdevice.compute_30.10.bc", + "cuda/nvvm/libdevice/libdevice.compute_20.10.bc", + "cuda/nvvm/libdevice/libdevice.compute_35.10.bc", + "cuda/nvvm/lib64/libnvvm.so.3", + "cuda/nvvm/lib64/libnvvm.so", + "cuda/nvvm/lib64/libnvvm.so.3.1.0", + "cuda/nvvm/include/nvvm.h", + "cuda/nvvm/libnvvm-samples/ptxgen/README.txt", + "cuda/nvvm/libnvvm-samples/ptxgen/ptxgen.c", + "cuda/nvvm/libnvvm-samples/ptxgen/CMakeLists.txt", + "cuda/nvvm/libnvvm-samples/build.bat", + "cuda/nvvm/libnvvm-samples/cuda-c-linking/README.txt", + "cuda/nvvm/libnvvm-samples/cuda-c-linking/math-funcs.cu", + "cuda/nvvm/libnvvm-samples/cuda-c-linking/CMakeLists.txt", + "cuda/nvvm/libnvvm-samples/cuda-c-linking/cuda-c-linking.cpp", + "cuda/nvvm/libnvvm-samples/README.txt", + "cuda/nvvm/libnvvm-samples/simple/simple.c", + "cuda/nvvm/libnvvm-samples/simple/simple-gpu.ll", + "cuda/nvvm/libnvvm-samples/simple/README.txt", + "cuda/nvvm/libnvvm-samples/simple/simple-gpu64.ll", + "cuda/nvvm/libnvvm-samples/simple/CMakeLists.txt", + "cuda/nvvm/libnvvm-samples/common/include/DDSWriter.h", + "cuda/nvvm/libnvvm-samples/common/include/drvapi_error_string.h", + "cuda/nvvm/libnvvm-samples/build.sh", + "cuda/nvvm/libnvvm-samples/CMakeLists.txt", + ], + cmd = """ +cp "/usr/local/cuda-8.0/nvvm/bin/cicc" "$(@D)/cuda/nvvm/bin/cicc" && cp "/usr/local/cuda-8.0/nvvm/libdevice/libdevice.compute_50.10.bc" "$(@D)/cuda/nvvm/libdevice/libdevice.compute_50.10.bc" && cp "/usr/local/cuda-8.0/nvvm/libdevice/libdevice.compute_30.10.bc" "$(@D)/cuda/nvvm/libdevice/libdevice.compute_30.10.bc" && cp "/usr/local/cuda-8.0/nvvm/libdevice/libdevice.compute_20.10.bc" 
"$(@D)/cuda/nvvm/libdevice/libdevice.compute_20.10.bc" && cp "/usr/local/cuda-8.0/nvvm/libdevice/libdevice.compute_35.10.bc" "$(@D)/cuda/nvvm/libdevice/libdevice.compute_35.10.bc" && cp "/usr/local/cuda-8.0/nvvm/lib64/libnvvm.so.3" "$(@D)/cuda/nvvm/lib64/libnvvm.so.3" && cp "/usr/local/cuda-8.0/nvvm/lib64/libnvvm.so" "$(@D)/cuda/nvvm/lib64/libnvvm.so" && cp "/usr/local/cuda-8.0/nvvm/lib64/libnvvm.so.3.1.0" "$(@D)/cuda/nvvm/lib64/libnvvm.so.3.1.0" && cp "/usr/local/cuda-8.0/nvvm/include/nvvm.h" "$(@D)/cuda/nvvm/include/nvvm.h" && cp "/usr/local/cuda-8.0/nvvm/libnvvm-samples/ptxgen/README.txt" "$(@D)/cuda/nvvm/libnvvm-samples/ptxgen/README.txt" && cp "/usr/local/cuda-8.0/nvvm/libnvvm-samples/ptxgen/ptxgen.c" "$(@D)/cuda/nvvm/libnvvm-samples/ptxgen/ptxgen.c" && cp "/usr/local/cuda-8.0/nvvm/libnvvm-samples/ptxgen/CMakeLists.txt" "$(@D)/cuda/nvvm/libnvvm-samples/ptxgen/CMakeLists.txt" && cp "/usr/local/cuda-8.0/nvvm/libnvvm-samples/build.bat" "$(@D)/cuda/nvvm/libnvvm-samples/build.bat" && cp "/usr/local/cuda-8.0/nvvm/libnvvm-samples/cuda-c-linking/README.txt" "$(@D)/cuda/nvvm/libnvvm-samples/cuda-c-linking/README.txt" && cp "/usr/local/cuda-8.0/nvvm/libnvvm-samples/cuda-c-linking/math-funcs.cu" "$(@D)/cuda/nvvm/libnvvm-samples/cuda-c-linking/math-funcs.cu" && cp "/usr/local/cuda-8.0/nvvm/libnvvm-samples/cuda-c-linking/CMakeLists.txt" "$(@D)/cuda/nvvm/libnvvm-samples/cuda-c-linking/CMakeLists.txt" && cp "/usr/local/cuda-8.0/nvvm/libnvvm-samples/cuda-c-linking/cuda-c-linking.cpp" "$(@D)/cuda/nvvm/libnvvm-samples/cuda-c-linking/cuda-c-linking.cpp" && cp "/usr/local/cuda-8.0/nvvm/libnvvm-samples/README.txt" "$(@D)/cuda/nvvm/libnvvm-samples/README.txt" && cp "/usr/local/cuda-8.0/nvvm/libnvvm-samples/simple/simple.c" "$(@D)/cuda/nvvm/libnvvm-samples/simple/simple.c" && cp "/usr/local/cuda-8.0/nvvm/libnvvm-samples/simple/simple-gpu.ll" "$(@D)/cuda/nvvm/libnvvm-samples/simple/simple-gpu.ll" && cp "/usr/local/cuda-8.0/nvvm/libnvvm-samples/simple/README.txt" 
"$(@D)/cuda/nvvm/libnvvm-samples/simple/README.txt" && cp "/usr/local/cuda-8.0/nvvm/libnvvm-samples/simple/simple-gpu64.ll" "$(@D)/cuda/nvvm/libnvvm-samples/simple/simple-gpu64.ll" && cp "/usr/local/cuda-8.0/nvvm/libnvvm-samples/simple/CMakeLists.txt" "$(@D)/cuda/nvvm/libnvvm-samples/simple/CMakeLists.txt" && cp "/usr/local/cuda-8.0/nvvm/libnvvm-samples/common/include/DDSWriter.h" "$(@D)/cuda/nvvm/libnvvm-samples/common/include/DDSWriter.h" && cp "/usr/local/cuda-8.0/nvvm/libnvvm-samples/common/include/drvapi_error_string.h" "$(@D)/cuda/nvvm/libnvvm-samples/common/include/drvapi_error_string.h" && cp "/usr/local/cuda-8.0/nvvm/libnvvm-samples/build.sh" "$(@D)/cuda/nvvm/libnvvm-samples/build.sh" && cp "/usr/local/cuda-8.0/nvvm/libnvvm-samples/CMakeLists.txt" "$(@D)/cuda/nvvm/libnvvm-samples/CMakeLists.txt" """, +) + +genrule( + name = "cuda-extras", + outs = [ + "cuda/extras/CUPTI/include/cupti_result.h", + "cuda/extras/CUPTI/include/cupti_events.h", + "cuda/extras/CUPTI/include/openacc/cupti_openacc.h", + "cuda/extras/CUPTI/include/cupti_version.h", + "cuda/extras/CUPTI/include/generated_cuda_gl_interop_meta.h", + "cuda/extras/CUPTI/include/generated_cudaVDPAU_meta.h", + "cuda/extras/CUPTI/include/cupti_activity.h", + "cuda/extras/CUPTI/include/generated_cuda_runtime_api_meta.h", + "cuda/extras/CUPTI/include/generated_cuda_meta.h", + "cuda/extras/CUPTI/include/cupti_nvtx_cbid.h", + "cuda/extras/CUPTI/include/cuda_stdint.h", + "cuda/extras/CUPTI/include/generated_cudaGL_meta.h", + "cuda/extras/CUPTI/include/generated_cuda_vdpau_interop_meta.h", + "cuda/extras/CUPTI/include/cupti_metrics.h", + "cuda/extras/CUPTI/include/cupti_callbacks.h", + "cuda/extras/CUPTI/include/cupti_runtime_cbid.h", + "cuda/extras/CUPTI/include/cupti.h", + "cuda/extras/CUPTI/include/GL/glut.h", + "cuda/extras/CUPTI/include/GL/glu.h", + "cuda/extras/CUPTI/include/GL/glxext.h", + "cuda/extras/CUPTI/include/GL/wglext.h", + "cuda/extras/CUPTI/include/GL/glx.h", + 
"cuda/extras/CUPTI/include/GL/glext.h", + "cuda/extras/CUPTI/include/GL/wglew.h", + "cuda/extras/CUPTI/include/GL/gl.h", + "cuda/extras/CUPTI/include/GL/glew.h", + "cuda/extras/CUPTI/include/cupti_driver_cbid.h", + "cuda/extras/CUPTI/include/generated_nvtx_meta.h", + ], + cmd = """ +cp "/usr/local/cuda-8.0/extras/CUPTI/include/cupti_result.h" "$(@D)/cuda/extras/CUPTI/include/cupti_result.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/cupti_events.h" "$(@D)/cuda/extras/CUPTI/include/cupti_events.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/openacc/cupti_openacc.h" "$(@D)/cuda/extras/CUPTI/include/openacc/cupti_openacc.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/cupti_version.h" "$(@D)/cuda/extras/CUPTI/include/cupti_version.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/generated_cuda_gl_interop_meta.h" "$(@D)/cuda/extras/CUPTI/include/generated_cuda_gl_interop_meta.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/generated_cudaVDPAU_meta.h" "$(@D)/cuda/extras/CUPTI/include/generated_cudaVDPAU_meta.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/cupti_activity.h" "$(@D)/cuda/extras/CUPTI/include/cupti_activity.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/generated_cuda_runtime_api_meta.h" "$(@D)/cuda/extras/CUPTI/include/generated_cuda_runtime_api_meta.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/generated_cuda_meta.h" "$(@D)/cuda/extras/CUPTI/include/generated_cuda_meta.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/cupti_nvtx_cbid.h" "$(@D)/cuda/extras/CUPTI/include/cupti_nvtx_cbid.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/cuda_stdint.h" "$(@D)/cuda/extras/CUPTI/include/cuda_stdint.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/generated_cudaGL_meta.h" "$(@D)/cuda/extras/CUPTI/include/generated_cudaGL_meta.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/generated_cuda_vdpau_interop_meta.h" "$(@D)/cuda/extras/CUPTI/include/generated_cuda_vdpau_interop_meta.h" && cp 
"/usr/local/cuda-8.0/extras/CUPTI/include/cupti_metrics.h" "$(@D)/cuda/extras/CUPTI/include/cupti_metrics.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/cupti_callbacks.h" "$(@D)/cuda/extras/CUPTI/include/cupti_callbacks.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/cupti_runtime_cbid.h" "$(@D)/cuda/extras/CUPTI/include/cupti_runtime_cbid.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/cupti.h" "$(@D)/cuda/extras/CUPTI/include/cupti.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/GL/glut.h" "$(@D)/cuda/extras/CUPTI/include/GL/glut.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/GL/glu.h" "$(@D)/cuda/extras/CUPTI/include/GL/glu.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/GL/glxext.h" "$(@D)/cuda/extras/CUPTI/include/GL/glxext.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/GL/wglext.h" "$(@D)/cuda/extras/CUPTI/include/GL/wglext.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/GL/glx.h" "$(@D)/cuda/extras/CUPTI/include/GL/glx.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/GL/glext.h" "$(@D)/cuda/extras/CUPTI/include/GL/glext.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/GL/wglew.h" "$(@D)/cuda/extras/CUPTI/include/GL/wglew.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/GL/gl.h" "$(@D)/cuda/extras/CUPTI/include/GL/gl.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/GL/glew.h" "$(@D)/cuda/extras/CUPTI/include/GL/glew.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/cupti_driver_cbid.h" "$(@D)/cuda/extras/CUPTI/include/cupti_driver_cbid.h" && cp "/usr/local/cuda-8.0/extras/CUPTI/include/generated_nvtx_meta.h" "$(@D)/cuda/extras/CUPTI/include/generated_nvtx_meta.h" """, +) + +genrule( + name = "cuda-lib", + outs = [ + "cuda/lib/libcuda.so", + "cuda/lib/libcudart.so.8.0", + "cuda/lib/libcudart_static.a", + "cuda/lib/libcublas.so.8.0", + "cuda/lib/libcusolver.so.8.0", + "cuda/lib/libcurand.so.8.0", + "cuda/lib/libcufft.so.8.0", + "cuda/lib/libcudnn.so.5", + "cuda/lib/libcupti.so.8.0", + ], + cmd = """ +cp 
"/usr/local/cuda-8.0/targets/x86_64-linux/lib/stubs/libcuda.so" "$(@D)/cuda/lib/libcuda.so" && cp "/usr/local/cuda-8.0/targets/x86_64-linux/lib/libcudart.so.8.0.61" "$(@D)/cuda/lib/libcudart.so.8.0" && cp "/usr/local/cuda-8.0/targets/x86_64-linux/lib/libcudart_static.a" "$(@D)/cuda/lib/libcudart_static.a" && cp "/usr/local/cuda-8.0/targets/x86_64-linux/lib/libcublas.so.8.0.71" "$(@D)/cuda/lib/libcublas.so.8.0" && cp "/usr/local/cuda-8.0/targets/x86_64-linux/lib/libcusolver.so.8.0.61" "$(@D)/cuda/lib/libcusolver.so.8.0" && cp "/usr/local/cuda-8.0/targets/x86_64-linux/lib/libcurand.so.8.0.61" "$(@D)/cuda/lib/libcurand.so.8.0" && cp "/usr/local/cuda-8.0/targets/x86_64-linux/lib/libcufft.so.8.0.61" "$(@D)/cuda/lib/libcufft.so.8.0" && cp "/usr/lib/x86_64-linux-gnu/libcudnn.so.5.1.10" "$(@D)/cuda/lib/libcudnn.so.5" && cp "/usr/local/cuda-8.0/extras/CUPTI/lib64/libcupti.so.8.0.61" "$(@D)/cuda/lib/libcupti.so.8.0" """, +) + +genrule( + name = "cudnn-include", + outs = [ + "cuda/include/cudnn.h", + ], + cmd = """ +cp "/usr/include/cudnn.h" "$(@D)/cudnn.h" """, +) diff --git a/third_party/toolchains/gpus/cuda/build_defs.bzl b/third_party/toolchains/gpus/cuda/build_defs.bzl new file mode 100644 index 00000000000..badaf430193 --- /dev/null +++ b/third_party/toolchains/gpus/cuda/build_defs.bzl @@ -0,0 +1,37 @@ +# Macros for building CUDA code used with Bazel remote +# execution service. +# DO NOT EDIT: automatically generated file + +def if_cuda(if_true, if_false = []): + """Shorthand for select()'ing on whether we're building with CUDA. + + Returns a select statement which evaluates to if_true if we're building + with CUDA enabled. Otherwise, the select statement evaluates to if_false. 
+ + """ + return select({ + "@local_config_cuda//cuda:using_nvcc": if_true, + "@local_config_cuda//cuda:using_clang": if_true, + "//conditions:default": if_false + }) + + +def cuda_default_copts(): + """Default options for all CUDA compilations.""" + return if_cuda(["-x", "cuda", "-DGOOGLE_CUDA=1"] + ["--cuda-gpu-arch=sm_30"]) + + +def cuda_is_configured(): + """Returns true if CUDA was enabled during the configure process.""" + return True + +def if_cuda_is_configured(x): + """Tests if the CUDA was enabled during the configure process. + + Unlike if_cuda(), this does not require that we are building with + --config=cuda. Used to allow non-CUDA code to depend on CUDA libraries. + """ + if cuda_is_configured(): + return x + return [] + diff --git a/tensorflow/tensorboard/components/tf_graph_loader/test/loader.ts b/third_party/toolchains/gpus/cuda/cuda/cuda_config.h similarity index 56% rename from tensorflow/tensorboard/components/tf_graph_loader/test/loader.ts rename to third_party/toolchains/gpus/cuda/cuda/cuda_config.h index fcd9f7b5295..dddee321938 100644 --- a/tensorflow/tensorboard/components/tf_graph_loader/test/loader.ts +++ b/third_party/toolchains/gpus/cuda/cuda/cuda_config.h @@ -1,25 +1,20 @@ /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the 'License'); +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an 'AS IS' BASIS, +distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -suite('graph loader', () => { - let assert = chai.assert; +// DO NOT EDIT: automatically generated file +#ifndef THIRD_PARTY_TENSORFLOW_OPENSOURCE_ONLY_TOOLCHAINS_GPUS_CUDA_CUDA_CONFIG_H_ +#define THIRD_PARTY_TENSORFLOW_OPENSOURCE_ONLY_TOOLCHAINS_GPUS_CUDA_CUDA_CONFIG_H_ - test('loader exists', () => { - assert.isTrue(document.getElementById('loader') != null); - }); - - // TODO(bp): write tests. - -}); +#endif // THIRD_PARTY_TENSORFLOW_OPENSOURCE_ONLY_TOOLCHAINS_GPUS_CUDA_CUDA_CONFIG_H_ diff --git a/third_party/typings.bzl b/third_party/typings.bzl deleted file mode 100644 index d0c9eddbb3f..00000000000 --- a/third_party/typings.bzl +++ /dev/null @@ -1,365 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the 'License'); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an 'AS IS' BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -# TensorBoard typing dependencies - -load("@io_bazel_rules_closure//closure:defs.bzl", "filegroup_external") - -def tensorboard_typings_workspace(): - filegroup_external( - name = "org_definitelytyped", - licenses = ["notice"], # MIT - sha256_urls = { - "b7da645f6e5555feb7aeede73775da0023ce2257df9c8e76c9159266035a9c0d": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/ebc69904eb78f94030d5d517b42db20867f679c0/chai/chai.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/ebc69904eb78f94030d5d517b42db20867f679c0/chai/chai.d.ts", - ], - "177293828c7a206bf2a7f725753d51396d38668311aa37c96445f91bbf8128a7": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/6e2f2280ef16ef277049d0ce8583af167d586c59/d3/d3.d.ts", # v3 - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/6e2f2280ef16ef277049d0ce8583af167d586c59/d3/d3.d.ts", # v3 - ], - "e4cd3d5de0eb3bc7b1063b50d336764a0ac82a658b39b5cf90511f489ffdee60": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/efd40e67ff323f7147651bdbef03c03ead7b1675/lodash/lodash.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/efd40e67ff323f7147651bdbef03c03ead7b1675/lodash/lodash.d.ts", - ], - "695a03dd2ccb238161d97160b239ab841562710e5c4e42886aefd4ace2ce152e": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/ebc69904eb78f94030d5d517b42db20867f679c0/mocha/mocha.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/ebc69904eb78f94030d5d517b42db20867f679c0/mocha/mocha.d.ts", - ], - "513ccd9ee1c708881120eeacd56788fc3b3da8e5c6172b20324cebbe858803fe": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/708609e0764daeb5eb64104af7aca50c520c4e6e/sinon/sinon.d.ts", - 
"https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/708609e0764daeb5eb64104af7aca50c520c4e6e/sinon/sinon.d.ts", - ], - "44eba36339bd1c0792072b7b204ee926fe5ffe1e9e2da916e67ac55548e3668a": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/a872802c0c84ba98ff207d5e673a1fa867c67fd6/polymer/polymer.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/a872802c0c84ba98ff207d5e673a1fa867c67fd6/polymer/polymer.d.ts", - ], - "9453c3e6bae824e90758c3b38975c1ed77e6abd79bf513bcb08368fcdb14898e": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/f5407eba29c04fb8387c86df27512bd055b195d2/threejs/three.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/f5407eba29c04fb8387c86df27512bd055b195d2/threejs/three.d.ts", - ], - "691756a6eb455f340c9e834de0d49fff269e7b8c1799c2454465dcd6a4435b80": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/46719185c564694c5583c4b7ad94dbb786ecad46/webcomponents.js/webcomponents.js.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/46719185c564694c5583c4b7ad94dbb786ecad46/webcomponents.js/webcomponents.js.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_array", - licenses = ["notice"], # MIT - sha256_urls = { - "61e7abb7b1f01fbcb0cab8cf39003392f422566209edd681fbd070eaa84ca000": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-array/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-array/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_axis", - licenses = ["notice"], # MIT - sha256_urls = { - "95f75c8dcc89850b2e72581d96a7b5f46ea4ac852f828893f141f14a597421f9": [ - 
"http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-axis/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-axis/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_brush", - licenses = ["notice"], # MIT - sha256_urls = { - "a2738e693ce8a8640c2d29001e77582c9c361fd23bda44db471629866b60ada7": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-brush/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-brush/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_chord", - licenses = ["notice"], # MIT - sha256_urls = { - "c54d24756eb6d744b31e538ad9bab3a75f6d54e2288b29cc72338d4a057d3e83": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-chord/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-chord/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_collection", - licenses = ["notice"], # MIT - sha256_urls = { - "f987667167b1d2970911247e325eb1c37ca0823646f81ccec837ae59039822f7": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-collection/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-collection/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_color", - licenses = ["notice"], # MIT - sha256_urls = { - 
"9580c81f38ddcce7be0ac9bd3d0d083adebc34e17441709f90b9e4dcd1c19a56": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-color/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-color/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_dispatch", - licenses = ["notice"], # MIT - sha256_urls = { - "169f80b4cceca8e2e9ed384d81a5db0624cc01a26451dfb5a7e0cec6ea9cfb06": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-dispatch/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-dispatch/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_drag", - licenses = ["notice"], # MIT - sha256_urls = { - "08d35d139dde58c2722be98d718d01204fd6167d310f09b379e832f3c741489d": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-drag/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-drag/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_dsv", - licenses = ["notice"], # MIT - sha256_urls = { - "62594d00cf9e4bb895339c8e56f64330e202a5eb2a0fa580a1f6e6336f2c93ce": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-dsv/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-dsv/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_ease", - licenses = ["notice"], # MIT - sha256_urls = { - 
"d1cf8f99b7bf758c2ba3c0a4ce553e151d4d9b4cf45a6e8bd0edec7ce90f725b": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-ease/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-ease/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_force", - licenses = ["notice"], # MIT - sha256_urls = { - "288421e2008668d2076a4684657dd3d29b992832ef02c552981eb94a91042553": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-force/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-force/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_format", - licenses = ["notice"], # MIT - sha256_urls = { - "b42cb17e580c1fd0b64d478f7bd80ca806efaefda24426a833cf1f30a7275bca": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-format/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-format/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_hierarchy", - licenses = ["notice"], # MIT - sha256_urls = { - "a5683f5835d8716c6b89c075235078438cfab5897023ed720bfa492e244e969e": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-hierarchy/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-hierarchy/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_interpolate", - licenses = ["notice"], # MIT 
- sha256_urls = { - "590a71b741323ac3139b333ec8b743e24717fdd5b32bcff48ee521162a9dfe1c": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-interpolate/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-interpolate/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_path", - licenses = ["notice"], # MIT - sha256_urls = { - "96f35ba041bcaa265e2b373ee675177410d44d31c980e4f7fbeefd4bcba15b00": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-path/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-path/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_polygon", - licenses = ["notice"], # MIT - sha256_urls = { - "ce453451e8105cac6a4f4a4263ca2142ebb4bf442e342f470a81da691f220fcb": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-polygon/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-polygon/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_quadtree", - licenses = ["notice"], # MIT - sha256_urls = { - "238e278f1be5d6985a19800800cffee80f81199f71d848e3bbc288d1791a6f90": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-quadtree/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-quadtree/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_queue", - 
licenses = ["notice"], # MIT - sha256_urls = { - "e6ae19aad83495475653578de64fb9d6bf9764eda6c84d70f7935ec84bcc482e": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-queue/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-queue/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_random", - licenses = ["notice"], # MIT - sha256_urls = { - "d31b92ed86c23ec0a4776f99fa81ff033c95b96c8304d8aa9baf3b94af779aa8": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-random/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-random/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_request", - licenses = ["notice"], # MIT - sha256_urls = { - "44bb7b07d977028e6567540a3303b06fc9b33fb0960bc75c520e0733c840d89f": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-request/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-request/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_scale", - licenses = ["notice"], # MIT - sha256_urls = { - "02ce7c644ba34bd1abb84da2e832f248b048b6a23812be4365bd837f186c9f1f": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-scale/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-scale/index.d.ts", - ], - }, - ) - - filegroup_external( - name = 
"org_definitelytyped_types_d3_selection", - licenses = ["notice"], # MIT - sha256_urls = { - "699043ddb28dfa5e46d87bc6a24cfc6d604237f298259d3fb3c7066e05e8c86e": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-selection/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-selection/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_shape", - licenses = ["notice"], # MIT - sha256_urls = { - "62668a7aaaf6232762b544f9f89c0f557ca7cfb0cd343a358dda7ecbe26f5739": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-shape/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-shape/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_time", - licenses = ["notice"], # MIT - sha256_urls = { - "0502490ce682fd9265fb1d5d693ce6cd82e3b05e5f5ee3433731266ecb03d5fc": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-time/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-time/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_timer", - licenses = ["notice"], # MIT - sha256_urls = { - "6f191f9aea704aa64b1defa40dfdff1447a6e6bb815feff1660f894500a9c94d": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-timer/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-timer/index.d.ts", - ], - }, - ) - - filegroup_external( 
- name = "org_definitelytyped_types_d3_transition", - licenses = ["notice"], # MIT - sha256_urls = { - "a0a7c0c9bfb5c7d6d9d22a8d16b4484b66d13f2ed226954037546cb3da4098ba": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-transition/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-transition/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_voronoi", - licenses = ["notice"], # MIT - sha256_urls = { - "c6bd5f229f915151d0ef678fe50b1aa6a62334ea0a8c6fc0effbac9f7032efc7": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-voronoi/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-voronoi/index.d.ts", - ], - }, - ) - - filegroup_external( - name = "org_definitelytyped_types_d3_zoom", - licenses = ["notice"], # MIT - sha256_urls = { - "a25dc17fbd304cf7a0e5e7bbb8339c930d464eb40c4d6e5f839ce9c0191f4110": [ - "http://mirror.bazel.build/raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-zoom/index.d.ts", - "https://raw.githubusercontent.com/DefinitelyTyped/DefinitelyTyped/1550dfd1b8e38d9bf104b3fd16ea9bf98a2b358e/types/d3-zoom/index.d.ts", - ], - }, - ) diff --git a/third_party/werkzeug.BUILD b/third_party/werkzeug.BUILD deleted file mode 100644 index 72a1402030d..00000000000 --- a/third_party/werkzeug.BUILD +++ /dev/null @@ -1,14 +0,0 @@ -# Description: -# Werkzeug provides utilities for making WSGI applications - -licenses(["notice"]) # BSD 3-Clause - -exports_files(["LICENSE"]) - -# Note: this library includes test code. Consider creating a testonly target. 
-py_library( - name = "werkzeug", - srcs = glob(["werkzeug/*.py"]), - srcs_version = "PY2AND3", - visibility = ["//visibility:public"], -)