diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD
index 85c4e4fd93c..52945d32391 100644
--- a/tensorflow/c/eager/BUILD
+++ b/tensorflow/c/eager/BUILD
@@ -79,6 +79,8 @@ tf_cc_test(
         "//tensorflow/cc:ops",
         "//tensorflow/cc:scope",
         "//tensorflow/core:core_cpu_internal",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
     ],
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 473e6339f36..56f7303f70f 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -479,20 +479,16 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
   if (kernel == nullptr) {
     const tensorflow::NodeDef& ndef = op->attrs.BuildNodeDef();
     kernel = new tensorflow::KernelAndDevice(ctx->rendezvous);
-    if (!op->is_function()) {
-      status->status =
-          tensorflow::KernelAndDevice::InitOp(device, ndef, kernel);
-    } else {
-      // Knowledge of the implementation of InitFn (and in-turn
-      // FunctionLibraryRuntime::CreateKernel) tells us that ctx->func_lib_def
-      // will be accessed, so grab on to the lock.
-      // See WARNING comment below - would be nice to rework to avoid this
-      // subtlety.
-      tensorflow::mutex_lock l(ctx->functions_mu);
-      status->status = tensorflow::KernelAndDevice::InitFn(
-          ndef, ctx->func_lib(device), kernel);
-    }
+    // Knowledge of the implementation of Init (and in-turn
+    // FunctionLibraryRuntime::CreateKernel) tells us that ctx->func_lib_def
+    // will be accessed, so grab on to the lock.
+    // See WARNING comment below - would be nice to rework to avoid this
+    // subtlety.
+    tensorflow::tf_shared_lock l(ctx->functions_mu);
+    status->status =
+        tensorflow::KernelAndDevice::Init(ndef, ctx->func_lib(device), kernel);
     if (!status->status.ok()) {
+      delete kernel;
       return;
     }
     tensorflow::gtl::InsertOrUpdate(&(ctx->kernel_cache), cache_key, kernel);
diff --git a/tensorflow/c/eager/runtime.cc b/tensorflow/c/eager/runtime.cc
index b6d53872c97..3b39903e09a 100644
--- a/tensorflow/c/eager/runtime.cc
+++ b/tensorflow/c/eager/runtime.cc
@@ -238,9 +238,8 @@ Status KernelAndDevice::InitOp(Device* device, const NodeDef& ndef,
 }
 
 // static
-Status KernelAndDevice::InitFn(const NodeDef& ndef,
-                               FunctionLibraryRuntime* flib,
-                               KernelAndDevice* out) {
+Status KernelAndDevice::Init(const NodeDef& ndef, FunctionLibraryRuntime* flib,
+                             KernelAndDevice* out) {
   OpKernel* k = nullptr;
   Status s = flib->CreateKernel(ndef, &k);
   out->device_ = flib->device();
diff --git a/tensorflow/c/eager/runtime.h b/tensorflow/c/eager/runtime.h
index bb098f74013..13b49e5e8cb 100644
--- a/tensorflow/c/eager/runtime.h
+++ b/tensorflow/c/eager/runtime.h
@@ -150,28 +150,19 @@ class KernelAndDevice {
  public:
   // Populates 'out' with a kernel appropriate for 'ndef'.
   //
-  // Assumes that 'ndef' refers to a primitive op (as opposed to a function).
-  static Status InitOp(Device* device, const NodeDef& ndef,
-                       KernelAndDevice* out);
-
-  // Like InitOp but for functions defined in flib (i.e., ndef.op() refers to a
-  // TensorFlow function in the FunctionLibraryRuntime).
-  //
   // The provided FunctionLibraryRuntime MUST outlive all calls to
   // Run() on the returned KernelAndDevice.
   //
-  // TODO(ashankar): There shouldn't be a need for a separate InitOp and InitFn.
-  // The implementation of InitFn should work for both because
-  // FunctionLibraryRuntime::CreateKernel will create a primitive op kernel if
-  // appropriate. However, for now we keep them separate because I haven't
-  // figured out thread-safety concerns around FunctionLibraryRuntime (in
-  // particular, how the underlying FunctionLibraryDefinition might be mutated
-  // by another thread as new functions are registered with it).
-  // Conservatively, thread-safe usage of the FunctionLibraryRuntime is pushed
-  // on to the caller (see locking in c_api.cc) for now. But I really should
-  // dig into this so that both InitOp and InitFn can be collapsed to
-  // FunctionLibraryRuntime::CreateKernel.
-  static Status InitFn(const NodeDef& ndef, FunctionLibraryRuntime* flib,
+  // TODO(ashankar): Figure out thread-safety concerns around
+  // FunctionLibraryRuntime (in particular, how the underlying
+  // FunctionLibraryDefinition might be mutated by another thread as new
+  // functions are registered with it). Conservatively, thread-safe usage of
+  // the FunctionLibraryRuntime is pushed on to the caller (see locking in
+  // c_api.cc).
+  static Status Init(const NodeDef& ndef, FunctionLibraryRuntime* flib,
+                     KernelAndDevice* out);
+  // TODO(ashankar): Remove this
+  static Status InitOp(Device* device, const NodeDef& ndef,
                        KernelAndDevice* out);
 
   KernelAndDevice(tensorflow::Rendezvous* rendez)
@@ -184,10 +175,10 @@ class KernelAndDevice {
 
  private:
   std::unique_ptr<OpKernel> kernel_;
-  tensorflow::Device* device_;
-  tensorflow::FunctionLibraryRuntime* flib_;
-  tensorflow::checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_;
-  tensorflow::Rendezvous* rendez_;
+  Device* device_;
+  FunctionLibraryRuntime* flib_;
+  checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_;
+  Rendezvous* rendez_;
 };
 
 }  // namespace tensorflow
diff --git a/tensorflow/c/eager/runtime_test.cc b/tensorflow/c/eager/runtime_test.cc
index f9bfce38580..3236c6be0ec 100644
--- a/tensorflow/c/eager/runtime_test.cc
+++ b/tensorflow/c/eager/runtime_test.cc
@@ -23,15 +23,36 @@ limitations under the License.
 #include "tensorflow/cc/framework/scope.h"
 #include "tensorflow/cc/ops/standard_ops.h"
 #include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/platform/env.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/public/version.h"
 
 namespace tensorflow {
 namespace {
 
-Device* CPUDevice() {
-  return DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0");
-}
+class TestEnv {
+ public:
+  TestEnv() : flib_def_(OpRegistry::Global(), {}) {
+    Device* device =
+        DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0");
+    device_mgr_.reset(new DeviceMgr({device}));
+    flib_runtime_ = NewFunctionLibraryRuntime(device_mgr_.get(), Env::Default(),
+                                              device, TF_GRAPH_DEF_VERSION,
+                                              &flib_def_, {}, nullptr);
+  }
+
+  FunctionLibraryRuntime* function_library_runtime() const {
+    return flib_runtime_.get();
+  }
+
+ private:
+  FunctionLibraryDefinition flib_def_;
+  std::unique_ptr<DeviceMgr> device_mgr_;
+  std::unique_ptr<FunctionLibraryRuntime> flib_runtime_;
+};
 
 TEST(AttrTypeMap, Lookup) {
   const AttrTypeMap* m = nullptr;
@@ -69,9 +90,10 @@ TEST(KernelAndDevice, Run) {
           .Set("transpose_b", false)
           .NumInputs(inputs.size())
           .BuildNodeDef());
-  std::unique_ptr<Device> device(CPUDevice());
+  TestEnv env;
   KernelAndDevice kernel(nullptr);
-  Status s = KernelAndDevice::InitOp(device.get(), ndef, &kernel);
+  Status s =
+      KernelAndDevice::Init(ndef, env.function_library_runtime(), &kernel);
   ASSERT_TRUE(s.ok()) << s;
   std::vector<Tensor> outputs;
   s = kernel.Run(&inputs, &outputs);
@@ -132,11 +154,12 @@ void BM_KernelAndDeviceInit(int iters) {
           .Set("transpose_b", false)
           .NumInputs(2)
           .BuildNodeDef());
-  std::unique_ptr<Device> device(CPUDevice());
+  TestEnv env;
   KernelAndDevice k(nullptr);
   tensorflow::testing::StartTiming();
   for (int i = 0; i < iters; ++i) {
-    TF_CHECK_OK(KernelAndDevice::InitOp(device.get(), ndef, &k));
+    TF_CHECK_OK(
+        KernelAndDevice::Init(ndef, env.function_library_runtime(), &k));
   }
 }
 BENCHMARK(BM_KernelAndDeviceInit);
@@ -154,9 +177,10 @@ void BM_KernelAndDeviceRun(int iters) {
           .Set("transpose_b", false)
           .NumInputs(inputs.size())
           .BuildNodeDef());
-  std::unique_ptr<Device> device(CPUDevice());
+  TestEnv env;
   KernelAndDevice kernel(nullptr);
-  TF_CHECK_OK(KernelAndDevice::InitOp(device.get(), ndef, &kernel));
+  TF_CHECK_OK(
+      KernelAndDevice::Init(ndef, env.function_library_runtime(), &kernel));
   tensorflow::testing::StartTiming();
   for (int i = 0; i < iters; ++i) {
     TF_CHECK_OK(kernel.Run(&inputs, &outputs));
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 98cc3401c14..c5ed976d78d 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -286,6 +286,7 @@ cc_library(
     srcs = ["call_inliner.cc"],
     hdrs = ["call_inliner.h"],
     deps = [
+        ":call_graph",
        ":hlo_pass",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/core:lib",
diff --git a/tensorflow/compiler/xla/service/call_inliner.cc b/tensorflow/compiler/xla/service/call_inliner.cc
index 817b59f7627..65472d9ac92 100644
--- a/tensorflow/compiler/xla/service/call_inliner.cc
+++ b/tensorflow/compiler/xla/service/call_inliner.cc
@@ -17,33 +17,11 @@ limitations under the License.
 #include <deque>
 
+#include "tensorflow/compiler/xla/service/call_graph.h"
 #include "tensorflow/core/lib/core/errors.h"
 
 namespace xla {
-
-StatusOr<bool> CallInliner::Run(HloModule* module) {
-  std::deque<HloInstruction*> work_queue;
-
-  // Seed the work queue with call instructions from the main computation.
-  TF_RETURN_IF_ERROR(
-      module->entry_computation()->Accept([&](HloInstruction* hlo) {
-        if (hlo->opcode() == HloOpcode::kCall) {
-          work_queue.push_back(hlo);
-        }
-        return Status::OK();
-      }));
-
-  VLOG(1) << "Work queue seeded with " << work_queue.size() << " entries.";
-
-  bool mutated = false;
-  while (!work_queue.empty()) {
-    mutated = true;
-    HloInstruction* call = work_queue.front();
-    work_queue.pop_front();
-    TF_RETURN_IF_ERROR(ReplaceWithInlinedBody(call, &work_queue));
-  }
-  return mutated;
-}
+namespace {
 
 // Traverses the callee computation, inlining cloned nodes into the caller
 // computation and connecting them to producers/consumers appropriately.
@@ -52,14 +30,18 @@ StatusOr<bool> CallInliner::Run(HloModule* module) {
 // computation have been added to the work_queue.
 class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault {
  public:
-  SubcomputationInsertionVisitor(HloInstruction* call,
-                                 std::deque<HloInstruction*>* work_queue)
-      : call_(call), outer_(call->parent()), work_queue_(work_queue) {}
+  // call is the call operation -- it will be replaced with the body of the
+  // called computation.
+  explicit SubcomputationInsertionVisitor(HloInstruction* call)
+      : call_(call), outer_(call->parent()) {
+    CHECK_EQ(HloOpcode::kCall, call_->opcode());
+  }
 
   // Resolves the operands to the HLO instruction in the inlined (caller) graph,
   // and clones the HLO instruction into that graph with the new operands.
   // If the instruction is a call, it is added to the work queue.
   Status DefaultAction(HloInstruction* hlo) override {
+    TF_RET_CHECK(hlo->opcode() != HloOpcode::kCall);
     std::vector<HloInstruction*> new_operands;
     for (HloInstruction* operand : hlo->operands()) {
       TF_ASSIGN_OR_RETURN(HloInstruction * new_operand, Resolve(operand));
@@ -79,12 +61,6 @@ class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault {
           new_control_predecessor->AddControlDependencyTo(new_hlo_pointer));
     }
 
-    if (new_hlo_pointer->opcode() == HloOpcode::kCall) {
-      VLOG(1) << "Adding new call HLO to work queue.";
-      // Call instructions we observe in the subcomputation are added to the
-      // inliner work queue.
-      work_queue_->push_back(new_hlo_pointer);
-    }
     return Status::OK();
   }
 
@@ -141,16 +117,30 @@ class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault {
   std::deque<HloInstruction*>* work_queue_;
 };
 
-Status CallInliner::ReplaceWithInlinedBody(
-    HloInstruction* call, std::deque<HloInstruction*>* work_queue) {
-  TF_RET_CHECK(call->opcode() == HloOpcode::kCall);
-  TF_RET_CHECK(call->called_computations().size() == 1);
-  HloComputation* called = call->called_computations()[0];
-  VLOG(1) << "Replacing call " << call->ToString() << " with inlined body of "
-          << called->name();
+}  // namespace
 
-  SubcomputationInsertionVisitor visitor(call, work_queue);
-  return called->Accept(&visitor);
+StatusOr<bool> CallInliner::Run(HloModule* module) {
+  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
+  // Because call graph nodes are visited in post-order (callees before callers)
+  // we'll always inline kCalls into their callers in the appropriate order.
+  bool did_mutate = false;
+  TF_RETURN_IF_ERROR(
+      call_graph->VisitNodes([&](const CallGraphNode& node) -> Status {
+        for (const CallSite& callsite : node.caller_callsites()) {
+          VLOG(1) << "Visiting callsite: " << callsite.ToString();
+          if (callsite.instruction()->opcode() == HloOpcode::kCall) {
+            did_mutate = true;
+            const auto& callees = callsite.called_computations();
+            TF_RET_CHECK(callees.size() == 1);
+            HloComputation* callee = callees[0];
+            // We visit the callee, cloning its body into its caller.
+            SubcomputationInsertionVisitor visitor(callsite.instruction());
+            TF_RETURN_IF_ERROR(callee->Accept(&visitor));
+          }
+        }
+        return Status::OK();
+      }));
+  return did_mutate;
 }
 
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/call_inliner.h b/tensorflow/compiler/xla/service/call_inliner.h
index 8647edffa7f..8660200bc40 100644
--- a/tensorflow/compiler/xla/service/call_inliner.h
+++ b/tensorflow/compiler/xla/service/call_inliner.h
@@ -31,16 +31,6 @@ class CallInliner : public HloPassInterface {
   tensorflow::StringPiece name() const override { return "CallInliner"; }
 
   StatusOr<bool> Run(HloModule* module) override;
-
- private:
-  // Replaces the given call operation -- which must be an operation inside the
-  // entry computation with opcode kCall -- with the called computation's body,
-  // such that the called computation is inline in the entry computation.
-  //
-  // On successful inlining, the inlined computation may have itself contained
-  // calls; if so, they are added to the work_queue.
-  Status ReplaceWithInlinedBody(HloInstruction* call,
-                                std::deque<HloInstruction*>* work_queue);
 };
 
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/call_inliner_test.cc b/tensorflow/compiler/xla/service/call_inliner_test.cc
index 77528d0b75f..f3e7407c544 100644
--- a/tensorflow/compiler/xla/service/call_inliner_test.cc
+++ b/tensorflow/compiler/xla/service/call_inliner_test.cc
@@ -44,6 +44,8 @@ namespace {
 using CallInlinerTest = HloTestBase;
 
 TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) {
+  // "inner" computation just has a control dependency from the "zero" value to
+  // the "one" value.
   HloComputation::Builder inner(TestName() + ".inner");
   HloInstruction* zero = inner.AddInstruction(
       HloInstruction::CreateConstant(Literal::CreateR0<float>(24.0f)));
@@ -54,6 +56,7 @@ TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) {
   HloComputation* inner_computation =
       module->AddEmbeddedComputation(inner.Build());
 
+  // "outer" computation just calls the "inner" computation.
   HloComputation::Builder outer(TestName() + ".outer");
   Shape r0f32 = ShapeUtil::MakeShape(F32, {});
   outer.AddInstruction(
@@ -73,5 +76,44 @@ TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) {
   EXPECT_EQ(prior->literal().GetFirstElement<float>(), 24);
 }
 
+// Tests for referential transparency (a function that calls a function that
+// returns false should be identical to just returning false).
+TEST_F(CallInlinerTest, CallsWithinWhileBodiesAreInlined) {
+  const Shape pred = ShapeUtil::MakeShape(PRED, {});
+  auto module = CreateNewModule();
+
+  // Create a lambda that calls a function that returns the false predicate.
+  // Note we also use this lambda twice by reference, just to make the test a
+  // little trickier.
+  HloComputation::Builder just_false(TestName() + ".false");
+  just_false.AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+  HloComputation* false_computation =
+      module->AddEmbeddedComputation(just_false.Build());
+
+  HloComputation::Builder call_false_builder(TestName() + ".call_false");
+  call_false_builder.AddInstruction(
+      HloInstruction::CreateCall(pred, {}, false_computation));
+  HloComputation* call_false =
+      module->AddEmbeddedComputation(call_false_builder.Build());
+
+  HloComputation::Builder outer(TestName() + ".outer");
+  HloInstruction* init_value = outer.AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
+  outer.AddInstruction(
+      HloInstruction::CreateWhile(pred, call_false, call_false, init_value));
+
+  auto computation = module->AddEntryComputation(outer.Build());
+
+  CallInliner call_inliner;
+  TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get()));
+  ASSERT_TRUE(mutated);
+  EXPECT_THAT(
+      computation->root_instruction()->while_condition()->root_instruction(),
+      op::Constant());
+  EXPECT_THAT(computation->root_instruction()->while_body()->root_instruction(),
+              op::Constant());
+}
+
 }  // namespace
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index 5983341c2b1..9472a5eddce 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -216,8 +216,7 @@ class CollectProfileCandidates : public DfsHloVisitorWithDefault {
   Status HandleCall(HloInstruction* call) override {
     TF_RETURN_IF_ERROR(DefaultAction(call));
     CollectProfileCandidates candidates_for_call(hlo_to_profile_idx_);
-    TF_RETURN_IF_ERROR(
-        call->to_apply()->root_instruction()->Accept(&candidates_for_call));
+    TF_RETURN_IF_ERROR(call->to_apply()->Accept(&candidates_for_call));
     return Status::OK();
   }
 
diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile.cc b/tensorflow/compiler/xla/service/hlo_execution_profile.cc
index 0fe7c9fe1b2..eaeb352183b 100644
--- a/tensorflow/compiler/xla/service/hlo_execution_profile.cc
+++ b/tensorflow/compiler/xla/service/hlo_execution_profile.cc
@@ -45,8 +45,7 @@ string HloExecutionProfile::ToString(
     const HloComputation& computation,
     const DeviceDescription& device_description,
     HloCostAnalysis* cost_analysis) const {
-  tensorflow::Status analysis_status =
-      computation.root_instruction()->Accept(cost_analysis);
+  tensorflow::Status analysis_status = computation.Accept(cost_analysis);
   if (!analysis_status.ok()) {
     return "";
   }
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index d63d33ecb00..00fe55419d6 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -1179,8 +1179,7 @@ tensorflow::Status Service::GetComputationStats(
 
   HloCostAnalysis analysis(
       execute_backend_->compiler()->ShapeSizeBytesFunction());
-  TF_RETURN_IF_ERROR(
-      module->entry_computation()->root_instruction()->Accept(&analysis));
+  TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&analysis));
 
   ComputationStats stats;
   stats.set_flop_count(analysis.flop_count());
diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc
index 77d1c019f3a..f3cbc013238 100644
--- a/tensorflow/compiler/xla/tests/scalar_computations_test.cc
+++ b/tensorflow/compiler/xla/tests/scalar_computations_test.cc
@@ -151,19 +151,6 @@ XLA_TEST_F(ScalarComputationsTest, SubtractTwoScalarsS32) {
   ComputeAndCompareR0<int32>(&builder, -3, {});
 }
 
-XLA_TEST_F(ScalarComputationsTest, CastS64ToF32) {
-  ComputationBuilder builder(client_, TestName());
-  auto a = builder.Parameter(0, ShapeUtil::MakeShape(S64, {}), "a");
-  builder.ConvertElementType(a, F32);
-
-  int64 value = 3LL << 32;
-  std::unique_ptr<Literal> a_literal = Literal::CreateR0<int64>(value);
-  std::unique_ptr<GlobalData> a_data =
-      client_->TransferToServer(*a_literal).ConsumeValueOrDie();
-  ComputeAndCompareR0<float>(&builder, static_cast<float>(value),
-                             {a_data.get()});
-}
-
 XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsF32) {
   ComputationBuilder builder(client_, TestName());
   builder.Mul(builder.Mul(builder.ConstantR0<float>(2.1f),
diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc
index 8d7f7fd1237..6c952b29e28 100644
--- a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc
+++ b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc
@@ -94,7 +94,7 @@ void RealMain(tensorflow::gtl::ArraySlice<char*> args) {
 
     OperationDumper dumper(arg);
     for (auto& computation : module.computations()) {
-      TF_CHECK_OK(computation->root_instruction()->Accept(&dumper));
+      TF_CHECK_OK(computation->Accept(&dumper));
     }
   }
 }
diff --git a/tensorflow/contrib/boosted_trees/BUILD b/tensorflow/contrib/boosted_trees/BUILD
index 11dc8dfd31d..7b20d31e27c 100644
--- a/tensorflow/contrib/boosted_trees/BUILD
+++ b/tensorflow/contrib/boosted_trees/BUILD
@@ -113,6 +113,7 @@ py_test(
     srcs_version = "PY2AND3",
     tags = [
         "nomac",  # b/63258195
+        "notsan",  # b/62863147
     ],
     deps = [
         ":gbdt_batch",
diff --git a/tensorflow/contrib/cmake/patches/gif/CMakeLists.txt b/tensorflow/contrib/cmake/patches/gif/CMakeLists.txt
index 0fe919d89e3..fd3fb76bb73 100644
--- a/tensorflow/contrib/cmake/patches/gif/CMakeLists.txt
+++ b/tensorflow/contrib/cmake/patches/gif/CMakeLists.txt
@@ -19,6 +19,14 @@ set(GIFLIB_INCLUDES
     "lib/gif_lib.h"
 )
 
+if (WIN32)
+  # Suppress warnings to reduce build log size.
+  add_definitions(/wd4267 /wd4244 /wd4800 /wd4503 /wd4554 /wd4996 /wd4348 /wd4018)
+  add_definitions(/wd4099 /wd4146 /wd4267 /wd4305 /wd4307)
+  add_definitions(/wd4715 /wd4722 /wd4723 /wd4838 /wd4309 /wd4334)
+  add_definitions(/wd4003 /wd4244 /wd4267 /wd4503 /wd4506 /wd4800 /wd4996)
+endif()
+
 include_directories("${CMAKE_CURRENT_SOURCE_DIR}/lib")
 
 add_library(giflib ${GIFLIB_SRCS})
diff --git a/tensorflow/contrib/cmake/patches/jpeg/CMakeLists.txt b/tensorflow/contrib/cmake/patches/jpeg/CMakeLists.txt
index 782076ef74c..5807813fde5 100644
--- a/tensorflow/contrib/cmake/patches/jpeg/CMakeLists.txt
+++ b/tensorflow/contrib/cmake/patches/jpeg/CMakeLists.txt
@@ -62,6 +62,14 @@ set(LIBJPEG_INCLUDES
     "jversion.h"
 )
 
+if (WIN32)
+  # Suppress warnings to reduce build log size.
+  add_definitions(/wd4267 /wd4244 /wd4800 /wd4503 /wd4554 /wd4996 /wd4348 /wd4018)
+  add_definitions(/wd4099 /wd4146 /wd4267 /wd4305 /wd4307)
+  add_definitions(/wd4715 /wd4722 /wd4723 /wd4838 /wd4309 /wd4334)
+  add_definitions(/wd4003 /wd4244 /wd4267 /wd4503 /wd4506 /wd4800 /wd4996)
+endif()
+
 include_directories("${CMAKE_CURRENT_SOURCE_DIR}")
 
 add_library(libjpeg ${LIBJPEG_SRCS})
diff --git a/tensorflow/contrib/cmake/patches/lmdb/CMakeLists.txt b/tensorflow/contrib/cmake/patches/lmdb/CMakeLists.txt
index 19fa607a101..45713b9b579 100644
--- a/tensorflow/contrib/cmake/patches/lmdb/CMakeLists.txt
+++ b/tensorflow/contrib/cmake/patches/lmdb/CMakeLists.txt
@@ -12,6 +12,14 @@ set(LIBLMDB_INCLUDES
     "libraries/liblmdb/midl.h"
 )
 
+if (WIN32)
+  # Suppress warnings to reduce build log size.
+  add_definitions(/wd4267 /wd4244 /wd4800 /wd4503 /wd4554 /wd4996 /wd4348 /wd4018)
+  add_definitions(/wd4099 /wd4146 /wd4267 /wd4305 /wd4307)
+  add_definitions(/wd4715 /wd4722 /wd4723 /wd4838 /wd4309 /wd4334)
+  add_definitions(/wd4003 /wd4244 /wd4267 /wd4503 /wd4506 /wd4800 /wd4996)
+endif()
+
 include_directories("${CMAKE_CURRENT_SOURCE_DIR}")
 
 add_library(lmdb ${LIBLMDB_SRCS})
diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake
index 896df8eb344..ce353023174 100755
--- a/tensorflow/contrib/cmake/tf_python.cmake
+++ b/tensorflow/contrib/cmake/tf_python.cmake
@@ -361,6 +361,11 @@ add_python_module("tensorflow/contrib/framework/python")
 add_python_module("tensorflow/contrib/framework/python/framework")
 add_python_module("tensorflow/contrib/framework/python/ops")
 add_python_module("tensorflow/contrib/gan")
+add_python_module("tensorflow/contrib/gan/python")
+add_python_module("tensorflow/contrib/gan/python/features")
+add_python_module("tensorflow/contrib/gan/python/features/python")
+add_python_module("tensorflow/contrib/gan/python/losses")
+add_python_module("tensorflow/contrib/gan/python/losses/python")
 add_python_module("tensorflow/contrib/graph_editor")
 add_python_module("tensorflow/contrib/graph_editor/examples")
 add_python_module("tensorflow/contrib/graph_editor/tests")
@@ -1147,12 +1152,24 @@ add_custom_command(TARGET tf_python_build_pip_package POST_BUILD
 add_custom_command(TARGET tf_python_build_pip_package POST_BUILD
                    COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_CURRENT_BINARY_DIR}/eigen/src/eigen/unsupported/Eigen ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/include/unsupported/Eigen)
 
-if(${tensorflow_ENABLE_GPU})
-  add_custom_command(TARGET tf_python_build_pip_package POST_BUILD
-                     COMMAND ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_BINARY_DIR}/tf_python/setup.py bdist_wheel --project_name tensorflow_gpu
-                     WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/tf_python)
+if(${tensorflow_TF_NIGHTLY})
+  if(${tensorflow_ENABLE_GPU})
+    add_custom_command(TARGET tf_python_build_pip_package POST_BUILD
+                       COMMAND ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_BINARY_DIR}/tf_python/setup.py bdist_wheel --project_name tf_nightly_gpu
+                       WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/tf_python)
+  else()
+    add_custom_command(TARGET tf_python_build_pip_package POST_BUILD
+                       COMMAND ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_BINARY_DIR}/tf_python/setup.py bdist_wheel --project_name tf_nightly
+                       WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/tf_python)
+  endif(${tensorflow_ENABLE_GPU})
 else()
-  add_custom_command(TARGET tf_python_build_pip_package POST_BUILD
-                     COMMAND ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_BINARY_DIR}/tf_python/setup.py bdist_wheel
-                     WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/tf_python)
-endif(${tensorflow_ENABLE_GPU})
+  if(${tensorflow_ENABLE_GPU})
+    add_custom_command(TARGET tf_python_build_pip_package POST_BUILD
+                       COMMAND ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_BINARY_DIR}/tf_python/setup.py bdist_wheel --project_name tensorflow_gpu
+                       WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/tf_python)
+  else()
+    add_custom_command(TARGET tf_python_build_pip_package POST_BUILD
+                       COMMAND ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_BINARY_DIR}/tf_python/setup.py bdist_wheel
+                       WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/tf_python)
+  endif(${tensorflow_ENABLE_GPU})
+endif(${tensorflow_TF_NIGHTLY})
diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD
index 2f93c345027..6891fd4231f 100644
--- a/tensorflow/contrib/data/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/BUILD
@@ -24,6 +24,8 @@ py_test(
         "//tensorflow/python:functional_ops",
         "//tensorflow/python:gradients",
         "//tensorflow/python:math_ops",
+        "//tensorflow/python:parsing_ops",
+        "//tensorflow/python:script_ops",
         "//tensorflow/python:training",
         "//third_party/py/numpy",
     ],
diff --git a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py
index b20742f7758..7ee21d4e01d 100644
--- a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py
@@ -27,10 +27,13 @@ from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import function
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import functional_ops
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import parsing_ops
+from tensorflow.python.ops import script_ops
 from tensorflow.python.platform import test
 from tensorflow.python.training import server_lib
 
@@ -420,7 +423,7 @@ class IteratorTest(test.TestCase):
 
   def testRemoteIteratorUsingRemoteCallOpDirectSession(self):
     worker_config = config_pb2.ConfigProto()
-    worker_config.device_count["CPU"] = 2
+    worker_config.device_count["CPU"] = 3
 
     with ops.device("/job:localhost/replica:0/task:0/cpu:1"):
       dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
@@ -448,12 +451,12 @@ class IteratorTest(test.TestCase):
               target_placeholder: "/job:localhost/replica:0/task:0/cpu:1"
           })
       self.assertEqual(elem, [1])
-      # Fails when target is cpu:0 where the resource is not located.
+      # Fails when target is cpu:2 where the resource is not located.
      with self.assertRaises(errors.InvalidArgumentError):
        sess.run(
            remote_op,
            feed_dict={
-                target_placeholder: "/job:localhost/replica:0/task:0/cpu:0"
+                target_placeholder: "/job:localhost/replica:0/task:0/cpu:2"
            })
 
      elem = sess.run(
          remote_op,
@@ -474,6 +477,61 @@ class IteratorTest(test.TestCase):
              target_placeholder: "/job:localhost/replica:0/task:0/cpu:1"
          })
 
+  def testRemoteIteratorUsingRemoteCallOpDirectSessionGPUCPU(self):
+    if not test_util.is_gpu_available():
+      self.skipTest("No GPU available")
+
+    with ops.device("/job:localhost/replica:0/task:0/cpu:0"):
+      dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
+      iterator_3 = dataset_3.make_one_shot_iterator()
+      iterator_3_handle = iterator_3.string_handle()
+
+    def _encode_raw(byte_array):
+      return "".join([chr(item) for item in byte_array])
+
+    @function.Defun(dtypes.uint8)
+    def _remote_fn(h):
+      handle = script_ops.py_func(_encode_raw, [h], dtypes.string)
+      remote_iterator = dataset_ops.Iterator.from_string_handle(
+          handle, dataset_3.output_types, dataset_3.output_shapes)
+      return remote_iterator.get_next()
+
+    with ops.device("/job:localhost/replica:0/task:0/device:GPU:0"):
+      target_placeholder = array_ops.placeholder(dtypes.string, shape=[])
+      iterator_3_handle_uint8 = parsing_ops.decode_raw(
+          bytes=iterator_3_handle, out_type=dtypes.uint8)
+      remote_op = functional_ops.remote_call(
+          args=[iterator_3_handle_uint8],
+          Tout=[dtypes.int32],
+          f=_remote_fn,
+          target=target_placeholder)
+
+    with self.test_session() as sess:
+      elem = sess.run(
+          remote_op,
+          feed_dict={
+              target_placeholder: "/job:localhost/replica:0/task:0/cpu:0"
+          })
+      self.assertEqual(elem, [1])
+      elem = sess.run(
+          remote_op,
+          feed_dict={
+              target_placeholder: "/job:localhost/replica:0/task:0/cpu:0"
+          })
+      self.assertEqual(elem, [2])
+      elem = sess.run(
+          remote_op,
+          feed_dict={
+              target_placeholder: "/job:localhost/replica:0/task:0/cpu:0"
+          })
+      self.assertEqual(elem, [3])
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(
+            remote_op,
+            feed_dict={
+                target_placeholder: "/job:localhost/replica:0/task:0/cpu:0"
+            })
+
 
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/sloppy_transformation_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/sloppy_transformation_dataset_op_test.py
index f9198bacfbd..f01f5f11f71 100644
--- a/tensorflow/contrib/data/python/kernel_tests/sloppy_transformation_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/sloppy_transformation_dataset_op_test.py
@@ -235,7 +235,7 @@ class SloppyInterleaveDatasetTest(test.TestCase):
           self.read_coordination_events[expected_element].acquire()
         else:
          self.write_coordination_events[expected_element].set()
-        time.sleep(0.01)  # Sleep to consistently "avoid" the race condition.
+        time.sleep(0.1)  # Sleep to consistently "avoid" the race condition.
         actual_element = sess.run(self.next_element)
         if not done_first_event:
           done_first_event = True
@@ -300,7 +300,7 @@ class SloppyInterleaveDatasetTest(test.TestCase):
          self.read_coordination_events[expected_element].acquire()
        else:
          self.write_coordination_events[expected_element].set()
-        time.sleep(0.01)  # Sleep to consistently "avoid" the race condition.
+        time.sleep(0.1)  # Sleep to consistently "avoid" the race condition.
         actual_element = sess.run(self.next_element)
         if not done_first_event:
           done_first_event = True
diff --git a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py
index 808d25c8c7d..b3de7795776 100644
--- a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py
@@ -49,25 +49,46 @@ class SqlDatasetTest(test.TestCase):
     c = conn.cursor()
     c.execute("DROP TABLE IF EXISTS students")
     c.execute("DROP TABLE IF EXISTS people")
+    c.execute("DROP TABLE IF EXISTS townspeople")
     c.execute(
-        "CREATE TABLE IF NOT EXISTS students (id INTEGER NOT NULL PRIMARY KEY,"
-        " first_name VARCHAR(100), last_name VARCHAR(100), motto VARCHAR(100),"
-        " school_id VARCHAR(100), favorite_nonsense_word VARCHAR(100), "
-        "grade_level INTEGER, income INTEGER, favorite_number INTEGER)")
+        "CREATE TABLE IF NOT EXISTS students (id INTEGER NOT NULL PRIMARY KEY, "
+        "first_name VARCHAR(100), last_name VARCHAR(100), motto VARCHAR(100), "
+        "school_id VARCHAR(100), favorite_nonsense_word VARCHAR(100), "
+        "desk_number INTEGER, income INTEGER, favorite_number INTEGER, "
+        "favorite_big_number INTEGER, favorite_negative_number INTEGER, "
+        "favorite_medium_sized_number INTEGER, brownie_points INTEGER, "
+        "account_balance INTEGER, registration_complete INTEGER)")
     c.executemany(
         "INSERT INTO students (first_name, last_name, motto, school_id, "
-        "favorite_nonsense_word, grade_level, income, favorite_number) "
-        "VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
-        [("John", "Doe", "Hi!", "123", "n\0nsense", 9, 0, 2147483647),
-         ("Jane", "Moe", "Hi again!", "1000", "nonsense\0", 11, -20000,
-          -2147483648)])
+        "favorite_nonsense_word, desk_number, income, favorite_number, "
+        "favorite_big_number, favorite_negative_number, "
+        "favorite_medium_sized_number, brownie_points, account_balance, "
+        "registration_complete) "
+        "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
+        [("John", "Doe", "Hi!", "123", "n\0nsense", 9, 0, 2147483647,
+          9223372036854775807, -2, 32767, 0, 0, 1),
+         ("Jane", "Moe", "Hi again!", "1000", "nonsense\0", 127, -20000,
+          -2147483648, -9223372036854775808, -128, -32768, 255, 65535, 0)])
     c.execute(
         "CREATE TABLE IF NOT EXISTS people (id INTEGER NOT NULL PRIMARY KEY, "
         "first_name VARCHAR(100), last_name VARCHAR(100), state VARCHAR(100))")
     c.executemany(
-        "INSERT INTO people (first_name, last_name, state) VALUES (?, ?, ?)",
+        "INSERT INTO PEOPLE (first_name, last_name, state) VALUES (?, ?, ?)",
         [("Benjamin", "Franklin", "Pennsylvania"), ("John", "Doe",
                                                     "California")])
+    c.execute(
+        "CREATE TABLE IF NOT EXISTS townspeople (id INTEGER NOT NULL PRIMARY "
+        "KEY, first_name VARCHAR(100), last_name VARCHAR(100), victories "
+        "FLOAT, accolades FLOAT, triumphs FLOAT)")
+    c.executemany(
+        "INSERT INTO townspeople (first_name, last_name, victories, "
+        "accolades, triumphs) VALUES (?, ?, ?, ?, ?)",
+        [("George", "Washington", 20.00,
+          1331241.321342132321324589798264627463827647382647382643874,
+          9007199254740991.0),
+         ("John", "Adams", -19.95,
+          1331241321342132321324589798264627463827647382647382643874.0,
+          9007199254740992.0)])
     conn.commit()
     conn.close()
 
@@ -80,7 +101,6 @@ class SqlDatasetTest(test.TestCase):
       sess.run(
          init_op,
          feed_dict={
-              self.driver_name: "sqlite",
              self.query: "SELECT first_name, last_name, motto FROM students "
                          "ORDER BY first_name DESC"
          })
@@ -98,7 +118,6 @@ class SqlDatasetTest(test.TestCase):
      sess.run(
          init_op,
          feed_dict={
-              self.driver_name: "sqlite",
              self.query:
                  "SELECT students.first_name, state, motto FROM students "
                  "INNER JOIN people "
@@ -118,7 +137,6 @@ class SqlDatasetTest(test.TestCase):
      sess.run(
          init_op,
          feed_dict={
-              self.driver_name: "sqlite",
              self.query:
                  "SELECT first_name, last_name, favorite_nonsense_word "
                  "FROM students ORDER BY first_name DESC"
@@ -249,20 +267,124 @@ class SqlDatasetTest(test.TestCase):
      with self.assertRaises(errors.InvalidArgumentError):
        sess.run(get_next)
 
+  # Test that `SqlDataset` can read an integer from a SQLite database table and
+  # place it in an `int8` tensor.
+  def testReadResultSetInt8(self):
+    init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int8))
+    with self.test_session() as sess:
+      sess.run(
+          init_op,
+          feed_dict={
+              self.query: "SELECT first_name, desk_number FROM students "
+                          "ORDER BY first_name DESC"
+          })
+      self.assertEqual((b"John", 9), sess.run(get_next))
+      self.assertEqual((b"Jane", 127), sess.run(get_next))
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(get_next)
+
+  # Test that `SqlDataset` can read a negative or 0-valued integer from a
+  # SQLite database table and place it in an `int8` tensor.
+  def testReadResultSetInt8NegativeAndZero(self):
+    init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int8,
+                                                dtypes.int8))
+    with self.test_session() as sess:
+      sess.run(
+          init_op,
+          feed_dict={
+              self.query: "SELECT first_name, income, favorite_negative_number "
+                          "FROM students "
+                          "WHERE first_name = 'John' ORDER BY first_name DESC"
+          })
+      self.assertEqual((b"John", 0, -2), sess.run(get_next))
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(get_next)
+
+  # Test that `SqlDataset` can read a large (positive or negative) integer from
+  # a SQLite database table and place it in an `int8` tensor.
+  def testReadResultSetInt8MaxValues(self):
+    init_op, get_next = self._createSqlDataset((dtypes.int8, dtypes.int8))
+    with self.test_session() as sess:
+      sess.run(
+          init_op,
+          feed_dict={
+              self.query:
+                  "SELECT desk_number, favorite_negative_number FROM students "
+                  "ORDER BY first_name DESC"
+          })
+      self.assertEqual((9, -2), sess.run(get_next))
+      # Max and min values of int8
+      self.assertEqual((127, -128), sess.run(get_next))
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(get_next)
+
+  # Test that `SqlDataset` can read an integer from a SQLite database table and
+  # place it in an `int16` tensor.
+  def testReadResultSetInt16(self):
+    init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int16))
+    with self.test_session() as sess:
+      sess.run(
+          init_op,
+          feed_dict={
+              self.query: "SELECT first_name, desk_number FROM students "
+                          "ORDER BY first_name DESC"
+          })
+      self.assertEqual((b"John", 9), sess.run(get_next))
+      self.assertEqual((b"Jane", 127), sess.run(get_next))
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(get_next)
+
+  # Test that `SqlDataset` can read a negative or 0-valued integer from a
+  # SQLite database table and place it in an `int16` tensor.
+  def testReadResultSetInt16NegativeAndZero(self):
+    init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int16,
+                                                dtypes.int16))
+    with self.test_session() as sess:
+      sess.run(
+          init_op,
+          feed_dict={
+              self.query: "SELECT first_name, income, favorite_negative_number "
+                          "FROM students "
+                          "WHERE first_name = 'John' ORDER BY first_name DESC"
+          })
+      self.assertEqual((b"John", 0, -2), sess.run(get_next))
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(get_next)
+
+  # Test that `SqlDataset` can read a large (positive or negative) integer from
+  # a SQLite database table and place it in an `int16` tensor.
+  def testReadResultSetInt16MaxValues(self):
+    init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int16))
+    with self.test_session() as sess:
+      sess.run(
+          init_op,
+          feed_dict={
+              self.query: "SELECT first_name, favorite_medium_sized_number "
+                          "FROM students ORDER BY first_name DESC"
+          })
+      # Max value of int16
+      self.assertEqual((b"John", 32767), sess.run(get_next))
+      # Min value of int16
+      self.assertEqual((b"Jane", -32768), sess.run(get_next))
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(get_next)
+
+  # Test that `SqlDataset` can read an integer from a SQLite database table and
+  # place it in an `int32` tensor.
   def testReadResultSetInt32(self):
     init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
     with self.test_session() as sess:
       sess.run(
           init_op,
           feed_dict={
-              self.query: "SELECT first_name, grade_level FROM students "
+              self.query: "SELECT first_name, desk_number FROM students "
                           "ORDER BY first_name DESC"
           })
       self.assertEqual((b"John", 9), sess.run(get_next))
-      self.assertEqual((b"Jane", 11), sess.run(get_next))
-      with self.assertRaises(errors.OutOfRangeError):
-        sess.run(get_next)
+      self.assertEqual((b"Jane", 127), sess.run(get_next))
 
+  # Test that `SqlDataset` can read a negative or 0-valued integer from a
+  # SQLite database table and place it in an `int32` tensor.
   def testReadResultSetInt32NegativeAndZero(self):
     init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
     with self.test_session() as sess:
@@ -277,6 +399,8 @@ class SqlDatasetTest(test.TestCase):
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(get_next)
 
+  # Test that `SqlDataset` can read a large (positive or negative) integer from
+  # a SQLite database table and place it in an `int32` tensor.
   def testReadResultSetInt32MaxValues(self):
     init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
     with self.test_session() as sess:
@@ -286,7 +410,9 @@ class SqlDatasetTest(test.TestCase):
              self.query: "SELECT first_name, favorite_number FROM students "
                          "ORDER BY first_name DESC"
          })
+      # Max value of int32
       self.assertEqual((b"John", 2147483647), sess.run(get_next))
+      # Min value of int32
       self.assertEqual((b"Jane", -2147483648), sess.run(get_next))
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(get_next)
@@ -307,6 +433,224 @@ class SqlDatasetTest(test.TestCase):
       with self.assertRaises(errors.OutOfRangeError):
         sess.run(get_next)
 
+  # Test that `SqlDataset` can read an integer from a SQLite database table
+  # and place it in an `int64` tensor.
+  def testReadResultSetInt64(self):
+    init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int64))
+    with self.test_session() as sess:
+      sess.run(
+          init_op,
+          feed_dict={
+              self.query: "SELECT first_name, desk_number FROM students "
+                          "ORDER BY first_name DESC"
+          })
+      self.assertEqual((b"John", 9), sess.run(get_next))
+      self.assertEqual((b"Jane", 127), sess.run(get_next))
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(get_next)
+
+  # Test that `SqlDataset` can read a negative or 0-valued integer from a
+  # SQLite database table and place it in an `int64` tensor.
+  def testReadResultSetInt64NegativeAndZero(self):
+    init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int64))
+    with self.test_session() as sess:
+      sess.run(
+          init_op,
+          feed_dict={
+              self.query: "SELECT first_name, income FROM students "
+                          "ORDER BY first_name DESC"
+          })
+      self.assertEqual((b"John", 0), sess.run(get_next))
+      self.assertEqual((b"Jane", -20000), sess.run(get_next))
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(get_next)
+
+  # Test that `SqlDataset` can read a large (positive or negative) integer from
+  # a SQLite database table and place it in an `int64` tensor.
+  def testReadResultSetInt64MaxValues(self):
+    init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int64))
+    with self.test_session() as sess:
+      sess.run(
+          init_op,
+          feed_dict={
+              self.query:
+                  "SELECT first_name, favorite_big_number FROM students "
+                  "ORDER BY first_name DESC"
+          })
+      # Max value of int64
+      self.assertEqual((b"John", 9223372036854775807), sess.run(get_next))
+      # Min value of int64
+      self.assertEqual((b"Jane", -9223372036854775808), sess.run(get_next))
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(get_next)
+
+  # Test that `SqlDataset` can read an integer from a SQLite database table and
+  # place it in a `uint8` tensor.
+  def testReadResultSetUInt8(self):
+    init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint8))
+    with self.test_session() as sess:
+      sess.run(
+          init_op,
+          feed_dict={
+              self.query: "SELECT first_name, desk_number FROM students "
+                          "ORDER BY first_name DESC"
+          })
+      self.assertEqual((b"John", 9), sess.run(get_next))
+      self.assertEqual((b"Jane", 127), sess.run(get_next))
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(get_next)
+
+  # Test that `SqlDataset` can read the minimum and maximum uint8 values from a
+  # SQLite database table and place them in `uint8` tensors.
+  def testReadResultSetUInt8MinAndMaxValues(self):
+    init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint8))
+    with self.test_session() as sess:
+      sess.run(
+          init_op,
+          feed_dict={
+              self.query: "SELECT first_name, brownie_points FROM students "
+                          "ORDER BY first_name DESC"
+          })
+      # Min value of uint8
+      self.assertEqual((b"John", 0), sess.run(get_next))
+      # Max value of uint8
+      self.assertEqual((b"Jane", 255), sess.run(get_next))
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(get_next)
+
+  # Test that `SqlDataset` can read an integer from a SQLite database table
+  # and place it in a `uint16` tensor.
+  def testReadResultSetUInt16(self):
+    init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint16))
+    with self.test_session() as sess:
+      sess.run(
+          init_op,
+          feed_dict={
+              self.query: "SELECT first_name, desk_number FROM students "
+                          "ORDER BY first_name DESC"
+          })
+      self.assertEqual((b"John", 9), sess.run(get_next))
+      self.assertEqual((b"Jane", 127), sess.run(get_next))
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(get_next)
+
+  # Test that `SqlDataset` can read the minimum and maximum uint16 values from a
+  # SQLite database table and place them in `uint16` tensors.
+  def testReadResultSetUInt16MinAndMaxValues(self):
+    init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint16))
+    with self.test_session() as sess:
+      sess.run(
+          init_op,
+          feed_dict={
+              self.query: "SELECT first_name, account_balance FROM students "
+                          "ORDER BY first_name DESC"
+          })
+      # Min value of uint16
+      self.assertEqual((b"John", 0), sess.run(get_next))
+      # Max value of uint16
+      self.assertEqual((b"Jane", 65535), sess.run(get_next))
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(get_next)
+
+  # Test that `SqlDataset` can read a 0-valued and 1-valued integer from a
+  # SQLite database table and place them as `True` and `False` respectively
+  # in `bool` tensors.
+  def testReadResultSetBool(self):
+    init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.bool))
+    with self.test_session() as sess:
+      sess.run(
+          init_op,
+          feed_dict={
+              self.query:
+                  "SELECT first_name, registration_complete FROM students "
+                  "ORDER BY first_name DESC"
+          })
+      self.assertEqual((b"John", True), sess.run(get_next))
+      self.assertEqual((b"Jane", False), sess.run(get_next))
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(get_next)
+
+  # Test that `SqlDataset` can read an integer that is not 0-valued or 1-valued
+  # from a SQLite database table and place it as `True` in a `bool` tensor.
+  def testReadResultSetBoolNotZeroOrOne(self):
+    init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.bool))
+    with self.test_session() as sess:
+      sess.run(
+          init_op,
+          feed_dict={
+              self.query: "SELECT first_name, favorite_medium_sized_number "
+                          "FROM students ORDER BY first_name DESC"
+          })
+      self.assertEqual((b"John", True), sess.run(get_next))
+      self.assertEqual((b"Jane", True), sess.run(get_next))
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(get_next)
+
+  # Test that `SqlDataset` can read a float from a SQLite database table
+  # and place it in a `float64` tensor.
+  def testReadResultSetFloat64(self):
+    init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
+                                                dtypes.float64))
+    with self.test_session() as sess:
+      sess.run(
+          init_op,
+          feed_dict={
+              self.query:
+                  "SELECT first_name, last_name, victories FROM townspeople "
+                  "ORDER BY first_name"
+          })
+      self.assertEqual((b"George", b"Washington", 20.0), sess.run(get_next))
+      self.assertEqual((b"John", b"Adams", -19.95), sess.run(get_next))
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(get_next)
+
+  # Test that `SqlDataset` can read a float from a SQLite database table beyond
+  # the precision of 64-bit IEEE, without throwing an error. Test that
+  # `SqlDataset` identifies such a value as equal to itself.
+  def testReadResultSetFloat64OverlyPrecise(self):
+    init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
+                                                dtypes.float64))
+    with self.test_session() as sess:
+      sess.run(
+          init_op,
+          feed_dict={
+              self.query:
+                  "SELECT first_name, last_name, accolades FROM townspeople "
+                  "ORDER BY first_name"
+          })
+      self.assertEqual(
+          (b"George", b"Washington",
+           1331241.321342132321324589798264627463827647382647382643874),
+          sess.run(get_next))
+      self.assertEqual(
+          (b"John", b"Adams",
+           1331241321342132321324589798264627463827647382647382643874.0),
+          sess.run(get_next))
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(get_next)
+
+  # Test that `SqlDataset` can read a float from a SQLite database table,
+  # representing the largest integer representable as a 64-bit IEEE float
+  # such that the previous integer is also representable as a 64-bit IEEE float.
+  # Test that `SqlDataset` can distinguish these two numbers.
+  def testReadResultSetFloat64LargestConsecutiveWholeNumbersNotEqual(self):
+    init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
+                                                dtypes.float64))
+    with self.test_session() as sess:
+      sess.run(
+          init_op,
+          feed_dict={
+              self.query:
+                  "SELECT first_name, last_name, triumphs FROM townspeople "
+                  "ORDER BY first_name"
+          })
+      self.assertNotEqual((b"George", b"Washington", 9007199254740992.0),
+                          sess.run(get_next))
+      self.assertNotEqual((b"John", b"Adams", 9007199254740991.0),
+                          sess.run(get_next))
+      with self.assertRaises(errors.OutOfRangeError):
+        sess.run(get_next)
+
 
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/contrib/data/python/ops/dataset_ops.py b/tensorflow/contrib/data/python/ops/dataset_ops.py
index 0ee9acfc97f..cfacfdd7d99 100644
--- a/tensorflow/contrib/data/python/ops/dataset_ops.py
+++ b/tensorflow/contrib/data/python/ops/dataset_ops.py
@@ -2276,6 +2276,23 @@ class SqlDataset(Dataset):
   def __init__(self, driver_name, data_source_name, query, output_types):
     """Creates a `SqlDataset`.
 
+    `SqlDataset` allows a user to read data from the result set of a SQL query.
+    For example:
+
+    ```python
+    dataset = tf.contrib.data.SqlDataset("sqlite", "/foo/bar.sqlite3",
+                                         "SELECT name, age FROM people",
+                                         (tf.string, tf.int32))
+    iterator = dataset.make_one_shot_iterator()
+    next_element = iterator.get_next()
+    # Prints the rows of the result set of the above query.
+    while True:
+      try:
+        print(sess.run(next_element))
+      except tf.errors.OutOfRangeError:
+        break
+    ```
+
     Args:
       driver_name: A 0-D `tf.string` tensor containing the database type.
         Currently, the only supported value is 'sqlite'.
diff --git a/tensorflow/contrib/eager/python/BUILD b/tensorflow/contrib/eager/python/BUILD
index 1b831f8afba..26ada939da2 100644
--- a/tensorflow/contrib/eager/python/BUILD
+++ b/tensorflow/contrib/eager/python/BUILD
@@ -21,6 +21,16 @@ py_library(
     ],
 )
 
+cuda_py_test(
+    name = "tfe_test",
+    srcs = ["tfe_test.py"],
+    additional_deps = [
+        ":tfe",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:platform_test",
+    ],
+)
+
 py_library(
     name = "saver",
     srcs = ["saver.py"],
diff --git a/tensorflow/contrib/eager/python/tfe.py b/tensorflow/contrib/eager/python/tfe.py
index 2c7494a0a86..b5bf839a89b 100644
--- a/tensorflow/contrib/eager/python/tfe.py
+++ b/tensorflow/contrib/eager/python/tfe.py
@@ -18,9 +18,9 @@ EXPERIMENTAL: APIs here are unstable and likely to change without notice.
 
 To use, at program startup, call `tfe.enable_eager_execution()`.
 
-@@list_devices
 @@device
-
+@@list_devices
+@@num_gpus
 
 @@defun
 @@implicit_gradients
@@ -58,9 +58,10 @@ from tensorflow.python.util.all_util import remove_undocumented
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager.custom_gradient import custom_gradient
 from tensorflow.python.eager import function
-from tensorflow.python.eager.context import context
 from tensorflow.python.eager.context import device
 from tensorflow.python.eager.context import enable_eager_execution
+from tensorflow.python.eager.context import list_devices
+from tensorflow.python.eager.context import num_gpus
 from tensorflow.python.eager.context import run
 from tensorflow.python.eager.core import enable_tracing
 from tensorflow.python.eager.execution_callbacks import add_execution_callback
@@ -70,10 +71,6 @@ from tensorflow.python.eager.execution_callbacks import inf_nan_callback
 from tensorflow.python.eager.execution_callbacks import nan_callback
 from tensorflow.python.eager.execution_callbacks import seterr
 
-
-def list_devices():
-  return context().devices()
-
 defun = function.defun
 implicit_gradients = backprop.implicit_grad
 implicit_value_and_gradients = backprop.implicit_val_and_grad
diff --git a/tensorflow/contrib/eager/python/tfe_test.py b/tensorflow/contrib/eager/python/tfe_test.py
new file mode 100644
index 00000000000..2a9d7589d3a
--- /dev/null
+++ b/tensorflow/contrib/eager/python/tfe_test.py
@@ -0,0 +1,36 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tfe.py."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.eager.python import tfe
+from tensorflow.python.platform import test
+
+
+class TFETest(test.TestCase):
+
+  def testListDevices(self):
+    # Expect at least one device.
+    self.assertTrue(tfe.list_devices())
+
+  def testNumGPUs(self):
+    devices = tfe.list_devices()
+    self.assertEqual(len(devices) - 1, tfe.num_gpus())
+
+
+if __name__ == "__main__":
+  test.main()
diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD
index 46cdf086ddc..2d2794e3504 100644
--- a/tensorflow/contrib/estimator/BUILD
+++ b/tensorflow/contrib/estimator/BUILD
@@ -26,6 +26,7 @@ py_library(
     srcs_version = "PY2AND3",
     deps = [
         ":extenders",
+        ":head",
     ],
 )
 
@@ -59,3 +60,14 @@ py_test(
         "//third_party/py/numpy",
     ],
 )
+
+py_library(
+    name = "head",
+    srcs = [
+        "python/estimator/head.py",
+    ],
+    srcs_version = "PY2AND3",
+    deps = [
+        "//tensorflow/python/estimator:head",
+    ],
+)
diff --git a/tensorflow/contrib/estimator/__init__.py b/tensorflow/contrib/estimator/__init__.py
index 9180a3acc36..346653c47f4 100644
--- a/tensorflow/contrib/estimator/__init__.py
+++ b/tensorflow/contrib/estimator/__init__.py
@@ -20,10 +20,16 @@ from __future__ import print_function
 
 # pylint: disable=unused-import,line-too-long,wildcard-import
 from tensorflow.contrib.estimator.python.estimator.extenders import *
+from tensorflow.contrib.estimator.python.estimator.head import *
 
 from tensorflow.python.util.all_util import remove_undocumented
 # pylint: enable=unused-import,line-too-long,wildcard-import
 
-_allowed_symbols = ['add_metrics']
+_allowed_symbols = [
+    'add_metrics',
+    'binary_classification_head',
+    'multi_class_head',
+    'regression_head',
+]
 
 remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow/contrib/estimator/python/estimator/head.py b/tensorflow/contrib/estimator/python/estimator/head.py
new file mode 100644
index 00000000000..005de115d43
--- /dev/null
+++ b/tensorflow/contrib/estimator/python/estimator/head.py
@@ -0,0 +1,125 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Abstractions for the head(s) of a model."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.estimator.canned import head as head_lib
+
+
+def multi_class_head(n_classes,
+                     weight_column=None,
+                     label_vocabulary=None,
+                     head_name=None):
+  """Creates a `_Head` for multi class classification.
+
+  Uses `sparse_softmax_cross_entropy` loss.
+
+  This head expects to be fed integer labels specifying the class index.
+
+  Args:
+    n_classes: Number of classes, must be greater than 2 (for 2 classes, use
+      `_BinaryLogisticHeadWithSigmoidCrossEntropyLoss`).
+    weight_column: A string or a `_NumericColumn` created by
+      `tf.feature_column.numeric_column` defining feature column representing
+      weights. It is used to down weight or boost examples during training. It
+      will be multiplied by the loss of the example.
+    label_vocabulary: A list of strings represents possible label values. If it
+      is not given, that means labels are already encoded as integer within
+      [0, n_classes). If given, labels must be string type and have any value in
+      `label_vocabulary`. Also there will be errors if vocabulary is not
+      provided and labels are string.
+    head_name: name of the head. If provided, summary and metrics keys will be
+      suffixed by `"/" + head_name`.
+
+  Returns:
+    An instance of `_Head` for multi class classification.
+
+  Raises:
+    ValueError: if `n_classes`, `metric_class_ids` or `label_keys` is invalid.
+  """
+  return head_lib._multi_class_head_with_softmax_cross_entropy_loss(  # pylint:disable=protected-access
+      n_classes=n_classes,
+      weight_column=weight_column,
+      label_vocabulary=label_vocabulary,
+      head_name=head_name)
+
+
+def binary_classification_head(
+    weight_column=None, thresholds=None, label_vocabulary=None,
+    head_name=None):
+  """Creates a `_Head` for single label binary classification.
+
+  This head uses `sigmoid_cross_entropy_with_logits` loss.
+
+  This head expects to be fed float labels of shape `(batch_size, 1)`.
+
+  Args:
+    weight_column: A string or a `_NumericColumn` created by
+      `tf.feature_column.numeric_column` defining feature column representing
+      weights. It is used to down weight or boost examples during training. It
+      will be multiplied by the loss of the example.
+    thresholds: Iterable of floats in the range `(0, 1)`. For binary
+      classification metrics such as precision and recall, an eval metric is
+      generated for each threshold value. This threshold is applied to the
+      logistic values to determine the binary classification (i.e., above the
+      threshold is `true`, below is `false`).
+    label_vocabulary: A list of strings represents possible label values. If it
+      is not given, that means labels are already encoded within [0, 1]. If
+      given, labels must be string type and have any value in
+      `label_vocabulary`. Also there will be errors if vocabulary is not
+      provided and labels are string.
+    head_name: name of the head. If provided, summary and metrics keys will be
+      suffixed by `"/" + head_name`.
+
+  Returns:
+    An instance of `_Head` for binary classification.
+
+  Raises:
+    ValueError: if `thresholds` contains a value outside of `(0, 1)`.
+  """
+  return head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(  # pylint:disable=protected-access
+      weight_column=weight_column,
+      thresholds=thresholds,
+      label_vocabulary=label_vocabulary,
+      head_name=head_name)
+
+
+def regression_head(weight_column=None,
+                    label_dimension=1,
+                    head_name=None):
+  """Creates a `_Head` for regression using the mean squared loss.
+
+  Uses `mean_squared_error` loss.
+
+  Args:
+    weight_column: A string or a `_NumericColumn` created by
+      `tf.feature_column.numeric_column` defining feature column representing
+      weights. It is used to down weight or boost examples during training. It
+      will be multiplied by the loss of the example.
+    label_dimension: Number of regression labels per example. This is the size
+      of the last dimension of the labels `Tensor` (typically, this has shape
+      `[batch_size, label_dimension]`).
+    head_name: name of the head. If provided, summary and metrics keys will be
+      suffixed by `"/" + head_name`.
+
+  Returns:
+    An instance of `_Head` for linear regression.
+ """ + return head_lib._regression_head_with_mean_squared_error_loss( # pylint:disable=protected-access + weight_column=weight_column, + label_dimension=label_dimension, + head_name=head_name) diff --git a/tensorflow/contrib/factorization/BUILD b/tensorflow/contrib/factorization/BUILD index 4737d1ae593..638a4be4464 100644 --- a/tensorflow/contrib/factorization/BUILD +++ b/tensorflow/contrib/factorization/BUILD @@ -162,6 +162,7 @@ tf_py_test( "//tensorflow/python:platform_test", "//tensorflow/python:variables", ], + tags = ["notsan"], # b/62863147 ) py_library( diff --git a/tensorflow/contrib/gan/BUILD b/tensorflow/contrib/gan/BUILD index b2de2823563..cb2cd7c7ef0 100644 --- a/tensorflow/contrib/gan/BUILD +++ b/tensorflow/contrib/gan/BUILD @@ -1,9 +1,12 @@ +# Files for using TFGAN framework. package(default_visibility = ["//tensorflow:__subpackages__"]) licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) +load("//tensorflow:tensorflow.bzl", "py_test") + py_library( name = "gan", srcs = [ @@ -11,6 +14,192 @@ py_library( ], srcs_version = "PY2AND3", deps = [ + ":features", + ":losses", + ], +) + +py_library( + name = "losses", + srcs = ["python/losses/__init__.py"], + srcs_version = "PY2AND3", + deps = [ + ":losses_impl", + ":tuple_losses", + "//tensorflow/python:util", + ], +) + +py_library( + name = "features", + srcs = ["python/features/__init__.py"], + srcs_version = "PY2AND3", + deps = [ + ":clip_weights", + ":conditioning_utils", + ":virtual_batchnorm", + "//tensorflow/python:util", + ], +) + +py_library( + name = "losses_impl", + srcs = ["python/losses/python/losses_impl.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/framework:framework_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:clip_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:gradients", + "//tensorflow/python:math_ops", + "//tensorflow/python:random_ops", + "//tensorflow/python:summary", + "//tensorflow/python:tensor_util", + "//tensorflow/python:variable_scope", + "//tensorflow/python/ops/distributions", + "//tensorflow/python/ops/losses", + "//third_party/py/numpy", + ], +) + +py_test( + name = "losses_impl_test", + srcs = ["python/losses/python/losses_impl_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":losses_impl", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:clip_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:random_ops", + "//tensorflow/python:random_seed", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + "//tensorflow/python/ops/distributions", + "//tensorflow/python/ops/losses", + ], +) + +py_library( + name = "tuple_losses", + srcs = [ + "python/losses/python/losses_wargs.py", + "python/losses/python/tuple_losses.py", + "python/losses/python/tuple_losses_impl.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":losses_impl", + "//tensorflow/python:util", + ], +) + +py_test( + name = "tuple_losses_test", + srcs = ["python/losses/python/tuple_losses_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":tuple_losses", + "//tensorflow/python:client_testlib", + "//third_party/py/numpy", + ], +) + +py_library( + name = "conditioning_utils", + srcs = [ + "python/features/python/conditioning_utils.py", + "python/features/python/conditioning_utils_impl.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/layers:layers_py", + 
"//tensorflow/python:array_ops", + "//tensorflow/python:embedding_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:tensor_util", + "//tensorflow/python:variable_scope", + ], +) + +py_test( + name = "conditioning_utils_test", + srcs = ["python/features/python/conditioning_utils_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":conditioning_utils", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + ], +) + +py_library( + name = "virtual_batchnorm", + srcs = [ + "python/features/python/virtual_batchnorm.py", + "python/features/python/virtual_batchnorm_impl.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:init_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:nn", + "//tensorflow/python:tensor_shape", + "//tensorflow/python:tensor_util", + "//tensorflow/python:variable_scope", + ], +) + +py_test( + name = "virtual_batchnorm_test", + srcs = ["python/features/python/virtual_batchnorm_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":virtual_batchnorm", + "//tensorflow/contrib/framework:framework_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:layers", + "//tensorflow/python:math_ops", + "//tensorflow/python:nn", + "//tensorflow/python:random_ops", + "//tensorflow/python:random_seed", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + "//third_party/py/numpy", + ], +) + +py_library( + name = "clip_weights", + srcs = [ + "python/features/python/clip_weights.py", + "python/features/python/clip_weights_impl.py", + ], + srcs_version = "PY2AND3", + deps = ["//tensorflow/contrib/opt:opt_py"], +) + +py_test( + name = "clip_weights_test", + srcs = ["python/features/python/clip_weights_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":clip_weights", + "//tensorflow/python:client_testlib", + "//tensorflow/python:training", + "//tensorflow/python:variables", ], ) diff --git a/tensorflow/contrib/gan/__init__.py b/tensorflow/contrib/gan/__init__.py index a46b0e8d5de..b2f4bf01190 100644 --- a/tensorflow/contrib/gan/__init__.py +++ b/tensorflow/contrib/gan/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2017 Google Inc. All Rights Reserved. +# Copyright 2016 Google Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,8 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""TFGAN grouped API.""" +"""TFGAN grouped API. Please see README.md for details and usage.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function + +# Collapse TFGAN into a tiered namespace. +from tensorflow.contrib.gan.python import features +from tensorflow.contrib.gan.python import losses + +del absolute_import +del division +del print_function diff --git a/tensorflow/contrib/gan/python/features/__init__.py b/tensorflow/contrib/gan/python/features/__init__.py new file mode 100644 index 00000000000..6d0972f8db4 --- /dev/null +++ b/tensorflow/contrib/gan/python/features/__init__.py @@ -0,0 +1,37 @@ +# Copyright 2017 Google Inc. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""TFGAN grouped API. Please see README.md for details and usage.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# Collapse features into a single namespace. +# pylint: disable=unused-import,wildcard-import +from tensorflow.contrib.gan.python.features.python import clip_weights +from tensorflow.contrib.gan.python.features.python import conditioning_utils +from tensorflow.contrib.gan.python.features.python import virtual_batchnorm + +from tensorflow.contrib.gan.python.features.python.clip_weights import * +from tensorflow.contrib.gan.python.features.python.conditioning_utils import * +from tensorflow.contrib.gan.python.features.python.virtual_batchnorm import * +# pylint: enable=unused-import,wildcard-import + +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = clip_weights.__all__ +_allowed_symbols += conditioning_utils.__all__ +_allowed_symbols += virtual_batchnorm.__all__ +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/gan/python/features/python/clip_weights.py b/tensorflow/contrib/gan/python/features/python/clip_weights.py new file mode 100644 index 00000000000..fa76fd7928f --- /dev/null +++ b/tensorflow/contrib/gan/python/features/python/clip_weights.py @@ -0,0 +1,28 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities to clip weights.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.gan.python.features.python import clip_weights_impl +# pylint: disable=wildcard-import +from tensorflow.contrib.gan.python.features.python.clip_weights_impl import * +# pylint: enable=wildcard-import +from tensorflow.python.util.all_util import remove_undocumented + +__all__ = clip_weights_impl.__all__ +remove_undocumented(__name__, __all__) diff --git a/tensorflow/contrib/gan/python/features/python/clip_weights_impl.py b/tensorflow/contrib/gan/python/features/python/clip_weights_impl.py new file mode 100644 index 00000000000..96fbb8186d7 --- /dev/null +++ b/tensorflow/contrib/gan/python/features/python/clip_weights_impl.py @@ -0,0 +1,80 @@ +# Copyright 2017 The TensorFlow Authors. 
 All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utilities to clip weights.
+
+This is useful in the original formulation of the Wasserstein loss, which
+requires that the discriminator be K-Lipschitz. See
+https://arxiv.org/pdf/1701.07875 for more details.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.opt.python.training import variable_clipping_optimizer
+
+
+__all__ = [
+    'clip_variables',
+    'clip_discriminator_weights',
+]
+
+
+def clip_discriminator_weights(optimizer, model, weight_clip):
+  """Modifies an optimizer so it clips discriminator weights to a given value.
+
+  Args:
+    optimizer: An optimizer to perform variable weight clipping.
+    model: A GANModel namedtuple.
+    weight_clip: Positive Python float to clip discriminator weights. Used to
+      enforce a K-Lipschitz condition, which is useful for some GAN training
+      schemes (e.g. WGAN: https://arxiv.org/pdf/1701.07875).
+
+  Returns:
+    An optimizer to perform weight clipping after updates.
+
+  Raises:
+    ValueError: If `weight_clip` is less than 0.
+  """
+  return clip_variables(optimizer, model.discriminator_variables, weight_clip)
+
+
+def clip_variables(optimizer, variables, weight_clip):
+  """Modifies an optimizer so it clips the given variables to a given value.
+
+  Args:
+    optimizer: An optimizer to perform variable weight clipping.
+    variables: A list of TensorFlow variables.
+    weight_clip: Positive Python float to clip the variables. Used to
+      enforce a K-Lipschitz condition, which is useful for some GAN training
+      schemes (e.g. WGAN: https://arxiv.org/pdf/1701.07875).
+
+  Returns:
+    An optimizer to perform weight clipping after updates.
+
+  Raises:
+    ValueError: If `weight_clip` is less than 0.
+  """
+  if weight_clip < 0:
+    raise ValueError(
+        '`weight_clip` must be positive. Instead, was %s' % weight_clip)
+  return variable_clipping_optimizer.VariableClippingOptimizer(
+      opt=optimizer,
+      # Do no reduction, so clipping happens per-value.
+      vars_to_clip_dims={var: [] for var in variables},
+      max_norm=weight_clip,
+      use_locking=True,
+      colocate_clip_ops_with_vars=True)
diff --git a/tensorflow/contrib/gan/python/features/python/clip_weights_test.py b/tensorflow/contrib/gan/python/features/python/clip_weights_test.py
new file mode 100644
index 00000000000..030e37ec679
--- /dev/null
+++ b/tensorflow/contrib/gan/python/features/python/clip_weights_test.py
@@ -0,0 +1,81 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for tfgan.python.features.clip_weights."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+from tensorflow.contrib.gan.python.features.python import clip_weights_impl as clip_weights
+
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+from tensorflow.python.training import training
+
+
+class ClipWeightsTest(test.TestCase):
+  """Tests for `clip_discriminator_weights` and `clip_variables`."""
+
+  def setUp(self):
+    self.variables = [variables.Variable(2.0)]
+    self.tuple = collections.namedtuple(
+        'VarTuple', ['discriminator_variables'])(self.variables)
+
+  def _test_weight_clipping_helper(self, use_tuple):
+    # With learning rate 1.0 and gradient 1.0, one unclipped step moves the
+    # variable from 2.0 to 1.0, which stays above the clip value 0.1.
+    loss = self.variables[0]
+    opt = training.GradientDescentOptimizer(1.0)
+    if use_tuple:
+      opt_clip = clip_weights.clip_discriminator_weights(opt, self.tuple, 0.1)
+    else:
+      opt_clip = clip_weights.clip_variables(opt, self.variables, 0.1)
+
+    train_op1 = opt.minimize(loss, var_list=self.variables)
+    train_op2 = opt_clip.minimize(loss, var_list=self.variables)
+
+    with self.test_session(use_gpu=True) as sess:
+      sess.run(variables.global_variables_initializer())
+      self.assertEqual(2.0, self.variables[0].eval())
+      sess.run(train_op1)
+      self.assertLess(0.1, self.variables[0].eval())
+
+    with self.test_session(use_gpu=True) as sess:
+      sess.run(variables.global_variables_initializer())
+      self.assertEqual(2.0, self.variables[0].eval())
+      sess.run(train_op2)
+      self.assertNear(0.1, self.variables[0].eval(), 1e-7)
+
+  def test_weight_clipping_argsonly(self):
+    self._test_weight_clipping_helper(False)
+
+  def test_weight_clipping_ganmodel(self):
+    self._test_weight_clipping_helper(True)
+
+  def _test_incorrect_weight_clip_value_helper(self, use_tuple):
+    opt = training.GradientDescentOptimizer(1.0)
+
+    if use_tuple:
+      with self.assertRaisesRegexp(ValueError, 'must be positive'):
+        clip_weights.clip_discriminator_weights(opt, self.tuple, weight_clip=-1)
+    else:
+      with self.assertRaisesRegexp(ValueError, 'must be positive'):
+        clip_weights.clip_variables(opt, self.variables, weight_clip=-1)
+
+  def test_incorrect_weight_clip_value_argsonly(self):
+    self._test_incorrect_weight_clip_value_helper(False)
+
+  def test_incorrect_weight_clip_value_tuple(self):
+    self._test_incorrect_weight_clip_value_helper(True)
+
+
+if __name__ == '__main__':
+  test.main()
diff --git a/tensorflow/contrib/gan/python/features/python/conditioning_utils.py b/tensorflow/contrib/gan/python/features/python/conditioning_utils.py
new file mode 100644
index 00000000000..df71187fbd9
--- /dev/null
+++ b/tensorflow/contrib/gan/python/features/python/conditioning_utils.py
@@ -0,0 +1,28 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Miscellaneous utilities for TFGAN code and examples."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.gan.python.features.python import conditioning_utils_impl
+# pylint: disable=wildcard-import
+from tensorflow.contrib.gan.python.features.python.conditioning_utils_impl import *
+# pylint: enable=wildcard-import
+from tensorflow.python.util.all_util import remove_undocumented
+
+__all__ = conditioning_utils_impl.__all__
+remove_undocumented(__name__, __all__)
diff --git a/tensorflow/contrib/gan/python/features/python/conditioning_utils_impl.py b/tensorflow/contrib/gan/python/features/python/conditioning_utils_impl.py
new file mode 100644
index 00000000000..cd31c62667f
--- /dev/null
+++ b/tensorflow/contrib/gan/python/features/python/conditioning_utils_impl.py
@@ -0,0 +1,112 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Miscellaneous utilities for TFGAN code and examples.
+
+Includes:
+1) Conditioning the value of a Tensor, based on techniques from
+   https://arxiv.org/abs/1609.03499.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.layers.python.layers import layers
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import embedding_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variable_scope
+
+
+__all__ = [
+    'condition_tensor',
+    'condition_tensor_from_onehot',
+]
+
+
+def _get_shape(tensor):
+  tensor_shape = array_ops.shape(tensor)
+  static_tensor_shape = tensor_util.constant_value(tensor_shape)
+  return (static_tensor_shape if static_tensor_shape is not None else
+          tensor_shape)
+
+
+def condition_tensor(tensor, conditioning):
+  """Condition the value of a tensor.
+
+  Conditioning scheme based on https://arxiv.org/abs/1609.03499.
+
+  Args:
+    tensor: A minibatch tensor to be conditioned.
+    conditioning: A minibatch Tensor to condition on. Must be 2D, with first
+      dimension the same as `tensor`.
+
+  Returns:
+    `tensor` conditioned on `conditioning`.
+
+  Raises:
+    ValueError: If the non-batch dimensions of `tensor` aren't fully defined.
+    ValueError: If `conditioning` isn't at least 2D.
+    ValueError: If the batch dimensions of the input Tensors don't match.
+  """
+  tensor.shape[1:].assert_is_fully_defined()
+  num_features = tensor.shape[1:].num_elements()
+
+  mapped_conditioning = layers.linear(
+      layers.flatten(conditioning), num_features)
+  if not mapped_conditioning.shape.is_compatible_with(tensor.shape):
+    mapped_conditioning = array_ops.reshape(
+        mapped_conditioning, _get_shape(tensor))
+  return tensor + mapped_conditioning
+
+
+def _one_hot_to_embedding(one_hot, embedding_size):
+  """Get a dense embedding vector from a one-hot encoding."""
+  num_tokens = one_hot.shape[1]
+  label_id = math_ops.argmax(one_hot, axis=1)
+  embedding = variable_scope.get_variable(
+      'embedding', [num_tokens, embedding_size])
+  return embedding_ops.embedding_lookup(
+      embedding, label_id, name='token_to_embedding')
+
+
+def _validate_onehot(one_hot_labels):
+  one_hot_labels.shape.assert_has_rank(2)
+  one_hot_labels.shape[1:].assert_is_fully_defined()
+
+
+def condition_tensor_from_onehot(tensor, one_hot_labels, embedding_size=256):
+  """Condition a tensor based on a one-hot tensor.
+
+  Conditioning scheme based on https://arxiv.org/abs/1609.03499.
+
+  Args:
+    tensor: Tensor to be conditioned.
+    one_hot_labels: A Tensor of one-hot labels. Shape is
+      [batch_size, num_classes].
+    embedding_size: The size of the class embedding.
+
+  Returns:
+    `tensor` conditioned on `one_hot_labels`.
+
+  Raises:
+    ValueError: If `one_hot_labels` isn't 2D, if the non-batch dimensions
+      aren't fully defined, or if the batch sizes don't match.
+  """
+  _validate_onehot(one_hot_labels)
+
+  conditioning = _one_hot_to_embedding(one_hot_labels, embedding_size)
+  return condition_tensor(tensor, conditioning)
diff --git a/tensorflow/contrib/gan/python/features/python/conditioning_utils_test.py b/tensorflow/contrib/gan/python/features/python/conditioning_utils_test.py
new file mode 100644
index 00000000000..0898fd3113d
--- /dev/null
+++ b/tensorflow/contrib/gan/python/features/python/conditioning_utils_test.py
@@ -0,0 +1,76 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================== +"""Tests for tfgan.python.features.conditioning_utils.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.gan.python.features.python import conditioning_utils_impl as conditioning_utils + +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class ConditioningUtilsTest(test.TestCase): + + def test_condition_tensor_multiple_shapes(self): + for tensor_shape in [(4, 1), (4, 2), (4, 2, 6), (None, 5, 3)]: + for conditioning_shape in [(4, 1), (4, 8), (4, 5, 3)]: + conditioning_utils.condition_tensor( + array_ops.placeholder(dtypes.float32, tensor_shape), + array_ops.placeholder(dtypes.float32, conditioning_shape)) + + def test_condition_tensor_asserts(self): + with self.assertRaisesRegexp(ValueError, 'Cannot reshape'): + conditioning_utils.condition_tensor( + array_ops.placeholder(dtypes.float32, (4, 1)), + array_ops.placeholder(dtypes.float32, (5, 1))) + + with self.assertRaisesRegexp(ValueError, 'Shape .* is not fully defined'): + conditioning_utils.condition_tensor( + array_ops.placeholder(dtypes.float32, (5, None)), + array_ops.placeholder(dtypes.float32, (5, 1))) + + with self.assertRaisesRegexp(ValueError, 'must have a least 2 dimensions.'): + conditioning_utils.condition_tensor( + array_ops.placeholder(dtypes.float32, (5, 2)), + array_ops.placeholder(dtypes.float32, (5))) + + def test_condition_tensor_from_onehot(self): + conditioning_utils.condition_tensor_from_onehot( + array_ops.placeholder(dtypes.float32, (5, 4, 1)), + array_ops.placeholder(dtypes.float32, (5, 10))) + + def test_condition_tensor_from_onehot_asserts(self): + with self.assertRaisesRegexp(ValueError, 'Shape .* must have rank 2'): + conditioning_utils.condition_tensor_from_onehot( + array_ops.placeholder(dtypes.float32, (5, 1)), + array_ops.placeholder(dtypes.float32, (5))) + + with self.assertRaisesRegexp(ValueError, 'Shape .* is not fully defined'): + conditioning_utils.condition_tensor_from_onehot( + array_ops.placeholder(dtypes.float32, (5, 1)), + array_ops.placeholder(dtypes.float32, (5, None))) + + with self.assertRaisesRegexp(ValueError, 'Cannot reshape a tensor'): + conditioning_utils.condition_tensor_from_onehot( + array_ops.placeholder(dtypes.float32, (5, 1)), + array_ops.placeholder(dtypes.float32, (4, 6))) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/gan/python/features/python/virtual_batchnorm.py b/tensorflow/contrib/gan/python/features/python/virtual_batchnorm.py new file mode 100644 index 00000000000..ea54ac01cee --- /dev/null +++ b/tensorflow/contrib/gan/python/features/python/virtual_batchnorm.py @@ -0,0 +1,27 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Virtual batch normalization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.gan.python.features.python import virtual_batchnorm_impl +# pylint: disable=wildcard-import +from tensorflow.contrib.gan.python.features.python.virtual_batchnorm_impl import * +# pylint: enable=wildcard-import +from tensorflow.python.util.all_util import remove_undocumented + +__all__ = virtual_batchnorm_impl.__all__ +remove_undocumented(__name__, __all__) diff --git a/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_impl.py b/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_impl.py new file mode 100644 index 00000000000..f8b372546b6 --- /dev/null +++ b/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_impl.py @@ -0,0 +1,306 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Virtual batch normalization. + +This technique was first introduced in `Improved Techniques for Training GANs` +(Salimans et al, https://arxiv.org/abs/1606.03498). Instead of using batch +normalization on a minibatch, it fixes a reference subset of the data to use for +calculating normalization statistics. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import variable_scope + +__all__ = [ + 'VBN', +] + + +def _static_or_dynamic_batch_size(tensor, batch_axis): + """Returns the static or dynamic batch size.""" + batch_size = array_ops.shape(tensor)[batch_axis] + static_batch_size = tensor_util.constant_value(batch_size) + return static_batch_size or batch_size + + +def _statistics(x, axes): + """Calculate the mean and mean square of `x`. + + Modified from the implementation of `tf.nn.moments`. + + Args: + x: A `Tensor`. + axes: Array of ints. Axes along which to compute mean and + variance. + + Returns: + Two `Tensor` objects: `mean` and `square mean`. + """ + # The dynamic range of fp16 is too limited to support the collection of + # sufficient statistics. As a workaround we simply perform the operations + # on 32-bit floats before converting the mean and variance back to fp16 + y = math_ops.cast(x, dtypes.float32) if x.dtype == dtypes.float16 else x + + # Compute true mean while keeping the dims for proper broadcasting. 
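+  # Subtracting a stop_gradient'd estimate of the mean before reducing
+  # improves numerical stability: E[y] = shift + E[y - shift], and the terms
+  # y - shift are small. Since `shift` carries no gradient, the gradients of
+  # the returned statistics are unchanged.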
+  shift = array_ops.stop_gradient(
+      math_ops.reduce_mean(y, axes, keep_dims=True))
+
+  shifted_mean = math_ops.reduce_mean(y - shift, axes, keep_dims=True)
+  mean = shifted_mean + shift
+  mean_squared = math_ops.reduce_mean(math_ops.square(y), axes, keep_dims=True)
+
+  mean = array_ops.squeeze(mean, axes)
+  mean_squared = array_ops.squeeze(mean_squared, axes)
+  if x.dtype == dtypes.float16:
+    return (math_ops.cast(mean, dtypes.float16),
+            math_ops.cast(mean_squared, dtypes.float16))
+  else:
+    return (mean, mean_squared)
+
+
+def _validate_init_input_and_get_axis(reference_batch, axis):
+  """Validate input and return the used axis value."""
+  if reference_batch.shape.ndims is None:
+    raise ValueError('`reference_batch` has unknown dimensions.')
+
+  ndims = reference_batch.shape.ndims
+  if axis < 0:
+    used_axis = ndims + axis
+  else:
+    used_axis = axis
+  if used_axis < 0 or used_axis >= ndims:
+    raise ValueError('Value of `axis` argument ' + str(used_axis) +
+                     ' is out of range for input with rank ' + str(ndims))
+  return used_axis
+
+
+def _validate_call_input(tensor_list, batch_dim):
+  """Verifies that tensor shapes are compatible, except for `batch_dim`."""
+  def _get_shape(tensor):
+    shape = tensor.shape.as_list()
+    del shape[batch_dim]
+    return shape
+  base_shape = tensor_shape.TensorShape(_get_shape(tensor_list[0]))
+  for tensor in tensor_list:
+    base_shape.assert_is_compatible_with(_get_shape(tensor))
+
+
+class VBN(object):
+  """A class to perform virtual batch normalization.
+
+  This technique was first introduced in `Improved Techniques for Training
+  GANs` (Salimans et al, https://arxiv.org/abs/1606.03498). Instead of using
+  batch normalization on a minibatch, it fixes a reference subset of the data
+  to use for calculating normalization statistics.
+
+  To do this, we calculate the reference batch mean and mean square, and modify
+  those statistics for each example. We use mean square instead of variance,
+  since it is linear.
+
+  Note that if `center` or `scale` variables are created, they are shared
+  between all calls to this object.
+
+  The `__init__` API is intended to mimic `tf.layers.batch_normalization` as
+  closely as possible.
+  """
+
+  def __init__(self,
+               reference_batch,
+               axis=-1,
+               epsilon=1e-3,
+               center=True,
+               scale=True,
+               beta_initializer=init_ops.zeros_initializer(),
+               gamma_initializer=init_ops.ones_initializer(),
+               beta_regularizer=None,
+               gamma_regularizer=None,
+               trainable=True,
+               name=None,
+               batch_axis=0):
+    """Initialize virtual batch normalization object.
+
+    We precompute the 'mean' and 'mean squared' of the reference batch, so that
+    `__call__` is efficient. This means that the axis must be supplied when the
+    object is created, not when it is called.
+
+    We precompute 'square mean' instead of 'variance', because the square mean
+    can be easily adjusted on a per-example basis.
+
+    Args:
+      reference_batch: A minibatch tensor. This will form the reference data
+        from which the normalization statistics are calculated. See
+        https://arxiv.org/abs/1606.03498 for more details.
+      axis: Integer, the axis that should be normalized (typically the features
+        axis). For instance, after a `Convolution2D` layer with
+        `data_format="channels_first"`, set `axis=1` in `BatchNormalization`.
+      epsilon: Small float added to variance to avoid dividing by zero.
+      center: If True, add offset of `beta` to normalized tensor. If False,
+        `beta` is ignored.
+      scale: If True, multiply by `gamma`. If False, `gamma` is
+        not used. When the next layer is linear (also e.g.
`nn.relu`), this can + be disabled since the scaling can be done by the next layer. + beta_initializer: Initializer for the beta weight. + gamma_initializer: Initializer for the gamma weight. + beta_regularizer: Optional regularizer for the beta weight. + gamma_regularizer: Optional regularizer for the gamma weight. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable). + name: String, the name of the ops. + batch_axis: The axis of the batch dimension. This dimension is treated + differently in `virtual batch normalization` vs `batch normalization`. + + Raises: + ValueError: If `reference_batch` has unknown dimensions at graph + construction. + ValueError: If `batch_axis` is the same as `axis`. + """ + axis = _validate_init_input_and_get_axis(reference_batch, axis) + self._epsilon = epsilon + self._beta = 0 + self._gamma = 1 + self._batch_axis = _validate_init_input_and_get_axis( + reference_batch, batch_axis) + + if axis == self._batch_axis: + raise ValueError('`axis` and `batch_axis` cannot be the same.') + + with variable_scope.variable_scope(name, 'VBN', + values=[reference_batch]) as self._vs: + self._reference_batch = reference_batch + + # Calculate important shapes: + # 1) Reduction axes for the reference batch + # 2) Broadcast shape, if necessary + # 3) Reduction axes for the virtual batchnormed batch + # 4) Shape for optional parameters + input_shape = self._reference_batch.shape + ndims = input_shape.ndims + reduction_axes = list(range(ndims)) + del reduction_axes[axis] + + self._broadcast_shape = [1] * len(input_shape) + self._broadcast_shape[axis] = input_shape[axis].value + + self._example_reduction_axes = list(range(ndims)) + del self._example_reduction_axes[max(axis, self._batch_axis)] + del self._example_reduction_axes[min(axis, self._batch_axis)] + + params_shape = self._reference_batch.shape[axis] + + # Determines whether broadcasting is needed. This is slightly different + # than in the `nn.batch_normalization` case, due to `batch_dim`. + self._needs_broadcasting = ( + sorted(self._example_reduction_axes) != list(range(ndims))[:-2]) + + # Calculate the sufficient statistics for the reference batch in a way + # that can be easily modified by additional examples. + self._ref_mean, self._ref_mean_squares = _statistics( + self._reference_batch, reduction_axes) + self._ref_variance = (self._ref_mean_squares - + math_ops.square(self._ref_mean)) + + # Virtual batch normalization uses a weighted average between example + # statistics and the reference batch statistics. + ref_batch_size = _static_or_dynamic_batch_size( + self._reference_batch, self._batch_axis) + self._example_weight = 1. / (math_ops.to_float(ref_batch_size) + 1.) + self._ref_weight = 1. - self._example_weight + + # Make the variables, if necessary. 
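+      # `beta` and `gamma` are created inside the `VBN` variable scope
+      # (`self._vs`), so a single pair is created per `VBN` object and shared
+      # by `reference_batch_normalization()` and all `__call__` invocations.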
+ if center: + self._beta = variable_scope.get_variable( + name='beta', + shape=(params_shape,), + initializer=beta_initializer, + regularizer=beta_regularizer, + trainable=trainable) + if scale: + self._gamma = variable_scope.get_variable( + name='gamma', + shape=(params_shape,), + initializer=gamma_initializer, + regularizer=gamma_regularizer, + trainable=trainable) + + def _virtual_statistics(self, inputs, reduction_axes): + """Compute the statistics needed for virtual batch normalization.""" + cur_mean, cur_mean_sq = _statistics(inputs, reduction_axes) + vb_mean = (self._example_weight * cur_mean + + self._ref_weight * self._ref_mean) + vb_mean_sq = (self._example_weight * cur_mean_sq + + self._ref_weight * self._ref_mean_squares) + return (vb_mean, vb_mean_sq) + + def _broadcast(self, v, broadcast_shape=None): + # The exact broadcast shape depends on the current batch, not the reference + # batch, unless we're calculating the batch normalization of the reference + # batch. + b_shape = broadcast_shape or self._broadcast_shape + if self._needs_broadcasting and v is not None: + return array_ops.reshape(v, b_shape) + return v + + def reference_batch_normalization(self): + """Return the reference batch, but batch normalized.""" + with ops.name_scope(self._vs.name): + return nn.batch_normalization(self._reference_batch, + self._broadcast(self._ref_mean), + self._broadcast(self._ref_variance), + self._broadcast(self._beta), + self._broadcast(self._gamma), + self._epsilon) + + def __call__(self, inputs): + """Run virtual batch normalization on inputs. + + Args: + inputs: Tensor input. + + Returns: + A virtual batch normalized version of `inputs`. + + Raises: + ValueError: If `inputs` shape isn't compatible with the reference batch. + """ + _validate_call_input([inputs, self._reference_batch], self._batch_axis) + + with ops.name_scope(self._vs.name, values=[inputs, self._reference_batch]): + # Calculate the statistics on the current input on a per-example basis. + vb_mean, vb_mean_sq = self._virtual_statistics( + inputs, self._example_reduction_axes) + vb_variance = vb_mean_sq - math_ops.square(vb_mean) + + # The exact broadcast shape of the input statistic Tensors depends on the + # current batch, not the reference batch. The parameter broadcast shape + # is independent of the shape of the input statistic Tensor dimensions. + b_shape = self._broadcast_shape[:] # deep copy + b_shape[self._batch_axis] = _static_or_dynamic_batch_size( + inputs, self._batch_axis) + return nn.batch_normalization( + inputs, + self._broadcast(vb_mean, b_shape), + self._broadcast(vb_variance, b_shape), + self._broadcast(self._beta, self._broadcast_shape), + self._broadcast(self._gamma, self._broadcast_shape), + self._epsilon) diff --git a/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_test.py b/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_test.py new file mode 100644 index 00000000000..845f89827b6 --- /dev/null +++ b/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_test.py @@ -0,0 +1,267 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tfgan.python.features.virtual_batchnorm.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.framework.python.ops import variables as contrib_variables_lib +from tensorflow.contrib.gan.python.features.python import virtual_batchnorm_impl as virtual_batchnorm +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import random_seed +from tensorflow.python.layers import normalization +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables as variables_lib +from tensorflow.python.platform import test + + +class VirtualBatchnormTest(test.TestCase): + + def test_syntax(self): + reference_batch = array_ops.zeros([5, 3, 16, 9, 15]) + vbn = virtual_batchnorm.VBN(reference_batch, batch_axis=1) + vbn(array_ops.ones([5, 7, 16, 9, 15])) + + def test_no_broadcast_needed(self): + """When `axis` and `batch_axis` are at the end, no broadcast is needed.""" + reference_batch = array_ops.zeros([5, 3, 16, 9, 15]) + minibatch = array_ops.zeros([5, 3, 16, 3, 15]) + vbn = virtual_batchnorm.VBN(reference_batch, axis=-1, batch_axis=-2) + vbn(minibatch) + + def test_statistics(self): + """Check that `_statistics` gives the same result as `nn.moments`.""" + random_seed.set_random_seed(1234) + + tensors = random_ops.random_normal([4, 5, 7, 3]) + for axes in [(3), (0, 2), (1, 2, 3)]: + vb_mean, mean_sq = virtual_batchnorm._statistics(tensors, axes) + mom_mean, mom_var = nn.moments(tensors, axes) + vb_var = mean_sq - math_ops.square(vb_mean) + + with self.test_session(use_gpu=True) as sess: + vb_mean_np, vb_var_np, mom_mean_np, mom_var_np = sess.run([ + vb_mean, vb_var, mom_mean, mom_var]) + + self.assertAllClose(mom_mean_np, vb_mean_np) + self.assertAllClose(mom_var_np, vb_var_np) + + def test_virtual_statistics(self): + """Check that `_virtual_statistics` gives same result as `nn.moments`.""" + random_seed.set_random_seed(1234) + + batch_axis = 0 + partial_batch = random_ops.random_normal([4, 5, 7, 3]) + single_example = random_ops.random_normal([1, 5, 7, 3]) + full_batch = array_ops.concat([partial_batch, single_example], axis=0) + + for reduction_axis in range(1, 4): + # Get `nn.moments` on the full batch. + reduction_axes = list(range(4)) + del reduction_axes[reduction_axis] + mom_mean, mom_variance = nn.moments(full_batch, reduction_axes) + + # Get virtual batch statistics. 
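+      # `_virtual_statistics` averages the example statistics with the
+      # reference statistics using weights 1 / (N_ref + 1) and
+      # N_ref / (N_ref + 1), which equals the moments of the combined batch
+      # (reference + example); that is what makes this comparison valid.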
+      vb_reduction_axes = list(range(4))
+      del vb_reduction_axes[reduction_axis]
+      del vb_reduction_axes[batch_axis]
+      vbn = virtual_batchnorm.VBN(partial_batch, reduction_axis)
+      vb_mean, mean_sq = vbn._virtual_statistics(
+          single_example, vb_reduction_axes)
+      vb_variance = mean_sq - math_ops.square(vb_mean)
+      # Remove singleton batch dim for easy comparisons.
+      vb_mean = array_ops.squeeze(vb_mean, batch_axis)
+      vb_variance = array_ops.squeeze(vb_variance, batch_axis)
+
+      with self.test_session(use_gpu=True) as sess:
+        vb_mean_np, vb_var_np, mom_mean_np, mom_var_np = sess.run([
+            vb_mean, vb_variance, mom_mean, mom_variance])
+
+      self.assertAllClose(mom_mean_np, vb_mean_np)
+      self.assertAllClose(mom_var_np, vb_var_np)
+
+  def test_reference_batch_normalization(self):
+    """Check that batch norm from VBN agrees with the `layers` version."""
+    random_seed.set_random_seed(1234)
+
+    batch = random_ops.random_normal([6, 5, 7, 3, 3])
+
+    for axis in range(5):
+      # Get `layers` batchnorm result.
+      bn_normalized = normalization.batch_normalization(
+          batch, axis, training=True)
+
+      # Get VBN's batch normalization on reference batch.
+      batch_axis = 0 if axis != 0 else 1  # `axis` and `batch_axis` can't be the same.
+      vbn = virtual_batchnorm.VBN(batch, axis, batch_axis=batch_axis)
+      vbn_normalized = vbn.reference_batch_normalization()
+
+      with self.test_session(use_gpu=True) as sess:
+        variables_lib.global_variables_initializer().run()
+
+        bn_normalized_np, vbn_normalized_np = sess.run(
+            [bn_normalized, vbn_normalized])
+      self.assertAllClose(bn_normalized_np, vbn_normalized_np)
+
+  def test_same_as_batchnorm(self):
+    """Check that batch norm on a set matches VBN of each example vs the rest."""
+    random_seed.set_random_seed(1234)
+
+    num_examples = 4
+    examples = [random_ops.random_normal([5, 7, 3]) for _ in
+                range(num_examples)]
+
+    # Get the result of the standard batch normalization.
+    batch_normalized = normalization.batch_normalization(
+        array_ops.stack(examples), training=True)
+
+    for i in range(num_examples):
+      examples_except_i = array_ops.stack(examples[:i] + examples[i+1:])
+      # Get the result of VBN's batch normalization.
+      vbn = virtual_batchnorm.VBN(examples_except_i)
+      vb_normed = array_ops.squeeze(
+          vbn(array_ops.expand_dims(examples[i], [0])), [0])
+
+      with self.test_session(use_gpu=True) as sess:
+        variables_lib.global_variables_initializer().run()
+        bn_np, vb_np = sess.run([batch_normalized, vb_normed])
+      self.assertAllClose(bn_np[i, ...], vb_np)
+
+  def test_minibatch_independent(self):
+    """Test that virtual batch normalized examples are independent.
+
+    Unlike batch normalization, virtual batch normalization has the property
+    that the virtual batch normalized value of an example is independent of the
+    other examples in the minibatch. In this test, we verify this property.
+    """
+    random_seed.set_random_seed(1234)
+
+    # These can be random, but must be the same for all session calls.
+    reference_batch = constant_op.constant(
+        np.random.normal(size=[4, 7, 3]), dtype=dtypes.float32)
+    fixed_example = constant_op.constant(np.random.normal(size=[7, 3]),
+                                         dtype=dtypes.float32)
+
+    # Get the VBN object and the virtual batch normalized value for
+    # `fixed_example`.
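+    # VBN output depends only on the fixed reference batch and the example
+    # itself, so this value must match the one computed inside any of the
+    # varying minibatches below.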
+ vbn = virtual_batchnorm.VBN(reference_batch) + vbn_fixed_example = array_ops.squeeze( + vbn(array_ops.expand_dims(fixed_example, 0)), 0) + with self.test_session(use_gpu=True): + variables_lib.global_variables_initializer().run() + vbn_fixed_example_np = vbn_fixed_example.eval() + + # Check that the value is the same for different minibatches, and different + # sized minibatches. + for minibatch_size in range(1, 6): + examples = [random_ops.random_normal([7, 3]) for _ in + range(minibatch_size)] + + minibatch = array_ops.stack([fixed_example] + examples) + vbn_minibatch = vbn(minibatch) + cur_vbn_fixed_example = vbn_minibatch[0, ...] + with self.test_session(use_gpu=True): + variables_lib.global_variables_initializer().run() + cur_vbn_fixed_example_np = cur_vbn_fixed_example.eval() + self.assertAllClose(vbn_fixed_example_np, cur_vbn_fixed_example_np) + + def test_variable_reuse(self): + """Test that variable scopes work and inference on a real-ish case.""" + tensor1_ref = array_ops.zeros([6, 5, 7, 3, 3]) + tensor1_examples = array_ops.zeros([4, 5, 7, 3, 3]) + tensor2_ref = array_ops.zeros([4, 2, 3]) + tensor2_examples = array_ops.zeros([2, 2, 3]) + + with variable_scope.variable_scope('dummy_scope', reuse=True): + with self.assertRaisesRegexp( + ValueError, 'does not exist, or was not created with ' + 'tf.get_variable()'): + virtual_batchnorm.VBN(tensor1_ref) + + vbn1 = virtual_batchnorm.VBN(tensor1_ref, name='vbn1') + vbn2 = virtual_batchnorm.VBN(tensor2_ref, name='vbn2') + + # Fetch reference and examples after virtual batch normalization. Also + # fetch in variable reuse case. + to_fetch = [] + + to_fetch.append(vbn1.reference_batch_normalization()) + to_fetch.append(vbn2.reference_batch_normalization()) + to_fetch.append(vbn1(tensor1_examples)) + to_fetch.append(vbn2(tensor2_examples)) + + variable_scope.get_variable_scope().reuse_variables() + + to_fetch.append(vbn1.reference_batch_normalization()) + to_fetch.append(vbn2.reference_batch_normalization()) + to_fetch.append(vbn1(tensor1_examples)) + to_fetch.append(vbn2(tensor2_examples)) + + self.assertEqual(4, len(contrib_variables_lib.get_variables())) + + with self.test_session(use_gpu=True) as sess: + variables_lib.global_variables_initializer().run() + sess.run(to_fetch) + + def test_invalid_input(self): + # Reference batch has unknown dimensions. + with self.assertRaisesRegexp( + ValueError, '`reference_batch` has unknown dimensions.'): + virtual_batchnorm.VBN(array_ops.placeholder(dtypes.float32), name='vbn1') + + # Axis too negative. + with self.assertRaisesRegexp( + ValueError, 'Value of `axis` argument .* is out of range'): + virtual_batchnorm.VBN(array_ops.zeros([1, 2]), axis=-3, name='vbn2') + + # Axis too large. + with self.assertRaisesRegexp( + ValueError, 'Value of `axis` argument .* is out of range'): + virtual_batchnorm.VBN(array_ops.zeros([1, 2]), axis=2, name='vbn3') + + # Batch axis too negative. + with self.assertRaisesRegexp( + ValueError, 'Value of `axis` argument .* is out of range'): + virtual_batchnorm.VBN(array_ops.zeros([1, 2]), name='vbn4', batch_axis=-3) + + # Batch axis too large. + with self.assertRaisesRegexp( + ValueError, 'Value of `axis` argument .* is out of range'): + virtual_batchnorm.VBN(array_ops.zeros([1, 2]), name='vbn5', batch_axis=2) + + # Axis and batch axis are the same. 
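+    # (The constructor rejects this configuration: the batch axis is averaged
+    # over, so it cannot also be the normalized `axis`.)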
+ with self.assertRaisesRegexp( + ValueError, '`axis` and `batch_axis` cannot be the same.'): + virtual_batchnorm.VBN(array_ops.zeros( + [1, 2]), axis=1, name='vbn6', batch_axis=1) + + # Reference Tensor and example Tensor have incompatible shapes. + tensor_ref = array_ops.zeros([5, 2, 3]) + tensor_examples = array_ops.zeros([3, 2, 3]) + vbn = virtual_batchnorm.VBN(tensor_ref, name='vbn7', batch_axis=1) + with self.assertRaisesRegexp(ValueError, 'Shapes .* are incompatible'): + vbn(tensor_examples) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/gan/python/losses/__init__.py b/tensorflow/contrib/gan/python/losses/__init__.py new file mode 100644 index 00000000000..290ff867a1e --- /dev/null +++ b/tensorflow/contrib/gan/python/losses/__init__.py @@ -0,0 +1,32 @@ +# Copyright 2017 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""TFGAN grouped API. Please see README.md for details and usage.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# Collapse losses into a single namespace. +from tensorflow.contrib.gan.python.losses.python import losses_wargs as wargs +from tensorflow.contrib.gan.python.losses.python import tuple_losses + +# pylint: disable=wildcard-import +from tensorflow.contrib.gan.python.losses.python.tuple_losses import * +# pylint: enable=wildcard-import + +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = ['wargs'] + tuple_losses.__all__ +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/gan/python/losses/python/losses_impl.py b/tensorflow/contrib/gan/python/losses/python/losses_impl.py new file mode 100644 index 00000000000..3f9d87f54ed --- /dev/null +++ b/tensorflow/contrib/gan/python/losses/python/losses_impl.py @@ -0,0 +1,887 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Losses that are useful for training GANs. 
+
+Most of the losses come in pairs, following the naming pattern:
+1) xxxxx_generator_loss
+2) xxxxx_discriminator_loss
+
+Example:
+1) wasserstein_generator_loss
+2) wasserstein_discriminator_loss
+
+A few losses stand alone, for example:
+wasserstein_gradient_penalty
+
+All losses must be able to accept 1D or 2D Tensors, so as to be compatible with
+PatchGAN-style losses (https://arxiv.org/abs/1611.07004).
+
+To make these losses usable in the TFGAN framework, please create a tuple
+version of the losses with `losses_utils.py`.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.framework.python.ops import variables as contrib_variables_lib
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import clip_ops
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops.distributions import distribution as ds
+from tensorflow.python.ops.losses import losses
+from tensorflow.python.ops.losses import util
+from tensorflow.python.summary import summary
+
+
+__all__ = [
+    'acgan_discriminator_loss',
+    'acgan_generator_loss',
+    'least_squares_discriminator_loss',
+    'least_squares_generator_loss',
+    'modified_discriminator_loss',
+    'modified_generator_loss',
+    'minimax_discriminator_loss',
+    'minimax_generator_loss',
+    'wasserstein_discriminator_loss',
+    'wasserstein_generator_loss',
+    'wasserstein_gradient_penalty',
+    'mutual_information_penalty',
+    'combine_adversarial_loss',
+]
+
+
+# Wasserstein losses from `Wasserstein GAN` (https://arxiv.org/abs/1701.07875).
+def wasserstein_generator_loss(
+    discriminator_gen_outputs,
+    weights=1.0,
+    scope=None,
+    loss_collection=ops.GraphKeys.LOSSES,
+    reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
+    add_summaries=False):
+  """Wasserstein generator loss for GANs.
+
+  See `Wasserstein GAN` (https://arxiv.org/abs/1701.07875) for more details.
+
+  Args:
+    discriminator_gen_outputs: Discriminator output on generated data. Expected
+      to be in the range of (-inf, inf).
+    weights: Optional `Tensor` whose rank is either 0, or the same rank as
+      `labels`, and must be broadcastable to `labels` (i.e., all dimensions
+      must be either `1`, or the same as the corresponding `losses` dimension).
+    scope: The scope for the operations performed in computing the loss.
+    loss_collection: collection to which this loss will be added.
+    reduction: A `tf.losses.Reduction` to apply to loss.
+    add_summaries: Whether or not to add detailed summaries for the loss.
+
+  Returns:
+    A loss Tensor. The shape depends on `reduction`.
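+
+  Example (an illustrative sketch; `my_discriminator` and `generated_data` are
+  hypothetical placeholders, not part of this module):
+
+    d_on_generated = my_discriminator(generated_data)
+    gen_loss = wasserstein_generator_loss(d_on_generated)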
+ """ + with ops.name_scope(scope, 'generator_wasserstein_loss', ( + discriminator_gen_outputs, weights)) as scope: + discriminator_gen_outputs = math_ops.to_float(discriminator_gen_outputs) + + loss = - discriminator_gen_outputs + loss = losses.compute_weighted_loss( + loss, weights, scope, loss_collection, reduction) + + if add_summaries: + summary.scalar('generator_wass_loss', loss) + + return loss + + +def wasserstein_discriminator_loss( + discriminator_real_outputs, + discriminator_gen_outputs, + real_weights=1.0, + generated_weights=1.0, + scope=None, + loss_collection=ops.GraphKeys.LOSSES, + reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS, + add_summaries=False): + """Wasserstein discriminator loss for GANs. + + See `Wasserstein GAN` (https://arxiv.org/abs/1701.07875) for more details. + + Args: + discriminator_real_outputs: Discriminator output on real data. + discriminator_gen_outputs: Discriminator output on generated data. Expected + to be in the range of (-inf, inf). + real_weights: A scalar or a `Tensor` of size [batch_size, K] used to rescale + the real loss. + generated_weights: A scalar or a `Tensor` of size [batch_size, K] used to + rescale the generated loss. + scope: The scope for the operations performed in computing the loss. + loss_collection: collection to which this loss will be added. + reduction: A `tf.losses.Reduction` to apply to loss. + add_summaries: Whether or not to add summaries for the loss. + + Returns: + A loss Tensor. The shape depends on `reduction`. + """ + with ops.name_scope(scope, 'discriminator_wasserstein_loss', ( + discriminator_real_outputs, discriminator_gen_outputs, real_weights, + generated_weights)) as scope: + discriminator_real_outputs = math_ops.to_float(discriminator_real_outputs) + discriminator_gen_outputs = math_ops.to_float(discriminator_gen_outputs) + discriminator_real_outputs.shape.assert_is_compatible_with( + discriminator_gen_outputs.shape) + + loss_on_generated = losses.compute_weighted_loss( + discriminator_gen_outputs, generated_weights, scope, + loss_collection=None, reduction=reduction) + loss_on_real = losses.compute_weighted_loss( + discriminator_real_outputs, real_weights, scope, loss_collection=None, + reduction=reduction) + loss = loss_on_generated - loss_on_real + util.add_loss(loss, loss_collection) + + if add_summaries: + summary.scalar('discriminator_gen_wass_loss', loss_on_generated) + summary.scalar('discriminator_real_wass_loss', loss_on_real) + summary.scalar('discriminator_wass_loss', loss) + + return loss + + +# ACGAN losses from `Conditional Image Synthesis With Auxiliary Classifier GANs` +# (https://arxiv.org/abs/1610.09585). +def acgan_discriminator_loss( + discriminator_gen_classification_logits, + discriminator_real_classification_logits, + one_hot_labels, + label_smoothing=0.0, + real_weights=1.0, + generated_weights=1.0, + scope=None, + loss_collection=ops.GraphKeys.LOSSES, + reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS, + add_summaries=False): + """ACGAN loss for the discriminator. + + The ACGAN loss adds a classification loss to the conditional discriminator. + Therefore, the discriminator must output a tuple consisting of + (1) the real/fake prediction and + (2) the logits for the classification (usually the last conv layer, + flattened). + + For more details: + ACGAN: https://arxiv.org/abs/1610.09585 + + Args: + discriminator_gen_classification_logits: Classification logits for generated + data. + discriminator_real_classification_logits: Classification logits for real + data. 
+    one_hot_labels: A Tensor holding one-hot labels for the batch.
+    label_smoothing: A float in [0, 1]. If greater than 0, smooth the labels for
+      "discriminator on real data" as suggested in
+      https://arxiv.org/pdf/1701.00160
+    real_weights: A scalar or a `Tensor` of size [batch_size, K] used to rescale
+      the real loss.
+    generated_weights: A scalar or a `Tensor` of size [batch_size, K] used to
+      rescale the generated loss.
+    scope: The scope for the operations performed in computing the loss.
+    loss_collection: collection to which this loss will be added.
+    reduction: A `tf.losses.Reduction` to apply to loss.
+    add_summaries: Whether or not to add summaries for the loss.
+
+  Returns:
+    A loss Tensor. Shape depends on `reduction`.
+
+  Raises:
+    TypeError: If the discriminator does not output a tuple.
+  """
+  loss_on_generated = losses.softmax_cross_entropy(
+      one_hot_labels, discriminator_gen_classification_logits,
+      weights=generated_weights, scope=scope, loss_collection=None,
+      reduction=reduction)
+  loss_on_real = losses.softmax_cross_entropy(
+      one_hot_labels, discriminator_real_classification_logits,
+      weights=real_weights, label_smoothing=label_smoothing, scope=scope,
+      loss_collection=None, reduction=reduction)
+  loss = loss_on_generated + loss_on_real
+  util.add_loss(loss, loss_collection)
+
+  if add_summaries:
+    summary.scalar('discriminator_gen_ac_loss', loss_on_generated)
+    summary.scalar('discriminator_real_ac_loss', loss_on_real)
+    summary.scalar('discriminator_ac_loss', loss)
+
+  return loss
+
+
+def acgan_generator_loss(
+    discriminator_gen_classification_logits,
+    one_hot_labels,
+    weights=1.0,
+    scope=None,
+    loss_collection=ops.GraphKeys.LOSSES,
+    reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
+    add_summaries=False):
+  """ACGAN loss for the generator.
+
+  The ACGAN loss adds a classification loss to the conditional discriminator.
+  Therefore, the discriminator must output a tuple consisting of
+    (1) the real/fake prediction and
+    (2) the logits for the classification (usually the last conv layer,
+        flattened).
+
+  For more details:
+    ACGAN: https://arxiv.org/abs/1610.09585
+
+  Args:
+    discriminator_gen_classification_logits: Classification logits for generated
+      data.
+    one_hot_labels: A Tensor holding one-hot labels for the batch.
+    weights: Optional `Tensor` whose rank is either 0, or the same rank as
+      `one_hot_labels`, and must be broadcastable to `one_hot_labels` (i.e., all
+      dimensions must be either `1`, or the same as the corresponding
+      dimension).
+    scope: The scope for the operations performed in computing the loss.
+    loss_collection: collection to which this loss will be added.
+    reduction: A `tf.losses.Reduction` to apply to loss.
+    add_summaries: Whether or not to add summaries for the loss.
+
+  Returns:
+    A loss Tensor. Shape depends on `reduction`.
+
+  Raises:
+    TypeError: If the discriminator does not output a tuple.
+  """
+  loss = losses.softmax_cross_entropy(
+      one_hot_labels, discriminator_gen_classification_logits, weights=weights,
+      scope=scope, loss_collection=loss_collection, reduction=reduction)
+
+  if add_summaries:
+    summary.scalar('generator_ac_loss', loss)
+
+  return loss
+
+
+# Wasserstein Gradient Penalty losses from `Improved Training of Wasserstein
+# GANs` (https://arxiv.org/abs/1704.00028).
+
+
+# TODO(joelshor): Figure out why this function can't be inside a name scope.
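+#
+# For reference, per example the penalty below computes
+#   interpolates = real_data + alpha * (generated_data - real_data),
+#       with alpha drawn from Uniform[0, 1),
+#   penalty = (||d D(interpolates) / d interpolates||_2 - 1) ** 2,
+# which pushes the norm of the discriminator's gradient at random interpolates
+# toward 1, i.e. the 1-Lipschitz constraint of the Wasserstein formulation.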
+def wasserstein_gradient_penalty( + generated_data, + real_data, + generator_inputs, + discriminator_fn, + discriminator_scope, + epsilon=1e-10, + weights=1.0, + scope=None, + loss_collection=ops.GraphKeys.LOSSES, + reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS, + add_summaries=False): + """The gradient penalty for the Wasserstein discriminator loss. + + See `Improved Training of Wasserstein GANs` + (https://arxiv.org/abs/1704.00028) for more details. + + Args: + generated_data: Output of the generator. + real_data: Real data. + generator_inputs: Exact argument to pass to the generator, which is used + as optional conditioning to the discriminator. + discriminator_fn: A discriminator function that conforms to TFGAN API. + discriminator_scope: If not `None`, reuse discriminators from this scope. + epsilon: A small positive number added for numerical stability when + computing the gradient norm. + weights: Optional `Tensor` whose rank is either 0, or the same rank as + `labels`, and must be broadcastable to `labels` (i.e., all dimensions must + be either `1`, or the same as the corresponding `losses` dimension). + scope: The scope for the operations performed in computing the loss. + loss_collection: collection to which this loss will be added. + reduction: A `tf.losses.Reduction` to apply to loss. + add_summaries: Whether or not to add summaries for the loss. + + Returns: + A loss Tensor. The shape depends on `reduction`. + + Raises: + ValueError: If the rank of data Tensors is unknown. + """ + if generated_data.shape.ndims is None: + raise ValueError('`generated_data` can\'t have unknown rank.') + if real_data.shape.ndims is None: + raise ValueError('`real_data` can\'t have unknown rank.') + + differences = generated_data - real_data + batch_size = differences.shape[0].value or array_ops.shape(differences)[0] + alpha_shape = [batch_size] + [1] * (differences.shape.ndims - 1) + alpha = random_ops.random_uniform(shape=alpha_shape) + interpolates = real_data + (alpha * differences) + + # Reuse variables if a discriminator scope already exists. + reuse = False if discriminator_scope is None else True + with variable_scope.variable_scope(discriminator_scope, 'gpenalty_dscope', + reuse=reuse): + disc_interpolates = discriminator_fn(interpolates, generator_inputs) + + if isinstance(disc_interpolates, tuple): + # ACGAN case: disc outputs more than one tensor + disc_interpolates = disc_interpolates[0] + + gradients = gradients_impl.gradients(disc_interpolates, interpolates)[0] + gradient_squares = math_ops.reduce_sum( + math_ops.square(gradients), axis=list(range(1, gradients.shape.ndims))) + # Propagate shape information, if possible. + if isinstance(batch_size, int): + gradient_squares.set_shape([ + batch_size] + gradient_squares.shape.as_list()[1:]) + # For numerical stability, add epsilon to the sum before taking the square + # root. Note tf.norm does not add epsilon. + slopes = math_ops.sqrt(gradient_squares + epsilon) + penalties = math_ops.square(slopes - 1.0) + penalty = losses.compute_weighted_loss( + penalties, weights, scope=scope, loss_collection=loss_collection, + reduction=reduction) + + if add_summaries: + summary.scalar('gradient_penalty_loss', penalty) + + return penalty + + +# Original losses from `Generative Adversarial Nets` +# (https://arxiv.org/abs/1406.2661). 
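+
+
+# Both losses below reduce to sigmoid cross entropy against all-ones labels
+# for real data and all-zeros labels for generated data. A minimal sketch of
+# the unsmoothed, unweighted discriminator loss (illustrative only, not part
+# of the public API; the public functions below add label smoothing, weights,
+# collections, and summaries):
+def _minimax_discriminator_loss_sketch(real_logits, gen_logits):
+  # -log(sigmoid(D(x))): push real logits toward one.
+  loss_on_real = losses.sigmoid_cross_entropy(
+      array_ops.ones_like(real_logits), real_logits, loss_collection=None)
+  # -log(1 - sigmoid(D(G(z)))): push generated logits toward zero.
+  loss_on_generated = losses.sigmoid_cross_entropy(
+      array_ops.zeros_like(gen_logits), gen_logits, loss_collection=None)
+  return loss_on_real + loss_on_generated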
+
+
+def minimax_discriminator_loss(
+    discriminator_real_outputs,
+    discriminator_gen_outputs,
+    label_smoothing=0.25,
+    real_weights=1.0,
+    generated_weights=1.0,
+    scope=None,
+    loss_collection=ops.GraphKeys.LOSSES,
+    reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
+    add_summaries=False):
+  """Original minimax discriminator loss for GANs, with label smoothing.
+
+  Note that the authors don't recommend using this loss. A more practically
+  useful loss is `modified_discriminator_loss`.
+
+  L = - real_weights * log(sigmoid(D(x)))
+      - generated_weights * log(1 - sigmoid(D(G(z))))
+
+  See `Generative Adversarial Nets` (https://arxiv.org/abs/1406.2661) for more
+  details.
+
+  Args:
+    discriminator_real_outputs: Discriminator output on real data.
+    discriminator_gen_outputs: Discriminator output on generated data. Expected
+      to be in the range of (-inf, inf).
+    label_smoothing: The amount of smoothing for positive labels. This technique
+      is taken from `Improved Techniques for Training GANs`
+      (https://arxiv.org/abs/1606.03498). `0.0` means no smoothing.
+    real_weights: A scalar or a `Tensor` of size [batch_size, K] used to rescale
+      the real loss.
+    generated_weights: A scalar or a `Tensor` of size [batch_size, K] used to
+      rescale the generated loss.
+    scope: The scope for the operations performed in computing the loss.
+    loss_collection: collection to which this loss will be added.
+    reduction: A `tf.losses.Reduction` to apply to loss.
+    add_summaries: Whether or not to add summaries for the loss.
+
+  Returns:
+    A loss Tensor. The shape depends on `reduction`.
+  """
+  with ops.name_scope(scope, 'discriminator_minimax_loss', (
+      discriminator_real_outputs, discriminator_gen_outputs, real_weights,
+      generated_weights, label_smoothing)) as scope:
+
+    # -log(sigmoid(D(x))), with positive labels smoothed to
+    # (1 - label_smoothing).
+    loss_on_real = losses.sigmoid_cross_entropy(
+        array_ops.ones_like(discriminator_real_outputs),
+        discriminator_real_outputs, real_weights, label_smoothing, scope,
+        loss_collection=None, reduction=reduction)
+    # -log(1 - sigmoid(D(G(z))))
+    loss_on_generated = losses.sigmoid_cross_entropy(
+        array_ops.zeros_like(discriminator_gen_outputs),
+        discriminator_gen_outputs, generated_weights, scope=scope,
+        loss_collection=None, reduction=reduction)
+
+    loss = loss_on_real + loss_on_generated
+    util.add_loss(loss, loss_collection)
+
+    if add_summaries:
+      summary.scalar('discriminator_gen_minimax_loss', loss_on_generated)
+      summary.scalar('discriminator_real_minimax_loss', loss_on_real)
+      summary.scalar('discriminator_minimax_loss', loss)
+
+  return loss
+
+
+def minimax_generator_loss(
+    discriminator_gen_outputs,
+    label_smoothing=0.0,
+    weights=1.0,
+    scope=None,
+    loss_collection=ops.GraphKeys.LOSSES,
+    reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
+    add_summaries=False):
+  """Original minimax generator loss for GANs.
+
+  Note that the authors don't recommend using this loss. A more practically
+  useful loss is `modified_generator_loss`.
+
+  L = log(sigmoid(D(x))) + log(1 - sigmoid(D(G(z))))
+
+  See `Generative Adversarial Nets` (https://arxiv.org/abs/1406.2661) for more
+  details.
+
+  Args:
+    discriminator_gen_outputs: Discriminator output on generated data. Expected
+      to be in the range of (-inf, inf).
+    label_smoothing: The amount of smoothing for positive labels. This technique
+      is taken from `Improved Techniques for Training GANs`
+      (https://arxiv.org/abs/1606.03498). `0.0` means no smoothing.
+    weights: A scalar or a `Tensor` of size [batch_size, K] used to rescale
+      the loss.
+ scope: The scope for the operations performed in computing the loss. + loss_collection: collection to which this loss will be added. + reduction: A `tf.losses.Reduction` to apply to loss. + add_summaries: Whether or not to add summaries for the loss. + + Returns: + A loss Tensor. The shape depends on `reduction`. + """ + with ops.name_scope(scope, 'generator_minimax_loss') as scope: + loss = - minimax_discriminator_loss( + array_ops.ones_like(discriminator_gen_outputs), + discriminator_gen_outputs, label_smoothing, weights, weights, scope, + loss_collection, reduction, add_summaries=False) + + if add_summaries: + summary.scalar('generator_minimax_loss', loss) + + return loss + + +def modified_discriminator_loss( + discriminator_real_outputs, + discriminator_gen_outputs, + label_smoothing=0.25, + real_weights=1.0, + generated_weights=1.0, + scope=None, + loss_collection=ops.GraphKeys.LOSSES, + reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS, + add_summaries=False): + """Same as minimax discriminator loss. + + See `Generative Adversarial Nets` (https://arxiv.org/abs/1406.2661) for more + details. + + Args: + discriminator_real_outputs: Discriminator output on real data. + discriminator_gen_outputs: Discriminator output on generated data. Expected + to be in the range of (-inf, inf). + label_smoothing: The amount of smoothing for positive labels. This technique + is taken from `Improved Techniques for Training GANs` + (https://arxiv.org/abs/1606.03498). `0.0` means no smoothing. + real_weights: A scalar or a `Tensor` of size [batch_size, K] used to rescale + the real loss. + generated_weights: A scalar or a `Tensor` of size [batch_size, K] used to + rescale the generated loss. + scope: The scope for the operations performed in computing the loss. + loss_collection: collection to which this loss will be added. + reduction: A `tf.losses.Reduction` to apply to loss. + add_summaries: Whether or not to add summaries for the loss. + + Returns: + A loss Tensor. The shape depends on `reduction`. + """ + return minimax_discriminator_loss( + discriminator_real_outputs, + discriminator_gen_outputs, + label_smoothing, + real_weights, + generated_weights, + scope or 'discriminator_modified_loss', + loss_collection, + reduction, + add_summaries) + + +def modified_generator_loss( + discriminator_gen_outputs, + label_smoothing=0.0, + weights=1.0, + scope='generator_modified_loss', + loss_collection=ops.GraphKeys.LOSSES, + reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS, + add_summaries=False): + """Modified generator loss for GANs. + + L = -log(sigmoid(D(G(z)))) + + This is the trick used in the original paper to avoid vanishing gradients + early in training. See `Generative Adversarial Nets` + (https://arxiv.org/abs/1406.2661) for more details. + + Args: + discriminator_gen_outputs: Discriminator output on generated data. Expected + to be in the range of (-inf, inf). + label_smoothing: The amount of smoothing for positive labels. This technique + is taken from `Improved Techniques for Training GANs` + (https://arxiv.org/abs/1606.03498). `0.0` means no smoothing. + weights: Optional `Tensor` whose rank is either 0, or the same rank as + `labels`, and must be broadcastable to `labels` (i.e., all dimensions must + be either `1`, or the same as the corresponding `losses` dimension). + scope: The scope for the operations performed in computing the loss. + loss_collection: collection to which this loss will be added. + reduction: A `tf.losses.Reduction` to apply to loss. 
+    add_summaries: Whether or not to add summaries for the loss.
+
+  Returns:
+    A loss Tensor. The shape depends on `reduction`.
+  """
+  loss = losses.sigmoid_cross_entropy(
+      array_ops.ones_like(discriminator_gen_outputs), discriminator_gen_outputs,
+      weights, label_smoothing, scope, loss_collection, reduction)
+
+  if add_summaries:
+    summary.scalar('generator_modified_loss', loss)
+
+  return loss
+
+
+# Least Squares loss from `Least Squares Generative Adversarial Networks`
+# (https://arxiv.org/abs/1611.04076).
+
+
+def least_squares_generator_loss(
+    discriminator_gen_outputs,
+    real_label=1,
+    weights=1.0,
+    scope=None,
+    loss_collection=ops.GraphKeys.LOSSES,
+    reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
+    add_summaries=False):
+  """Least squares generator loss.
+
+  This loss comes from `Least Squares Generative Adversarial Networks`
+  (https://arxiv.org/abs/1611.04076).
+
+  L = 1/2 * (D(G(z)) - `real_label`) ** 2
+
+  where D(y) are discriminator logits.
+
+  Args:
+    discriminator_gen_outputs: Discriminator output on generated data. Expected
+      to be in the range of (-inf, inf).
+    real_label: The value that the generator is trying to get the discriminator
+      to output on generated data.
+    weights: Optional `Tensor` whose rank is either 0, or the same rank as
+      `discriminator_gen_outputs`, and must be broadcastable to
+      `discriminator_gen_outputs` (i.e., all dimensions must be either `1`, or
+      the same as the corresponding dimension).
+    scope: The scope for the operations performed in computing the loss.
+    loss_collection: collection to which this loss will be added.
+    reduction: A `tf.losses.Reduction` to apply to loss.
+    add_summaries: Whether or not to add summaries for the loss.
+
+  Returns:
+    A loss Tensor. The shape depends on `reduction`.
+  """
+  with ops.name_scope(scope, 'lsq_generator_loss',
+                      (discriminator_gen_outputs, real_label)) as scope:
+    discriminator_gen_outputs = math_ops.to_float(discriminator_gen_outputs)
+    loss = math_ops.squared_difference(
+        discriminator_gen_outputs, real_label) / 2.0
+    loss = losses.compute_weighted_loss(
+        loss, weights, scope, loss_collection, reduction)
+
+    if add_summaries:
+      summary.scalar('generator_lsq_loss', loss)
+
+  return loss
+
+
+def least_squares_discriminator_loss(
+    discriminator_real_outputs,
+    discriminator_gen_outputs,
+    real_label=1,
+    fake_label=0,
+    real_weights=1.0,
+    generated_weights=1.0,
+    scope=None,
+    loss_collection=ops.GraphKeys.LOSSES,
+    reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
+    add_summaries=False):
+  """Least squares discriminator loss.
+
+  This loss comes from `Least Squares Generative Adversarial Networks`
+  (https://arxiv.org/abs/1611.04076).
+
+  L = 1/2 * (D(x) - `real_label`) ** 2 +
+      1/2 * (D(G(z)) - `fake_label`) ** 2
+
+  where D(y) are discriminator logits.
+
+  Args:
+    discriminator_real_outputs: Discriminator output on real data.
+    discriminator_gen_outputs: Discriminator output on generated data. Expected
+      to be in the range of (-inf, inf).
+    real_label: The value that the discriminator tries to output for real data.
+    fake_label: The value that the discriminator tries to output for fake data.
+    real_weights: A scalar or a `Tensor` of size [batch_size, K] used to rescale
+      the real loss.
+    generated_weights: A scalar or a `Tensor` of size [batch_size, K] used to
+      rescale the generated loss.
+    scope: The scope for the operations performed in computing the loss.
+    loss_collection: collection to which this loss will be added.
+    reduction: A `tf.losses.Reduction` to apply to loss.
+    add_summaries: Whether or not to add summaries for the loss.
+
+  Returns:
+    A loss Tensor. The shape depends on `reduction`.
+  """
+  with ops.name_scope(scope, 'lsq_discriminator_loss',
+                      (discriminator_gen_outputs, real_label)) as scope:
+    discriminator_real_outputs = math_ops.to_float(discriminator_real_outputs)
+    discriminator_gen_outputs = math_ops.to_float(discriminator_gen_outputs)
+    discriminator_real_outputs.shape.assert_is_compatible_with(
+        discriminator_gen_outputs.shape)
+
+    real_losses = math_ops.squared_difference(
+        discriminator_real_outputs, real_label) / 2.0
+    fake_losses = math_ops.squared_difference(
+        discriminator_gen_outputs, fake_label) / 2.0
+
+    loss_on_real = losses.compute_weighted_loss(
+        real_losses, real_weights, scope, loss_collection=None,
+        reduction=reduction)
+    loss_on_generated = losses.compute_weighted_loss(
+        fake_losses, generated_weights, scope, loss_collection=None,
+        reduction=reduction)
+
+    loss = loss_on_real + loss_on_generated
+    util.add_loss(loss, loss_collection)
+
+    if add_summaries:
+      summary.scalar('discriminator_gen_lsq_loss', loss_on_generated)
+      summary.scalar('discriminator_real_lsq_loss', loss_on_real)
+      summary.scalar('discriminator_lsq_loss', loss)
+
+  return loss
+
+
+# InfoGAN loss from `InfoGAN: Interpretable Representation Learning by
+# Information Maximizing Generative Adversarial Nets`
+# (https://arxiv.org/abs/1606.03657).
+
+
+def _validate_distributions(distributions):
+  if not isinstance(distributions, (list, tuple)):
+    raise ValueError('`distributions` must be a list or tuple. Instead, '
+                     'found %s.' % type(distributions))
+  for x in distributions:
+    if not isinstance(x, ds.Distribution):
+      raise ValueError('`distributions` must be a list of `Distributions`. '
+                       'Instead, found %s.' % type(x))
+
+
+def _validate_information_penalty_inputs(
+    structured_generator_inputs, predicted_distributions):
+  """Validate input to `mutual_information_penalty`."""
+  _validate_distributions(predicted_distributions)
+  if len(structured_generator_inputs) != len(predicted_distributions):
+    raise ValueError('`structured_generator_inputs` length %i must be the same '
+                     'as `predicted_distributions` length %i.' % (
+                         len(structured_generator_inputs),
+                         len(predicted_distributions)))
+
+
+def mutual_information_penalty(
+    structured_generator_inputs,
+    predicted_distributions,
+    weights=1.0,
+    scope=None,
+    loss_collection=ops.GraphKeys.LOSSES,
+    reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
+    add_summaries=False):
+  """Returns a penalty on the mutual information in an InfoGAN model.
+
+  This loss comes from the InfoGAN paper (https://arxiv.org/abs/1606.03657).
+
+  Args:
+    structured_generator_inputs: A list of Tensors representing the random noise
+      that must have high mutual information with the generator output. List
+      length should match `predicted_distributions`.
+    predicted_distributions: A list of tf.Distributions. Predicted by the
+      recognizer, and used to evaluate the likelihood of the structured noise.
+      List length should match `structured_generator_inputs`.
+    weights: Optional scalar or `Tensor` broadcastable to the vector of
+      per-distribution losses, which has length
+      `len(predicted_distributions)`.
+    scope: The scope for the operations performed in computing the loss.
+    loss_collection: collection to which this loss will be added.
+    reduction: A `tf.losses.Reduction` to apply to loss.
+    add_summaries: Whether or not to add summaries for the loss.
+
+  Returns:
+    A loss Tensor representing the mutual information penalty. The shape
+    depends on `reduction`.
+  """
+  _validate_information_penalty_inputs(
+      structured_generator_inputs, predicted_distributions)
+
+  # Calculate the negative log-likelihood of the reconstructed noise.
+  log_probs = [math_ops.reduce_mean(dist.log_prob(noise)) for dist, noise in
+               zip(predicted_distributions, structured_generator_inputs)]
+  loss = -1 * losses.compute_weighted_loss(
+      log_probs, weights, scope, loss_collection=loss_collection,
+      reduction=reduction)
+
+  if add_summaries:
+    summary.scalar('mutual_information_penalty', loss)
+
+  return loss
+
+
+def _numerically_stable_global_norm(tensor_list):
+  """Compute the global norm of a list of Tensors, with improved stability.
+
+  The global norm computation sometimes overflows due to the intermediate L2
+  step. To avoid this, we divide by a cheap-to-compute max over the
+  matrix elements.
+
+  Args:
+    tensor_list: A list of tensors; entries may be `None`.
+
+  Returns:
+    A scalar tensor with the global norm, or `0.0` if all entries of
+    `tensor_list` are `None`.
+  """
+  if np.all([x is None for x in tensor_list]):
+    return 0.0
+
+  list_max = math_ops.reduce_max([math_ops.reduce_max(math_ops.abs(x)) for x in
+                                  tensor_list if x is not None])
+  return list_max * clip_ops.global_norm([x / list_max for x in tensor_list
+                                          if x is not None])
+
+
+def _used_weight(weights_list):
+  """Returns the concrete value of the first weight that is not `None`."""
+  for weight in weights_list:
+    if weight is not None:
+      return tensor_util.constant_value(ops.convert_to_tensor(weight))
+
+
+def _validate_args(losses_list, weight_factor, gradient_ratio):
+  for loss in losses_list:
+    loss.shape.assert_is_compatible_with([])
+  if weight_factor is None and gradient_ratio is None:
+    raise ValueError(
+        '`weight_factor` and `gradient_ratio` cannot both be `None`.')
+  if weight_factor is not None and gradient_ratio is not None:
+    raise ValueError(
+        '`weight_factor` and `gradient_ratio` cannot both be specified.')
+
+
+# TODO(joelshor): Add ability to pass in gradients, to avoid recomputing.
+def combine_adversarial_loss(main_loss,
+                             adversarial_loss,
+                             weight_factor=None,
+                             gradient_ratio=None,
+                             gradient_ratio_epsilon=1e-6,
+                             variables=None,
+                             scalar_summaries=True,
+                             gradient_summaries=True,
+                             scope=None):
+  """Utility to combine main and adversarial losses.
+
+  This utility combines the main and adversarial losses in one of two ways.
+  1) Fixed coefficient on adversarial loss. Use `weight_factor` in this case.
+  2) Fixed ratio of gradients. Use `gradient_ratio` in this case. This is often
+    used to make sure both losses affect weights roughly equally, as in
+    https://arxiv.org/pdf/1705.05823.
+
+  One can optionally also visualize the scalar and gradient behavior of the
+  losses.
+
+  Args:
+    main_loss: A floating scalar Tensor indicating the main loss.
+    adversarial_loss: A floating scalar Tensor indicating the adversarial loss.
+    weight_factor: If not `None`, the coefficient by which to multiply the
+      adversarial loss. Exactly one of this and `gradient_ratio` must be
+      non-None.
+    gradient_ratio: If not `None`, the ratio of the magnitude of the gradients.
+      Specifically,
+        gradient_ratio = grad_mag(main_loss) / grad_mag(adversarial_loss)
+      Exactly one of this and `weight_factor` must be non-None.
+    gradient_ratio_epsilon: An epsilon to add to the adversarial loss
+      coefficient denominator, to avoid division-by-zero.
+    variables: List of variables to calculate gradients with respect to. If not
+      present, defaults to all trainable variables.
+    scalar_summaries: Create scalar summaries of losses.
+ gradient_summaries: Create gradient summaries of losses. + scope: Optional name scope. + + Returns: + A floating scalar Tensor indicating the desired combined loss. + + Raises: + ValueError: Malformed input. + """ + _validate_args([main_loss, adversarial_loss], weight_factor, gradient_ratio) + if variables is None: + variables = contrib_variables_lib.get_trainable_variables() + + with ops.name_scope(scope, 'adversarial_loss', + values=[main_loss, adversarial_loss]): + # Compute gradients if we will need them. + if gradient_summaries or gradient_ratio is not None: + main_loss_grad_mag = _numerically_stable_global_norm( + gradients_impl.gradients(main_loss, variables)) + adv_loss_grad_mag = _numerically_stable_global_norm( + gradients_impl.gradients(adversarial_loss, variables)) + + # Add summaries, if applicable. + if scalar_summaries: + summary.scalar('main_loss', main_loss) + summary.scalar('adversarial_loss', adversarial_loss) + if gradient_summaries: + summary.scalar('main_loss_gradients', main_loss_grad_mag) + summary.scalar('adversarial_loss_gradients', adv_loss_grad_mag) + + # Combine losses in the appropriate way. + # If `weight_factor` is always `0`, avoid computing the adversarial loss + # tensor entirely. + if _used_weight((weight_factor, gradient_ratio)) == 0: + final_loss = main_loss + elif weight_factor is not None: + final_loss = (main_loss + + array_ops.stop_gradient(weight_factor) * adversarial_loss) + elif gradient_ratio is not None: + grad_mag_ratio = main_loss_grad_mag / ( + adv_loss_grad_mag + gradient_ratio_epsilon) + adv_coeff = grad_mag_ratio / gradient_ratio + summary.scalar('adversarial_coefficient', adv_coeff) + final_loss = (main_loss + + array_ops.stop_gradient(adv_coeff) * adversarial_loss) + + return final_loss diff --git a/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py b/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py new file mode 100644 index 00000000000..3e003dd0f80 --- /dev/null +++ b/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py @@ -0,0 +1,606 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for TFGAN losses.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.gan.python.losses.python import losses_impl as tfgan_losses +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import random_seed +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import clip_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.ops.distributions import categorical +from tensorflow.python.ops.distributions import normal +from tensorflow.python.ops.losses import losses as tf_losses +from tensorflow.python.platform import test + + +# TODO(joelshor): Use `parameterized` tests when opensourced. +class _LossesTest(object): + + def init_constants(self): + self._discriminator_real_outputs_np = [-5.0, 1.4, 12.5, 2.7] + self._discriminator_gen_outputs_np = [10.0, 4.4, -5.5, 3.6] + self._weights = 2.3 + self._discriminator_real_outputs = constant_op.constant( + self._discriminator_real_outputs_np, dtype=dtypes.float32) + self._discriminator_gen_outputs = constant_op.constant( + self._discriminator_gen_outputs_np, dtype=dtypes.float32) + + def test_generator_all_correct(self): + loss = self._g_loss_fn(self._discriminator_gen_outputs) + self.assertEqual(self._discriminator_gen_outputs.dtype, loss.dtype) + self.assertEqual(self._generator_loss_name, loss.op.name) + with self.test_session(): + self.assertAlmostEqual(self._expected_g_loss, loss.eval(), 5) + + def test_discriminator_all_correct(self): + loss = self._d_loss_fn( + self._discriminator_real_outputs, self._discriminator_gen_outputs) + self.assertEqual(self._discriminator_gen_outputs.dtype, loss.dtype) + self.assertEqual(self._discriminator_loss_name, loss.op.name) + with self.test_session(): + self.assertAlmostEqual(self._expected_d_loss, loss.eval(), 5) + + def test_generator_loss_collection(self): + self.assertEqual(0, len(ops.get_collection('collection'))) + self._g_loss_fn( + self._discriminator_gen_outputs, loss_collection='collection') + self.assertEqual(1, len(ops.get_collection('collection'))) + + def test_discriminator_loss_collection(self): + self.assertEqual(0, len(ops.get_collection('collection'))) + self._d_loss_fn( + self._discriminator_real_outputs, self._discriminator_gen_outputs, + loss_collection='collection') + self.assertEqual(1, len(ops.get_collection('collection'))) + + def test_generator_no_reduction(self): + loss = self._g_loss_fn( + self._discriminator_gen_outputs, reduction=tf_losses.Reduction.NONE) + self.assertAllEqual([4], loss.shape) + + def test_discriminator_no_reduction(self): + loss = self._d_loss_fn( + self._discriminator_real_outputs, self._discriminator_gen_outputs, + reduction=tf_losses.Reduction.NONE) + self.assertAllEqual([4], loss.shape) + + def test_generator_patch(self): + loss = self._g_loss_fn( + array_ops.reshape(self._discriminator_gen_outputs, [2, 2])) + self.assertEqual(self._discriminator_gen_outputs.dtype, loss.dtype) + with self.test_session(): + self.assertAlmostEqual(self._expected_g_loss, loss.eval(), 5) + + def test_discriminator_patch(self): + loss = self._d_loss_fn( + 
array_ops.reshape(self._discriminator_real_outputs, [2, 2]), + array_ops.reshape(self._discriminator_gen_outputs, [2, 2])) + self.assertEqual(self._discriminator_gen_outputs.dtype, loss.dtype) + with self.test_session(): + self.assertAlmostEqual(self._expected_d_loss, loss.eval(), 5) + + def test_generator_loss_with_placeholder_for_logits(self): + logits = array_ops.placeholder(dtypes.float32, shape=(None, 4)) + weights = array_ops.ones_like(logits, dtype=dtypes.float32) + + loss = self._g_loss_fn(logits, weights=weights) + self.assertEqual(logits.dtype, loss.dtype) + + with self.test_session() as sess: + loss = sess.run(loss, + feed_dict={ + logits: [[10.0, 4.4, -5.5, 3.6]], + }) + self.assertAlmostEqual(self._expected_g_loss, loss, 5) + + def test_discriminator_loss_with_placeholder_for_logits(self): + logits = array_ops.placeholder(dtypes.float32, shape=(None, 4)) + logits2 = array_ops.placeholder(dtypes.float32, shape=(None, 4)) + real_weights = array_ops.ones_like(logits, dtype=dtypes.float32) + generated_weights = array_ops.ones_like(logits, dtype=dtypes.float32) + + loss = self._d_loss_fn( + logits, logits2, real_weights=real_weights, + generated_weights=generated_weights) + + with self.test_session() as sess: + loss = sess.run(loss, + feed_dict={ + logits: [self._discriminator_real_outputs_np], + logits2: [self._discriminator_gen_outputs_np], + }) + self.assertAlmostEqual(self._expected_d_loss, loss, 5) + + def test_generator_with_python_scalar_weight(self): + loss = self._g_loss_fn( + self._discriminator_gen_outputs, weights=self._weights) + with self.test_session(): + self.assertAlmostEqual(self._expected_g_loss * self._weights, + loss.eval(), 4) + + def test_discriminator_with_python_scalar_weight(self): + loss = self._d_loss_fn( + self._discriminator_real_outputs, self._discriminator_gen_outputs, + real_weights=self._weights, generated_weights=self._weights) + with self.test_session(): + self.assertAlmostEqual(self._expected_d_loss * self._weights, + loss.eval(), 4) + + def test_generator_with_scalar_tensor_weight(self): + loss = self._g_loss_fn(self._discriminator_gen_outputs, + weights=constant_op.constant(self._weights)) + with self.test_session(): + self.assertAlmostEqual(self._expected_g_loss * self._weights, + loss.eval(), 4) + + def test_discriminator_with_scalar_tensor_weight(self): + weights = constant_op.constant(self._weights) + loss = self._d_loss_fn( + self._discriminator_real_outputs, self._discriminator_gen_outputs, + real_weights=weights, generated_weights=weights) + with self.test_session(): + self.assertAlmostEqual(self._expected_d_loss * self._weights, + loss.eval(), 4) + + def test_generator_add_summaries(self): + self.assertEqual(0, len(ops.get_collection(ops.GraphKeys.SUMMARIES))) + self._g_loss_fn(self._discriminator_gen_outputs, add_summaries=True) + self.assertLess(0, len(ops.get_collection(ops.GraphKeys.SUMMARIES))) + + def test_discriminator_add_summaries(self): + self.assertEqual(0, len(ops.get_collection(ops.GraphKeys.SUMMARIES))) + self._d_loss_fn( + self._discriminator_real_outputs, self._discriminator_gen_outputs, + add_summaries=True) + self.assertLess(0, len(ops.get_collection(ops.GraphKeys.SUMMARIES))) + + +class LeastSquaresLossTest(test.TestCase, _LossesTest): + """Tests for least_squares_xxx_loss.""" + + def setUp(self): + super(LeastSquaresLossTest, self).setUp() + self.init_constants() + self._expected_g_loss = 17.69625 + self._expected_d_loss = 41.73375 + self._generator_loss_name = 'lsq_generator_loss/value' + 
self._discriminator_loss_name = 'lsq_discriminator_loss/add' + self._g_loss_fn = tfgan_losses.least_squares_generator_loss + self._d_loss_fn = tfgan_losses.least_squares_discriminator_loss + + +class ModifiedLossTest(test.TestCase, _LossesTest): + """Tests for modified_xxx_loss.""" + + def setUp(self): + super(ModifiedLossTest, self).setUp() + self.init_constants() + self._expected_g_loss = 1.38582 + self._expected_d_loss = 6.19637 + self._generator_loss_name = 'generator_modified_loss/value' + self._discriminator_loss_name = 'discriminator_modified_loss/add_1' + self._g_loss_fn = tfgan_losses.modified_generator_loss + self._d_loss_fn = tfgan_losses.modified_discriminator_loss + + +class MinimaxLossTest(test.TestCase, _LossesTest): + """Tests for minimax_xxx_loss.""" + + def setUp(self): + super(MinimaxLossTest, self).setUp() + self.init_constants() + self._expected_g_loss = -4.82408 + self._expected_d_loss = 6.19637 + self._generator_loss_name = 'generator_minimax_loss/Neg' + self._discriminator_loss_name = 'discriminator_minimax_loss/add_1' + self._g_loss_fn = tfgan_losses.minimax_generator_loss + self._d_loss_fn = tfgan_losses.minimax_discriminator_loss + + +class WassersteinLossTest(test.TestCase, _LossesTest): + """Tests for wasserstein_xxx_loss.""" + + def setUp(self): + super(WassersteinLossTest, self).setUp() + self.init_constants() + self._expected_g_loss = -3.12500 + self._expected_d_loss = 0.22500 + self._generator_loss_name = 'generator_wasserstein_loss/value' + self._discriminator_loss_name = 'discriminator_wasserstein_loss/sub' + self._g_loss_fn = tfgan_losses.wasserstein_generator_loss + self._d_loss_fn = tfgan_losses.wasserstein_discriminator_loss + + +# TODO(joelshor): Use `parameterized` tests when opensourced. +# TODO(joelshor): Refactor this test to use the same code as the other losses. 
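+# The hard-coded `_expected_*` values in these test classes can be reproduced
+# by hand from the shared constants in `init_constants`. For example, for the
+# Wasserstein losses above:
+#   expected_g_loss = -mean(gen)  = -(10.0 + 4.4 - 5.5 + 3.6) / 4 = -3.125
+#   expected_d_loss = mean(gen) - mean(real)
+#                   = 3.125 - (-5.0 + 1.4 + 12.5 + 2.7) / 4
+#                   = 3.125 - 2.9 = 0.225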
+class ACGANLossTest(test.TestCase):
+  """Tests for acgan_xxx_loss."""
+
+  def setUp(self):
+    super(ACGANLossTest, self).setUp()
+    self._g_loss_fn = tfgan_losses.acgan_generator_loss
+    self._d_loss_fn = tfgan_losses.acgan_discriminator_loss
+    self._discriminator_gen_classification_logits_np = [[10.0, 4.4, -5.5, 3.6],
+                                                        [-4.0, 4.4, 5.2, 4.6],
+                                                        [1.1, 2.4, -3.5, 5.6],
+                                                        [1.1, 2.4, -3.5, 5.6]]
+    self._discriminator_real_classification_logits_np = [[-2.0, 0.4, 12.5, 2.7],
+                                                         [-1.2, 1.9, 12.3, 2.6],
+                                                         [-2.4, -1.7, 2.5, 2.7],
+                                                         [1.1, 2.4, -3.5, 5.6]]
+    self._one_hot_labels_np = [[0, 1, 0, 0],
+                               [0, 0, 1, 0],
+                               [1, 0, 0, 0],
+                               [1, 0, 0, 0]]
+    self._weights = 2.3
+
+    self._discriminator_gen_classification_logits = constant_op.constant(
+        self._discriminator_gen_classification_logits_np, dtype=dtypes.float32)
+    self._discriminator_real_classification_logits = constant_op.constant(
+        self._discriminator_real_classification_logits_np, dtype=dtypes.float32)
+    self._one_hot_labels = constant_op.constant(
+        self._one_hot_labels_np, dtype=dtypes.float32)
+    self._generator_kwargs = {
+        'discriminator_gen_classification_logits':
+            self._discriminator_gen_classification_logits,
+        'one_hot_labels': self._one_hot_labels,
+    }
+    self._discriminator_kwargs = {
+        'discriminator_gen_classification_logits':
+            self._discriminator_gen_classification_logits,
+        'discriminator_real_classification_logits':
+            self._discriminator_real_classification_logits,
+        'one_hot_labels': self._one_hot_labels,
+    }
+    self._generator_loss_name = 'softmax_cross_entropy_loss/value'
+    self._discriminator_loss_name = 'add'
+    self._expected_g_loss = 3.84974
+    self._expected_d_loss = 9.43950
+
+  def test_generator_all_correct(self):
+    loss = self._g_loss_fn(**self._generator_kwargs)
+    self.assertEqual(
+        self._discriminator_gen_classification_logits.dtype, loss.dtype)
+    self.assertEqual(self._generator_loss_name, loss.op.name)
+    with self.test_session():
+      self.assertAlmostEqual(self._expected_g_loss, loss.eval(), 5)
+
+  def test_discriminator_all_correct(self):
+    loss = self._d_loss_fn(**self._discriminator_kwargs)
+    self.assertEqual(
+        self._discriminator_gen_classification_logits.dtype, loss.dtype)
+    self.assertEqual(self._discriminator_loss_name, loss.op.name)
+    with self.test_session():
+      self.assertAlmostEqual(self._expected_d_loss, loss.eval(), 5)
+
+  def test_generator_loss_collection(self):
+    self.assertEqual(0, len(ops.get_collection('collection')))
+    self._g_loss_fn(loss_collection='collection', **self._generator_kwargs)
+    self.assertEqual(1, len(ops.get_collection('collection')))
+
+  def test_discriminator_loss_collection(self):
+    self.assertEqual(0, len(ops.get_collection('collection')))
+    self._d_loss_fn(loss_collection='collection', **self._discriminator_kwargs)
+    self.assertEqual(1, len(ops.get_collection('collection')))
+
+  def test_generator_no_reduction(self):
+    loss = self._g_loss_fn(
+        reduction=tf_losses.Reduction.NONE, **self._generator_kwargs)
+    self.assertAllEqual([4], loss.shape)
+
+  def test_discriminator_no_reduction(self):
+    loss = self._d_loss_fn(
+        reduction=tf_losses.Reduction.NONE, **self._discriminator_kwargs)
+    self.assertAllEqual([4], loss.shape)
+
+  def test_generator_patch(self):
+    patch_args = {x: array_ops.reshape(y, [2, 2, 4]) for x, y in
+                  self._generator_kwargs.items()}
+    loss = self._g_loss_fn(**patch_args)
+    with self.test_session():
+      self.assertAlmostEqual(self._expected_g_loss, loss.eval(), 5)
+
+  def test_discriminator_patch(self):
+    patch_args = {x: array_ops.reshape(y, [2, 2, 4]) for x,
y in + self._discriminator_kwargs.items()} + loss = self._d_loss_fn(**patch_args) + with self.test_session(): + self.assertAlmostEqual(self._expected_d_loss, loss.eval(), 5) + + def test_generator_loss_with_placeholder_for_logits(self): + gen_logits = array_ops.placeholder(dtypes.float32, shape=(None, 4)) + one_hot_labels = array_ops.placeholder(dtypes.int32, shape=(None, 4)) + + loss = self._g_loss_fn(gen_logits, one_hot_labels) + with self.test_session() as sess: + loss = sess.run( + loss, feed_dict={ + gen_logits: self._discriminator_gen_classification_logits_np, + one_hot_labels: self._one_hot_labels_np, + }) + self.assertAlmostEqual(self._expected_g_loss, loss, 5) + + def test_discriminator_loss_with_placeholder_for_logits_and_weights(self): + gen_logits = array_ops.placeholder(dtypes.float32, shape=(None, 4)) + real_logits = array_ops.placeholder(dtypes.float32, shape=(None, 4)) + one_hot_labels = array_ops.placeholder(dtypes.int32, shape=(None, 4)) + + loss = self._d_loss_fn(gen_logits, real_logits, one_hot_labels) + + with self.test_session() as sess: + loss = sess.run( + loss, feed_dict={ + gen_logits: self._discriminator_gen_classification_logits_np, + real_logits: self._discriminator_real_classification_logits_np, + one_hot_labels: self._one_hot_labels_np, + }) + self.assertAlmostEqual(self._expected_d_loss, loss, 5) + + def test_generator_with_python_scalar_weight(self): + loss = self._g_loss_fn(weights=self._weights, **self._generator_kwargs) + with self.test_session(): + self.assertAlmostEqual(self._expected_g_loss * self._weights, + loss.eval(), 4) + + def test_discriminator_with_python_scalar_weight(self): + loss = self._d_loss_fn( + real_weights=self._weights, generated_weights=self._weights, + **self._discriminator_kwargs) + with self.test_session(): + self.assertAlmostEqual(self._expected_d_loss * self._weights, + loss.eval(), 4) + + def test_generator_with_scalar_tensor_weight(self): + loss = self._g_loss_fn( + weights=constant_op.constant(self._weights), **self._generator_kwargs) + with self.test_session(): + self.assertAlmostEqual(self._expected_g_loss * self._weights, + loss.eval(), 4) + + def test_discriminator_with_scalar_tensor_weight(self): + weights = constant_op.constant(self._weights) + loss = self._d_loss_fn(real_weights=weights, generated_weights=weights, + **self._discriminator_kwargs) + with self.test_session(): + self.assertAlmostEqual(self._expected_d_loss * self._weights, + loss.eval(), 4) + + def test_generator_add_summaries(self): + self.assertEqual(0, len(ops.get_collection(ops.GraphKeys.SUMMARIES))) + self._g_loss_fn(add_summaries=True, **self._generator_kwargs) + self.assertLess(0, len(ops.get_collection(ops.GraphKeys.SUMMARIES))) + + def test_discriminator_add_summaries(self): + self.assertEqual(0, len(ops.get_collection(ops.GraphKeys.SUMMARIES))) + self._d_loss_fn(add_summaries=True, **self._discriminator_kwargs) + self.assertLess(0, len(ops.get_collection(ops.GraphKeys.SUMMARIES))) + + +class _PenaltyTest(object): + + def test_all_correct(self): + loss = self._penalty_fn(**self._kwargs) + self.assertEqual(self._expected_dtype, loss.dtype) + self.assertEqual(self._expected_op_name, loss.op.name) + with self.test_session(): + variables.global_variables_initializer().run() + self.assertAlmostEqual(self._expected_loss, loss.eval(), 6) + + def test_loss_collection(self): + self.assertEqual(0, len(ops.get_collection('collection'))) + self._penalty_fn(loss_collection='collection', **self._kwargs) + self.assertEqual(1, 
len(ops.get_collection('collection'))) + + def test_no_reduction(self): + loss = self._penalty_fn(reduction=tf_losses.Reduction.NONE, **self._kwargs) + self.assertAllEqual([self._batch_size], loss.shape) + + def test_python_scalar_weight(self): + loss = self._penalty_fn(weights=2.3, **self._kwargs) + with self.test_session(): + variables.global_variables_initializer().run() + self.assertAlmostEqual(self._expected_loss * 2.3, loss.eval(), 3) + + def test_scalar_tensor_weight(self): + loss = self._penalty_fn(weights=constant_op.constant(2.3), **self._kwargs) + with self.test_session(): + variables.global_variables_initializer().run() + self.assertAlmostEqual(self._expected_loss * 2.3, loss.eval(), 3) + + +class GradientPenaltyTest(test.TestCase, _PenaltyTest): + """Tests for wasserstein_gradient_penalty.""" + + def setUp(self): + super(GradientPenaltyTest, self).setUp() + self._penalty_fn = tfgan_losses.wasserstein_gradient_penalty + self._generated_data_np = [[3.1, 2.3, -12.3, 32.1]] + self._real_data_np = [[-12.3, 23.2, 16.3, -43.2]] + self._expected_dtype = dtypes.float32 + + with variable_scope.variable_scope('fake_scope') as self._scope: + self._discriminator_fn(0.0, 0.0) + + self._kwargs = { + 'generated_data': constant_op.constant( + self._generated_data_np, dtype=self._expected_dtype), + 'real_data': constant_op.constant( + self._real_data_np, dtype=self._expected_dtype), + 'generator_inputs': None, + 'discriminator_fn': self._discriminator_fn, + 'discriminator_scope': self._scope, + } + self._expected_loss = 9.00000 + self._expected_op_name = 'weighted_loss/value' + self._batch_size = 1 + + def _discriminator_fn(self, inputs, _): + return variable_scope.get_variable('dummy_d', initializer=2.0) * inputs + + def test_loss_with_placeholder(self): + generated_data = array_ops.placeholder(dtypes.float32, shape=(None, None)) + real_data = array_ops.placeholder(dtypes.float32, shape=(None, None)) + + loss = tfgan_losses.wasserstein_gradient_penalty( + generated_data, + real_data, + self._kwargs['generator_inputs'], + self._kwargs['discriminator_fn'], + self._kwargs['discriminator_scope']) + self.assertEqual(generated_data.dtype, loss.dtype) + + with self.test_session() as sess: + variables.global_variables_initializer().run() + loss = sess.run(loss, + feed_dict={ + generated_data: self._generated_data_np, + real_data: self._real_data_np, + }) + self.assertAlmostEqual(self._expected_loss, loss, 5) + + def test_reuses_scope(self): + """Test that gradient penalty reuses discriminator scope.""" + num_vars = len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)) + tfgan_losses.wasserstein_gradient_penalty(**self._kwargs) + self.assertEqual( + num_vars, len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))) + + +class MutualInformationPenaltyTest(test.TestCase, _PenaltyTest): + """Tests for mutual_information_penalty.""" + + def setUp(self): + super(MutualInformationPenaltyTest, self).setUp() + self._penalty_fn = tfgan_losses.mutual_information_penalty + self._structured_generator_inputs = [1.0, 2.0] + self._predicted_distributions = [categorical.Categorical(logits=[1.0, 2.0]), + normal.Normal([0.0], [1.0])] + self._expected_dtype = dtypes.float32 + + self._kwargs = { + 'structured_generator_inputs': self._structured_generator_inputs, + 'predicted_distributions': self._predicted_distributions, + } + self._expected_loss = 1.61610 + self._expected_op_name = 'mul' + self._batch_size = 2 + + +class CombineAdversarialLossTest(test.TestCase): + """Tests for combine_adversarial_loss.""" + + def 
setUp(self): + super(CombineAdversarialLossTest, self).setUp() + self._generated_data_np = [[3.1, 2.3, -12.3, 32.1]] + self._real_data_np = [[-12.3, 23.2, 16.3, -43.2]] + self._generated_data = constant_op.constant( + self._generated_data_np, dtype=dtypes.float32) + self._real_data = constant_op.constant( + self._real_data_np, dtype=dtypes.float32) + self._generated_inputs = None + self._expected_loss = 9.00000 + + def _test_correct_helper(self, use_weight_factor): + variable_list = [variables.Variable(1.0)] + main_loss = variable_list[0] * 2 + adversarial_loss = variable_list[0] * 3 + gradient_ratio_epsilon = 1e-6 + if use_weight_factor: + weight_factor = constant_op.constant(2.0) + gradient_ratio = None + adv_coeff = 2.0 + expected_loss = 1.0 * 2 + adv_coeff * 1.0 * 3 + else: + weight_factor = None + gradient_ratio = constant_op.constant(0.5) + adv_coeff = 2.0 / (3 * 0.5 + gradient_ratio_epsilon) + expected_loss = 1.0 * 2 + adv_coeff * 1.0 * 3 + combined_loss = tfgan_losses.combine_adversarial_loss( + main_loss, + adversarial_loss, + weight_factor=weight_factor, + gradient_ratio=gradient_ratio, + gradient_ratio_epsilon=gradient_ratio_epsilon, + variables=variable_list) + + with self.test_session(use_gpu=True): + variables.global_variables_initializer().run() + self.assertNear(expected_loss, combined_loss.eval(), 1e-5) + + def test_correct_useweightfactor(self): + self._test_correct_helper(True) + + def test_correct_nouseweightfactor(self): + self._test_correct_helper(False) + + def _test_no_weight_skips_adversarial_loss_helper(self, use_weight_factor): + """Test the 0 adversarial weight or grad ratio skips adversarial loss.""" + main_loss = constant_op.constant(1.0) + adversarial_loss = constant_op.constant(1.0) + + weight_factor = 0.0 if use_weight_factor else None + gradient_ratio = None if use_weight_factor else 0.0 + + combined_loss = tfgan_losses.combine_adversarial_loss( + main_loss, + adversarial_loss, + weight_factor=weight_factor, + gradient_ratio=gradient_ratio, + gradient_summaries=False) + + with self.test_session(use_gpu=True): + self.assertEqual(1.0, combined_loss.eval()) + + def test_no_weight_skips_adversarial_loss_useweightfactor(self): + self._test_no_weight_skips_adversarial_loss_helper(True) + + def test_no_weight_skips_adversarial_loss_nouseweightfactor(self): + self._test_no_weight_skips_adversarial_loss_helper(False) + + def test_stable_global_norm_avoids_overflow(self): + tensors = [array_ops.ones([4]), array_ops.ones([4, 4]) * 1e19, None] + gnorm_is_inf = math_ops.is_inf(clip_ops.global_norm(tensors)) + stable_gnorm_is_inf = math_ops.is_inf( + tfgan_losses._numerically_stable_global_norm(tensors)) + + with self.test_session(use_gpu=True): + self.assertTrue(gnorm_is_inf.eval()) + self.assertFalse(stable_gnorm_is_inf.eval()) + + def test_stable_global_norm_unchanged(self): + """Test that preconditioning doesn't change global norm value.""" + random_seed.set_random_seed(1234) + tensors = [random_ops.random_uniform([3]*i, -10.0, 10.0) for i in range(6)] + gnorm = clip_ops.global_norm(tensors) + precond_gnorm = tfgan_losses._numerically_stable_global_norm(tensors) + + with self.test_session(use_gpu=True) as sess: + for _ in range(10): # spot check closeness on more than one sample. 
+ gnorm_np, precond_gnorm_np = sess.run([gnorm, precond_gnorm]) + self.assertNear(gnorm_np, precond_gnorm_np, 1e-5) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/gan/python/losses/python/losses_wargs.py b/tensorflow/contrib/gan/python/losses/python/losses_wargs.py new file mode 100644 index 00000000000..f212bdcf30b --- /dev/null +++ b/tensorflow/contrib/gan/python/losses/python/losses_wargs.py @@ -0,0 +1,27 @@ +# Copyright 2017 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""TFGAN grouped API. Please see README.md for details and usage.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=wildcard-import +from tensorflow.contrib.gan.python.losses.python import losses_impl +from tensorflow.contrib.gan.python.losses.python.losses_impl import * +# pylint: enable=wildcard-import + +from tensorflow.python.util.all_util import remove_undocumented + +remove_undocumented(__name__, losses_impl.__all__) diff --git a/tensorflow/contrib/gan/python/losses/python/tuple_losses.py b/tensorflow/contrib/gan/python/losses/python/tuple_losses.py new file mode 100644 index 00000000000..1a50b3f5880 --- /dev/null +++ b/tensorflow/contrib/gan/python/losses/python/tuple_losses.py @@ -0,0 +1,27 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""TFGAN utilities for loss functions that accept GANModel namedtuples.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=wildcard-import +from tensorflow.contrib.gan.python.losses.python import tuple_losses_impl +from tensorflow.contrib.gan.python.losses.python.tuple_losses_impl import * +# pylint: enable=wildcard-import +from tensorflow.python.util.all_util import remove_undocumented + +__all__ = tuple_losses_impl.__all__ +remove_undocumented(__name__, __all__) diff --git a/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py b/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py new file mode 100644 index 00000000000..8805633deeb --- /dev/null +++ b/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py @@ -0,0 +1,203 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""TFGAN utilities for loss functions that accept GANModel namedtuples.
+
+Example:
+  ```python
+  # `tfgan.losses.args` losses take individual arguments.
+  w_loss = tfgan.losses.args.wasserstein_discriminator_loss(
+      discriminator_real_outputs,
+      discriminator_gen_outputs)
+
+  # `tfgan.losses` losses take GANModel namedtuples.
+  w_loss2 = tfgan.losses.wasserstein_discriminator_loss(gan_model)
+  ```
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.gan.python.losses.python import losses_impl
+from tensorflow.python.util import tf_inspect
+
+
+__all__ = [
+    'acgan_discriminator_loss',
+    'acgan_generator_loss',
+    'least_squares_discriminator_loss',
+    'least_squares_generator_loss',
+    'modified_discriminator_loss',
+    'modified_generator_loss',
+    'minimax_discriminator_loss',
+    'minimax_generator_loss',
+    'wasserstein_discriminator_loss',
+    'wasserstein_generator_loss',
+    'wasserstein_gradient_penalty',
+    'mutual_information_penalty',
+    'combine_adversarial_loss',
+]
+
+
+def _args_to_gan_model(loss_fn):
+  """Converts a loss taking individual args to one taking a GANModel namedtuple.
+
+  The new function has the same name as the original one.
+
+  Args:
+    loss_fn: A python loss function taking individual arguments, some of which
+      match fields of a `GANModel` namedtuple, and returning a loss Tensor.
+      The shape of the loss depends on `reduction`.
+
+  Returns:
+    A new function that takes a `GANModel` namedtuple and returns the same
+    loss.
+  """
+  # Match arguments in `loss_fn` to elements of `namedtuple`.
+  # TODO(joelshor): Properly handle `varargs` and `keywords`.
+  argspec = tf_inspect.getargspec(loss_fn)
+  defaults = argspec.defaults or []
+
+  # Note: slicing with `-len(defaults)` would misbehave when there are no
+  # defaults, since `args[:-0]` is empty; compute the split point explicitly.
+  num_required_args = len(argspec.args) - len(defaults)
+  required_args = set(argspec.args[:num_required_args])
+  args_with_defaults = argspec.args[num_required_args:]
+  default_args_dict = dict(zip(args_with_defaults, defaults))
+
+  def new_loss_fn(gan_model, **kwargs):  # pylint:disable=missing-docstring
+    gan_model_dict = gan_model._asdict()
+
+    # Make sure non-tuple required args are supplied.
+    args_from_tuple = set(argspec.args).intersection(set(gan_model._fields))
+    required_args_not_from_tuple = required_args - args_from_tuple
+    for arg in required_args_not_from_tuple:
+      if arg not in kwargs:
+        raise ValueError('`%s` must be supplied to %s loss function.' % (
+            arg, loss_fn.__name__))
+
+    # Make sure tuple args aren't also supplied as keyword args.
+    ambiguous_args = set(gan_model._fields).intersection(set(kwargs.keys()))
+    if ambiguous_args:
+      raise ValueError(
+          'The following args are present in both the tuple and keyword args '
+          'for %s: %s' % (loss_fn.__name__, ambiguous_args))
+
+    # Add required args to arg dictionary.
+ required_args_from_tuple = required_args.intersection(args_from_tuple) + for arg in required_args_from_tuple: + assert arg not in kwargs + kwargs[arg] = gan_model_dict[arg] + + # Add arguments that have defaults. + for arg in default_args_dict: + val_from_tuple = gan_model_dict[arg] if arg in gan_model_dict else None + val_from_kwargs = kwargs[arg] if arg in kwargs else None + assert not (val_from_tuple is not None and val_from_kwargs is not None) + kwargs[arg] = (val_from_tuple if val_from_tuple is not None else + val_from_kwargs if val_from_kwargs is not None else + default_args_dict[arg]) + + return loss_fn(**kwargs) + + new_docstring = """The gan_model version of %s.""" % loss_fn.__name__ + new_loss_fn.__docstring__ = new_docstring + new_loss_fn.__name__ = loss_fn.__name__ + new_loss_fn.__module__ = loss_fn.__module__ + return new_loss_fn + + +# Wasserstein losses from `Wasserstein GAN` (https://arxiv.org/abs/1701.07875). +wasserstein_generator_loss = _args_to_gan_model( + losses_impl.wasserstein_generator_loss) +wasserstein_discriminator_loss = _args_to_gan_model( + losses_impl.wasserstein_discriminator_loss) +wasserstein_gradient_penalty = _args_to_gan_model( + losses_impl.wasserstein_gradient_penalty) + +# ACGAN losses from `Conditional Image Synthesis With Auxiliary Classifier GANs` +# (https://arxiv.org/abs/1610.09585). +acgan_discriminator_loss = _args_to_gan_model( + losses_impl.acgan_discriminator_loss) +acgan_generator_loss = _args_to_gan_model( + losses_impl.acgan_generator_loss) + + +# Original losses from `Generative Adversarial Nets` +# (https://arxiv.org/abs/1406.2661). +minimax_discriminator_loss = _args_to_gan_model( + losses_impl.minimax_discriminator_loss) +minimax_generator_loss = _args_to_gan_model( + losses_impl.minimax_generator_loss) +modified_discriminator_loss = _args_to_gan_model( + losses_impl.modified_discriminator_loss) +modified_generator_loss = _args_to_gan_model( + losses_impl.modified_generator_loss) + + +# Least Squares loss from `Least Squares Generative Adversarial Networks` +# (https://arxiv.org/abs/1611.04076). +least_squares_generator_loss = _args_to_gan_model( + losses_impl.least_squares_generator_loss) +least_squares_discriminator_loss = _args_to_gan_model( + losses_impl.least_squares_discriminator_loss) + + +# InfoGAN loss from `InfoGAN: Interpretable Representation Learning by +# `Information Maximizing Generative Adversarial Nets` +# https://arxiv.org/abs/1606.03657 +mutual_information_penalty = _args_to_gan_model( + losses_impl.mutual_information_penalty) + + +def combine_adversarial_loss(gan_loss, + gan_model, + non_adversarial_loss, + weight_factor=None, + gradient_ratio=None, + gradient_ratio_epsilon=1e-6, + scalar_summaries=True, + gradient_summaries=True): + """Combine adversarial loss and main loss. + + Uses `combine_adversarial_loss` to combine the losses, and returns + a modified GANLoss namedtuple. + + Args: + gan_loss: A GANLoss namedtuple. Assume the GANLoss.generator_loss is the + adversarial loss. + gan_model: A GANModel namedtuple. Used to access the generator's variables. + non_adversarial_loss: Same as `main_loss` from + `combine_adversarial_loss`. + weight_factor: Same as `weight_factor` from + `combine_adversarial_loss`. + gradient_ratio: Same as `gradient_ratio` from + `combine_adversarial_loss`. + gradient_ratio_epsilon: Same as `gradient_ratio_epsilon` from + `combine_adversarial_loss`. + scalar_summaries: Same as `scalar_summaries` from + `combine_adversarial_loss`. 
+ gradient_summaries: Same as `gradient_summaries` from + `combine_adversarial_loss`. + + Returns: + A modified GANLoss namedtuple, with `non_adversarial_loss` included + appropriately. + """ + combined_loss = losses_impl.combine_adversarial_loss( + non_adversarial_loss, + gan_loss.generator_loss, + weight_factor, + gradient_ratio, + gradient_ratio_epsilon, + gan_model.generator_variables, + scalar_summaries, + gradient_summaries) + return gan_loss._replace(generator_loss=combined_loss) diff --git a/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py b/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py new file mode 100644 index 00000000000..f65b20d0b57 --- /dev/null +++ b/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py @@ -0,0 +1,134 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for contrib.gan.python.losses.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +import numpy as np + +from tensorflow.contrib.gan.python.losses.python import tuple_losses_impl as tfgan_losses + +from tensorflow.python.platform import test + + +class ArgsToGanModelTest(test.TestCase): + + def test_args_to_gan_model(self): + """Test `_args_to_gan_model`.""" + tuple_type = collections.namedtuple('fake_type', ['arg1', 'arg3']) + + def args_loss(arg1, arg2, arg3=3, arg4=4): + return arg1 + arg2 + arg3 + arg4 + + gan_model_loss = tfgan_losses._args_to_gan_model(args_loss) + + # Value is correct. + self.assertEqual(1 + 2 + 5 + 6, + gan_model_loss(tuple_type(1, 2), arg2=5, arg4=6)) + + # Uses tuple argument with defaults. + self.assertEqual(1 + 5 + 3 + 7, + gan_model_loss(tuple_type(1, None), arg2=5, arg4=7)) + + # Uses non-tuple argument with defaults. + self.assertEqual(1 + 5 + 2 + 4, + gan_model_loss(tuple_type(1, 2), arg2=5)) + + # Requires non-tuple, non-default arguments. + with self.assertRaisesRegexp(ValueError, '`arg2` must be supplied'): + gan_model_loss(tuple_type(1, 2)) + + # Can't pass tuple argument outside tuple. 
+ with self.assertRaisesRegexp( + ValueError, 'present in both the tuple and keyword args'): + gan_model_loss(tuple_type(1, 2), arg2=1, arg3=5) + + def test_args_to_gan_model_name(self): + """Test that `_args_to_gan_model` produces correctly named functions.""" + def loss_fn(x): + return x + new_loss_fn = tfgan_losses._args_to_gan_model(loss_fn) + self.assertEqual('loss_fn', new_loss_fn.__name__) + self.assertTrue('The gan_model version of' in new_loss_fn.__docstring__) + + def test_tuple_respects_optional_args(self): + """Test that optional args can be changed with tuple losses.""" + tuple_type = collections.namedtuple('fake_type', ['arg1', 'arg2']) + def args_loss(arg1, arg2, arg3=3): + return arg1 + 2 * arg2 + 3 * arg3 + + loss_fn = tfgan_losses._args_to_gan_model(args_loss) + loss = loss_fn(tuple_type(arg1=-1, arg2=2), arg3=4) + + # If `arg3` were not set properly, this value would be different. + self.assertEqual(-1 + 2 * 2 + 3 * 4, loss) + + +class ConsistentLossesTest(test.TestCase): + + pass + + +def _tuple_from_dict(args_dict): + return collections.namedtuple('Tuple', args_dict.keys())(**args_dict) + + +def add_loss_consistency_test(test_class, loss_name_str, loss_args): + tuple_loss = getattr(tfgan_losses, loss_name_str) + arg_loss = getattr(tfgan_losses.losses_impl, loss_name_str) + + def consistency_test(self): + self.assertEqual(arg_loss.__name__, tuple_loss.__name__) + with self.test_session(): + self.assertEqual(arg_loss(**loss_args).eval(), + tuple_loss(_tuple_from_dict(loss_args)).eval()) + + test_name = 'test_loss_consistency_%s' % loss_name_str + setattr(test_class, test_name, consistency_test) + + +# A list of consistency tests which need to be manually written. +manual_tests = [ + 'acgan_discriminator_loss', + 'acgan_generator_loss', + 'combine_adversarial_loss', + 'mutual_information_penalty', + 'wasserstein_gradient_penalty', +] + +discriminator_keyword_args = { + 'discriminator_real_outputs': np.array([[3.4, 2.3, -2.3], + [6.3, -2.1, 0.2]]), + 'discriminator_gen_outputs': np.array([[6.2, -1.5, 2.3], + [-2.9, -5.1, 0.1]]), +} +generator_keyword_args = { + 'discriminator_gen_outputs': np.array([[6.2, -1.5, 2.3], + [-2.9, -5.1, 0.1]]), +} + + +if __name__ == '__main__': + for loss_name in tfgan_losses.__all__: + if loss_name in manual_tests: continue + keyword_args = (generator_keyword_args if 'generator' in loss_name else + discriminator_keyword_args) + add_loss_consistency_test(ConsistentLossesTest, loss_name, keyword_args) + + test.main() diff --git a/tensorflow/contrib/linear_optimizer/BUILD b/tensorflow/contrib/linear_optimizer/BUILD index 22398d22556..fe2f183ac97 100644 --- a/tensorflow/contrib/linear_optimizer/BUILD +++ b/tensorflow/contrib/linear_optimizer/BUILD @@ -127,7 +127,6 @@ py_test( name = "sdca_estimator_test", srcs = ["python/sdca_estimator_test.py"], srcs_version = "PY2AND3", - tags = ["notsan"], deps = [ ":sdca_estimator_py", "//tensorflow/contrib/layers:layers_py", diff --git a/tensorflow/contrib/nearest_neighbor/BUILD b/tensorflow/contrib/nearest_neighbor/BUILD index 4c507aafb65..84d59cc4be8 100644 --- a/tensorflow/contrib/nearest_neighbor/BUILD +++ b/tensorflow/contrib/nearest_neighbor/BUILD @@ -61,6 +61,7 @@ tf_kernel_library( srcs = ["kernels/hyperplane_lsh_probes.cc"], deps = [ ":hyperplane_lsh_probes", + ":nearest_neighbor_ops_op_lib", "//tensorflow/core:framework", "//tensorflow/core:lib", "//third_party/eigen3", diff --git a/tensorflow/contrib/timeseries/__init__.py b/tensorflow/contrib/timeseries/__init__.py index 3cd9366f1f4..11db56b1b7a 
100644 --- a/tensorflow/contrib/timeseries/__init__.py +++ b/tensorflow/contrib/timeseries/__init__.py @@ -20,6 +20,7 @@ @@ARModel @@CSVReader +@@NumpyReader @@RandomWindowInputFn @@WholeDatasetInputFn @@predict_continuation_input_fn diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 03cf77e0399..a52c1daf1fb 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -1788,6 +1788,7 @@ tf_cuda_library( "common_runtime/process_util.cc", "common_runtime/renamed_device.cc", "common_runtime/rendezvous_mgr.cc", + "common_runtime/rendezvous_util.cc", "common_runtime/resource_variable_read_optimizer.cc", "common_runtime/session.cc", "common_runtime/session_factory.cc", @@ -1831,6 +1832,7 @@ tf_cuda_library( "common_runtime/profile_handler.h", "common_runtime/renamed_device.h", "common_runtime/rendezvous_mgr.h", + "common_runtime/rendezvous_util.h", "common_runtime/session_factory.h", "common_runtime/graph_execution_state.h", "common_runtime/placer.h", @@ -2675,29 +2677,29 @@ tf_cc_test( srcs = ["common_runtime/process_function_library_runtime_test.cc"], linkstatic = tf_kernel_tests_linkstatic(), deps = [ - ":core", ":core_cpu", ":core_cpu_internal", - ":direct_session_internal", ":framework", - ":framework_internal", - ":lib", - ":lib_internal", - ":ops", - ":protos_all_cc", ":test", ":test_main", ":testlib", - "//tensorflow/cc:cc_ops", - "//tensorflow/cc:cc_ops_internal", "//tensorflow/cc:function_ops", - "//tensorflow/cc:functional_ops", "//tensorflow/core/kernels:cast_op", "//tensorflow/core/kernels:cwise_op", "//tensorflow/core/kernels:function_ops", - "//tensorflow/core/kernels:matmul_op", - "//tensorflow/core/kernels:shape_ops", - "//third_party/eigen3", + ], +) + +tf_cc_test( + name = "common_runtime_rendezvous_util_test", + size = "small", + srcs = ["common_runtime/rendezvous_util_test.cc"], + linkstatic = tf_kernel_tests_linkstatic(), + deps = [ + ":core_cpu_internal", + ":lib", + ":test", + ":test_main", ], ) diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc index 4b239606a84..4aeacc6d612 100644 --- a/tensorflow/core/common_runtime/function.cc +++ b/tensorflow/core/common_runtime/function.cc @@ -213,6 +213,9 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime { FunctionBody** g_body); bool IsLocalTarget(const AttrSlice& attrs); AttrValueMap FixAttrs(const AttrSlice& attrs); + void RunRemote(const Options& opts, Handle handle, + gtl::ArraySlice args, std::vector* rets, + Executor::Args* exec_args, Item* item, DoneCallback done); TF_DISALLOW_COPY_AND_ASSIGN(FunctionLibraryRuntimeImpl); }; @@ -557,52 +560,130 @@ Status FunctionLibraryRuntimeImpl::GetOrCreateItem(Handle handle, Item** item) { return Status::OK(); } +void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle, + gtl::ArraySlice args, + std::vector* rets, + Executor::Args* exec_args, + Item* item, DoneCallback done) { + FunctionCallFrame* frame = exec_args->call_frame; + string target_device = parent_->GetDeviceName(handle); + string source_device = opts.source_device; + Rendezvous* rendezvous = opts.rendezvous; + // TODO(rohanj): Handle alloc_attrs in Rendezvous::Args. + Rendezvous::Args rendez_args; + Status s = + parent_->GetDeviceContext(target_device, &rendez_args.device_context); + if (!s.ok()) { + delete frame; + delete exec_args; + done(s); + return; + } + + // The ProcFLR sends the arguments to the function from the source_device to + // the target_device. So here we receive those arguments. 
Similarly, when the + // computation is done and stored in *rets, we send the return values back + // to the source_device (caller) so that the ProcFLR can receive them later. + std::vector* remote_args = new std::vector; + ProcessFunctionLibraryRuntime::ReceiveTensorsAsync( + source_device, target_device, "arg_", args.size(), rendez_args, + rendezvous, remote_args, + [frame, remote_args, item, source_device, target_device, rendezvous, + rendez_args, rets, done, exec_args](const Status& status) { + Status s = status; + s = frame->SetArgs(*remote_args); + if (!s.ok()) { + delete frame; + delete remote_args; + delete exec_args; + done(s); + return; + } + item->exec->RunAsync( + *exec_args, + [item, frame, rets, done, source_device, target_device, rendezvous, + rendez_args, remote_args, exec_args](const Status& status) { + item->Unref(); + Status s = status; + if (s.ok()) { + s = frame->ConsumeRetvals(rets); + } + delete frame; + if (!s.ok()) { + delete remote_args; + delete exec_args; + done(s); + return; + } + s = ProcessFunctionLibraryRuntime::SendTensors( + target_device, source_device, "ret_", *rets, rendez_args, + rendezvous); + delete remote_args; + delete exec_args; + done(s); + }); + }); +} + void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle, gtl::ArraySlice args, std::vector* rets, DoneCallback done) { if (opts.cancellation_manager && opts.cancellation_manager->IsCancelled()) { - return done(errors::Cancelled("")); + done(errors::Cancelled("")); + return; } if (!parent_->IsInstantiatedOnDevice(device_name_, handle)) { - return parent_->Run(opts, handle, args, rets, done); + parent_->Run(opts, handle, args, rets, done); + return; } const FunctionBody* fbody = GetFunctionBody(handle); FunctionCallFrame* frame = new FunctionCallFrame(fbody->arg_types, fbody->ret_types); - Status s = frame->SetArgs(args); - if (!s.ok()) { - delete frame; - return done(s); - } + Item* item = nullptr; - s = GetOrCreateItem(handle, &item); + Status s = GetOrCreateItem(handle, &item); if (!s.ok()) { delete frame; - return done(s); + done(s); + return; } DCHECK(opts.runner != nullptr); - Executor::Args exec_args; + Executor::Args* exec_args = new Executor::Args; // Inherit the step_id from the caller. - exec_args.step_id = opts.step_id; - exec_args.rendezvous = opts.rendezvous; - exec_args.stats_collector = opts.stats_collector; - exec_args.call_frame = frame; - exec_args.cancellation_manager = opts.cancellation_manager; - exec_args.step_container = opts.step_container; - exec_args.runner = *opts.runner; + exec_args->step_id = opts.step_id; + exec_args->rendezvous = opts.rendezvous; + exec_args->stats_collector = opts.stats_collector; + exec_args->call_frame = frame; + exec_args->cancellation_manager = opts.cancellation_manager; + exec_args->step_container = opts.step_container; + exec_args->runner = *opts.runner; + + if (opts.remote_execution) { + RunRemote(opts, handle, args, rets, exec_args, item, done); + return; + } + + s = frame->SetArgs(args); + if (!s.ok()) { + delete frame; + delete exec_args; + done(s); + return; + } item->exec->RunAsync( // Executor args - exec_args, + *exec_args, // Done callback. 
- [item, frame, rets, done](const Status& status) { + [item, frame, rets, done, exec_args](const Status& status) { item->Unref(); Status s = status; if (s.ok()) { - s = frame->GetRetvals(rets); + s = frame->ConsumeRetvals(rets); } delete frame; + delete exec_args; done(s); }); } diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc index a9f06c4df03..7eac1674e71 100644 --- a/tensorflow/core/common_runtime/function_test.cc +++ b/tensorflow/core/common_runtime/function_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/executor.h" #include "tensorflow/core/common_runtime/function_testlib.h" +#include "tensorflow/core/common_runtime/rendezvous_mgr.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/op.h" @@ -155,6 +156,7 @@ class FunctionLibraryRuntimeTest : public ::testing::Test { } Status Run(FunctionLibraryRuntime* flr, FunctionLibraryRuntime::Handle handle, + FunctionLibraryRuntime::Options opts, const std::vector& args, std::vector rets) { std::atomic call_count(0); std::function)> runner = @@ -164,7 +166,6 @@ class FunctionLibraryRuntimeTest : public ::testing::Test { }; Notification done; - FunctionLibraryRuntime::Options opts; opts.runner = &runner; std::vector out; Status status; @@ -205,7 +206,8 @@ class FunctionLibraryRuntimeTest : public ::testing::Test { if (!status.ok()) { return status; } - return Run(flr, handle, args, std::move(rets)); + FunctionLibraryRuntime::Options opts; + return Run(flr, handle, opts, args, std::move(rets)); } std::unique_ptr GetFuncBody(FunctionLibraryRuntime* flr, @@ -963,15 +965,21 @@ TEST_F(FunctionLibraryRuntimeTest, CrossDevice) { {{"_target", "/job:localhost/replica:0/task:0/cpu:1"}}, &handle)); Tensor y; + FunctionLibraryRuntime::Options opts; + opts.rendezvous = new IntraProcessRendezvous(device_mgr_.get()); + opts.source_device = "/device:CPU:1"; // Run on flr1_, flr2_ and make sure that the device it ran on was cpu:1. - TF_CHECK_OK(Run(flr1_, handle, {}, {&y})); + TF_CHECK_OK(Run(flr1_, handle, opts, {}, {&y})); test::ExpectTensorEqual( y, test::AsTensor({"/job:localhost/replica:0/task:0/cpu:1"}, TensorShape({}))); - TF_CHECK_OK(Run(flr2_, handle, {}, {&y})); + opts.remote_execution = true; + opts.source_device = "/job:localhost/replica:0/task:0/cpu:2"; + TF_CHECK_OK(Run(flr2_, handle, opts, {}, {&y})); test::ExpectTensorEqual( y, test::AsTensor({"/job:localhost/replica:0/task:0/cpu:1"}, TensorShape({}))); + opts.rendezvous->Unref(); } namespace { diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc index 0caec036252..c39bab2348e 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include #include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/rendezvous_util.h" #include "tensorflow/core/lib/gtl/map_util.h" namespace tensorflow { @@ -57,6 +58,7 @@ ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime( } } +/* static */ string ProcessFunctionLibraryRuntime::ObtainFunctionTarget( const AttrSlice& attrs) { const AttrValue* value; @@ -66,6 +68,63 @@ string ProcessFunctionLibraryRuntime::ObtainFunctionTarget( return value->s(); } +/* static */ +Status ProcessFunctionLibraryRuntime::SendTensors( + const string& source_device, const string& target_device, + const string& key_prefix, gtl::ArraySlice tensors_to_send, + const Rendezvous::Args& args, Rendezvous* rendezvous) { + std::vector keys; + for (int i = 0; i < tensors_to_send.size(); ++i) { + string name = strings::StrCat(key_prefix, i); + string key = Rendezvous::CreateKey(source_device, i, target_device, name, + FrameAndIter(0, 0)); + keys.push_back(key); + } + TF_RETURN_IF_ERROR( + SendTensorsToRendezvous(rendezvous, args, keys, tensors_to_send)); + return Status::OK(); +} + +/* static */ +void ProcessFunctionLibraryRuntime::ReceiveTensorsAsync( + const string& source_device, const string& target_device, + const string& key_prefix, int64 num_tensors, const Rendezvous::Args& args, + Rendezvous* rendezvous, std::vector* received_tensors, + const StatusCallback& done) { + std::vector keys; + for (int64 i = 0; i < num_tensors; ++i) { + string name = strings::StrCat(key_prefix, i); + string key = Rendezvous::CreateKey(source_device, i, target_device, name, + FrameAndIter(0, 0)); + keys.push_back(key); + } + RecvOutputsFromRendezvousAsync( + rendezvous, args, keys, received_tensors, + [done](const Status& status) { done(status); }); +} + +Status ProcessFunctionLibraryRuntime::GetDeviceContext( + const string& device_name, DeviceContext** device_context) { + *device_context = nullptr; + FunctionLibraryRuntime* flr = GetFLR(device_name); + if (flr == nullptr) { + return errors::InvalidArgument("Device name: ", device_name, " not found."); + } + Device* device = flr->device(); + string device_type = device->parsed_name().type; + if (device_type == "CPU") return Status::OK(); + if (device_type == "GPU") { + auto* dev_info = flr->device()->tensorflow_gpu_device_info(); + if (dev_info) { + *device_context = dev_info->default_context; + return Status::OK(); + } + } + return errors::Internal("Device type: ", device_type, + " is currently unsupported for remote ", + "function executions"); +} + FunctionLibraryRuntime* ProcessFunctionLibraryRuntime::GetFLR( const string& device_name) { if (flr_map_.find(device_name) == flr_map_.end()) { @@ -105,6 +164,7 @@ FunctionLibraryRuntime::LocalHandle ProcessFunctionLibraryRuntime::GetHandleOnDevice( const string& device_name, FunctionLibraryRuntime::Handle handle) { mutex_lock l(mu_); + CHECK_LE(handle, function_data_.size()); std::pair p = function_data_[handle]; if (p.first != device_name) { @@ -113,6 +173,15 @@ ProcessFunctionLibraryRuntime::GetHandleOnDevice( return p.second; } +string ProcessFunctionLibraryRuntime::GetDeviceName( + FunctionLibraryRuntime::Handle handle) { + mutex_lock l(mu_); + CHECK_LE(handle, function_data_.size()); + std::pair p = + function_data_[handle]; + return p.first; +} + Status ProcessFunctionLibraryRuntime::Instantiate( const string& function_name, AttrSlice attrs, FunctionLibraryRuntime::Handle* handle) { @@ -129,15 +198,58 @@ void ProcessFunctionLibraryRuntime::Run( const FunctionLibraryRuntime::Options& opts, 
FunctionLibraryRuntime::Handle handle, gtl::ArraySlice args, std::vector* rets, FunctionLibraryRuntime::DoneCallback done) { + if (!opts.remote_execution) { + done(errors::InvalidArgument( + "ProcessFunctionLibraryRuntime::Run should only be called when there ", + "is a remote execution.")); + return; + } + FunctionLibraryRuntime* flr = nullptr; + string target_device; { mutex_lock l(mu_); + CHECK_LE(handle, function_data_.size()); std::pair p = function_data_[handle]; + target_device = p.first; flr = GetFLR(p.first); } if (flr != nullptr) { - return flr->Run(opts, handle, args, rets, std::move(done)); + auto rendezvous = opts.rendezvous; + string source_device = opts.source_device; + Rendezvous::Args rendez_args; + Status s = GetDeviceContext(source_device, &rendez_args.device_context); + if (!s.ok()) { + done(s); + return; + } + // Send the args over to the target device. + s = SendTensors(source_device, target_device, "arg_", args, rendez_args, + rendezvous); + if (!s.ok()) { + done(s); + return; + } + std::vector* remote_rets = new std::vector; + flr->Run(opts, handle, args, remote_rets, + [source_device, target_device, rendezvous, remote_rets, rets, done, + rendez_args](const Status& status) { + if (!status.ok()) { + delete remote_rets; + done(status); + return; + } + int64 num_returns = remote_rets->size(); + delete remote_rets; + // Now receive the return values from the target. + ReceiveTensorsAsync(target_device, source_device, "ret_", + num_returns, rendez_args, rendezvous, rets, + done); + }); + } else { + done(errors::Internal("Could not find device")); + return; } } diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.h b/tensorflow/core/common_runtime/process_function_library_runtime.h index 2259997005e..2e97bae4b4f 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.h +++ b/tensorflow/core/common_runtime/process_function_library_runtime.h @@ -45,6 +45,31 @@ class ProcessFunctionLibraryRuntime { // attribute, returns "". Canonicalizes the device name. static string ObtainFunctionTarget(const AttrSlice& attrs); + // Sends `tensors_to_send` from `source_device` to `target_device` using + // `rendezvous`. `key_prefix` is used as a prefix for the keys sent to the + // Rendezvous. Method takes references on each of the `tensors_to_send`. + // Method doesn't block. + static Status SendTensors(const string& source_device, + const string& target_device, + const string& key_prefix, + gtl::ArraySlice tensors_to_send, + const Rendezvous::Args& args, + Rendezvous* rendezvous); + + typedef std::function StatusCallback; + + // Receives `received_tensors` from `target_device` (originally sent from + // `source_device`) using `rendezvous`. Uses `key_prefix` to construct the + // keys to be retrieved. Method doesn't block and calls `done` when + // `num_tensors` are fetched. + static void ReceiveTensorsAsync(const string& source_device, + const string& target_device, + const string& key_prefix, int64 num_tensors, + const Rendezvous::Args& args, + Rendezvous* rendezvous, + std::vector* received_tensors, + const StatusCallback& done); + static const char kDefaultFLRDevice[]; // Returns the FunctionLibraryRuntime for the corresponding device_name. FunctionLibraryRuntime* GetFLR(const string& device_name); @@ -85,6 +110,17 @@ class ProcessFunctionLibraryRuntime { FunctionLibraryRuntime::DoneCallback done); private: + // For a given device_name, returns a DeviceContext for copying + // tensors to/from the device. 
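+  // For CPU the context is left null (host memory needs no special copy
+  // path); for GPU it is the device's default context; any other device
+  // type currently fails with an Internal error (see the .cc above).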
+ Status GetDeviceContext(const string& device_name, + DeviceContext** device_context); + + // Looks up the information for the given `handle` and returns the name + // of the device where the function is registered. + string GetDeviceName(FunctionLibraryRuntime::Handle handle); + + friend class FunctionLibraryRuntimeImpl; + mutable mutex mu_; // Holds all the function invocations here. diff --git a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc index 1536aedde58..fdbab46f547 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/function_testlib.h" +#include "tensorflow/core/common_runtime/rendezvous_mgr.h" #include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/platform/test.h" @@ -43,10 +44,12 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test { proc_flr_.reset(new ProcessFunctionLibraryRuntime( device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(), opts)); + rendezvous_ = new IntraProcessRendezvous(device_mgr_.get()); } - Status Run(const string& name, test::function::Attrs attrs, - const std::vector& args, std::vector rets) { + Status Run(const string& name, FunctionLibraryRuntime::Options opts, + test::function::Attrs attrs, const std::vector& args, + std::vector rets) { FunctionLibraryRuntime::Handle handle; Status status = proc_flr_->Instantiate(name, attrs, &handle); if (!status.ok()) { @@ -61,7 +64,6 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test { }; Notification done; - FunctionLibraryRuntime::Options opts; opts.runner = &runner; std::vector out; proc_flr_->Run(opts, handle, args, &out, [&status, &done](const Status& s) { @@ -86,6 +88,7 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test { std::unique_ptr device_mgr_; std::unique_ptr lib_def_; std::unique_ptr proc_flr_; + IntraProcessRendezvous* rendezvous_; }; TEST_F(ProcessFunctionLibraryRuntimeTest, Basic) { @@ -99,6 +102,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, Basic) { EXPECT_EQ(flr->device(), devices_[1]); flr = proc_flr_->GetFLR("abc"); EXPECT_EQ(flr, nullptr); + rendezvous_->Unref(); } TEST_F(ProcessFunctionLibraryRuntimeTest, ObtainFunctionTarget) { @@ -118,69 +122,94 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, ObtainFunctionTarget) { TEST_F(ProcessFunctionLibraryRuntimeTest, SingleCall) { Init({test::function::XTimesTwo()}); + FunctionLibraryRuntime::Options opts; + opts.source_device = "/job:a/replica:0/task:0/cpu:0"; + opts.rendezvous = rendezvous_; + opts.remote_execution = true; auto x = test::AsTensor({1, 2, 3, 4}); Tensor y; TF_CHECK_OK( - Run("XTimesTwo", + Run("XTimesTwo", opts, {{"T", DT_FLOAT}, {"_target", "/job:a/replica:0/task:0/cpu:0"}}, {x}, {&y})); test::ExpectTensorEqual(y, test::AsTensor({2, 4, 6, 8})); + rendezvous_->Unref(); } TEST_F(ProcessFunctionLibraryRuntimeTest, SingleCallFindDevice) { Init({test::function::FindDevice()}); + FunctionLibraryRuntime::Options opts; + opts.source_device = "/job:a/replica:0/task:0/cpu:0"; + opts.rendezvous = rendezvous_; + opts.remote_execution = true; Tensor y; - TF_CHECK_OK(Run("FindDevice", {{"_target", "/job:a/replica:0/task:0/cpu:0"}}, - {}, {&y})); + 
TF_CHECK_OK(Run("FindDevice", opts, + {{"_target", "/job:a/replica:0/task:0/cpu:0"}}, {}, {&y})); test::ExpectTensorEqual( y, test::AsTensor({"/job:a/replica:0/task:0/cpu:0"}, TensorShape({}))); + rendezvous_->Unref(); } TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsSameDeviceXTimes) { Init({test::function::XTimesTwo(), test::function::XTimesFour()}); auto x = test::AsTensor({1, 2, 3, 4}); + FunctionLibraryRuntime::Options opts; + opts.source_device = "/job:a/replica:0/task:0/cpu:0"; + opts.rendezvous = rendezvous_; + opts.remote_execution = true; Tensor y; TF_CHECK_OK( - Run("XTimesTwo", + Run("XTimesTwo", opts, {{"T", DT_FLOAT}, {"_target", "/job:a/replica:0/task:0/cpu:0"}}, {x}, {&y})); test::ExpectTensorEqual(y, test::AsTensor({2, 4, 6, 8})); TF_CHECK_OK( - Run("XTimesFour", + Run("XTimesFour", opts, {{"T", DT_FLOAT}, {"_target", "/job:a/replica:0/task:0/cpu:0"}}, {x}, {&y})); test::ExpectTensorEqual(y, test::AsTensor({4, 8, 12, 16})); + rendezvous_->Unref(); } TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsSameDeviceFindDevice) { Init({test::function::FindDevice()}); + FunctionLibraryRuntime::Options opts; + opts.source_device = "/job:a/replica:0/task:0/cpu:0"; + opts.rendezvous = rendezvous_; + opts.remote_execution = true; Tensor y; - TF_CHECK_OK(Run("FindDevice", {{"_target", "/job:a/replica:0/task:0/cpu:1"}}, - {}, {&y})); + TF_CHECK_OK(Run("FindDevice", opts, + {{"_target", "/job:a/replica:0/task:0/cpu:1"}}, {}, {&y})); test::ExpectTensorEqual( y, test::AsTensor({"/job:a/replica:0/task:0/cpu:1"}, TensorShape({}))); - TF_CHECK_OK(Run("FindDevice", {{"_target", "/job:a/replica:0/task:0/cpu:1"}}, - {}, {&y})); + TF_CHECK_OK(Run("FindDevice", opts, + {{"_target", "/job:a/replica:0/task:0/cpu:1"}}, {}, {&y})); test::ExpectTensorEqual( y, test::AsTensor({"/job:a/replica:0/task:0/cpu:1"}, TensorShape({}))); + rendezvous_->Unref(); } TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsDiffDeviceFindDevice) { Init({test::function::FindDevice()}); + FunctionLibraryRuntime::Options opts; + opts.source_device = "/job:a/replica:0/task:0/cpu:0"; + opts.rendezvous = rendezvous_; + opts.remote_execution = true; Tensor y; - TF_CHECK_OK(Run("FindDevice", {{"_target", "/job:a/replica:0/task:0/cpu:0"}}, - {}, {&y})); + TF_CHECK_OK(Run("FindDevice", opts, + {{"_target", "/job:a/replica:0/task:0/cpu:0"}}, {}, {&y})); test::ExpectTensorEqual( y, test::AsTensor({"/job:a/replica:0/task:0/cpu:0"}, TensorShape({}))); - TF_CHECK_OK(Run("FindDevice", {{"_target", "/job:a/replica:0/task:0/cpu:1"}}, - {}, {&y})); + TF_CHECK_OK(Run("FindDevice", opts, + {{"_target", "/job:a/replica:0/task:0/cpu:1"}}, {}, {&y})); test::ExpectTensorEqual( y, test::AsTensor({"/job:a/replica:0/task:0/cpu:1"}, TensorShape({}))); + rendezvous_->Unref(); } } // anonymous namespace diff --git a/tensorflow/core/common_runtime/rendezvous_util.cc b/tensorflow/core/common_runtime/rendezvous_util.cc new file mode 100644 index 00000000000..a0d409e7735 --- /dev/null +++ b/tensorflow/core/common_runtime/rendezvous_util.cc @@ -0,0 +1,119 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/common_runtime/rendezvous_util.h"
+
+namespace tensorflow {
+
+Status SendTensorsToRendezvous(Rendezvous* rendezvous,
+                               const Rendezvous::Args& args,
+                               const std::vector<string>& keys,
+                               gtl::ArraySlice<Tensor> tensors_to_send) {
+  if (keys.size() != tensors_to_send.size()) {
+    return errors::InvalidArgument(
+        "keys and tensors_to_send are not the same size. keys.size() = ",
+        keys.size(), "; tensors_to_send.size() = ", tensors_to_send.size());
+  }
+  Rendezvous::ParsedKey parsed;
+  for (int i = 0; i < keys.size(); ++i) {
+    TF_RETURN_IF_ERROR(Rendezvous::ParseKey(keys[i], &parsed));
+    TF_RETURN_IF_ERROR(
+        rendezvous->Send(parsed, args, tensors_to_send[i], false));
+  }
+  return Status::OK();
+}
+
+void RecvOutputsFromRendezvousAsync(Rendezvous* rendezvous,
+                                    const Rendezvous::Args& args,
+                                    const std::vector<string>& keys,
+                                    std::vector<Tensor>* received_tensors,
+                                    const StatusCallback& done) {
+  if (keys.empty()) {
+    done(Status::OK());
+    return;
+  }
+  received_tensors->reserve(keys.size());
+  std::vector<std::tuple<string, Tensor*, Rendezvous::ParsedKey>> arguments;
+  for (int i = 0; i < keys.size(); ++i) {
+    Rendezvous::ParsedKey parsed;
+    Status s = Rendezvous::ParseKey(keys[i], &parsed);
+    received_tensors->push_back(Tensor());
+    if (!s.ok()) {
+      done(s);
+      return;
+    }
+    arguments.push_back(
+        std::make_tuple(keys[i], &((*received_tensors)[i]), parsed));
+  }
+
+  typedef struct {
+    mutex mu;
+    int64 done_counter;
+    Status shared_status = Status::OK();
+  } CallState;
+  CallState* call_state = new CallState;
+  call_state->done_counter = keys.size();
+  for (auto& p : arguments) {
+    const string& key = std::get<0>(p);
+    Tensor* val = std::get<1>(p);
+    Rendezvous::ParsedKey parsed = std::get<2>(p);
+    rendezvous->RecvAsync(
+        parsed, args,
+        [val, done, key, call_state](const Status& s,
+                                     const Rendezvous::Args& send_args,
+                                     const Rendezvous::Args& recv_args,
+                                     const Tensor& v, const bool is_dead) {
+          Status status = s;
+          if (status.ok()) {
+            *val = v;
+            if (is_dead) {
+              status = errors::InvalidArgument("The tensor returned for ", key,
+                                               " was not valid.");
+            }
+          }
+          call_state->mu.lock();
+          call_state->shared_status.Update(status);
+          call_state->done_counter--;
+          // If we are the last async call to return, call the done callback.
+          if (call_state->done_counter == 0) {
+            const Status& final_status = call_state->shared_status;
+            call_state->mu.unlock();
+            done(final_status);
+            delete call_state;
+            return;
+          }
+          call_state->mu.unlock();
+        });
+  }
+}
+
+Status RecvOutputsFromRendezvous(Rendezvous* rendezvous, NamedTensors* out,
+                                 const Rendezvous::Args& args) {
+  // Receives values requested by the caller.
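+  // NOTE: unlike RecvOutputsFromRendezvousAsync above, this helper issues a
+  // blocking Recv for each key in map order, so it should not be called
+  // where waiting is unacceptable; the async variant exists for that case.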
+  Rendezvous::ParsedKey parsed;
+  for (auto& p : *out) {
+    const string& key = p.first;
+    Tensor* val = &p.second;
+    bool is_dead = false;
+    TF_RETURN_IF_ERROR(Rendezvous::ParseKey(key, &parsed));
+    TF_RETURN_IF_ERROR(rendezvous->Recv(parsed, args, val, &is_dead));
+    if (is_dead) {
+      return errors::InvalidArgument("The tensor returned for ", key,
+                                     " was not valid.");
+    }
+  }
+  return Status::OK();
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/rendezvous_util.h b/tensorflow/core/common_runtime/rendezvous_util.h
new file mode 100644
index 00000000000..a54f8c3f948
--- /dev/null
+++ b/tensorflow/core/common_runtime/rendezvous_util.h
@@ -0,0 +1,44 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_RENDEZVOUS_UTIL_H_
+#define THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_RENDEZVOUS_UTIL_H_
+
+#include <map>
+
+#include "tensorflow/core/framework/rendezvous.h"
+
+namespace tensorflow {
+
+typedef std::map<string, Tensor> NamedTensors;
+typedef std::function<void(const Status&)> StatusCallback;
+
+// Uses `rendezvous` to send `tensors_to_send`, one per key in `keys`.
+Status SendTensorsToRendezvous(Rendezvous* rendezvous,
+                               const Rendezvous::Args& args,
+                               const std::vector<string>& keys,
+                               gtl::ArraySlice<Tensor> tensors_to_send);
+
+void RecvOutputsFromRendezvousAsync(Rendezvous* rendezvous,
+                                    const Rendezvous::Args& args,
+                                    const std::vector<string>& keys,
+                                    std::vector<Tensor>* received_tensors,
+                                    const StatusCallback& done);
+
+Status RecvOutputsFromRendezvous(Rendezvous* rendezvous, NamedTensors* out,
+                                 const Rendezvous::Args& args);
+
+}  // namespace tensorflow
+
+#endif  // THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_RENDEZVOUS_UTIL_H_
diff --git a/tensorflow/core/common_runtime/rendezvous_util_test.cc b/tensorflow/core/common_runtime/rendezvous_util_test.cc
new file mode 100644
index 00000000000..8ee9f4d5226
--- /dev/null
+++ b/tensorflow/core/common_runtime/rendezvous_util_test.cc
@@ -0,0 +1,94 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ +#include "tensorflow/core/common_runtime/rendezvous_util.h" + +#include "tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +class RendezvousUtilTest : public ::testing::Test { + public: + RendezvousUtilTest() { rendez_ = NewLocalRendezvous(); } + + ~RendezvousUtilTest() override { rendez_->Unref(); } + + Rendezvous* rendez_; +}; + +// string -> Tensor +Tensor V(const string& content) { + Tensor tensor(DT_STRING, TensorShape({})); + tensor.scalar()() = content; + return tensor; +} + +// Tensor -> string +string V(const Tensor& tensor) { + CHECK_EQ(tensor.dtype(), DT_STRING); + CHECK(TensorShapeUtils::IsScalar(tensor.shape())); + return tensor.scalar()(); +} + +string MakeStringKey(const string& name) { + return Rendezvous::CreateKey( + "/job:localhost/replica:0/task:0/device:CPU:0", 0, + "/job:localhost/replica:0/task:0/device:GPU:0", name, FrameAndIter(0, 0)); +} + +TEST_F(RendezvousUtilTest, SendBeforeRecv) { + // Fire off sends before receive the tensors. + Rendezvous::Args args; + TF_ASSERT_OK(SendTensorsToRendezvous( + rendez_, args, {MakeStringKey("hello1"), MakeStringKey("hello2")}, + {V("hello1"), V("hello2")})); + + Notification n; + std::vector received_keys; + RecvOutputsFromRendezvousAsync( + rendez_, args, {MakeStringKey("hello1"), MakeStringKey("hello2")}, + &received_keys, [&n](const Status& status) { n.Notify(); }); + n.WaitForNotification(); + + EXPECT_EQ(2, received_keys.size()); + EXPECT_EQ("hello1", V(received_keys[0])); + EXPECT_EQ("hello2", V(received_keys[1])); +} + +TEST_F(RendezvousUtilTest, RecvBeforeSend) { + // Fire off recvs, wait for a notification in the callback. + Rendezvous::Args args; + + Notification n; + std::vector received_keys; + RecvOutputsFromRendezvousAsync( + rendez_, args, {MakeStringKey("hello1"), MakeStringKey("hello2")}, + &received_keys, [&n](const Status& status) { n.Notify(); }); + + TF_ASSERT_OK(SendTensorsToRendezvous( + rendez_, args, {MakeStringKey("hello1"), MakeStringKey("hello2")}, + {V("hello1"), V("hello2")})); + + n.WaitForNotification(); + + EXPECT_EQ(2, received_keys.size()); + EXPECT_EQ("hello1", V(received_keys[0])); + EXPECT_EQ("hello2", V(received_keys[1])); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/debug/debug_io_utils.cc b/tensorflow/core/debug/debug_io_utils.cc index 4f52cb0b4d7..c9f2c247326 100644 --- a/tensorflow/core/debug/debug_io_utils.cc +++ b/tensorflow/core/debug/debug_io_utils.cc @@ -26,12 +26,13 @@ limitations under the License. #include "grpc++/create_channel.h" #else // winsock2.h is used in grpc, so Ws2_32.lib is needed -#pragma comment(lib,"Ws2_32.lib") +#pragma comment(lib, "Ws2_32.lib") #endif // #ifndef PLATFORM_WINDOWS #include "tensorflow/core/debug/debugger_event_metadata.pb.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/summary.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/lib/core/bits.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/io/path.h" diff --git a/tensorflow/core/distributed_runtime/graph_mgr.cc b/tensorflow/core/distributed_runtime/graph_mgr.cc index 1169b86c9db..411b6d861b7 100644 --- a/tensorflow/core/distributed_runtime/graph_mgr.cc +++ b/tensorflow/core/distributed_runtime/graph_mgr.cc @@ -26,6 +26,7 @@ limitations under the License. 
#include "tensorflow/core/common_runtime/memory_types.h" #include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/common_runtime/process_util.h" +#include "tensorflow/core/common_runtime/rendezvous_util.h" #include "tensorflow/core/common_runtime/step_stats_collector.h" #include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h" #include "tensorflow/core/framework/cancellation.h" @@ -321,116 +322,25 @@ Status GraphMgr::DeregisterAll() { return Status::OK(); } -Status GraphMgr::SendInputsToRendezvous(Rendezvous* rendezvous, - const NamedTensors& in) { - Rendezvous::ParsedKey parsed; - for (const auto& p : in) { - const string& key = p.first; - const Tensor& val = p.second; - - Status s = Rendezvous::ParseKey(key, &parsed); - if (s.ok()) { - s = rendezvous->Send(parsed, Rendezvous::Args(), val, false); - } - if (!s.ok()) { - return s; - } - } - return Status::OK(); -} - -Status GraphMgr::RecvOutputsFromRendezvous(Rendezvous* rendezvous, - NamedTensors* out) { - // Receives values requested by the caller. - Rendezvous::ParsedKey parsed; - for (auto& p : *out) { - const string& key = p.first; - Tensor* val = &p.second; - bool is_dead = false; - Status s = Rendezvous::ParseKey(key, &parsed); - if (s.ok()) { - s = rendezvous->Recv(parsed, Rendezvous::Args(), val, &is_dead); - } - if (is_dead) { - s = errors::InvalidArgument("The tensor returned for ", key, - " was not valid."); - } - if (!s.ok()) return s; - } - return Status::OK(); -} - -void GraphMgr::RecvOutputsFromRendezvousAsync(Rendezvous* rendezvous, - NamedTensors* out, - const StatusCallback& done) { - if (out->empty()) { - done(Status::OK()); - return; - } - // We compute the args before calling RecvAsync because we need to ensure that - // out isn't being iterated over after done is called, since done deletes out. - std::vector> args; - for (auto& p : *out) { - Rendezvous::ParsedKey parsed; - Status s = Rendezvous::ParseKey(p.first, &parsed); - if (!s.ok()) { - done(s); - return; - } - args.push_back(std::make_tuple(p.first, &p.second, parsed)); - } - - typedef struct { - mutex mu; - int done_counter; - Status shared_status = Status::OK(); - } CallState; - CallState* call_state = new CallState; - call_state->done_counter = out->size(); - for (auto& p : args) { - const string& key = std::get<0>(p); - Tensor* val = std::get<1>(p); - Rendezvous::ParsedKey parsed = std::get<2>(p); - rendezvous->RecvAsync( - parsed, Rendezvous::Args(), - [val, done, key, call_state](const Status& s, - const Rendezvous::Args& send_args, - const Rendezvous::Args& recv_args, - const Tensor& v, const bool is_dead) { - Status status = s; - if (status.ok()) { - *val = v; - if (is_dead) { - status = errors::InvalidArgument("The tensor returned for ", key, - " was not valid."); - } - } - call_state->mu.lock(); - call_state->shared_status.Update(status); - call_state->done_counter--; - // If we are the last async call to return, call the done callback. 
- if (call_state->done_counter == 0) { - const Status& final_status = call_state->shared_status; - call_state->mu.unlock(); - done(final_status); - delete call_state; - return; - } - call_state->mu.unlock(); - }); - } -} - Status GraphMgr::SendInputs(const int64 step_id, const NamedTensors& in) { Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id); - Status s = SendInputsToRendezvous(rendezvous, in); + std::vector keys; + std::vector tensors_to_send; + keys.reserve(in.size()); + tensors_to_send.reserve(in.size()); + for (const auto& p : in) { + keys.push_back(p.first); + tensors_to_send.push_back(p.second); + } + Status s = SendTensorsToRendezvous(rendezvous, Rendezvous::Args(), keys, + tensors_to_send); rendezvous->Unref(); return s; } Status GraphMgr::RecvOutputs(const int64 step_id, NamedTensors* out) { Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id); - Status s = RecvOutputsFromRendezvous(rendezvous, out); + Status s = RecvOutputsFromRendezvous(rendezvous, out, Rendezvous::Args()); rendezvous->Unref(); return s; } @@ -438,11 +348,24 @@ Status GraphMgr::RecvOutputs(const int64 step_id, NamedTensors* out) { void GraphMgr::RecvOutputsAsync(const int64 step_id, NamedTensors* out, StatusCallback done) { Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id); - RecvOutputsFromRendezvousAsync(rendezvous, out, - [done, rendezvous](const Status s) { - rendezvous->Unref(); - done(s); - }); + std::vector keys; + std::vector* received_keys = new std::vector; + keys.reserve(out->size()); + received_keys->reserve(out->size()); + for (const auto& p : *out) { + keys.push_back(p.first); + received_keys->push_back(p.second); + } + RecvOutputsFromRendezvousAsync( + rendezvous, Rendezvous::Args(), keys, received_keys, + [done, rendezvous, received_keys, out, keys](const Status s) { + rendezvous->Unref(); + for (int i = 0; i < keys.size(); ++i) { + (*out)[keys[i]] = (*received_keys)[i]; + } + delete received_keys; + done(s); + }); } void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id, @@ -484,7 +407,16 @@ void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id, // Sends values specified by the caller. 
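+  // The keys in `in` are already fully-formed rendezvous keys, so they are
+  // handed to the shared SendTensorsToRendezvous helper as-is rather than
+  // being re-derived here.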
if (s.ok()) { - s = SendInputsToRendezvous(rendezvous, in); + std::vector keys; + std::vector tensors_to_send; + keys.reserve(in.size()); + tensors_to_send.reserve(in.size()); + for (auto& p : in) { + keys.push_back(p.first); + tensors_to_send.push_back(p.second); + } + s = SendTensorsToRendezvous(rendezvous, Rendezvous::Args(), keys, + tensors_to_send); } if (!s.ok()) { diff --git a/tensorflow/core/distributed_runtime/graph_mgr.h b/tensorflow/core/distributed_runtime/graph_mgr.h index d719dd4ec6b..c6f55e4ef9c 100644 --- a/tensorflow/core/distributed_runtime/graph_mgr.h +++ b/tensorflow/core/distributed_runtime/graph_mgr.h @@ -169,11 +169,6 @@ class GraphMgr { void BuildCostModel(Item* item, StepStatsCollector* collector, CostGraphDef* cost_graph); - Status SendInputsToRendezvous(Rendezvous* rendezvous, const NamedTensors& in); - Status RecvOutputsFromRendezvous(Rendezvous* rendezvous, NamedTensors* out); - void RecvOutputsFromRendezvousAsync(Rendezvous* rendezvous, NamedTensors* out, - const StatusCallback& done); - Status InitItem(const string& session, const GraphDef& gdef, const GraphOptions& graph_options, const DebugOptions& debug_options, Item* item); diff --git a/tensorflow/core/distributed_runtime/rpc/BUILD b/tensorflow/core/distributed_runtime/rpc/BUILD index 6512f5bf5dd..ce8323303fe 100644 --- a/tensorflow/core/distributed_runtime/rpc/BUILD +++ b/tensorflow/core/distributed_runtime/rpc/BUILD @@ -465,7 +465,6 @@ tf_cuda_cc_test( linkstatic = tf_kernel_tests_linkstatic(), tags = tf_cuda_tests_tags() + [ "no_oss", # b/62956105: port conflicts. - "noguitar", # b/64805119 ], deps = [ ":grpc_channel", diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h index 317707644b3..e3842ea58d3 100644 --- a/tensorflow/core/framework/function.h +++ b/tensorflow/core/framework/function.h @@ -426,6 +426,10 @@ class FunctionLibraryRuntime { StepStatsCollector* stats_collector = nullptr; std::function)>* runner = nullptr; + + // Parameters for remote function execution. + bool remote_execution = false; + string source_device = ""; // Fully specified device name. 
   };
   typedef std::function<void(const Status&)> DoneCallback;
   virtual void Run(const Options& opts, Handle handle,
diff --git a/tensorflow/core/framework/op_def_builder.cc b/tensorflow/core/framework/op_def_builder.cc
index 62b504691b2..962bc11ccbd 100644
--- a/tensorflow/core/framework/op_def_builder.cc
+++ b/tensorflow/core/framework/op_def_builder.cc
@@ -110,6 +110,37 @@ bool ConsumeAttrNumber(StringPiece* sp, int64* out) {
   } \
 } while (false)
 
+bool ConsumeCompoundAttrType(StringPiece* sp, StringPiece* out) {
+  auto capture_begin = sp->begin();
+  if (sp->Consume("numbertype") || sp->Consume("numerictype") ||
+      sp->Consume("quantizedtype") || sp->Consume("realnumbertype") ||
+      sp->Consume("realnumerictype")) {
+    *out = StringPiece(capture_begin, sp->begin() - capture_begin);
+    return true;
+  }
+  return false;
+}
+
+bool ProcessCompoundType(const StringPiece type_string, AttrValue* allowed) {
+  if (type_string == "numbertype" || type_string == "numerictype") {
+    for (DataType dt : NumberTypes()) {
+      allowed->mutable_list()->add_type(dt);
+    }
+  } else if (type_string == "quantizedtype") {
+    for (DataType dt : QuantizedTypes()) {
+      allowed->mutable_list()->add_type(dt);
+    }
+  } else if (type_string == "realnumbertype" ||
+             type_string == "realnumerictype") {
+    for (DataType dt : RealNumberTypes()) {
+      allowed->mutable_list()->add_type(dt);
+    }
+  } else {
+    return false;
+  }
+  return true;
+}
+
 void FinalizeAttr(StringPiece spec, OpDef* op_def,
                   std::vector<string>* errors) {
   OpDef::AttrDef* attr = op_def->add_attr();
@@ -123,6 +154,7 @@ void FinalizeAttr(StringPiece spec, OpDef* op_def,
   // Read "<type>" or "list(<type>)".
   bool is_list = ConsumeListPrefix(&spec);
   string type;
+  StringPiece type_string;  // Used if type == "type"
   if (spec.Consume("string")) {
     type = "string";
   } else if (spec.Consume("int")) {
     type = "int";
@@ -139,29 +171,15 @@ void FinalizeAttr(StringPiece spec, OpDef* op_def,
     type = "tensor";
   } else if (spec.Consume("func")) {
     type = "func";
-  } else if (spec.Consume("numbertype") || spec.Consume("numerictype")) {
+  } else if (ConsumeCompoundAttrType(&spec, &type_string)) {
     type = "type";
     AttrValue* allowed = attr->mutable_allowed_values();
-    for (DataType dt : NumberTypes()) {
-      allowed->mutable_list()->add_type(dt);
-    }
-  } else if (spec.Consume("quantizedtype")) {
-    type = "type";
-    AttrValue* allowed = attr->mutable_allowed_values();
-    for (DataType dt : QuantizedTypes()) {
-      allowed->mutable_list()->add_type(dt);
-    }
-  } else if (spec.Consume("realnumbertype") ||
-             spec.Consume("realnumerictype")) {
-    type = "type";
-    AttrValue* allowed = attr->mutable_allowed_values();
-    for (DataType dt : RealNumberTypes()) {
-      allowed->mutable_list()->add_type(dt);
-    }
+    VERIFY(ProcessCompoundType(type_string, allowed),
+           "Expected to see a compound type, saw: ", type_string);
   } else if (spec.Consume("{")) {
    // e.g. "{ int32, float, bool }" or "{ \"foo\", \"bar\" }"
-    str_util::RemoveLeadingWhitespace(&spec);
     AttrValue* allowed = attr->mutable_allowed_values();
+    str_util::RemoveLeadingWhitespace(&spec);
     if (spec.starts_with("\"") || spec.starts_with("'")) {
       type = "string";  // "{ \"foo\", \"bar\" }" or "{ 'foo', 'bar' }"
       while (true) {
@@ -172,8 +190,8 @@ void FinalizeAttr(StringPiece spec, OpDef* op_def,
         string unescaped;
         string error;
         VERIFY(str_util::CUnescape(escaped_string, &unescaped, &error),
-               "Trouble unescaping \"", escaped_string, "\", got error: ",
-               error);
+               "Trouble unescaping \"", escaped_string,
+               "\", got error: ", error);
         allowed->mutable_list()->add_s(unescaped);
         if (spec.Consume(",")) {
           str_util::RemoveLeadingWhitespace(&spec);
@@ -184,16 +202,19 @@ void FinalizeAttr(StringPiece spec, OpDef* op_def,
           break;
         }
       }
-    } else {  // "{ int32, float, bool }"
+    } else {  // "{ bool, numbertype, string }"
       type = "type";
       while (true) {
-        StringPiece type_string;
         VERIFY(ConsumeAttrType(&spec, &type_string),
               "Trouble parsing type string at '", spec, "'");
-        DataType dt;
-        VERIFY(DataTypeFromString(type_string, &dt),
-               "Unrecognized type string '", type_string, "'");
-        allowed->mutable_list()->add_type(dt);
+        if (ProcessCompoundType(type_string, allowed)) {
+          // Processed a compound type.
+        } else {
+          DataType dt;
+          VERIFY(DataTypeFromString(type_string, &dt),
+                 "Unrecognized type string '", type_string, "'");
+          allowed->mutable_list()->add_type(dt);
+        }
         if (spec.Consume(",")) {
           str_util::RemoveLeadingWhitespace(&spec);
           if (spec.Consume("}")) break;  // Allow ending with ", }".
@@ -204,7 +225,7 @@ void FinalizeAttr(StringPiece spec, OpDef* op_def,
         }
       }
     }
-  } else {
+  } else {  // if spec.Consume("{")
    VERIFY(false, "Trouble parsing type string at '", spec, "'");
  }
  str_util::RemoveLeadingWhitespace(&spec);
diff --git a/tensorflow/core/framework/op_def_builder.h b/tensorflow/core/framework/op_def_builder.h
index 0d492208d47..0c91d271b74 100644
--- a/tensorflow/core/framework/op_def_builder.h
+++ b/tensorflow/core/framework/op_def_builder.h
@@ -57,8 +57,10 @@ class OpDefBuilder {
 //   (by convention only using capital letters for attrs that can be inferred)
 // <type> can be:
 //   "string", "int", "float", "bool", "type", "shape", or "tensor"
-//   "numbertype", "realnumbertype", "quantizedtype", "{int32,int64}"
+//   "numbertype", "realnumbertype", "quantizedtype"
 //       (meaning "type" with a restriction on valid values)
+//   "{int32,int64}" or "{realnumbertype, quantizedtype, string}"
+//       (meaning "type" with a restriction containing unions of value types)
 //   "{\"foo\", \"bar\n baz\"}", or "{'foo', 'bar\n baz'}"
 //       (meaning "string" with a restriction on valid values)
 //   "list(string)", ..., "list(tensor)", "list(numbertype)", ...
diff --git a/tensorflow/core/framework/op_def_builder_test.cc b/tensorflow/core/framework/op_def_builder_test.cc index d545db5e091..efedb221e70 100644 --- a/tensorflow/core/framework/op_def_builder_test.cc +++ b/tensorflow/core/framework/op_def_builder_test.cc @@ -125,13 +125,27 @@ TEST_F(OpDefBuilderTest, AttrWithRestrictions) { "[DT_HALF, DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, DT_INT16, " "DT_UINT16, DT_INT8, DT_COMPLEX64, DT_COMPLEX128, DT_QINT8, DT_QUINT8, " "DT_QINT32] } } }"); + ExpectSuccess( + b().Attr("a:{numbertype, variant}"), + "attr: { name: 'a' type: 'type' allowed_values { list { type: " + "[DT_HALF, DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, DT_INT16, " + "DT_UINT16, DT_INT8, DT_COMPLEX64, DT_COMPLEX128, DT_QINT8, DT_QUINT8, " + "DT_QINT32, DT_VARIANT] } } }"); ExpectSuccess(b().Attr("a:realnumbertype"), "attr: { name: 'a' type: 'type' allowed_values { list { type: " "[DT_HALF, DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, " "DT_INT16, DT_UINT16, DT_INT8] } } }"); + ExpectSuccess(b().Attr("a:{realnumbertype, variant , string, }"), + "attr: { name: 'a' type: 'type' allowed_values { list { type: " + "[DT_HALF, DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, " + "DT_INT16, DT_UINT16, DT_INT8, DT_VARIANT, DT_STRING] } } }"); ExpectSuccess(b().Attr("a:quantizedtype"), "attr: { name: 'a' type: 'type' allowed_values { list { type: " "[DT_QINT8, DT_QUINT8, DT_QINT32, DT_QINT16, DT_QUINT16]} } }"); + ExpectSuccess(b().Attr("a:{quantizedtype ,string}"), + "attr: { name: 'a' type: 'type' allowed_values { list { type: " + "[DT_QINT8, DT_QUINT8, DT_QINT32, DT_QINT16, DT_QUINT16, " + "DT_STRING]} } }"); ExpectSuccess(b().Attr("a:{string,int32}"), "attr: { name: 'a' type: 'type' allowed_values { list { type: " "[DT_STRING, DT_INT32] } } }"); @@ -202,6 +216,11 @@ TEST_F(OpDefBuilderTest, AttrListOfRestricted) { "attr: { name: 'a' type: 'list(type)' allowed_values { list { type: " "[DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, DT_INT16, " "DT_UINT16, DT_INT8, DT_HALF] } } }"); + ExpectSuccess( + b().Attr("a:list({realnumbertype, variant})"), + "attr: { name: 'a' type: 'list(type)' allowed_values { list { type: " + "[DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, DT_INT16, " + "DT_UINT16, DT_INT8, DT_HALF, DT_VARIANT] } } }"); ExpectSuccess( b().Attr("a:list(quantizedtype)"), "attr: { name: 'a' type: 'list(type)' allowed_values { list { type: " diff --git a/tensorflow/core/framework/variant_op_registry.cc b/tensorflow/core/framework/variant_op_registry.cc index 9cc7530459e..22a0b4ca01f 100644 --- a/tensorflow/core/framework/variant_op_registry.cc +++ b/tensorflow/core/framework/variant_op_registry.cc @@ -24,6 +24,12 @@ limitations under the License. 
namespace tensorflow { +std::unordered_set* UnaryVariantOpRegistry::PersistentStringStorage() { + static std::unordered_set* string_storage = + new std::unordered_set(); + return string_storage; +} + // static UnaryVariantOpRegistry* UnaryVariantOpRegistry::Global() { static UnaryVariantOpRegistry* global_unary_variant_op_registry = @@ -32,7 +38,7 @@ UnaryVariantOpRegistry* UnaryVariantOpRegistry::Global() { } UnaryVariantOpRegistry::VariantShapeFn* UnaryVariantOpRegistry::GetShapeFn( - const string& type_name) { + StringPiece type_name) { auto found = shape_fns.find(type_name); if (found == shape_fns.end()) return nullptr; return &found->second; @@ -45,7 +51,8 @@ void UnaryVariantOpRegistry::RegisterShapeFn(const string& type_name, CHECK_EQ(existing, nullptr) << "Unary VariantShapeFn for type_name: " << type_name << " already registered"; - shape_fns.insert(std::pair(type_name, shape_fn)); + shape_fns.insert(std::pair( + GetPersistentStringPiece(type_name), shape_fn)); } Status GetUnaryVariantShape(const Tensor& variant_tensor, TensorShape* shape) { @@ -65,8 +72,29 @@ Status GetUnaryVariantShape(const Tensor& variant_tensor, TensorShape* shape) { return (*shape_fn)(v, shape); } +// Add some basic registrations for use by others, e.g., for testing. +namespace { +template +Status ScalarShape(const T&, TensorShape* shape) { + *shape = TensorShape({}); + return Status::OK(); +} +} // namespace + +#define REGISTER_VARIANT_SHAPE_TYPE(T) \ + REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(T, TF_STR(T), ScalarShape); + +// No encode/shape registered for std::complex<> and Eigen::half +// objects yet. +REGISTER_VARIANT_SHAPE_TYPE(int); +REGISTER_VARIANT_SHAPE_TYPE(float); +REGISTER_VARIANT_SHAPE_TYPE(bool); +REGISTER_VARIANT_SHAPE_TYPE(double); + +#undef REGISTER_VARIANT_SHAPE_TYPE + UnaryVariantOpRegistry::VariantDecodeFn* UnaryVariantOpRegistry::GetDecodeFn( - const string& type_name) { + StringPiece type_name) { auto found = decode_fns.find(type_name); if (found == decode_fns.end()) return nullptr; return &found->second; @@ -79,7 +107,8 @@ void UnaryVariantOpRegistry::RegisterDecodeFn( CHECK_EQ(existing, nullptr) << "Unary VariantDecodeFn for type_name: " << type_name << " already registered"; - decode_fns.insert(std::pair(type_name, decode_fn)); + decode_fns.insert(std::pair( + GetPersistentStringPiece(type_name), decode_fn)); } bool DecodeUnaryVariant(Variant* variant) { @@ -103,13 +132,6 @@ bool DecodeUnaryVariant(Variant* variant) { // Add some basic registrations for use by others, e.g., for testing. -namespace { -string MaybeRemoveTFPrefix(const StringPiece& str) { - return str.starts_with("::tensorflow::") ? str.substr(14).ToString() - : str.ToString(); -} -} // namespace - #define REGISTER_VARIANT_DECODE_TYPE(T) \ REGISTER_UNARY_VARIANT_DECODE_FUNCTION(T, TF_STR(T)); @@ -122,30 +144,31 @@ REGISTER_VARIANT_DECODE_TYPE(double); #undef REGISTER_VARIANT_DECODE_TYPE -// Special casing ZerosLikeFn per device. -UnaryVariantOpRegistry::VariantZerosLikeFn* -UnaryVariantOpRegistry::GetZerosLikeFn(const string& device, - const string& type_name) { - auto found = zeros_like_fns.find(std::make_pair(device, type_name)); - if (found == zeros_like_fns.end()) return nullptr; +// Special casing UnaryOpFn per op and per device. 
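+// The registry key is the full (op, device, type_name) triple, so a kernel
+// for, say, a GPU zeros_like of a given variant type can be registered
+// independently of the CPU version registered below.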
+UnaryVariantOpRegistry::VariantUnaryOpFn* UnaryVariantOpRegistry::GetUnaryOpFn( + VariantUnaryOp op, StringPiece device, StringPiece type_name) { + auto found = unary_op_fns.find(std::make_tuple(op, device, type_name)); + if (found == unary_op_fns.end()) return nullptr; return &found->second; } -void UnaryVariantOpRegistry::RegisterZerosLikeFn( - const string& device, const string& type_name, - const VariantZerosLikeFn& zeros_like_fn) { - CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantZerosLike"; - VariantZerosLikeFn* existing = GetZerosLikeFn(device, type_name); +void UnaryVariantOpRegistry::RegisterUnaryOpFn( + VariantUnaryOp op, const string& device, const string& type_name, + const VariantUnaryOpFn& unary_op_fn) { + CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantUnaryOp"; + VariantUnaryOpFn* existing = GetUnaryOpFn(op, device, type_name); CHECK_EQ(existing, nullptr) - << "Unary VariantZerosLikeFn for type_name: " << type_name + << "Unary VariantUnaryOpFn for type_name: " << type_name << " already registered for device type: " << device; - zeros_like_fns.insert( - std::pair, VariantZerosLikeFn>( - std::make_pair(device, type_name), zeros_like_fn)); + unary_op_fns.insert( + std::pair, + VariantUnaryOpFn>( + std::make_tuple(op, GetPersistentStringPiece(device), + GetPersistentStringPiece(type_name)), + unary_op_fn)); } namespace { - template Status ZerosLikeVariantPrimitiveType(OpKernelContext* ctx, const T& t, T* t_out) { @@ -154,9 +177,10 @@ Status ZerosLikeVariantPrimitiveType(OpKernelContext* ctx, const T& t, } } // namespace -#define REGISTER_VARIANT_ZEROS_LIKE_TYPE(T) \ - REGISTER_UNARY_VARIANT_ZEROS_LIKE_FUNCTION( \ - DEVICE_CPU, T, TF_STR(T), ZerosLikeVariantPrimitiveType); +#define REGISTER_VARIANT_ZEROS_LIKE_TYPE(T) \ + REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP, \ + DEVICE_CPU, T, TF_STR(T), \ + ZerosLikeVariantPrimitiveType); // No zeros_like registered for std::complex<> or Eigen::half objects yet. REGISTER_VARIANT_ZEROS_LIKE_TYPE(int); @@ -166,4 +190,51 @@ REGISTER_VARIANT_ZEROS_LIKE_TYPE(bool); #undef REGISTER_VARIANT_ZEROS_LIKE_TYPE +// Special casing BinaryOpFn per op and per device. +UnaryVariantOpRegistry::VariantBinaryOpFn* +UnaryVariantOpRegistry::GetBinaryOpFn(VariantBinaryOp op, StringPiece device, + StringPiece type_name) { + auto found = binary_op_fns.find(std::make_tuple(op, device, type_name)); + if (found == binary_op_fns.end()) return nullptr; + return &found->second; +} + +void UnaryVariantOpRegistry::RegisterBinaryOpFn( + VariantBinaryOp op, const string& device, const string& type_name, + const VariantBinaryOpFn& add_fn) { + CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantBinaryOp"; + VariantBinaryOpFn* existing = GetBinaryOpFn(op, device, type_name); + CHECK_EQ(existing, nullptr) + << "Unary VariantBinaryOpFn for type_name: " << type_name + << " already registered for device type: " << device; + binary_op_fns.insert( + std::pair, + VariantBinaryOpFn>( + std::make_tuple(op, GetPersistentStringPiece(device), + GetPersistentStringPiece(type_name)), + add_fn)); +} + +namespace { +template +Status AddVariantPrimitiveType(OpKernelContext* ctx, const T& a, const T& b, + T* out) { + *out = a + b; + return Status::OK(); +} +} // namespace + +#define REGISTER_VARIANT_ADD_TYPE(T) \ + REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU, \ + T, TF_STR(T), \ + AddVariantPrimitiveType); + +// No add registered for std::complex<> or Eigen::half objects yet. 
+REGISTER_VARIANT_ADD_TYPE(int); +REGISTER_VARIANT_ADD_TYPE(float); +REGISTER_VARIANT_ADD_TYPE(double); +REGISTER_VARIANT_ADD_TYPE(bool); + +#undef REGISTER_VARIANT_ADD_TYPE + } // namespace tensorflow diff --git a/tensorflow/core/framework/variant_op_registry.h b/tensorflow/core/framework/variant_op_registry.h index 37e54f82c0f..2e9f2243ad1 100644 --- a/tensorflow/core/framework/variant_op_registry.h +++ b/tensorflow/core/framework/variant_op_registry.h @@ -17,11 +17,13 @@ limitations under the License. #define TENSORFLOW_FRAMEWORK_VARIANT_OP_REGISTRY_H_ #include +#include #include #include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/variant.h" #include "tensorflow/core/framework/variant_encode_decode.h" +#include "tensorflow/core/lib/hash/hash.h" namespace tensorflow { @@ -30,49 +32,110 @@ class OpKernelContext; // for different variant types. To be used by ShapeOp, RankOp, and // SizeOp, decoding, etc. +enum VariantUnaryOp { + INVALID_VARIANT_UNARY_OP = 0, + ZEROS_LIKE_VARIANT_UNARY_OP = 1, +}; + +enum VariantBinaryOp { + INVALID_VARIANT_BINARY_OP = 0, + ADD_VARIANT_BINARY_OP = 1, +}; + class UnaryVariantOpRegistry { public: typedef std::function VariantShapeFn; typedef std::function VariantDecodeFn; typedef std::function - VariantZerosLikeFn; + VariantUnaryOpFn; + typedef std::function + VariantBinaryOpFn; // Add a shape lookup function to the registry. void RegisterShapeFn(const string& type_name, const VariantShapeFn& shape_fn); // Returns nullptr if no shape function was found for the given TypeName. - VariantShapeFn* GetShapeFn(const string& type_name); + VariantShapeFn* GetShapeFn(StringPiece type_name); // Add a decode function to the registry. void RegisterDecodeFn(const string& type_name, const VariantDecodeFn& decode_fn); // Returns nullptr if no decode function was found for the given TypeName. - VariantDecodeFn* GetDecodeFn(const string& type_name); + VariantDecodeFn* GetDecodeFn(StringPiece type_name); - // Add a zeros-like function to the registry. - void RegisterZerosLikeFn(const string& device, const string& type_name, - const VariantZerosLikeFn& zeros_like_fn); + // Add a unary op function to the registry. + void RegisterUnaryOpFn(VariantUnaryOp op, const string& device, + const string& type_name, + const VariantUnaryOpFn& unary_op_fn); - // Returns nullptr if no zeros-like function was found for the given - // device and TypeName. - VariantZerosLikeFn* GetZerosLikeFn(const string& device, - const string& type_name); + // Returns nullptr if no unary op function was found for the given + // op, device, and TypeName. + VariantUnaryOpFn* GetUnaryOpFn(VariantUnaryOp op, StringPiece device, + StringPiece type_name); + // Add a binary op function to the registry. + void RegisterBinaryOpFn(VariantBinaryOp op, const string& device, + const string& type_name, + const VariantBinaryOpFn& add_fn); + + // Returns nullptr if no binary op function was found for the given + // op, device and TypeName. + VariantBinaryOpFn* GetBinaryOpFn(VariantBinaryOp op, StringPiece device, + StringPiece type_name); + + // Get a pointer to a global UnaryVariantOpRegistry object static UnaryVariantOpRegistry* Global(); + // Get a pointer to a global persistent string storage object. + // ISO/IEC C++ working draft N4296 clarifies that insertion into an + // std::unordered_set does not invalidate memory locations of + // *values* inside the set (though it may invalidate existing + // iterators). 
In other words, one may safely point a StringPiece to + // a value in the set without that StringPiece being invalidated by + // future insertions. + static std::unordered_set* PersistentStringStorage(); + private: - std::unordered_map shape_fns; - std::unordered_map decode_fns; - // Map std::pair to function. - struct PairHash { - template - std::size_t operator()(const std::pair& x) const { - return std::hash()(x.first) ^ std::hash()(x.second); + std::unordered_map + shape_fns; + std::unordered_map + decode_fns; + + // Map std::tuple to function. + struct TupleHash { + template + std::size_t operator()( + const std::tuple& x) const { + // The hash of an enum is just its value as a std::size_t. + std::size_t ret = static_cast(std::get<0>(x)); + StringPiece::Hasher sp_hasher; + ret = Hash64Combine(ret, sp_hasher(std::get<1>(x))); + ret = Hash64Combine(ret, sp_hasher(std::get<2>(x))); + return ret; } }; - std::unordered_map, VariantZerosLikeFn, PairHash> - zeros_like_fns; + std::unordered_map, + VariantUnaryOpFn, TupleHash> + unary_op_fns; + std::unordered_map, + VariantBinaryOpFn, TupleHash> + binary_op_fns; + + // Find or insert a string into a persistent string storage + // container; return the StringPiece pointing to the permanent + // string location. + static StringPiece GetPersistentStringPiece(const string& str) { + const auto string_storage = PersistentStringStorage(); + auto found = string_storage->find(str); + if (found == string_storage->end()) { + auto inserted = string_storage->insert(str); + return StringPiece(*inserted.first); + } else { + return StringPiece(*found); + } + } }; // Gets a TensorShape from a Tensor containing a scalar Variant. @@ -94,26 +157,57 @@ Status GetUnaryVariantShape(const Tensor& variant_tensor, TensorShape* shape); // bool DecodeUnaryVariant(Variant* variant); -// Sets *z_out = zeros_like(v). The variant v must have a registered -// ZerosLike function for the given Device. Returns an Internal error -// if v does not have a registered zeros_like function for this device, or if -// ZerosLike fails. +// Sets *v_out = unary_op(v). The variant v must have a registered +// UnaryOp function for the given Device. Returns an Internal error +// if v does not have a registered unary_op function for this device, or if +// UnaryOp fails. // // REQUIRES: // v_out is not null. // template -Status CreateZerosLikeVariant(OpKernelContext* ctx, const Variant& v, - Variant* v_out) { +Status UnaryOpVariant(OpKernelContext* ctx, VariantUnaryOp op, const Variant& v, + Variant* v_out) { const string& device = DeviceName::value; - UnaryVariantOpRegistry::VariantZerosLikeFn* zeros_like_fn = - UnaryVariantOpRegistry::Global()->GetZerosLikeFn(device, v.TypeName()); - if (zeros_like_fn == nullptr) { + UnaryVariantOpRegistry::VariantUnaryOpFn* unary_op_fn = + UnaryVariantOpRegistry::Global()->GetUnaryOpFn(op, device, v.TypeName()); + if (unary_op_fn == nullptr) { return errors::Internal( - "No unary variant zeros_like function found for Variant type_name: ", - v.TypeName(), " for device type: ", device); + "No unary variant unary_op function found for unary variant op enum: ", + op, " Variant type_name: ", v.TypeName(), " for device type: ", device); } - return (*zeros_like_fn)(ctx, v, v_out); + return (*unary_op_fn)(ctx, v, v_out); +} + +// Sets *out = binary_op(a, b). The variants a and b must be the same type +// and have a registered binary_op function for the given Device. 
Returns an
+// Internal error if a and b are not the same type_name, or if
+// a does not have a registered op function for this device, or if
+// BinaryOp fails.
+//
+// REQUIRES:
+//   out is not null.
+//
+template <typename Device>
+Status BinaryOpVariants(OpKernelContext* ctx, VariantBinaryOp op,
+                        const Variant& a, const Variant& b, Variant* out) {
+  if (a.TypeName() != b.TypeName()) {
+    return errors::Internal(
+        "BinaryOpVariants: Variants a and b have different "
+        "type names: '",
+        a.TypeName(), "' vs. '", b.TypeName(), "'");
+  }
+  const string& device = DeviceName<Device>::value;
+  UnaryVariantOpRegistry::VariantBinaryOpFn* binary_op_fn =
+      UnaryVariantOpRegistry::Global()->GetBinaryOpFn(op, device,
+                                                      a.TypeName());
+  if (binary_op_fn == nullptr) {
+    return errors::Internal(
+        "No unary variant binary_op function found for binary variant op "
+        "enum: ",
+        op, " Variant type_name: '", a.TypeName(),
+        "' for device type: ", device);
+  }
+  return (*binary_op_fn)(ctx, a, b, out);
 }
 
 namespace variant_op_registry_fn_registration {
@@ -165,30 +259,65 @@ class UnaryVariantDecodeRegistration {
 };
 
 template <typename T>
-class UnaryVariantZerosLikeRegistration {
+class UnaryVariantUnaryOpRegistration {
   typedef std::function<Status(OpKernelContext*, const T& t, T* t_out)>
-      LocalVariantZerosLikeFn;
+      LocalVariantUnaryOpFn;
 
  public:
-  UnaryVariantZerosLikeRegistration(
-      const string& device, const string& type_name,
-      const LocalVariantZerosLikeFn& zeros_like_fn) {
-    auto wrapped_fn = [type_name, zeros_like_fn](OpKernelContext* ctx,
-                                                 const Variant& v,
-                                                 Variant* v_out) -> Status {
+  UnaryVariantUnaryOpRegistration(VariantUnaryOp op, const string& device,
+                                  const string& type_name,
+                                  const LocalVariantUnaryOpFn& unary_op_fn) {
+    auto wrapped_fn = [type_name, unary_op_fn](OpKernelContext* ctx,
+                                               const Variant& v,
+                                               Variant* v_out) -> Status {
       CHECK_NOTNULL(v_out);
       *v_out = T();
       if (v.get<T>() == nullptr) {
         return errors::Internal(
-            "VariantZerosLikeFn: Could not access object, type_name: ",
+            "VariantUnaryOpFn: Could not access object, type_name: ",
             type_name);
       }
       const T& t = *v.get<T>();
       T* t_out = v_out->get<T>();
-      return zeros_like_fn(ctx, t, t_out);
+      return unary_op_fn(ctx, t, t_out);
     };
-    UnaryVariantOpRegistry::Global()->RegisterZerosLikeFn(device, type_name,
-                                                          wrapped_fn);
+    UnaryVariantOpRegistry::Global()->RegisterUnaryOpFn(op, device, type_name,
+                                                        wrapped_fn);
+  }
+};
+
+template <typename T>
+class UnaryVariantBinaryOpRegistration {
+  typedef std::function<Status(OpKernelContext*, const T& a, const T& b,
+                               T* out)>
+      LocalVariantBinaryOpFn;
+
+ public:
+  UnaryVariantBinaryOpRegistration(VariantBinaryOp op, const string& device,
+                                   const string& type_name,
+                                   const LocalVariantBinaryOpFn& binary_op_fn) {
+    auto wrapped_fn = [type_name, binary_op_fn](
+                          OpKernelContext* ctx, const Variant& a,
+                          const Variant& b, Variant* out) -> Status {
+      CHECK_NOTNULL(out);
+      *out = T();
+      if (a.get<T>() == nullptr) {
+        return errors::Internal(
+            "VariantBinaryOpFn: Could not access object 'a', type_name: ",
+            type_name);
+      }
+      if (b.get<T>() == nullptr) {
+        return errors::Internal(
+            "VariantBinaryOpFn: Could not access object 'b', type_name: ",
+            type_name);
+      }
+      const T& t_a = *a.get<T>();
+      const T& t_b = *b.get<T>();
+      T* t_out = out->get<T>();
+      return binary_op_fn(ctx, t_a, t_b, t_out);
+    };
+    UnaryVariantOpRegistry::Global()->RegisterBinaryOpFn(op, device, type_name,
+                                                         wrapped_fn);
   }
 };
 
@@ -223,25 +352,47 @@ class UnaryVariantZerosLikeRegistration {
                                                    T> \
       register_unary_variant_op_decoder_fn_##ctr(type_name)
 
-// Register a unary zeros_like variant function with the signature:
-//    Status ZerosLikeFn(OpKernelContext* ctx, const T& t, T* t_out);
-// to Variants having
TypeName type_name, for device string device.
-#define REGISTER_UNARY_VARIANT_ZEROS_LIKE_FUNCTION(device, T, type_name,   \
-                                                   zeros_like_function)    \
-  REGISTER_UNARY_VARIANT_ZEROS_LIKE_FUNCTION_UNIQ_HELPER(                  \
-      __COUNTER__, device, T, type_name, zeros_like_function)
+// Register a unary unary_op variant function with the signature:
+//    Status UnaryOpFn(OpKernelContext* ctx, const T& t, T* t_out);
+// to Variants having TypeName type_name, for device string device,
+// for UnaryVariantOp enum op.
+#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(op, device, T, type_name, \
+                                                 unary_op_function)        \
+  REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER(                    \
+      __COUNTER__, op, device, T, type_name, unary_op_function)
 
-#define REGISTER_UNARY_VARIANT_ZEROS_LIKE_FUNCTION_UNIQ_HELPER(              \
-    ctr, device, T, type_name, zeros_like_function)                          \
-  REGISTER_UNARY_VARIANT_ZEROS_LIKE_FUNCTION_UNIQ(ctr, device, T, type_name, \
-                                                  zeros_like_function)
+#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER(                  \
+    ctr, op, device, T, type_name, unary_op_function)                          \
+  REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ(ctr, op, device, T, type_name, \
+                                                unary_op_function)
 
-#define REGISTER_UNARY_VARIANT_ZEROS_LIKE_FUNCTION_UNIQ(                \
-    ctr, device, T, type_name, zeros_like_function)                     \
-  static variant_op_registry_fn_registration::                          \
-      UnaryVariantZerosLikeRegistration<T>                              \
-          register_unary_variant_op_decoder_fn_##ctr(device, type_name, \
-                                                     zeros_like_function)
+#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ(                         \
+    ctr, op, device, T, type_name, unary_op_function)                          \
+  static variant_op_registry_fn_registration::UnaryVariantUnaryOpRegistration< \
+      T>                                                                       \
+      register_unary_variant_op_decoder_fn_##ctr(op, device, type_name,       \
+                                                 unary_op_function)
+
+// Register a binary_op variant function with the signature:
+//    Status BinaryOpFn(OpKernelContext* ctx, const T& a, const T& b, T* out);
+// to Variants having TypeName type_name, for device string device,
+// for VariantBinaryOp enum OP.
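Before the binary-op #define that completes this macro family just below, a usage sketch. MyCounter, its field n, and AddMyCounters are invented for illustration; the macro, enum, and device names come from this patch, and the sketch assumes the payload type meets the framework's Variant requirements (TypeName/Encode/Decode):

```cpp
// Hypothetical variant payload type.
struct MyCounter {
  int64 n = 0;
  string TypeName() const { return "MyCounter"; }
  void Encode(VariantTensorData* data) const { /* ... */ }
  bool Decode(const VariantTensorData& data) { /* ... */ return true; }
};

Status AddMyCounters(OpKernelContext* ctx, const MyCounter& a,
                     const MyCounter& b, MyCounter* out) {
  out->n = a.n + b.n;
  return Status::OK();
}

// One registration per (op enum, device, type_name) triple at namespace
// scope; registering the same triple twice is a CHECK failure, as the
// TestDuplicate tests below exercise.
REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU,
                                          MyCounter, "MyCounter",
                                          AddMyCounters);
```

Dispatch then goes through BinaryOpVariants<CPUDevice>(ctx, ADD_VARIANT_BINARY_OP, a, b, &out), which returns an Internal error when the two variants' type names differ or when no function is registered for the triple.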
+#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(op, device, T, type_name, \ + binary_op_function) \ + REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ_HELPER( \ + __COUNTER__, op, device, T, type_name, binary_op_function) + +#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ_HELPER( \ + ctr, op, device, T, type_name, binary_op_function) \ + REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ( \ + ctr, op, device, T, type_name, binary_op_function) + +#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ( \ + ctr, op, device, T, type_name, binary_op_function) \ + static variant_op_registry_fn_registration:: \ + UnaryVariantBinaryOpRegistration \ + register_unary_variant_op_decoder_fn_##ctr(op, device, type_name, \ + binary_op_function) } // end namespace tensorflow diff --git a/tensorflow/core/framework/variant_op_registry_test.cc b/tensorflow/core/framework/variant_op_registry_test.cc index 4e79180217a..8102f1e18be 100644 --- a/tensorflow/core/framework/variant_op_registry_test.cc +++ b/tensorflow/core/framework/variant_op_registry_test.cc @@ -50,7 +50,7 @@ struct VariantValue { if (v.early_exit) { return errors::InvalidArgument("early exit zeros_like!"); } - v_out->zeros_like_set = 1; // CPU + v_out->value = 1; // CPU return Status::OK(); } static Status GPUZerosLikeFn(OpKernelContext* ctx, const VariantValue& v, @@ -58,11 +58,27 @@ struct VariantValue { if (v.early_exit) { return errors::InvalidArgument("early exit zeros_like!"); } - v_out->zeros_like_set = 2; // GPU + v_out->value = 2; // GPU + return Status::OK(); + } + static Status CPUAddFn(OpKernelContext* ctx, const VariantValue& a, + const VariantValue& b, VariantValue* out) { + if (a.early_exit) { + return errors::InvalidArgument("early exit add!"); + } + out->value = a.value + b.value; // CPU + return Status::OK(); + } + static Status GPUAddFn(OpKernelContext* ctx, const VariantValue& a, + const VariantValue& b, VariantValue* out) { + if (a.early_exit) { + return errors::InvalidArgument("early exit add!"); + } + out->value = -(a.value + b.value); // GPU return Status::OK(); } bool early_exit; - int zeros_like_set; + int value; }; REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(VariantValue, "TEST VariantValue", @@ -70,13 +86,23 @@ REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(VariantValue, "TEST VariantValue", REGISTER_UNARY_VARIANT_DECODE_FUNCTION(VariantValue, "TEST VariantValue"); -REGISTER_UNARY_VARIANT_ZEROS_LIKE_FUNCTION(DEVICE_CPU, VariantValue, - "TEST VariantValue", - VariantValue::CPUZerosLikeFn); +REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP, + DEVICE_CPU, VariantValue, + "TEST VariantValue", + VariantValue::CPUZerosLikeFn); -REGISTER_UNARY_VARIANT_ZEROS_LIKE_FUNCTION(DEVICE_GPU, VariantValue, - "TEST VariantValue", - VariantValue::GPUZerosLikeFn); +REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP, + DEVICE_GPU, VariantValue, + "TEST VariantValue", + VariantValue::GPUZerosLikeFn); + +REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU, + VariantValue, "TEST VariantValue", + VariantValue::CPUAddFn); + +REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_GPU, + VariantValue, "TEST VariantValue", + VariantValue::GPUAddFn); } // namespace @@ -104,8 +130,9 @@ TEST(VariantOpShapeRegistryTest, TestBasic) { TEST(VariantOpShapeRegistryTest, TestDuplicate) { UnaryVariantOpRegistry registry; UnaryVariantOpRegistry::VariantShapeFn f; - registry.RegisterShapeFn("fjfjfj", f); - EXPECT_DEATH(registry.RegisterShapeFn("fjfjfj", f), + string kTypeName = "fjfjfj"; + 
registry.RegisterShapeFn(kTypeName, f); + EXPECT_DEATH(registry.RegisterShapeFn(kTypeName, f), "fjfjfj already registered"); } @@ -133,71 +160,146 @@ TEST(VariantOpDecodeRegistryTest, TestBasic) { TEST(VariantOpDecodeRegistryTest, TestDuplicate) { UnaryVariantOpRegistry registry; UnaryVariantOpRegistry::VariantDecodeFn f; - registry.RegisterDecodeFn("fjfjfj", f); - EXPECT_DEATH(registry.RegisterDecodeFn("fjfjfj", f), + string kTypeName = "fjfjfj"; + registry.RegisterDecodeFn(kTypeName, f); + EXPECT_DEATH(registry.RegisterDecodeFn(kTypeName, f), "fjfjfj already registered"); } TEST(VariantOpZerosLikeRegistryTest, TestBasicCPU) { - EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetZerosLikeFn( - DEVICE_CPU, "YOU SHALL NOT PASS"), + EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetUnaryOpFn( + ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, "YOU SHALL NOT PASS"), nullptr); - VariantValue vv_early_exit{true /* early_exit */, 0 /* zeros_like_set */}; + VariantValue vv_early_exit{true /* early_exit */, 0 /* value */}; Variant v = vv_early_exit; Variant v_out = VariantValue(); OpKernelContext* null_context_pointer = nullptr; - Status s0 = - CreateZerosLikeVariant(null_context_pointer, v, &v_out); + Status s0 = UnaryOpVariant(null_context_pointer, + ZEROS_LIKE_VARIANT_UNARY_OP, v, &v_out); EXPECT_FALSE(s0.ok()); EXPECT_TRUE( StringPiece(s0.error_message()).contains("early exit zeros_like")); - VariantValue vv_ok{false /* early_exit */, 0 /* zeros_like_set */}; + VariantValue vv_ok{false /* early_exit */, 0 /* value */}; v = vv_ok; - TF_EXPECT_OK( - CreateZerosLikeVariant(null_context_pointer, v, &v_out)); + TF_EXPECT_OK(UnaryOpVariant( + null_context_pointer, ZEROS_LIKE_VARIANT_UNARY_OP, v, &v_out)); VariantValue* vv_out = CHECK_NOTNULL(v_out.get()); - EXPECT_EQ(vv_out->zeros_like_set, 1); // CPU + EXPECT_EQ(vv_out->value, 1); // CPU } #if GOOGLE_CUDA -TEST(VariantOpZerosLikeRegistryTest, TestBasicGPU) { - EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetZerosLikeFn( - DEVICE_GPU, "YOU SHALL NOT PASS"), +TEST(VariantOpUnaryOpRegistryTest, TestBasicGPU) { + EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetUnaryOpFn( + ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU, "YOU SHALL NOT PASS"), nullptr); - VariantValue vv_early_exit{true /* early_exit */, 0 /* zeros_like_set */}; + VariantValue vv_early_exit{true /* early_exit */, 0 /* value */}; Variant v = vv_early_exit; Variant v_out = VariantValue(); OpKernelContext* null_context_pointer = nullptr; - Status s0 = - CreateZerosLikeVariant(null_context_pointer, v, &v_out); + Status s0 = UnaryOpVariant(null_context_pointer, + ZEROS_LIKE_VARIANT_UNARY_OP, v, &v_out); EXPECT_FALSE(s0.ok()); EXPECT_TRUE( StringPiece(s0.error_message()).contains("early exit zeros_like")); - VariantValue vv_ok{false /* early_exit */, 0 /* zeros_like_set */}; + VariantValue vv_ok{false /* early_exit */, 0 /* value */}; v = vv_ok; - TF_EXPECT_OK( - CreateZerosLikeVariant(null_context_pointer, v, &v_out)); + TF_EXPECT_OK(UnaryOpVariant( + null_context_pointer, ZEROS_LIKE_VARIANT_UNARY_OP, v, &v_out)); VariantValue* vv_out = CHECK_NOTNULL(v_out.get()); - EXPECT_EQ(vv_out->zeros_like_set, 2); // GPU + EXPECT_EQ(vv_out->value, 2); // GPU } #endif // GOOGLE_CUDA -TEST(VariantOpZerosLikeRegistryTest, TestDuplicate) { +TEST(VariantOpUnaryOpRegistryTest, TestDuplicate) { UnaryVariantOpRegistry registry; - UnaryVariantOpRegistry::VariantZerosLikeFn f; + UnaryVariantOpRegistry::VariantUnaryOpFn f; + string kTypeName = "fjfjfj"; - registry.RegisterZerosLikeFn(DEVICE_CPU, "fjfjfj", f); - 
EXPECT_DEATH(registry.RegisterZerosLikeFn(DEVICE_CPU, "fjfjfj", f),
+  registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU,
+                             kTypeName, f);
+  EXPECT_DEATH(registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP,
+                                          DEVICE_CPU, kTypeName, f),
               "fjfjfj already registered");
 
-  registry.RegisterZerosLikeFn(DEVICE_GPU, "fjfjfj", f);
-  EXPECT_DEATH(registry.RegisterZerosLikeFn(DEVICE_GPU, "fjfjfj", f),
+  registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU,
+                             kTypeName, f);
+  EXPECT_DEATH(registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP,
+                                          DEVICE_GPU, kTypeName, f),
+               "fjfjfj already registered");
+}
+
+TEST(VariantOpAddRegistryTest, TestBasicCPU) {
+  EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetBinaryOpFn(
+                ADD_VARIANT_BINARY_OP, DEVICE_CPU, "YOU SHALL NOT PASS"),
+            nullptr);
+
+  VariantValue vv_early_exit{true /* early_exit */, 3 /* value */};
+  VariantValue vv_other{true /* early_exit */, 4 /* value */};
+  Variant v_a = vv_early_exit;
+  Variant v_b = vv_other;
+  Variant v_out = VariantValue();
+
+  OpKernelContext* null_context_pointer = nullptr;
+  Status s0 = BinaryOpVariants<CPUDevice>(
+      null_context_pointer, ADD_VARIANT_BINARY_OP, v_a, v_b, &v_out);
+  EXPECT_FALSE(s0.ok());
+  EXPECT_TRUE(StringPiece(s0.error_message()).contains("early exit add"));
+
+  VariantValue vv_ok{false /* early_exit */, 3 /* value */};
+  v_a = vv_ok;
+  TF_EXPECT_OK(BinaryOpVariants<CPUDevice>(
+      null_context_pointer, ADD_VARIANT_BINARY_OP, v_a, v_b, &v_out));
+  VariantValue* vv_out = CHECK_NOTNULL(v_out.get<VariantValue>());
+  EXPECT_EQ(vv_out->value, 7);  // CPU
+}
+
+#if GOOGLE_CUDA
+TEST(VariantOpAddRegistryTest, TestBasicGPU) {
+  EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetBinaryOpFn(
+                ADD_VARIANT_BINARY_OP, DEVICE_GPU, "YOU SHALL NOT PASS"),
+            nullptr);
+
+  VariantValue vv_early_exit{true /* early_exit */, 3 /* value */};
+  VariantValue vv_other{true /* early_exit */, 4 /* value */};
+  Variant v_a = vv_early_exit;
+  Variant v_b = vv_other;
+  Variant v_out = VariantValue();
+
+  OpKernelContext* null_context_pointer = nullptr;
+  Status s0 = BinaryOpVariants<GPUDevice>(
+      null_context_pointer, ADD_VARIANT_BINARY_OP, v_a, v_b, &v_out);
+  EXPECT_FALSE(s0.ok());
+  EXPECT_TRUE(StringPiece(s0.error_message()).contains("early exit add"));
+
+  VariantValue vv_ok{false /* early_exit */, 3 /* value */};
+  v_a = vv_ok;
+  TF_EXPECT_OK(BinaryOpVariants<GPUDevice>(
+      null_context_pointer, ADD_VARIANT_BINARY_OP, v_a, v_b, &v_out));
+  VariantValue* vv_out = CHECK_NOTNULL(v_out.get<VariantValue>());
+  EXPECT_EQ(vv_out->value, -7);  // GPU
+}
+#endif  // GOOGLE_CUDA
+
+TEST(VariantOpAddRegistryTest, TestDuplicate) {
+  UnaryVariantOpRegistry registry;
+  UnaryVariantOpRegistry::VariantBinaryOpFn f;
+  string kTypeName = "fjfjfj";
+
+  registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_CPU, kTypeName, f);
+  EXPECT_DEATH(registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_CPU,
+                                           kTypeName, f),
+               "fjfjfj already registered");
+
+  registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_GPU, kTypeName, f);
+  EXPECT_DEATH(registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_GPU,
+                                           kTypeName, f),
               "fjfjfj already registered");
 }
diff --git a/tensorflow/core/graph/tensor_id.cc b/tensorflow/core/graph/tensor_id.cc
index 985654d027c..089ea5e527a 100644
--- a/tensorflow/core/graph/tensor_id.cc
+++ b/tensorflow/core/graph/tensor_id.cc
@@ -34,8 +34,8 @@ TensorId ParseTensorName(StringPiece name) {
   // whole name string forms the first part of the tensor name.
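This hunk switches the suffix accumulator from int to unsigned int: on a pathological all-digit name the multiply-accumulate can overflow, which is undefined behavior for signed integers but well-defined wrap-around for unsigned. A standalone sketch of the same right-to-left scan, not the TensorFlow implementation (std::string_view stands in for StringPiece; ParseSuffix is an invented name):

```cpp
#include <string_view>
#include <utility>

// Parse "name:index" by scanning digits from the right. Unsigned arithmetic
// makes overflow on an absurdly long digit suffix wrap instead of invoking
// undefined behavior.
std::pair<std::string_view, unsigned int> ParseSuffix(std::string_view name) {
  if (name.empty()) return {name, 0};
  const char* base = name.data();
  const char* p = base + name.size() - 1;
  unsigned int index = 0;
  unsigned int mul = 1;
  while (p > base && (*p >= '0' && *p <= '9')) {
    index += (*p - '0') * mul;
    mul *= 10;
    --p;
  }
  if (p > base && *p == ':' && mul > 1) {
    return {std::string_view(base, p - base), index};  // e.g. "op:123"
  }
  return {name, 0};  // no numeric suffix; whole string is the op name
}
```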
const char* base = name.data();
   const char* p = base + name.size() - 1;
-  int index = 0;
-  int mul = 1;
+  unsigned int index = 0;
+  unsigned int mul = 1;
   while (p > base && (*p >= '0' && *p <= '9')) {
     index += ((*p - '0') * mul);
     mul *= 10;
diff --git a/tensorflow/core/kernels/aggregate_ops.cc b/tensorflow/core/kernels/aggregate_ops.cc
index 0aa65729de2..0099984f69c 100644
--- a/tensorflow/core/kernels/aggregate_ops.cc
+++ b/tensorflow/core/kernels/aggregate_ops.cc
@@ -24,6 +24,9 @@ limitations under the License.
 
 #include "tensorflow/core/framework/numeric_op.h"
 #include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/variant.h"
+#include "tensorflow/core/framework/variant_encode_decode.h"
+#include "tensorflow/core/framework/variant_op_registry.h"
 #include "tensorflow/core/lib/gtl/inlined_vector.h"
 #include "tensorflow/core/platform/logging.h"
@@ -33,7 +36,7 @@
 typedef Eigen::ThreadPoolDevice CPUDevice;
 typedef Eigen::GpuDevice GPUDevice;
 #ifdef TENSORFLOW_USE_SYCL
 typedef Eigen::SyclDevice SYCLDevice;
-#endif // TENSORFLOW_USE_SYCL
+#endif  // TENSORFLOW_USE_SYCL
 
 template <typename Device, typename T>
 class AddNOp : public OpKernel {
@@ -150,6 +153,65 @@ class AddNOp : public OpKernel {
   }
 };
 
+template <typename Device>
+class AddNOp<Device, Variant> : public OpKernel {
+ public:
+  explicit AddNOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+  void Compute(OpKernelContext* ctx) override {
+    if (!ctx->ValidateInputsAreSameShape(this)) return;
+
+    const Tensor& input0 = ctx->input(0);
+    const int num = ctx->num_inputs();
+
+    if (num == 1) {
+      ctx->set_output(0, input0);
+      return;
+    }
+
+    for (int i = 0; i < num; ++i) {
+      // Step 1: ensure unary variants.
+      OP_REQUIRES(
+          ctx, ctx->input(i).dims() == 0,
+          errors::InvalidArgument(
+              "AddN of non-scalar Tensor with dtype=DT_VARIANT is not "
+              "supported; inputs[",
+              i, "] has shape: ", ctx->input(i).shape().DebugString(), "."));
+    }
+
+    TensorShape common_shape;
+    OP_REQUIRES_OK(ctx, GetUnaryVariantShape(ctx->input(0), &common_shape));
+    // Step 2: access all variants and ensure shapes match.
+    for (int i = 1; i < num; ++i) {
+      TensorShape check_shape;
+      OP_REQUIRES_OK(ctx, GetUnaryVariantShape(ctx->input(i), &check_shape));
+      OP_REQUIRES(ctx, common_shape == check_shape,
+                  errors::InvalidArgument(
+                      "AddN of Variants of differing shapes; inputs[0] shape: ",
+                      common_shape.DebugString(), ", inputs[", i,
+                      "] shape: ", check_shape.DebugString()));
+    }
+
+    // Step 3: attempt to add using
+    //   BinaryOpVariants(ADD_VARIANT_BINARY_OP, ...)
+    // For the output create a default-constructed variant object.
+    // TODO(ebrevdo): Perform summation in a tree-structure.
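+    // Note: with no tree reduction yet, the adds below fold sequentially:
+    //   v_out = add(inputs[0], inputs[1]); then v_out = add(inputs[i], v_out)
+    // for i >= 2, giving O(num) dependent BinaryOpVariants calls.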
+ Tensor out(cpu_allocator(), DT_VARIANT, TensorShape({})); + Variant* v_out = &(out.scalar()()); + OP_REQUIRES_OK( + ctx, BinaryOpVariants( + ctx, ADD_VARIANT_BINARY_OP, ctx->input(0).scalar()(), + ctx->input(1).scalar()(), v_out)); + for (int i = 2; i < num; ++i) { + const Variant tmp = std::move(*v_out); + const Variant& inp = ctx->input(i).scalar()(); + OP_REQUIRES_OK(ctx, BinaryOpVariants(ctx, ADD_VARIANT_BINARY_OP, + inp, tmp, v_out)); + } + ctx->set_output(0, out); + } +}; + #define REGISTER_ADDN(type, dev) \ REGISTER_KERNEL_BUILDER( \ Name("AddN").Device(DEVICE_##dev).TypeConstraint("T"), \ @@ -158,6 +220,8 @@ class AddNOp : public OpKernel { #define REGISTER_ADDN_CPU(type) REGISTER_ADDN(type, CPU) TF_CALL_NUMBER_TYPES(REGISTER_ADDN_CPU); +REGISTER_ADDN_CPU(Variant); + #undef REGISTER_ADDN_CPU #if GOOGLE_CUDA @@ -176,6 +240,16 @@ REGISTER_KERNEL_BUILDER(Name("AddN") .HostMemory("inputs") .HostMemory("sum"), AddNOp); + +// TODO(ebrevdo): Once rendezvous has been properly set up for +// Variants, we'll no longer need a HostMemory attribute for this case. +REGISTER_KERNEL_BUILDER(Name("AddN") + .Device(DEVICE_GPU) + .TypeConstraint("T") + .HostMemory("inputs") + .HostMemory("sum"), + AddNOp); + #endif // GOOGLE_CUDA #ifdef TENSORFLOW_USE_SYCL @@ -191,7 +265,7 @@ REGISTER_KERNEL_BUILDER(Name("AddN") .HostMemory("inputs") .HostMemory("sum"), AddNOp); -#endif // TENSORFLOW_USE_SYCL +#endif // TENSORFLOW_USE_SYCL #undef REGISTER_ADDN diff --git a/tensorflow/core/kernels/constant_op.cc b/tensorflow/core/kernels/constant_op.cc index cdc11452827..6c9c48d41bc 100644 --- a/tensorflow/core/kernels/constant_op.cc +++ b/tensorflow/core/kernels/constant_op.cc @@ -279,13 +279,15 @@ class ZerosLikeOp : public OpKernel { const Tensor& input = ctx->input(0); const Device& d = ctx->eigen_device(); if (std::is_same::value) { - OP_REQUIRES(ctx, input.dims() == 0, - errors::InvalidArgument( - "ZerosLike of non-unary Variant not supported.")); + OP_REQUIRES( + ctx, input.dims() == 0, + errors::InvalidArgument("ZerosLike non-scalar Tensor with " + "dtype=DT_VARIANT is not supported.")); const Variant& v = input.scalar()(); Tensor out(cpu_allocator(), DT_VARIANT, TensorShape({})); Variant* out_v = &(out.scalar()()); - OP_REQUIRES_OK(ctx, CreateZerosLikeVariant(ctx, v, out_v)); + OP_REQUIRES_OK(ctx, UnaryOpVariant( + ctx, ZEROS_LIKE_VARIANT_UNARY_OP, v, out_v)); ctx->set_output(0, out); } else { Tensor* out = nullptr; diff --git a/tensorflow/core/kernels/function_ops.cc b/tensorflow/core/kernels/function_ops.cc index a1dfd4c3d31..629e29958f6 100644 --- a/tensorflow/core/kernels/function_ops.cc +++ b/tensorflow/core/kernels/function_ops.cc @@ -292,7 +292,8 @@ class RemoteCallOp : public AsyncOpKernel { OP_REQUIRES_OK_ASYNC(ctx, ctx->input("target", &target), done); AttrValueMap attr_values = func_->attr(); AttrValue v; - v.set_s(target->scalar()()); + const string& target_device = target->scalar()(); + v.set_s(target_device); AddAttr("_target", v, &attr_values); FunctionLibraryRuntime* lib = ctx->function_library(); @@ -310,6 +311,11 @@ class RemoteCallOp : public AsyncOpKernel { FunctionLibraryRuntime::Options opts; opts.step_id = ctx->step_id(); opts.runner = ctx->runner(); + opts.source_device = lib->device()->name(); + if (opts.source_device != target_device) { + opts.remote_execution = true; + } + opts.rendezvous = ctx->rendezvous(); std::vector args; args.reserve(arguments.size()); for (const Tensor& argument : arguments) { @@ -334,10 +340,13 @@ class RemoteCallOp : public AsyncOpKernel { 
TF_DISALLOW_COPY_AND_ASSIGN(RemoteCallOp); }; -REGISTER_KERNEL_BUILDER(Name("RemoteCall").Device(DEVICE_CPU), RemoteCallOp); -REGISTER_KERNEL_BUILDER(Name("RemoteCall").Device(DEVICE_GPU), RemoteCallOp); +REGISTER_KERNEL_BUILDER( + Name("RemoteCall").Device(DEVICE_CPU).HostMemory("target"), RemoteCallOp); +REGISTER_KERNEL_BUILDER( + Name("RemoteCall").Device(DEVICE_GPU).HostMemory("target"), RemoteCallOp); #if TENSORFLOW_USE_SYCL -REGISTER_KERNEL_BUILDER(Name("RemoteCall").Device(DEVICE_SYCL), RemoteCallOp); +REGISTER_KERNEL_BUILDER( + Name("RemoteCall").Device(DEVICE_SYCL).HostMemory("target"), RemoteCallOp); #endif // TENSORFLOW_USE_SYCL } // namespace tensorflow diff --git a/tensorflow/core/kernels/maxpooling_op.cc b/tensorflow/core/kernels/maxpooling_op.cc index 8d825c13d76..60ed1263a23 100644 --- a/tensorflow/core/kernels/maxpooling_op.cc +++ b/tensorflow/core/kernels/maxpooling_op.cc @@ -920,6 +920,13 @@ class MaxPoolingGradWithArgmaxOp : public OpKernel { public: explicit MaxPoolingGradWithArgmaxOp(OpKernelConstruction* context) : OpKernel(context) { + string data_format_str; + auto status = context->GetAttr("data_format", &data_format_str); + if (status.ok()) { + OP_REQUIRES(context, FormatFromString(data_format_str, &data_format_), + errors::InvalidArgument("Invalid data format")); + } + OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_)); OP_REQUIRES(context, ksize_.size() == 4, errors::InvalidArgument("Sliding window ksize field must " @@ -959,6 +966,7 @@ class MaxPoolingGradWithArgmaxOp : public OpKernel { std::vector ksize_; std::vector stride_; Padding padding_; + TensorFormat data_format_; }; template @@ -1051,17 +1059,36 @@ class MaxPoolingNoMaskOp : public OpKernel { TensorShape out_shape = ShapeFromFormat(data_format_, params.tensor_in_batch, params.out_height, params.out_width, params.depth); - if (use_dnn_ && data_format_ == FORMAT_NCHW) { + + // Assuming qint8 <--> NCHW_VECT_C (int8x4) here. + constexpr bool is_int8x4 = std::is_same::value; + OP_REQUIRES(context, (is_int8x4 == (data_format_ == FORMAT_NCHW_VECT_C)), + errors::InvalidArgument( + "qint8 should be used with data_format NCHW_VECT_C.")); + + // These is_int8x4 checks avoid linker errors for missing qint8 kernels. + if (!is_int8x4 && use_dnn_ && data_format_ == FORMAT_NCHW) { DnnPoolingOp::Compute( context, perftools::gputools::dnn::PoolingMode::kMaximum, ksize_, stride_, padding_, data_format_, tensor_in, out_shape); } else { - CHECK(data_format_ == FORMAT_NHWC) - << "Non-Cudnn MaxPool only supports NHWC format"; Tensor* output = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output)); - LaunchMaxPoolingNoMask::launch(context, params, tensor_in, - output); + if (is_int8x4) { + LaunchMaxPoolingNoMask_NCHW_VECT_C::launch(context, params, + tensor_in, output); + } else if (data_format_ == FORMAT_NHWC) { + LaunchMaxPoolingNoMask::launch(context, params, tensor_in, + output); + } else { + LOG(FATAL) << "MaxPool currently only supports the following (layout, " + "type) combinations: (NHWC, non-qint8), " + "(NCHW, non-qint8) or (NCHW_VECT_C, qint8). 
The " + "requested combination (" + << ToString(data_format_) << ", " + << DataTypeString(DataTypeToEnum::v()) + << ") is not supported."; + } } } @@ -1346,6 +1373,26 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_MAX_POOL_KERNELS); .TypeConstraint("Targmax"), \ MaxPoolingGradGradWithArgmaxOp); TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_ONLY_POOL_KERNELS); + +REGISTER_KERNEL_BUILDER( + Name("MaxPool").Device(DEVICE_GPU).TypeConstraint("T"), + MaxPoolingNoMaskOp); + +REGISTER_KERNEL_BUILDER(Name("MaxPoolV2") + .Device(DEVICE_GPU) + .HostMemory("ksize") + .HostMemory("strides") + .TypeConstraint("T"), + MaxPoolingV2Op); + +REGISTER_KERNEL_BUILDER(Name("MaxPoolV2") + .Device(DEVICE_GPU) + .HostMemory("ksize") + .HostMemory("strides") + .TypeConstraint("T") + .Label("eigen_tensor"), + MaxPoolingV2Op); + #undef REGISTER_GPU_ONLY_POOL_KERNELS #endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/maxpooling_op.h b/tensorflow/core/kernels/maxpooling_op.h index 1670c1b26d8..f82e57d44c2 100644 --- a/tensorflow/core/kernels/maxpooling_op.h +++ b/tensorflow/core/kernels/maxpooling_op.h @@ -17,7 +17,9 @@ limitations under the License. #define TENSORFLOW_KERNELS_MAXPOOLING_OP_H_ // Functor definition for MaxPoolingOp, must be compilable by nvcc. +#include "tensorflow/core/framework/numeric_types.h" #include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/type_traits.h" #include "tensorflow/core/kernels/eigen_pooling.h" #include "tensorflow/core/platform/types.h" @@ -37,6 +39,14 @@ struct SpatialMaxPooling { } }; +template +struct SpatialMaxPooling { + void operator()(const Device& d, typename TTypes::Tensor output, + typename TTypes::ConstTensor input, int window_rows, + int window_cols, int row_stride, int col_stride, + const Eigen::PaddingType& padding) {} +}; + } // namespace functor } // namespace tensorflow diff --git a/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc b/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc index e3a57d2f28a..26f52748045 100644 --- a/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc +++ b/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/type_traits.h" #include "tensorflow/core/kernels/maxpooling_op.h" #include "tensorflow/core/kernels/maxpooling_op_gpu.h" #include "tensorflow/core/util/cuda_kernel_helper.h" @@ -89,6 +90,42 @@ __global__ void MaxPoolForwardNCHW(const int nthreads, const dtype* bottom_data, } } +// The parameters for MaxPoolForwardNoMaskKernel_NCHW_VECT_C are the same as for +// MaxPoolForwardNCHW above, except that mask is not supported, and each +// element of the input and output contains 4 adjacent channel values for +// the same X, y coordinate. +// (so channels = outer_channels, output_size = real output size / 4). +__global__ void MaxPoolForwardNoMaskKernel_NCHW_VECT_C( + const int nthreads, const int32* bottom_data, const int height, + const int width, const int channels, const int pooled_height, + const int pooled_width, const int kernel_h, const int kernel_w, + const int stride_h, const int stride_w, const int pad_t, const int pad_l, + int32* top_data) { + // TODO(pauldonnelly): Implement a better optimized version of this kernel. 
+ const int32 kMinINT8X4 = 0x80808080; + CUDA_1D_KERNEL_LOOP(index, nthreads) { + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int c = (index / pooled_width / pooled_height) % channels; + int n = index / pooled_width / pooled_height / channels; + int hstart = ph * stride_h - pad_t; + int wstart = pw * stride_w - pad_l; + int hend = min(hstart + kernel_h, height); + int wend = min(wstart + kernel_w, width); + hstart = max(hstart, 0); + wstart = max(wstart, 0); + int32 maxval = kMinINT8X4; + const int32* bottom_data_n = bottom_data + n * channels * height * width; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + int idx = (c * height + h) * width + w; + maxval = __vmaxs4(maxval, bottom_data_n[idx]); + } + } + top_data[index] = maxval; + } +} + template __global__ void MaxPoolForwardNHWC(const int nthreads, const dtype* bottom_data, const int height, const int width, @@ -328,6 +365,25 @@ __global__ void MaxPoolGradBackward(const int nthreads, const dtype* top_diff, namespace functor { +// Note: channels is the outer channels (dim 1) which has already been +// divided by 4. +bool MaxPoolForwardNoMask_NCHW_VECT_C::operator()( + const int32* bottom_data, const int batch, const int height, + const int width, int channels, const int pooled_height, + const int pooled_width, const int kernel_h, const int kernel_w, + const int stride_h, const int stride_w, const int pad_t, const int pad_l, + int32* top_data, const Eigen::GpuDevice& d) { + const int kThreadsPerBlock = 1024; + const int output_size = batch * channels * pooled_height * pooled_width; + MaxPoolForwardNoMaskKernel_NCHW_VECT_C<<< + (output_size + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, + 0, d.stream()>>>(output_size, bottom_data, height, width, channels, + pooled_height, pooled_width, kernel_h, kernel_w, + stride_h, stride_w, pad_t, pad_l, top_data); + d.synchronize(); + return d.ok(); +} + template bool MaxPoolForwardWithOptionalArgmax::operator()( const T* bottom_data, const int batch, const int height, const int width, diff --git a/tensorflow/core/kernels/maxpooling_op_gpu.h b/tensorflow/core/kernels/maxpooling_op_gpu.h index d2029f5719a..34203797cf0 100644 --- a/tensorflow/core/kernels/maxpooling_op_gpu.h +++ b/tensorflow/core/kernels/maxpooling_op_gpu.h @@ -42,6 +42,15 @@ struct MaxPoolForwardWithOptionalArgmax { const Eigen::GpuDevice& d); }; +struct MaxPoolForwardNoMask_NCHW_VECT_C { + bool operator()(const int32* bottom_data, const int batch, const int height, + const int width, int channels, const int pooled_height, + const int pooled_width, const int kernel_h, + const int kernel_w, const int stride_h, const int stride_w, + const int pad_t, const int pad_l, int32* top_data, + const Eigen::GpuDevice& d); +}; + template struct MaxPoolBackwardWithArgmax { bool operator()(const int output_size, const int input_size, diff --git a/tensorflow/core/kernels/pooling_ops_common.cc b/tensorflow/core/kernels/pooling_ops_common.cc index 37747a31999..7dee751c4f3 100644 --- a/tensorflow/core/kernels/pooling_ops_common.cc +++ b/tensorflow/core/kernels/pooling_ops_common.cc @@ -22,7 +22,6 @@ limitations under the License. 
#if GOOGLE_CUDA #include "tensorflow/core/kernels/conv_2d.h" -#include "tensorflow/core/kernels/maxpooling_op_gpu.h" #include "tensorflow/core/kernels/pooling_ops_common_gpu.h" #include "tensorflow/core/platform/stream_executor.h" #endif // GOOGLE_CUDA @@ -34,12 +33,18 @@ PoolParameters::PoolParameters(OpKernelContext* context, const std::vector& stride, Padding padding, TensorFormat data_format, const TensorShape& tensor_in_shape) { - // For maxpooling, tensor_in should have 4 dimensions. - OP_REQUIRES(context, tensor_in_shape.dims() == 4, - errors::InvalidArgument("tensor_in must be 4-dimensional")); + // For maxpooling, tensor_in should have 2 spatial dimensions. + // Note: the total number of dimensions could be 4 for NHWC, NCHW, + // or 5 for NCHW_VECT_C. + OP_REQUIRES(context, + GetTensorSpatialDims(tensor_in_shape.dims(), data_format) == 2, + errors::InvalidArgument( + "tensor_in_shape must have 2 spatial dimensions. ", + tensor_in_shape.dims(), " ", data_format)); this->data_format = data_format; - depth = GetTensorDim(tensor_in_shape, data_format, 'C'); + depth = GetTensorDim(tensor_in_shape, data_format, 'C') * + (data_format == FORMAT_NCHW_VECT_C ? 4 : 1); tensor_in_cols = GetTensorDim(tensor_in_shape, data_format, 'W'); tensor_in_rows = GetTensorDim(tensor_in_shape, data_format, 'H'); tensor_in_batch = GetTensorDim(tensor_in_shape, data_format, 'N'); diff --git a/tensorflow/core/kernels/pooling_ops_common.h b/tensorflow/core/kernels/pooling_ops_common.h index 1b59c18df79..75a6fc371b4 100644 --- a/tensorflow/core/kernels/pooling_ops_common.h +++ b/tensorflow/core/kernels/pooling_ops_common.h @@ -29,6 +29,10 @@ limitations under the License. #include "tensorflow/core/util/tensor_format.h" #include "tensorflow/core/util/work_sharder.h" +#if GOOGLE_CUDA +#include "tensorflow/core/kernels/maxpooling_op_gpu.h" +#endif // GOOGLE_CUDA + namespace tensorflow { typedef Eigen::GpuDevice GPUDevice; @@ -256,6 +260,30 @@ class MaxPoolingOp : public OpKernel { TensorFormat data_format_; }; +template +struct LaunchMaxPoolingNoMask_NCHW_VECT_C; + +#ifdef GOOGLE_CUDA +template <> +struct LaunchMaxPoolingNoMask_NCHW_VECT_C { + static void launch(OpKernelContext* context, const PoolParameters& params, + const Tensor& input, Tensor* output) { + bool status = functor::MaxPoolForwardNoMask_NCHW_VECT_C()( + reinterpret_cast(input.flat().data()), + params.tensor_in_batch, params.tensor_in_rows, params.tensor_in_cols, + params.depth, params.out_height, params.out_width, params.window_rows, + params.window_cols, params.row_stride, params.col_stride, + params.pad_rows, params.pad_cols, + reinterpret_cast(output->flat().data()), + context->eigen_gpu_device()); + if (!status) { + context->SetStatus(errors::Internal( + "Failed launching LaunchMaxPoolingNoMask_NCHW_VECT_C")); + } + } +}; +#endif + template class MaxPoolingV2Op : public OpKernel { public: @@ -266,8 +294,11 @@ class MaxPoolingV2Op : public OpKernel { OP_REQUIRES(context, FormatFromString(data_format, &data_format_), errors::InvalidArgument("Invalid data format")); OP_REQUIRES( - context, data_format_ == FORMAT_NHWC, - errors::InvalidArgument("Default MaxPoolingOp only supports NHWC.")); + context, + data_format_ == FORMAT_NHWC || data_format_ == FORMAT_NCHW_VECT_C, + errors::InvalidArgument( + "MaxPoolingV2Op only supports NHWC or NCHW_VECT_C. 
Got: ", + data_format)); } else { data_format_ = FORMAT_NHWC; } @@ -315,8 +346,8 @@ class MaxPoolingV2Op : public OpKernel { errors::Unimplemented( "Pooling is not yet supported on the batch dimension.")); - PoolParameters params{context, ksize, stride, - padding_, FORMAT_NHWC, tensor_in.shape()}; + PoolParameters params{context, ksize, stride, + padding_, data_format_, tensor_in.shape()}; if (!context->status().ok()) { return; } @@ -368,13 +399,21 @@ class MaxPoolingV2Op : public OpKernel { // Spatial MaxPooling implementation. // // TODO(vrv): Remove this once we no longer need it. +#ifdef GOOGLE_CUDA if (std::is_same::value) { Eigen::PaddingType pt = BrainPadding2EigenPadding(padding); - functor::SpatialMaxPooling()( - context->eigen_device(), output->tensor(), - tensor_in.tensor(), params.window_rows, params.window_cols, - params.row_stride, params.col_stride, pt); - } else { + if (std::is_same::value) { + LaunchMaxPoolingNoMask_NCHW_VECT_C::launch( + context, params, tensor_in, output); + } else { + functor::SpatialMaxPooling()( + context->eigen_device(), output->tensor(), + tensor_in.tensor(), params.window_rows, params.window_cols, + params.row_stride, params.col_stride, pt); + } + } else +#endif + { typedef Eigen::Map> ConstEigenMatrixMap; typedef Eigen::Map> diff --git a/tensorflow/core/kernels/sql/sqlite_query_connection.cc b/tensorflow/core/kernels/sql/sqlite_query_connection.cc index b39b38b4b8d..118c9f55458 100644 --- a/tensorflow/core/kernels/sql/sqlite_query_connection.cc +++ b/tensorflow/core/kernels/sql/sqlite_query_connection.cc @@ -82,8 +82,6 @@ Status SqliteQueryConnection::GetNext(std::vector* out_tensors, int rc = sqlite3_step(stmt_); if (rc == SQLITE_ROW) { for (int i = 0; i < column_count_; i++) { - // TODO(b/64276939) Support other tensorflow types. Interpret columns as - // the types that the client specifies. DataType dt = output_types_[i]; Tensor tensor(cpu_allocator(), dt, {}); FillTensorWithResultSetEntry(dt, i, &tensor); @@ -125,11 +123,46 @@ void SqliteQueryConnection::FillTensorWithResultSetEntry( tensor->scalar()() = value; break; } + case DT_INT8: { + int8 value = sqlite3_column_int(stmt_, column_index); + tensor->scalar()() = value; + break; + } + case DT_INT16: { + int16 value = sqlite3_column_int(stmt_, column_index); + tensor->scalar()() = value; + break; + } case DT_INT32: { int32 value = sqlite3_column_int(stmt_, column_index); tensor->scalar()() = value; break; } + case DT_INT64: { + int64 value = sqlite3_column_int64(stmt_, column_index); + tensor->scalar()() = value; + break; + } + case DT_UINT8: { + uint8 value = sqlite3_column_int(stmt_, column_index); + tensor->scalar()() = value; + break; + } + case DT_UINT16: { + uint16 value = sqlite3_column_int(stmt_, column_index); + tensor->scalar()() = value; + break; + } + case DT_BOOL: { + int value = sqlite3_column_int(stmt_, column_index); + tensor->scalar()() = value ? true : false; + break; + } + case DT_DOUBLE: { + double value = sqlite3_column_double(stmt_, column_index); + tensor->scalar()() = value; + break; + } // Error preemptively thrown by SqlDatasetOp::MakeDataset in this case. 
default: { LOG(FATAL) diff --git a/tensorflow/core/kernels/sql_dataset_ops.cc b/tensorflow/core/kernels/sql_dataset_ops.cc index c8713f7996d..23846d65bb8 100644 --- a/tensorflow/core/kernels/sql_dataset_ops.cc +++ b/tensorflow/core/kernels/sql_dataset_ops.cc @@ -34,13 +34,15 @@ class SqlDatasetOp : public DatasetOpKernel { explicit SqlDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); - // TODO(b/64276939) Remove this check when we add support for other - // tensorflow types. for (const DataType& dt : output_types_) { - OP_REQUIRES( - ctx, dt == DT_STRING || dt == DT_INT32, - errors::InvalidArgument( - "Each element of `output_types_` must be DT_STRING or DT_INT32")); + OP_REQUIRES(ctx, + dt == DT_STRING || dt == DT_INT8 || dt == DT_INT16 || + dt == DT_INT32 || dt == DT_INT64 || dt == DT_UINT8 || + dt == DT_UINT16 || dt == DT_BOOL || dt == DT_DOUBLE, + errors::InvalidArgument( + "Each element of `output_types_` must be one of: " + "DT_STRING, DT_INT8, DT_INT16, DT_INT32, DT_INT64, " + "DT_UINT8, DT_UINT16, DT_BOOL, DT_DOUBLE ")); } for (const PartialTensorShape& pts : output_shapes_) { OP_REQUIRES(ctx, pts.dims() == 0, diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt index a8338620d69..4919266d8ee 100644 --- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt +++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt @@ -303,6 +303,49 @@ op { is_aggregate: true is_commutative: true } +op { + name: "AddN" + input_arg { + name: "inputs" + type_attr: "T" + number_attr: "N" + } + output_arg { + name: "sum" + type_attr: "T" + } + attr { + name: "N" + type: "int" + has_minimum: true + minimum: 1 + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT64 + type: DT_INT32 + type: DT_UINT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT8 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_HALF + type: DT_VARIANT + } + } + } + is_aggregate: true + is_commutative: true +} op { name: "AddSparseToTensorsMap" input_arg { @@ -13473,6 +13516,74 @@ op { } } } +op { + name: "MaxPool" + input_arg { + name: "input" + type_attr: "T" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + default_value { + type: DT_FLOAT + } + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_INT64 + type: DT_UINT8 + type: DT_INT16 + type: DT_INT8 + type: DT_UINT16 + type: DT_HALF + type: DT_QINT8 + } + } + } + attr { + name: "ksize" + type: "list(int)" + has_minimum: true + minimum: 4 + } + attr { + name: "strides" + type: "list(int)" + has_minimum: true + minimum: 4 + } + attr { + name: "padding" + type: "string" + allowed_values { + list { + s: "SAME" + s: "VALID" + } + } + } + attr { + name: "data_format" + type: "string" + default_value { + s: "NHWC" + } + allowed_values { + list { + s: "NHWC" + s: "NCHW" + s: "NCHW_VECT_C" + } + } + } +} op { name: "MaxPool3D" input_arg { @@ -14435,6 +14546,70 @@ op { } } } +op { + name: "MaxPoolV2" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "ksize" + type: DT_INT32 + } + input_arg { + name: "strides" + type: DT_INT32 + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + default_value { + type: DT_FLOAT + } + allowed_values { + 
list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_INT64 + type: DT_UINT8 + type: DT_INT16 + type: DT_INT8 + type: DT_UINT16 + type: DT_HALF + type: DT_QINT8 + } + } + } + attr { + name: "padding" + type: "string" + allowed_values { + list { + s: "SAME" + s: "VALID" + } + } + } + attr { + name: "data_format" + type: "string" + default_value { + s: "NHWC" + } + allowed_values { + list { + s: "NHWC" + s: "NCHW" + s: "NCHW_VECT_C" + } + } + } +} op { name: "MaxPoolWithArgmax" input_arg { diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc index c21b9a7977a..ef4737cafe4 100644 --- a/tensorflow/core/ops/math_ops.cc +++ b/tensorflow/core/ops/math_ops.cc @@ -28,7 +28,7 @@ REGISTER_OP("AddN") .Input("inputs: N * T") .Output("sum: T") .Attr("N: int >= 1") - .Attr("T: numbertype") + .Attr("T: {numbertype, variant}") .SetIsCommutative() .SetIsAggregate() .SetShapeFn([](InferenceContext* c) { diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc index 6651ad41e9a..22afa4db9aa 100644 --- a/tensorflow/core/ops/nn_ops.cc +++ b/tensorflow/core/ops/nn_ops.cc @@ -1344,11 +1344,13 @@ output: The gradients for LRN. // -------------------------------------------------------------------------- REGISTER_OP("MaxPool") - .Attr("T: realnumbertype = DT_FLOAT") + .Attr( + "T: {float, double, int32, int64, uint8, int16, int8, uint16, " + "half, qint8} = DT_FLOAT") .Attr("ksize: list(int) >= 4") .Attr("strides: list(int) >= 4") .Attr(GetPaddingAttrString()) - .Attr(GetConvnetDataFormatAttrString()) + .Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'") .Input("input: T") .Output("output: T") .SetShapeFn(shape_inference::MaxPoolShape) @@ -1369,9 +1371,11 @@ output: The max pooled output tensor. )doc"); REGISTER_OP("MaxPoolV2") - .Attr("T: realnumbertype = DT_FLOAT") + .Attr( + "T: {float, double, int32, int64, uint8, int16, int8, uint16, " + "half, qint8} = DT_FLOAT") .Attr(GetPaddingAttrString()) - .Attr(GetConvnetDataFormatAttrString()) + .Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'") .Input("input: T") .Input("ksize: int32") .Input("strides: int32") diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index cfd3869d059..3f78c72e038 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -334,6 +334,7 @@ op { type: DT_QUINT8 type: DT_QINT32 type: DT_HALF + type: DT_VARIANT } } } @@ -12628,6 +12629,7 @@ op { type: DT_INT8 type: DT_UINT16 type: DT_HALF + type: DT_QINT8 } } } @@ -12667,6 +12669,7 @@ op { list { s: "NHWC" s: "NCHW" + s: "NCHW_VECT_C" } } } @@ -13401,6 +13404,7 @@ op { type: DT_INT8 type: DT_UINT16 type: DT_HALF + type: DT_QINT8 } } } @@ -13426,6 +13430,7 @@ op { list { s: "NHWC" s: "NCHW" + s: "NCHW_VECT_C" } } } diff --git a/tensorflow/core/profiler/README.md b/tensorflow/core/profiler/README.md index 2ddac0a79b7..78557994172 100644 --- a/tensorflow/core/profiler/README.md +++ b/tensorflow/core/profiler/README.md @@ -216,7 +216,7 @@ seq2seq_attention_model.py:363:build_graph:self._add_train_o..., cpu: 1.28sec, a ```shell # The following example generates a timeline. -tfprof> graph -step 0 -max_depth 100000 -output timeline:outfile= +tfprof> graph -step -1 -max_depth 100000 -output timeline:outfile= generating trace file. 
diff --git a/tensorflow/core/profiler/g3doc/command_line.md b/tensorflow/core/profiler/g3doc/command_line.md index fb4207c7841..d41ac7290db 100644 --- a/tensorflow/core/profiler/g3doc/command_line.md +++ b/tensorflow/core/profiler/g3doc/command_line.md @@ -14,7 +14,12 @@ ### Command Line Inputs -tfprof command line tool uses the following inputs: +tfprof command line tool uses the following input: + +--profile_path: A ProfileProto binary proto file. +See QuickStart on generating the file. + +THE OLD WAY BELOW IS DEPRECATED: --graph_path: GraphDef proto file (required). Used to build in-memory data structure of the model. For example, graph.pbtxt written by tf.Supervisor diff --git a/tensorflow/core/profiler/internal/print_model_analysis.cc b/tensorflow/core/profiler/internal/print_model_analysis.cc index ddf3c7f1f28..575ae182ee8 100644 --- a/tensorflow/core/profiler/internal/print_model_analysis.cc +++ b/tensorflow/core/profiler/internal/print_model_analysis.cc @@ -84,7 +84,6 @@ string RunProfile(const string& command, const string& options, } // namespace bool NewProfiler(const string* graph, const string* op_log) { - CHECK(!tf_stat) << "Currently only 1 living tfprof profiler is allowed"; CHECK(graph) << "graph mustn't be null"; std::unique_ptr graph_ptr(new GraphDef()); if (!graph_ptr->ParseFromString(*graph)) { diff --git a/tensorflow/core/profiler/internal/tfprof_node.h b/tensorflow/core/profiler/internal/tfprof_node.h index 55d53f39237..95d199e5b90 100644 --- a/tensorflow/core/profiler/internal/tfprof_node.h +++ b/tensorflow/core/profiler/internal/tfprof_node.h @@ -175,22 +175,22 @@ class ExecStep { std::map> output_memory_; }; -#define GRAPH_NODE_BYTES(type) \ - do { \ - if (execs_.empty()) { \ - return 0; \ - } \ - if (step >= 0) { \ - auto exec = execs_.find(step); \ - CHECK(exec != execs_.end()) << "unknown step " << step; \ - return exec->second.type##_bytes(); \ - } \ - \ - int64 bytes = 0; \ - for (const auto& exec : execs_) { \ - bytes += exec.second.type##_bytes(); \ - } \ - return bytes / execs_.size(); \ +#define GRAPH_NODE_BYTES(type) \ + do { \ + if (execs_.empty()) { \ + return 0; \ + } \ + if (step >= 0) { \ + auto exec = execs_.find(step); \ + if (exec == execs_.end()) return 0; \ + return exec->second.type##_bytes(); \ + } \ + \ + int64 bytes = 0; \ + for (const auto& exec : execs_) { \ + bytes += exec.second.type##_bytes(); \ + } \ + return bytes / execs_.size(); \ } while (0) class TFGraphNode { @@ -372,7 +372,9 @@ class TFGraphNode { } if (step >= 0) { auto exec = execs_.find(step); - CHECK(exec != execs_.end()); + if (exec == execs_.end()) { + return 0; + } return exec->second.run_count(); } int64 total_run_count = 0; @@ -390,7 +392,9 @@ class TFGraphNode { } if (step >= 0) { auto exec = execs_.find(step); - CHECK(exec != execs_.end()); + if (exec == execs_.end()) { + return 0; + } return exec->second.exec_micros(); } @@ -410,7 +414,9 @@ class TFGraphNode { } if (step >= 0) { auto exec = execs_.find(step); - CHECK(exec != execs_.end()); + if (exec == execs_.end()) { + return 0; + } return exec->second.accelerator_exec_micros(); } @@ -430,7 +436,9 @@ class TFGraphNode { } if (step >= 0) { auto exec = execs_.find(step); - CHECK(exec != execs_.end()); + if (exec == execs_.end()) { + return 0; + } return exec->second.cpu_exec_micros(); } @@ -448,20 +456,26 @@ class TFGraphNode { int64 all_start_micros(int64 step) const { auto exec = execs_.find(step); - CHECK(exec != execs_.end()) << "unknown step " << step; + if (exec == execs_.end()) { + return 0; + } return 
exec->second.all_start_micros(); } int64 latest_end_micros(int64 step) const { auto exec = execs_.find(step); - CHECK(exec != execs_.end()) << "unknown step " << step; + if (exec == execs_.end()) { + return 0; + } return exec->second.latest_end_micros(); } const std::map>>& op_execs( int64 step) const { auto exec = execs_.find(step); - CHECK(exec != execs_.end()) << "unknown step " << step; + if (exec == execs_.end()) { + return empty_op_execs_; + } return exec->second.op_execs(); } @@ -469,33 +483,45 @@ class TFGraphNode { int64 accelerator_temp_bytes(int64 step) const { auto exec = execs_.find(step); - CHECK(exec != execs_.end()) << "unknown step " << step; + if (exec == execs_.end()) { + return 0; + } return exec->second.accelerator_temp_bytes(); } int64 host_temp_bytes(int64 step) const { auto exec = execs_.find(step); - CHECK(exec != execs_.end()) << "unknown step " << step; + if (exec == execs_.end()) { + return 0; + } return exec->second.host_temp_bytes(); } int64 accelerator_persistent_bytes(int64 step) const { auto exec = execs_.find(step); - CHECK(exec != execs_.end()) << "unknown step " << step; + if (exec == execs_.end()) { + return 0; + } return exec->second.accelerator_persistent_bytes(); } int64 host_persistent_bytes(int64 step) const { auto exec = execs_.find(step); - CHECK(exec != execs_.end()) << "unknown step " << step; + if (exec == execs_.end()) { + return 0; + } return exec->second.host_persistent_bytes(); } const std::map>& output_memory( int64 step) const { auto exec = execs_.find(step); - CHECK(exec != execs_.end()) << "unknown step " << step; + if (exec == execs_.end()) { + return empty_output_memory_; + } return exec->second.output_memory(); } int64 allocator_bytes_in_use(int64 step) const { auto exec = execs_.find(step); - CHECK(exec != execs_.end()) << "unknown step " << step; + if (exec == execs_.end()) { + return 0; + } return exec->second.allocator_bytes_in_use(); } @@ -566,6 +592,9 @@ class TFGraphNode { std::set op_types_; std::map execs_; + + std::map> empty_output_memory_; + std::map>> empty_op_execs_; }; class TFMultiGraphNode { diff --git a/tensorflow/core/profiler/internal/tfprof_stats.cc b/tensorflow/core/profiler/internal/tfprof_stats.cc index 5b549583446..eb84bada135 100644 --- a/tensorflow/core/profiler/internal/tfprof_stats.cc +++ b/tensorflow/core/profiler/internal/tfprof_stats.cc @@ -88,6 +88,9 @@ TFStats::TFStats(const string& filename, node_pb.second.name(), std::move(node))); } has_code_traces_ = profile.has_trace(); + for (int64 s : profile.steps()) { + steps_.insert(s); + } } void TFStats::BuildView(const string& cmd) { @@ -136,6 +139,14 @@ const GraphNodeProto& TFStats::ShowGraphNode(const string& cmd, if (cmd == kCmds[0]) { return scope_view_->Show(opts); } else if (cmd == kCmds[1]) { + if (opts.step < 0 && opts.output_type == kOutput[0]) { + for (int64 step : steps_) { + Options nopts = opts; + nopts.step = step; + graph_view_->Show(nopts); + } + return empty_graph_node_; + } return graph_view_->Show(opts); } else { fprintf(stderr, "Unknown command: %s\n", cmd.c_str()); @@ -148,7 +159,11 @@ const MultiGraphNodeProto& TFStats::ShowMultiGraphNode( if (!Validate(opts)) { return empty_multi_graph_node_; } - if (cmd == kCmds[2] && has_code_traces()) { + if (cmd == kCmds[2]) { + if (!has_code_traces()) { + fprintf(stderr, "No code trace information\n"); + return empty_multi_graph_node_; + } return code_view_->Show(opts); } else if (cmd == kCmds[3]) { return op_view_->Show(opts); @@ -212,7 +227,9 @@ void TFStats::AddOpLogProto(std::unique_ptr 
op_log) { } if (entry.has_code_def()) { has_code_traces_ = true; - node->second->AddCode(entry.code_def()); + if (node->second->code().traces_size() == 0) { + node->second->AddCode(entry.code_def()); + } } } } @@ -258,9 +275,11 @@ void TFStats::WriteProfile(const string& filename) { } (*profile.mutable_nodes())[it->second->id()].MergeFrom( it->second->ToProto(nodes_map_)); - if (it->second->code().traces_size() > 0) { - profile.set_has_trace(true); - } + } + + profile.set_has_trace(has_code_traces_); + for (int64 s : steps_) { + profile.add_steps(s); } Status s = WriteStringToFile(Env::Default(), filename, profile.SerializeAsString()); @@ -271,7 +290,12 @@ void TFStats::WriteProfile(const string& filename) { bool TFStats::Validate(const Options& opts) const { if (opts.step >= 0 && steps_.find(opts.step) == steps_.end()) { - fprintf(stderr, "Options -step=%lld not found\n", opts.step); + fprintf(stderr, + "Options -step=%lld not found.\nAvailable steps: ", opts.step); + for (int64 s : steps_) { + fprintf(stderr, "%lld ", s); + } + fprintf(stderr, "\n"); return false; } return true; diff --git a/tensorflow/core/profiler/internal/tfprof_timeline.cc b/tensorflow/core/profiler/internal/tfprof_timeline.cc index f3934860d9a..1732574cc41 100644 --- a/tensorflow/core/profiler/internal/tfprof_timeline.cc +++ b/tensorflow/core/profiler/internal/tfprof_timeline.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/profiler/internal/tfprof_utils.h" namespace tensorflow { @@ -303,11 +304,12 @@ void Timeline::GenerateCodeTimeline(const CodeNode* node) { } void Timeline::OutputTimeline() { + string outfile = strings::Printf("%s_%lld", outfile_.c_str(), step()); Status s = - WriteStringToFile(Env::Default(), outfile_, chrome_formatter_.Format()); + WriteStringToFile(Env::Default(), outfile, chrome_formatter_.Format()); if (!s.ok()) { fprintf(stderr, "Failed to write timeline file: %s\nError: %s\n", - outfile_.c_str(), s.ToString().c_str()); + outfile.c_str(), s.ToString().c_str()); return; } fprintf(stdout, "\n******************************************************\n"); @@ -315,7 +317,7 @@ void Timeline::OutputTimeline() { "Timeline file is written to %s.\n" "Open a Chrome browser, enter URL chrome://tracing and " "load the timeline file.", - outfile_.c_str()); + outfile.c_str()); fprintf(stdout, "\n******************************************************\n"); fflush(stdout); } diff --git a/tensorflow/core/profiler/internal/tfprof_timeline_test.cc b/tensorflow/core/profiler/internal/tfprof_timeline_test.cc index 5dd440e9a26..babae395bad 100644 --- a/tensorflow/core/profiler/internal/tfprof_timeline_test.cc +++ b/tensorflow/core/profiler/internal/tfprof_timeline_test.cc @@ -70,7 +70,7 @@ TEST_F(TFProfTimelineTest, GraphView) { tf_stats_->ShowGraphNode("graph", opts); string dump_str; - TF_CHECK_OK(ReadFileToString(Env::Default(), dump_file, &dump_str)); + TF_CHECK_OK(ReadFileToString(Env::Default(), dump_file + "_0", &dump_str)); EXPECT_EQ(1754536562981488144ull, Hash64(dump_str)); } @@ -84,7 +84,7 @@ TEST_F(TFProfTimelineTest, ScopeView) { tf_stats_->ShowGraphNode("scope", opts); string dump_str; - TF_CHECK_OK(ReadFileToString(Env::Default(), dump_file, &dump_str)); + TF_CHECK_OK(ReadFileToString(Env::Default(), dump_file + "_0", &dump_str)); EXPECT_EQ(17545174915963890413ull, Hash64(dump_str)); } diff --git 
a/tensorflow/core/profiler/tfprof_log.proto b/tensorflow/core/profiler/tfprof_log.proto index 1ce5a5eecf7..ae571e2540b 100644 --- a/tensorflow/core/profiler/tfprof_log.proto +++ b/tensorflow/core/profiler/tfprof_log.proto @@ -42,6 +42,8 @@ message ProfileProto { map nodes = 1; // Whether or not has code traces. bool has_trace = 2; + // Traced steps. + repeated int64 steps = 3; } message ProfileNode { diff --git a/tensorflow/docs_src/extend/adding_an_op.md b/tensorflow/docs_src/extend/adding_an_op.md index 424648d54ab..cd086d80130 100644 --- a/tensorflow/docs_src/extend/adding_an_op.md +++ b/tensorflow/docs_src/extend/adding_an_op.md @@ -632,6 +632,22 @@ define an attr with constraints, you can use the following ``s: tf.number_type(t=tf.bool) # Invalid ``` + Lists can be combined with other lists and single types. The following + op allows attr `t` to be any of the numeric types, or the bool type: + + ```c++ + REGISTER_OP("NumberOrBooleanType") + .Attr("t: {numbertype, bool}"); + ``` + + For this op: + + ```python + tf.number_or_boolean_type(t=tf.int32) # Valid + tf.number_or_boolean_type(t=tf.bool) # Valid + tf.number_or_boolean_type(t=tf.string) # Invalid + ``` + * `int >= `: The value must be an int whose value is greater than or equal to ``, where `` is a natural number. diff --git a/tensorflow/docs_src/install/index.md b/tensorflow/docs_src/install/index.md index 3df16139fb8..eddbfe9e31e 100644 --- a/tensorflow/docs_src/install/index.md +++ b/tensorflow/docs_src/install/index.md @@ -1,5 +1,13 @@ # Installing TensorFlow +We've built and tested TensorFlow on the following 64-bit laptop/desktop +operating systems: + * MacOS X 10.11 (El Capitan) or later. + * Ubuntu 14.04 or later. + * Windows 7 or later. +Although you might be able to install TensorFlow on other laptop or desktop +systems, we only support (and only fix issues in) the preceding configurations. + The following guides explain how to install a version of TensorFlow that enables you to write applications in Python: diff --git a/tensorflow/docs_src/performance/performance_guide.md b/tensorflow/docs_src/performance/performance_guide.md index 2448ffac8bb..bf69b7e6fc2 100644 --- a/tensorflow/docs_src/performance/performance_guide.md +++ b/tensorflow/docs_src/performance/performance_guide.md @@ -1,43 +1,182 @@ # Performance Guide -This guide contains a collection of best practices for optimizing your -TensorFlow code. The best practices apply to both new and experienced -Tensorflow users. As a complement to the best practices in this document, the -@{$performance_models$High-Performance Models} document links to example code -and details for creating models that scale on a variety of hardware. +This guide contains a collection of best practices for optimizing TensorFlow +code. The guide is divided into a few sections: -## Best Practices -While optimizing implementations of different types of models can be different, -the topics below cover best practices to get the most performance from -TensorFlow. Although these suggestions focus on image-based models, we will -regularly add tips for all kinds of models. The following list highlights key -best practices: +* [General best practices](#general_best_practices) covers topics that are + common across a variety of model types and hardware. +* [Optimizing for GPU](#optimizing_for_gpu) details tips specifically relevant + to GPUs. +* [Optimizing for CPU](#optimizing_for_cpu) details CPU-specific information.
-* Build and install from source -* Utilize queues for reading data -* Preprocessing on the CPU -* Use `NCHW` image data format -* Place shared parameters on the GPU -* Use fused batch norm +## General best practices -The following sections detail the preceding suggestions. +The sections below cover best practices that are relevant to a variety of +hardware and models. The best practices section is broken down into the +following sections: -### Build and install from source +* [Input pipeline optimizations](#input-pipeline-optimization) +* [Data formats](#data-formats) +* [Common fused Ops](#common-fused-ops) +* [Building and installing from source](#building-and-installing-from-source) -To install the most optimized version of TensorFlow, build and install -TensorFlow from source by following [Installing TensorFlow from Source](../install/install_sources). -Building from source with compiler optimizations for the target hardware and -ensuring the latest CUDA platform and cuDNN libraries are installed results in -the highest performing installs. +### Input pipeline optimization -For the most stable experience, build from the [latest release](https://github.com/tensorflow/tensorflow/releases) -branch. To get the latest performance changes and accept some stability risk, -build from [master](https://github.com/tensorflow/tensorflow). +Typical models retrieve data from disk and preprocess it before sending the data +through the network. For example, models that process JPEG images will follow +this flow: load image from disk, decode JPEG into a tensor, crop and pad, +possibly flip and distort, and then batch. This flow is referred to as the input +pipeline. As GPUs and other hardware accelerators get faster, preprocessing of +data can be a bottleneck. -If there is a need to build TensorFlow on a platform that has different hardware -than the target, then cross-compile with the highest optimizations for the target -platform. The following command is an example of telling `bazel` to compile for -a specific platform: +Determining if the input pipeline is the bottleneck can be complicated. One of +the most straightforward methods is to reduce the model to a single operation +(trivial model) after the input pipeline and measure the examples per second. If +the difference in examples per second for the full model and the trivial model +is minimal then the input pipeline is likely a bottleneck. Below are some other +approaches to identifying issues: + +* Check if a GPU is underutilized by running `watch -n 2 nvidia-smi`. If GPU + utilization is not approaching 80-100%, then the input pipeline may be the + bottleneck. +* Generate a timeline and look for large blocks of white space (waiting). An + example of generating a timeline exists as part of the @{$jit$XLA JIT} + tutorial. +* Check CPU usage. It is possible to have an optimized input pipeline and lack + the CPU cycles to process the pipeline. +* Estimate the throughput needed and verify the disk used is capable of that + level of throughput. Some cloud solutions have network attached disks that + start as low as 50 MB/sec, which is slower than spinning disks (150 MB/sec), + SATA SSDs (500 MB/sec), and PCIe SSDs (2,000+ MB/sec). + +#### Preprocessing on the CPU + +Placing input pipeline operations on the CPU can significantly improve +performance. Utilizing the CPU for the input pipeline frees the GPU to focus on +training. 
To ensure preprocessing is on the CPU, wrap the preprocessing +operations as shown below: + +```python +with tf.device('/cpu:0'): + # function to get and process images or data. + distorted_inputs = load_and_distort_images() +``` + +If using `tf.estimator.Estimator`, the input function is automatically placed on +the CPU. + +#### Using the Dataset API + +The @{$datasets$Dataset API} is replacing `queue_runner` as the recommended API +for building input pipelines. The API was added to contrib as part of TensorFlow +1.2 and will move to core in the near future. This +[ResNet example](https://github.com/tensorflow/models/tree/master/tutorials/image/cifar10_estimator/cifar10_main.py) +([arXiv:1512.03385](https://arxiv.org/abs/1512.03385)) +training CIFAR-10 illustrates the use of the Dataset API along with +`tf.estimator.Estimator`. The Dataset API utilizes C++ multi-threading and has a +much lower overhead than the Python-based `queue_runner` that is limited by +Python's multi-threading performance. + +While feeding data using a `feed_dict` offers a high level of flexibility, in +most instances using `feed_dict` does not scale optimally. However, in instances +where only a single GPU is being used, the difference can be negligible. Using +the Dataset API is still strongly recommended. Try to avoid the following: + +```python +# feed_dict often results in suboptimal performance when using large inputs. +sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys}) +``` + +#### Use large files + +Reading large numbers of small files significantly impacts I/O performance. +One approach to get maximum I/O throughput is to preprocess input data into +larger (~100MB) `TFRecord` files. For smaller data sets (200MB-1GB), the best +approach is often to load the entire data set into memory. The document +[Downloading and converting to TFRecord format](https://github.com/tensorflow/models/tree/master/slim#Data) +includes information and scripts for creating `TFRecords` and this +[script](https://github.com/tensorflow/models/tree/master/tutorials/image/cifar10_estimator/generate_cifar10_tfrecords.py) +converts the CIFAR-10 data set into `TFRecords`. + +### Data formats + +Data format refers to the structure of the Tensor passed to a given Op. The +discussion below is specifically about 4D Tensors representing images. In +TensorFlow the parts of the 4D tensor are often referred to by the following +letters: + +* N refers to the number of images in a batch. +* H refers to the number of pixels in the vertical (height) dimension. +* W refers to the number of pixels in the horizontal (width) dimension. +* C refers to the channels. For example, 1 for black and white or grayscale + and 3 for RGB. + +Within TensorFlow there are two naming conventions representing the two most +common data formats: + +* `NCHW` or `channels_first` +* `NHWC` or `channels_last` + +`NHWC` is the TensorFlow default and `NCHW` is the optimal format to use when +training on NVIDIA GPUs using [cuDNN](https://developer.nvidia.com/cudnn). + +The best practice is to build models that work with both data formats. This +simplifies training on GPUs and then running inference on CPUs. If TensorFlow is +compiled with the [Intel MKL](#tensorflow_with_intel_mkl-dnn) optimizations, +many operations, especially those related to CNN-based models, will be optimized +and support `NCHW`. If not using the MKL, some operations are not supported on +CPU when using `NCHW`.
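Where a graph has to bridge the two layouts, the conversion is a single transpose of the channel axis. A minimal sketch with illustrative shapes (TF 1.x API); the tensor names are hypothetical:

```python
import tensorflow as tf

# Illustrative NHWC batch: 8 images, 224x224 pixels, 3 channels.
nhwc = tf.placeholder(tf.float32, shape=[8, 224, 224, 3])

# NHWC -> NCHW: move channels (axis 3) ahead of height and width.
nchw = tf.transpose(nhwc, [0, 3, 1, 2])

# NCHW -> NHWC: the inverse permutation.
nhwc_again = tf.transpose(nchw, [0, 2, 3, 1])
```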
+ +The brief history of these two formats is that TensorFlow started by using +`NHWC` because it was a little faster on CPUs. In the long term, we are working +on tools to auto rewrite graphs to make switching between the formats +transparent and take advantage of micro optimizations where a GPU Op may be +faster using `NHWC` than the normally most efficient `NCHW`. + +### Common fused Ops + +Fused Ops combine multiple operations into a single kernel for improved +performance. There are many fused Ops within TensorFlow and @{$xla$XLA} will +create fused Ops when possible to automatically improve performance. Collected +below are select fused Ops that can greatly improve performance and may be +overlooked. + +#### Fused batch norm + +Fused batch norm combines the multiple operations needed to do batch +normalization into a single kernel. Batch norm is an expensive process that for +some models makes up a large percentage of the operation time. Using fused batch +norm can result in a 12%-30% speedup. + +There are two commonly used batch norms and both support fusing. The core +@{tf.layers.batch_normalization} added the `fused` option starting in TensorFlow 1.3. + +```python +bn = tf.layers.batch_normalization( + input_layer, fused=True, data_format='NCHW') +``` + +The contrib @{tf.contrib.layers.batch_norm} method has had `fused` as an option +since before TensorFlow 1.0. + +```python +bn = tf.contrib.layers.batch_norm(input_layer, fused=True, data_format='NCHW') +``` + +### Building and installing from source + +The default TensorFlow binaries target the broadest range of hardware to make +TensorFlow accessible to everyone. If using CPUs for training or inference, it +is recommended to compile TensorFlow with all of the optimizations available for +the CPU in use. Speedups for training and inference on CPU are documented below +in [Comparing compiler optimizations](#comparing-compiler-optimizations). + +To install the most optimized version of TensorFlow, +@{$install_sources$build and install} from source. If there is a need to build +TensorFlow on a platform that has different hardware than the target, then +cross-compile with the highest optimizations for the target platform. The +following command is an example of using `bazel` to compile for a specific +platform: ```python # This command optimizes for Intel’s Broadwell processor @@ -47,106 +186,467 @@ bazel build -c opt --copt=-march="broadwell" --config=cuda //tensorflow/tools/pi #### Environment, build, and install tips -* Compile with the highest level of compute the [GPU - supports](http://developer.nvidia.com/cuda-gpus), e.g. P100: 6.0, Titan X - (pascal): 6.2, Titan X (maxwell): 5.2, and K80: 3.7. -* Install the latest CUDA platform and cuDNN libraries. -* Make sure to use a version of gcc that supports all of the optimizations of - the target CPU. The recommended minimum gcc version is 4.8.3. On OS X upgrade - to the latest Xcode version and use the version of clang that comes with Xcode. -* TensorFlow checks on startup whether it has been compiled with the - optimizations available on the CPU. If the optimizations are not included, - TensorFlow will emit warnings, e.g. AVX, AVX2, and FMA instructions not - included. +* `./configure` asks which compute capability to include in the build. This + does not impact overall performance but does impact initial startup. After + running TensorFlow once, the compiled kernels are cached by CUDA. If using + a Docker container, the data is not cached and the penalty is paid each time + TensorFlow starts.
The best practice is to include the + [compute capabilities](http://developer.nvidia.com/cuda-gpus) + of the GPUs that will be used, e.g. P100: 6.0, Titan X (Pascal): 6.1, Titan + X (Maxwell): 5.2, and K80: 3.7. +* Use a version of gcc that supports all of the optimizations of the target + CPU. The recommended minimum gcc version is 4.8.3. On OS X, upgrade to the + latest Xcode version and use the version of clang that comes with Xcode. +* Install the latest stable CUDA platform and cuDNN libraries supported by + TensorFlow. -### Utilize queues for reading data +## Optimizing for GPU -One common cause of poor performance is underutilizing GPUs, or essentially -"starving" them of data by not setting up an efficient pipeline. Make sure to -set up an input pipeline to utilize queues and stream data effectively. Review -the @{$reading_data#reading_from_files$Reading Data guide} for implementation -details. One way to identify a "starved" GPU is to generate and review -timelines. A detailed tutorial for timelines does not exist, but a quick example -of generating a timeline exists as part of the @{$jit$XLA JIT} tutorial. Another -simple way to check if a GPU is underutilized is to run `watch nvidia-smi`, and -if GPU utilization is not approaching 100% then the GPU is not getting data fast -enough. +This section contains GPU-specific tips that are not covered in the +[General best practices](#general-best-practices). Obtaining optimal performance +on multiple GPUs is a challenge. A common approach is to use data parallelism. +Scaling through the use of data parallelism involves making multiple copies of +the model, which are referred to as "towers", and then placing one tower on each +of the GPUs. Each tower operates on a different mini-batch of data and then +updates variables, also known as parameters, that need to be shared between +each of the towers. How each tower gets the updated variables and how the +gradients are applied has an impact on the performance, scaling, and convergence +of the model. The rest of this section provides an overview of variable +placement and the towering of a model on multiple GPUs. +@{$performance_models$High-Performance Models} gets into more details regarding +more complex methods that can be used to share and update variables between +towers. -Unless for a special circumstance or for example code, do not feed data -into the session from Python variables, e.g. `dictionary`. +The best approach to handling variable updates depends on the model, hardware, +and even how the hardware has been configured. An example of this is that two +systems can be built with NVIDIA Tesla P100s but one may be using PCIe and the +other [NVLink](http://www.nvidia.com/object/nvlink.html). In that scenario, the +optimal solution for each system may be different. For real-world examples, read +the @{$benchmarks$benchmark} page, which details the settings that were optimal +for a variety of platforms. Below is a summary of what was learned from +benchmarking various platforms and configurations: + +* **Tesla K80**: If the GPUs are on the same PCI Express root complex and are + able to use [NVIDIA GPUDirect](https://developer.nvidia.com/gpudirect) Peer + to Peer, then placing the variables equally across the GPUs used for + training is the best approach. If the GPUs cannot use GPUDirect, then + placing the variables on the CPU is the best option.
+ +* **Titan X (Maxwell and Pascal), M40, P100, and similar**: For models like + ResNet and InceptionV3, placing variables on the CPU is the optimal setting, + but for models with a lot of variables like AlexNet and VGG, using GPUs with + `NCCL` is better. + +A common approach to managing where variables are placed is to create a method +to determine where each Op is to be placed and use that method in place of a +specific device name when calling `with tf.device():`. Consider a scenario where +a model is being trained on 2 GPUs and the variables are to be placed on the +CPU. There would be a loop for creating and placing the "towers" on each of the +2 GPUs. A custom device placement method would be created that watches for Ops +of type `Variable`, `VariableV2`, and `VarHandleOp` and indicates that they are +to be placed on the CPU. All other Ops would be placed on the target GPU. +The building of the graph would proceed as follows: + +* On the first loop, a "tower" of the model would be created for `gpu:0`. + During the placement of the Ops, the custom device placement method would + indicate that variables are to be placed on `cpu:0` and all other Ops on + `gpu:0`. + +* On the second loop, `reuse` is set to `True` to indicate that variables are + to be reused and then the "tower" is created on `gpu:1`. During the + placement of the Ops associated with the "tower", the variables that were + placed on `cpu:0` are reused and all other Ops are created and placed on + `gpu:1`. + +The final result is that all of the variables are placed on the CPU with each GPU +having a copy of all of the computational Ops associated with the model. + +The code snippet below illustrates two different approaches for variable +placement: one is placing variables on the CPU; the other is placing variables +equally across the GPUs. ```python + +class GpuParamServerDeviceSetter(object): + """Used with tf.device() to place variables on the least loaded GPU. + + A common use for this class is to pass a list of GPU devices, e.g. ['gpu:0', + 'gpu:1', 'gpu:2'], as ps_devices. When each variable is placed, it will be + placed on the least loaded gpu. All other Ops, which will be the computation + Ops, will be placed on the worker_device. + """ + + def __init__(self, worker_device, ps_devices): + """Initializer for GpuParamServerDeviceSetter. + Args: + worker_device: the device to use for computation Ops. + ps_devices: a list of devices to use for Variable Ops. Each variable is + assigned to the least loaded device. + """ + self.ps_devices = ps_devices + self.worker_device = worker_device + self.ps_sizes = [0] * len(self.ps_devices) + + def __call__(self, op): + if op.device: + return op.device + if op.type not in ['Variable', 'VariableV2', 'VarHandleOp']: + return self.worker_device + + # Gets the least loaded ps_device + device_index, _ = min(enumerate(self.ps_sizes), key=operator.itemgetter(1)) + device_name = self.ps_devices[device_index] + var_size = op.outputs[0].get_shape().num_elements() + self.ps_sizes[device_index] += var_size + + return device_name + +def _create_device_setter(is_cpu_ps, worker, num_gpus): + """Create device setter object.""" + if is_cpu_ps: + # tf.train.replica_device_setter supports placing variables on the CPU, all + # on one GPU, or on ps_servers defined in a cluster_spec.
+ return tf.train.replica_device_setter( + worker_device=worker, ps_device='/cpu:0', ps_tasks=1) + else: + gpus = ['/gpu:%d' % i for i in range(num_gpus)] + return GpuParamServerDeviceSetter(worker, gpus) + +# The method below is a modified snippet from the full example. +def _resnet_model_fn(): + # When set to False, variables are placed on the least loaded GPU. If set + # to True, the variables will be placed on the CPU. + is_cpu_ps = False + + # Loops over the number of GPUs and creates a copy ("tower") of the model on + # each GPU. + for i in range(num_gpus): + worker = '/gpu:%d' % i + # Creates a device setter used to determine where Ops are to be placed. + device_setter = _create_device_setter(is_cpu_ps, worker, FLAGS.num_gpus) + # Creates variables on the first loop. On subsequent loops reuse is set + # to True, which results in the "towers" sharing variables. + with tf.variable_scope('resnet', reuse=bool(i != 0)): + with tf.name_scope('tower_%d' % i) as name_scope: + # tf.device calls the device_setter for each Op that is created. + # device_setter returns the device the Op is to be placed on. + with tf.device(device_setter): + # Creates the "tower". + _tower_fn(is_training, weight_decay, tower_features[i], + tower_labels[i], tower_losses, tower_gradvars, + tower_preds, False) + ``` -### Preprocessing on the CPU +In the near future, the above code will be for illustration purposes only, as +there will be easy-to-use, high-level methods to support a wide range of popular +approaches. This +[example](https://github.com/tensorflow/models/tree/master/tutorials/image/cifar10_estimator) +will continue to get updated as the API expands and evolves to address multi-GPU +scenarios. -Placing preprocessing operations on the CPU can significantly improve -performance. When preprocessing occurs on the GPU the flow of data is -CPU -> GPU (preprocessing) -> CPU -> GPU (training). The data is bounced back -and forth between the CPU and GPU. When preprocessing is placed on the CPU, -the data flow is CPU (preprocessing) -> GPU (training). Another benefit is -preprocessing on the CPU frees GPU time to focus on training. +## Optimizing for CPU -Placing preprocessing on the CPU can result in a 6X+ increase in samples/sec -processed, which could lead to training in 1/6th of the time. To ensure -preprocessing is on the CPU, wrap the preprocessing operations as shown below: +CPUs, which include Intel® Xeon Phi™, achieve optimal performance when +TensorFlow is @{$install_sources$built from source} with all of the instructions +supported by the target CPU. + +Beyond using the latest instruction sets, Intel® has added support for the +Intel® Math Kernel Library for Deep Neural Networks (Intel® MKL-DNN) to +TensorFlow. While the name is not completely accurate, these optimizations are +often simply referred to as 'MKL' or 'TensorFlow with MKL'. [TensorFlow +with Intel® MKL-DNN](#tensorflow_with_intel_mkl_dnn) contains details on the +MKL optimizations. + +The two configurations listed below are used to optimize CPU performance by +adjusting the thread pools. + +* `intra_op_parallelism_threads`: Nodes that can use multiple threads to + parallelize their execution will schedule the individual pieces into this + pool. +* `inter_op_parallelism_threads`: All ready nodes are scheduled in this pool. + +These configurations are set via `tf.ConfigProto` and passed to `tf.Session` in the `config` attribute as shown in the snippet below.
Both configuration +options, if unset or set to 0, default to the number of logical +CPU cores. Testing has shown that the default is effective for systems ranging +from one CPU with 4 cores to multiple CPUs with 70+ combined logical cores. +A common alternative optimization is to set the number of threads in both pools +equal to the number of physical cores rather than logical cores. ```python -with tf.device('/cpu:0'): - # function to get and process images or data. - distorted_inputs = load_and_distort_images() + + config = tf.ConfigProto() + config.intra_op_parallelism_threads = 44 + config.inter_op_parallelism_threads = 44 + tf.Session(config=config) + ``` -### Use large files +The [Comparing compiler optimizations](#comparing-compiler-optimizations) +section contains the results of tests that used different compiler +optimizations. -Under some circumstances, both the CPU and GPU can be starved for data by the -I/O system. If you are using many small files to form your input data set, you -may be limited by the speed of your filesystem. If your training loop runs -faster when using SSDs vs HDDs for storing your input data, you could be -I/O bottlenecked. +### TensorFlow with Intel® MKL DNN -If this is the case, you should pre-process your input data, creating a few -large TFRecord files. +Intel® has added optimizations to TensorFlow for Intel® Xeon® and Intel® Xeon +Phi™ through the use of Intel® Math Kernel Library for Deep Neural Networks +(Intel® MKL-DNN) optimized primitives. The optimizations also provide speedups +for the consumer line of processors, e.g. i5 and i7 Intel processors. The Intel +published paper +[TensorFlow* Optimizations on Modern Intel® Architecture](https://software.intel.com/en-us/articles/tensorflow-optimizations-on-modern-intel-architecture) +contains additional details on the implementation. -### Use NCHW image data format +> Note: MKL was added as of TensorFlow 1.2 and currently only works on Linux. It +> also does not work when using `--config=cuda`. -Image data format refers to the representation of batches of images. TensorFlow -supports `NHWC` (TensorFlow default) and `NCHW` (cuDNN default). N refers to the -number of images in a batch, H refers to the number of pixels in the vertical -dimension, W refers to the number of pixels in the horizontal dimension, and C -refers to the channels (e.g. 1 for black and white, 3 for RGB, etc.) Although -cuDNN can operate on both formats, it is faster to operate in its default -format. +In addition to providing significant performance improvements for training +CNN-based models, compiling with the MKL creates a binary that is optimized for AVX +and AVX2. The result is a single binary that is optimized and compatible with +most modern (post-2011) processors. -The best practice is to build models that work with both `NCHW` and `NHWC` as it -is common to train using `NCHW` on GPU, and then do inference with `NHWC` on CPU. +TensorFlow can be compiled with the MKL optimizations using the following +commands, depending on the version of the TensorFlow source used.
+For TensorFlow source versions after 1.3.0: -The very brief history of these two formats is that TensorFlow started by using -`NHWC` because it was a little faster on CPUs. Then the TensorFlow team -discovered that `NCHW` performs better when using the NVIDIA cuDNN library. The -current recommendation is that users support both formats in their models. In -the long term, we plan to rewrite graphs to make switching between the formats -transparent. +```bash +./configure +# Pick the desired options +bazel build --config=mkl -c opt //tensorflow/tools/pip_package:build_pip_package -### Use fused batch norm +``` -When using batch norm -@{tf.contrib.layers.batch_norm} set the attribute `fused=True`: +For TensorFlow versions 1.2.0 through 1.3.0: + +```bash +./configure +Do you wish to build TensorFlow with MKL support? [y/N] Y +Do you wish to download MKL LIB from the web? [Y/n] Y +# Select the defaults for the rest of the options. + +bazel build --config=mkl --copt="-DEIGEN_USE_VML" -c opt //tensorflow/tools/pip_package:build_pip_package + +``` + +#### Tuning MKL for the best performance + +This section details the different configurations and environment variables that +can be used to tune the MKL to get optimal performance. Before tweaking various +environment variables make sure the model is using the `NCHW` (`channels_first`) +[data format](#data-formats). The MKL is optimized for `NCHW` and Intel is +working to get near performance parity when using `NHWC`. + +MKL uses the following environment variables to tune performance: + +* KMP_BLOCKTIME - Sets the time, in milliseconds, that a thread should wait, + after completing the execution of a parallel region, before sleeping. +* KMP_AFFINITY - Enables the run-time library to bind threads to physical + processing units. +* KMP_SETTINGS - Enables (true) or disables (false) the printing of OpenMP* + run-time library environment variables during program execution. +* OMP_NUM_THREADS - Specifies the number of threads to use. + +More details on the KMP variables are on +[Intel's](https://software.intel.com/en-us/node/522775) site and the OMP +variables on +[gnu.org](https://gcc.gnu.org/onlinedocs/libgomp/Environment-Variables.html) + +While there can be substantial gains from adjusting the environment variables, +which is discussed below, the simplified advice is to set the +`inter_op_parallelism_threads` equal to the number of physical CPUs and to set +the following environment variables: + +* KMP_BLOCKTIME=0 +* KMP_AFFINITY=granularity=fine,verbose,compact,1,0 + +Example setting MKL variables with command-line arguments: + +```bash +KMP_BLOCKTIME=0 KMP_AFFINITY=granularity=fine,verbose,compact,1,0 \ +KMP_SETTINGS=1 python your_python_script.py +``` + +Example setting MKL variables with python `os.environ`: ```python -bn = tf.contrib.layers.batch_norm( - input_layer, fused=True, data_format='NCHW' - scope=scope, **kwargs) +os.environ["KMP_BLOCKTIME"] = str(FLAGS.kmp_blocktime) +os.environ["KMP_SETTINGS"] = str(FLAGS.kmp_settings) +os.environ["KMP_AFFINITY"]= FLAGS.kmp_affinity +if FLAGS.num_intra_threads > 0: + os.environ["OMP_NUM_THREADS"]= str(FLAGS.num_intra_threads) + ``` -The non-fused batch norm does computations using several individual Ops. Fused -batch norm combines the individual operations into a single kernel, which runs -faster. +There are models and hardware platforms that benefit from different settings. +Each variable that impacts performance is discussed below. 
+* **KMP_BLOCKTIME**: The MKL default is 200ms, which was not optimal in our + testing. 0 (0ms) was a good default for CNN-based models that were tested. + The best performance for AlexNet was achieved at 30ms and both GoogleNet and + VGG11 performed best when set to 1ms. + +* **KMP_AFFINITY**: The recommended setting is + `granularity=fine,verbose,compact,1,0`. + +* **OMP_NUM_THREADS**: This defaults to the number of physical cores. + Adjusting this parameter beyond matching the number of cores can have an + impact when using Intel® Xeon Phi™ (Knights Landing) for some models. See + [TensorFlow* Optimizations on Modern Intel® Architecture](https://software.intel.com/en-us/articles/tensorflow-optimizations-on-modern-intel-architecture) + for optimal settings. + +* **intra_op_parallelism_threads**: Setting this equal to the number of + physical cores is recommended. Setting the value to 0, which is the default + and results in the value being set to the number of logical cores, is an + option to try for some architectures. This value and `OMP_NUM_THREADS` + should be equal. + +* **inter_op_parallelism_threads**: Setting this equal to the number of + sockets is recommended. Setting the value to 0, which is the default, + results in the value being set to the number of logical cores. + +### Comparing compiler optimizations + +Collected below are performance results running training and inference on +different types of CPUs on different platforms with various compiler +optimizations. The models used were ResNet-50 +([arXiv:1512.03385](https://arxiv.org/abs/1512.03385)) and +InceptionV3 ([arXiv:1512.00567](https://arxiv.org/abs/1512.00567)). + +For each test, when the MKL optimization was used, the environment variable +KMP_BLOCKTIME was set to 0 (0ms) and KMP_AFFINITY to +`granularity=fine,verbose,compact,1,0`.
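Taken together, the tuning bullets and per-test settings above amount to a setup like the following sketch; the hardware counts (one socket, four physical cores) are illustrative assumptions, not recommendations for any specific machine:

```python
import os

# OpenMP/MKL variables should be set before TensorFlow initializes MKL,
# hence before the tensorflow import below.
os.environ["KMP_BLOCKTIME"] = "0"
os.environ["KMP_AFFINITY"] = "granularity=fine,verbose,compact,1,0"

NUM_PHYSICAL_CORES = 4  # assumed for illustration
NUM_SOCKETS = 1         # assumed for illustration
os.environ["OMP_NUM_THREADS"] = str(NUM_PHYSICAL_CORES)

import tensorflow as tf

config = tf.ConfigProto()
config.intra_op_parallelism_threads = NUM_PHYSICAL_CORES  # = OMP_NUM_THREADS
config.inter_op_parallelism_threads = NUM_SOCKETS
sess = tf.Session(config=config)
```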
+ +#### Inference InceptionV3 + +**Environment** + +* Instance Type: AWS EC2 m4.xlarge +* CPU: Intel(R) Xeon(R) CPU E5-2686 v4 @ 2.30GHz (Broadwell) +* Dataset: ImageNet +* TensorFlow Version: 1.2.0 RC2 +* Test Script: [tf_cnn_benchmarks.py](https://github.com/tensorflow/benchmarks/blob/mkl_experiment/scripts/tf_cnn_benchmarks/tf_cnn_benchmarks.py) + +**Batch Size: 1** + +Command executed for the MKL test: + +```bash +python tf_cnn_benchmarks.py --forward_only=True --device=cpu --mkl=True \ +--kmp_blocktime=0 --nodistortions --model=inception3 --data_format=NCHW \ +--batch_size=1 --num_inter_threads=1 --num_intra_threads=4 \ +--data_dir= +``` + +| Optimization | Data Format | Images/Sec | Intra threads | Inter Threads | +: : : (step time) : : : +| ------------ | ----------- | ------------ | ------------- | ------------- | +| AVX2 | NHWC | 6.8 (147ms) | 4 | 0 | +| MKL | NCHW | 6.6 (151ms) | 4 | 1 | +| MKL | NHWC | 5.95 (168ms) | 4 | 1 | +| AVX | NHWC | 4.7 (211ms) | 4 | 0 | +| SSE3 | NHWC | 2.7 (370ms) | 4 | 0 | + +**Batch Size: 32** + +Command executed for the MKL test: + +```bash +python tf_cnn_benchmarks.py --forward_only=True --device=cpu --mkl=True \ +--kmp_blocktime=0 --nodistortions --model=inception3 --data_format=NCHW \ +--batch_size=32 --num_inter_threads=1 --num_intra_threads=4 \ +--data_dir= +``` + +| Optimization | Data Format | Images/Sec | Intra threads | Inter Threads | +: : : (step time) : : : +| ------------ | ----------- | ------------- | ------------- | ------------- | +| MKL | NCHW | 10.24 | 4 | 1 | +: : : (3125ms) : : : +| MKL | NHWC | 8.9 (3595ms) | 4 | 1 | +| AVX2 | NHWC | 7.3 (4383ms) | 4 | 0 | +| AVX | NHWC | 5.1 (6275ms) | 4 | 0 | +| SSE3 | NHWC | 2.8 (11428ms) | 4 | 0 | + +#### Inference ResNet-50 + +**Environment** + +* Instance Type: AWS EC2 m4.xlarge +* CPU: Intel(R) Xeon(R) CPU E5-2686 v4 @ 2.30GHz (Broadwell) +* Dataset: ImageNet +* TensorFlow Version: 1.2.0 RC2 +* Test Script: [tf_cnn_benchmarks.py](https://github.com/tensorflow/benchmarks/blob/mkl_experiment/scripts/tf_cnn_benchmarks/tf_cnn_benchmarks.py) + +**Batch Size: 1** + +Command executed for the MKL test: + +```bash +python tf_cnn_benchmarks.py --forward_only=True --device=cpu --mkl=True \ +--kmp_blocktime=0 --nodistortions --model=resnet50 --data_format=NCHW \ +--batch_size=1 --num_inter_threads=1 --num_intra_threads=4 \ +--data_dir= +``` + +| Optimization | Data Format | Images/Sec | Intra threads | Inter Threads | +: : : (step time) : : : +| ------------ | ----------- | ------------ | ------------- | ------------- | +| AVX2 | NHWC | 6.8 (147ms) | 4 | 0 | +| MKL | NCHW | 6.6 (151ms) | 4 | 1 | +| MKL | NHWC | 5.95 (168ms) | 4 | 1 | +| AVX | NHWC | 4.7 (211ms) | 4 | 0 | +| SSE3 | NHWC | 2.7 (370ms) | 4 | 0 | + +**Batch Size: 32** + +Command executed for the MKL test: + +```bash +python tf_cnn_benchmarks.py --forward_only=True --device=cpu --mkl=True \ +--kmp_blocktime=0 --nodistortions --model=resnet50 --data_format=NCHW \ +--batch_size=32 --num_inter_threads=1 --num_intra_threads=4 \ +--data_dir= +``` + +| Optimization | Data Format | Images/Sec | Intra threads | Inter Threads | +: : : (step time) : : : +| ------------ | ----------- | ------------- | ------------- | ------------- | +| MKL | NCHW | 10.24 | 4 | 1 | +: : : (3125ms) : : : +| MKL | NHWC | 8.9 (3595ms) | 4 | 1 | +| AVX2 | NHWC | 7.3 (4383ms) | 4 | 0 | +| AVX | NHWC | 5.1 (6275ms) | 4 | 0 | +| SSE3 | NHWC | 2.8 (11428ms) | 4 | 0 | + +#### Training InceptionV3 + +**Environment** + +* Instance Type: Dedicated AWS EC2 r4.16xlarge 
(Broadwell) +* CPU: Intel Xeon E5-2686 v4 (Broadwell) Processors +* Dataset: ImageNet +* TensorFlow Version: 1.2.0 RC2 +* Test Script: [tf_cnn_benchmarks.py](https://github.com/tensorflow/benchmarks/blob/mkl_experiment/scripts/tf_cnn_benchmarks/tf_cnn_benchmarks.py) + +Command executed for MKL test: + +```bash +python tf_cnn_benchmarks.py --device=cpu --mkl=True --kmp_blocktime=0 \ +--nodistortions --model=resnet50 --data_format=NCHW --batch_size=32 \ +--num_inter_threads=2 --num_intra_threads=36 \ +--data_dir= +``` + +Optimization | Data Format | Images/Sec | Intra threads | Inter Threads +------------ | ----------- | ---------- | ------------- | ------------- +MKL | NCHW | 20.8 | 36 | 2 +AVX2 | NHWC | 6.2 | 36 | 0 +AVX | NHWC | 5.7 | 36 | 0 +SSE3 | NHWC | 4.3 | 36 | 0 + +ResNet and [AlexNet](http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf) +were also run on this configuration but in an ad hoc manner. There were not +enough runs executed to publish a coherent table of results. The incomplete +results strongly indicated the final result would be similar to the table above +with MKL providing significant 3x+ gains over AVX2. diff --git a/tensorflow/python/client/session_list_devices_test.py b/tensorflow/python/client/session_list_devices_test.py index c1e9e5e48fc..584b1abe55c 100644 --- a/tensorflow/python/client/session_list_devices_test.py +++ b/tensorflow/python/client/session_list_devices_test.py @@ -21,7 +21,10 @@ from __future__ import print_function from tensorflow.core.protobuf import cluster_pb2 from tensorflow.core.protobuf import config_pb2 +from tensorflow.python import pywrap_tensorflow as tf_session from tensorflow.python.client import session +from tensorflow.python.framework import c_api_util +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.platform import googletest @@ -38,6 +41,24 @@ class SessionListDevicesTestMethods(object): [d.name for d in devices]), devices) self.assertGreaterEqual(1, len(devices), devices) + def testInvalidDeviceNumber(self): + opts = tf_session.TF_NewSessionOptions() + with errors.raise_exception_on_not_ok_status() as status: + c_session = tf_session.TF_NewSession( + ops.get_default_graph()._c_graph, opts, status) + raw_device_list = tf_session.TF_SessionListDevices( + c_session, status) + size = tf_session.TF_DeviceListCount(raw_device_list) + # Test that invalid device numbers return -1 rather than a Swig-wrapped + # pointer. 
+ status_no_exception = c_api_util.ScopedTFStatus() + memory = tf_session.TF_DeviceListMemoryBytes( + raw_device_list, size, status_no_exception) + self.assertEqual(memory, -1) + tf_session.TF_DeleteDeviceList(raw_device_list) + with errors.raise_exception_on_not_ok_status() as status: + tf_session.TF_CloseSession(c_session, status) + def testListDevicesGrpcSession(self): server = server_lib.Server.create_local_server() with session.Session(server.target) as sess: diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i index fa49e66e87b..9c2ffe1e5cb 100644 --- a/tensorflow/python/client/tf_session.i +++ b/tensorflow/python/client/tf_session.i @@ -75,6 +75,11 @@ tensorflow::ImportNumpy(); $result = PyUnicode_FromString($1); } +// Convert TF_DeviceListMemoryBytes and TF_Dim int64_t output to Python integers +%typemap(out) int64_t { + $result = PyInt_FromLong($1); +} + // We use TF_OperationGetControlInputs_wrapper instead of // TF_OperationGetControlInputs %ignore TF_OperationGetControlInputs; diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD index 4bb810a6c8f..c848ee96bc7 100644 --- a/tensorflow/python/eager/BUILD +++ b/tensorflow/python/eager/BUILD @@ -297,7 +297,6 @@ py_library( "//tensorflow/python:framework_ops", "//tensorflow/python:gradients", "//tensorflow/python:graph_to_function_def", - "//tensorflow/python:pywrap_tensorflow", "//tensorflow/python:util", "//tensorflow/python/eager:context", "//tensorflow/python/eager:core", diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py index c6af7df176f..79374d2bb55 100644 --- a/tensorflow/python/eager/context.py +++ b/tensorflow/python/eager/context.py @@ -53,6 +53,7 @@ class _EagerContext(threading.local): self.mode = _default_mode self.scope_name = "" self.recording_summaries = False + self.scalar_cache = {} # TODO(agarwal): rename to EagerContext / EagerRuntime ? @@ -157,6 +158,10 @@ class Context(object): """Returns True if current thread is in EAGER mode.""" return self._eager_context.mode == EAGER_MODE + def scalar_cache(self): + """Per-device cache for scalars.""" + return self._eager_context.scalar_cache + @property def scope_name(self): """Returns scope name for the current thread.""" @@ -245,6 +250,23 @@ class Context(object): # TODO(ashankar): Use TF_DeviceListType to count GPU devices. return len(self._devices) - 1 + def add_function_def(self, fdef): + """Add a function definition to the context. + + Once added, the function (identified by its name) can be executed like any + other operation. + + Args: + fdef: A FunctionDef protocol buffer message. + """ + fdef_string = fdef.SerializeToString() + with errors.raise_exception_on_not_ok_status() as status: + pywrap_tensorflow.TFE_ContextAddFunctionDef( + self._handle, # pylint: disable=protected-access + fdef_string, + len(fdef_string), + status) + def add_post_execution_callback(self, callback): """Add a post-execution callback to the context. @@ -389,3 +411,21 @@ def enable_eager_execution(): global _default_mode assert _default_mode == GRAPH_MODE _default_mode = EAGER_MODE + + +def list_devices(): + """List the names of the available devices. + + Returns: + Names of the available devices, as a `list`. + """ + return context().devices() + + +def num_gpus(): + """Get the number of available GPU devices. + + Returns: + The number of available GPU devices. 
+ """ + return context().num_gpus() diff --git a/tensorflow/python/eager/core_test.py b/tensorflow/python/eager/core_test.py index 5de396f62c3..670bb5c73cc 100644 --- a/tensorflow/python/eager/core_test.py +++ b/tensorflow/python/eager/core_test.py @@ -33,7 +33,7 @@ from tensorflow.python.framework import test_util def truncated_normal(shape): return execute.execute( - 'TruncatedNormal', + b'TruncatedNormal', 1, inputs=[shape], attrs=('dtype', dtypes.float32.as_datatype_enum, 'T', @@ -118,7 +118,7 @@ class TFETest(test_util.TensorFlowTestCase): y = tensor.Tensor(2.) # Add would fail if t2 were not on GPU result = execute.execute( - 'Add', 1, inputs=[x, y], + b'Add', 1, inputs=[x, y], attrs=('T', x.dtype.as_datatype_enum))[0].as_cpu_tensor().numpy() self.assertEqual(3, result) @@ -161,7 +161,7 @@ class TFETest(test_util.TensorFlowTestCase): three = tensor.Tensor(3) five = tensor.Tensor(5) product = execute.execute( - 'Mul', + b'Mul', num_outputs=1, inputs=[three, five], attrs=('T', three.dtype.as_datatype_enum))[0] @@ -171,7 +171,7 @@ class TFETest(test_util.TensorFlowTestCase): # num_outputs provided is 50, but only one output is produced. # That should be okay. product = execute.execute( - 'Mul', + b'Mul', num_outputs=50, inputs=[tensor.Tensor(3), tensor.Tensor(5)], attrs=('T', dtypes.int32.as_datatype_enum))[0] @@ -183,7 +183,7 @@ class TFETest(test_util.TensorFlowTestCase): three = tensor.Tensor([[3.]]).as_gpu_tensor() five = tensor.Tensor([[5.]]).as_gpu_tensor() product = execute.execute( - 'MatMul', + b'MatMul', num_outputs=1, inputs=[three, five], attrs=('transpose_a', False, 'transpose_b', False, 'T', @@ -192,7 +192,7 @@ class TFETest(test_util.TensorFlowTestCase): def testExecuteStringAttr(self): checked_three = execute.execute( - 'CheckNumerics', + b'CheckNumerics', num_outputs=1, inputs=[tensor.Tensor(3.)], attrs=('message', 'just checking', 'T', @@ -202,14 +202,14 @@ class TFETest(test_util.TensorFlowTestCase): def testExecuteStringAttrBadValue(self): with self.assertRaises(errors.InvalidArgumentError): _ = execute.execute( - 'CheckNumerics', + b'CheckNumerics', num_outputs=1, inputs=[tensor.Tensor(3.)], attrs=('message', 1, 'T', dtypes.float32.as_datatype_enum)) def testExecuteFloatAttr(self): almost_equal = execute.execute( - 'ApproximateEqual', + b'ApproximateEqual', num_outputs=1, inputs=[tensor.Tensor(3.0), tensor.Tensor(2.9)], attrs=('tolerance', 0.3, 'T', dtypes.float32.as_datatype_enum))[0] @@ -218,14 +218,14 @@ class TFETest(test_util.TensorFlowTestCase): def testExecuteFloatAttrBadValue(self): with self.assertRaises(errors.InvalidArgumentError): _ = execute.execute( - 'ApproximateEqual', + b'ApproximateEqual', num_outputs=1, inputs=[tensor.Tensor(3.0), tensor.Tensor(2.9)], attrs=('tolerance', '0.3', 'T', dtypes.float32.as_datatype_enum)) def testExecuteIntAttr(self): total = execute.execute( - 'AddN', + b'AddN', num_outputs=1, inputs=[tensor.Tensor(3), tensor.Tensor(4)], attrs=('T', dtypes.int32.as_datatype_enum, 'N', 2))[0] @@ -234,7 +234,7 @@ class TFETest(test_util.TensorFlowTestCase): def testExecuteIntAttrBadValue(self): with self.assertRaises(errors.InvalidArgumentError): _ = execute.execute( - 'AddN', + b'AddN', num_outputs=1, inputs=[tensor.Tensor(3), tensor.Tensor(4)], attrs=('T', dtypes.int32.as_datatype_enum, 'N', '2')) @@ -242,7 +242,7 @@ class TFETest(test_util.TensorFlowTestCase): # Looks like we don't have an existing op with list(bool) attrs. 
def testExecuteBoolAttr(self): product = execute.execute( - 'MatMul', + b'MatMul', num_outputs=1, inputs=[tensor.Tensor([[3]]), tensor.Tensor([[5]])], @@ -252,7 +252,7 @@ class TFETest(test_util.TensorFlowTestCase): def testExecuteShapeAttr(self): execute.execute( - 'VarHandleOp', + b'VarHandleOp', num_outputs=1, inputs=[], attrs=('shape', [1, 2], 'dtype', dtypes.int32.as_datatype_enum, @@ -261,7 +261,7 @@ class TFETest(test_util.TensorFlowTestCase): def testExecuteShapeAttrBadValue(self): with self.assertRaises(errors.InvalidArgumentError): execute.execute( - 'VarHandleOp', + b'VarHandleOp', num_outputs=1, inputs=[], attrs=('shape', 1, 'dtype', dtypes.int32.as_datatype_enum, @@ -269,7 +269,7 @@ class TFETest(test_util.TensorFlowTestCase): def testExecuteListStringAttr(self): execute.execute( - 'TensorSummary', + b'TensorSummary', num_outputs=1, inputs=[tensor.Tensor(3.0)], attrs=('T', dtypes.float32.as_datatype_enum, 'description', @@ -279,7 +279,7 @@ class TFETest(test_util.TensorFlowTestCase): def testExecuteListStringAttrBadValue(self): with self.assertRaises(errors.InvalidArgumentError): execute.execute( - 'TensorSummary', + b'TensorSummary', num_outputs=1, inputs=[tensor.Tensor(3.0)], attrs=('T', dtypes.float32.as_datatype_enum, 'description', '', @@ -288,7 +288,7 @@ class TFETest(test_util.TensorFlowTestCase): def testExecuteListStringAttrBadListValue(self): with self.assertRaises(errors.InvalidArgumentError): execute.execute( - 'TensorSummary', + b'TensorSummary', num_outputs=1, inputs=[tensor.Tensor(3.0)], attrs=('T', dtypes.float32.as_datatype_enum, 'description', '', @@ -296,7 +296,7 @@ class TFETest(test_util.TensorFlowTestCase): def testExecuteListFloatAttr(self): b = execute.execute( - 'Bucketize', + b'Bucketize', num_outputs=1, inputs=[tensor.Tensor([3.0, 5.0, 7.0])], attrs=('T', dtypes.float32.as_datatype_enum, 'boundaries', [4.0, @@ -306,7 +306,7 @@ class TFETest(test_util.TensorFlowTestCase): def testExecuteListFloatAttrBadValue(self): with self.assertRaises(errors.InvalidArgumentError): execute.execute( - 'Bucketize', + b'Bucketize', num_outputs=1, inputs=[tensor.Tensor([3.0, 5.0, 7.0])], attrs=('T', dtypes.float32.as_datatype_enum, 'boundaries', 4.0)) @@ -314,7 +314,7 @@ class TFETest(test_util.TensorFlowTestCase): def testExecuteListFloatAttrBadListValue(self): with self.assertRaises(errors.InvalidArgumentError): execute.execute( - 'Bucketize', + b'Bucketize', num_outputs=1, inputs=[tensor.Tensor([3.0, 5.0, 7.0])], attrs=('T', dtypes.float32.as_datatype_enum, 'boundaries', @@ -322,7 +322,7 @@ class TFETest(test_util.TensorFlowTestCase): def testExecuteListIntAttr(self): b = execute.execute( - 'Squeeze', + b'Squeeze', num_outputs=1, inputs=[tensor.Tensor([[[3.0]]])], attrs=('T', dtypes.float32.as_datatype_enum, 'squeeze_dims', [0, 2]))[0] @@ -331,7 +331,7 @@ class TFETest(test_util.TensorFlowTestCase): def testExecuteListIntAttrBadValue(self): with self.assertRaises(errors.InvalidArgumentError): execute.execute( - 'Squeeze', + b'Squeeze', num_outputs=1, inputs=[tensor.Tensor([[[3.0]]])], attrs=('T', dtypes.float32.as_datatype_enum, 'squeeze_dims', 0)) @@ -339,7 +339,7 @@ class TFETest(test_util.TensorFlowTestCase): def testExecuteListIntAttrBadListValue(self): with self.assertRaises(errors.InvalidArgumentError): execute.execute( - 'Squeeze', + b'Squeeze', num_outputs=1, inputs=[tensor.Tensor([[[3.0]]])], attrs=('T', dtypes.float32.as_datatype_enum, 'squeeze_dims', @@ -347,7 +347,7 @@ class TFETest(test_util.TensorFlowTestCase): def 
testExecuteListTypeListShapeAttr(self): execute.execute( - 'Barrier', + b'Barrier', num_outputs=1, inputs=[], attrs=('component_types', [dtypes.float64.as_datatype_enum], 'shapes', @@ -356,7 +356,7 @@ class TFETest(test_util.TensorFlowTestCase): def testExecuteListTypeAttrBadValue(self): with self.assertRaises(errors.InvalidArgumentError): execute.execute( - 'Barrier', + b'Barrier', num_outputs=1, inputs=[], attrs=('component_types', dtypes.float64.as_datatype_enum, 'shapes', @@ -365,7 +365,7 @@ class TFETest(test_util.TensorFlowTestCase): def testExecuteListTypeAttrBadListValue(self): with self.assertRaises(errors.InvalidArgumentError): execute.execute( - 'Barrier', + b'Barrier', num_outputs=1, inputs=[], attrs=('component_types', '1', 'shapes', [[1, 2]], 'capacity', -1, @@ -374,7 +374,7 @@ class TFETest(test_util.TensorFlowTestCase): def testExecuteListShapeAttrBadValue(self): with self.assertRaises(errors.InvalidArgumentError): execute.execute( - 'Barrier', + b'Barrier', num_outputs=1, inputs=[], attrs=('component_types', [dtypes.float64.as_datatype_enum], 'shapes', @@ -383,7 +383,7 @@ class TFETest(test_util.TensorFlowTestCase): def testExecuteListShapeAttrBadListValue(self): with self.assertRaises(errors.InvalidArgumentError): execute.execute( - 'Barrier', + b'Barrier', num_outputs=1, inputs=[], attrs=('component_types', [dtypes.float64.as_datatype_enum], 'shapes', @@ -393,7 +393,7 @@ class TFETest(test_util.TensorFlowTestCase): split_dim = 1 value = [[0, 1, 2], [3, 4, 5]] x1, x2, x3 = execute.execute( - 'Split', + b'Split', num_outputs=3, inputs=[tensor.Tensor(split_dim), tensor.Tensor(value)], @@ -405,18 +405,18 @@ class TFETest(test_util.TensorFlowTestCase): def testExecuteBadNumOutputsArgument(self): with self.assertRaises(TypeError): execute.execute( - 'Relu', [], + b'Relu', [], inputs=[tensor.Tensor(3.0)], attrs=('T', dtypes.float32.as_datatype_enum)) def testExecuteUnknownOp(self): with self.assertRaises(errors.NotFoundError): - execute.execute('BlahBlahBlah', num_outputs=1, inputs=[], attrs=None) + execute.execute(b'BlahBlahBlah', num_outputs=1, inputs=[], attrs=None) def testExecuteUnknownAttr(self): with self.assertRaises(errors.InvalidArgumentError): execute.execute( - 'Identity', + b'Identity', num_outputs=1, inputs=[tensor.Tensor(3)], attrs=('T', dtypes.int32.as_datatype_enum, 'unknown_attr', 'blah')) @@ -425,7 +425,7 @@ class TFETest(test_util.TensorFlowTestCase): def add(x, y): return execute.execute( - 'Add', + b'Add', num_outputs=1, inputs=[x, y], attrs=('T', dtypes.int32.as_datatype_enum))[0] @@ -447,7 +447,7 @@ class TFETest(test_util.TensorFlowTestCase): y = truncated_normal(shape) # Add would fail if x and y were not on the same device. 
execute.execute( - 'Add', 1, inputs=[x, y], attrs=('T', x.dtype.as_datatype_enum)) + b'Add', 1, inputs=[x, y], attrs=('T', x.dtype.as_datatype_enum)) def testInvalidDevice(self): with self.assertRaises(ValueError): diff --git a/tensorflow/python/eager/execute.py b/tensorflow/python/eager/execute.py index 2b5a76ca121..722094ad7f4 100644 --- a/tensorflow/python/eager/execute.py +++ b/tensorflow/python/eager/execute.py @@ -63,15 +63,14 @@ def execute(op_name, num_outputs, inputs, attrs=None, name=None): device_name = ctx.device_name try: outh = pywrap_tensorflow.TFE_Py_Execute(ctx._handle, device_name, - str(op_name), input_handles, attrs, + op_name, input_handles, attrs, num_outputs) - # pylint: enable=protected-access - except core._NotOkStatusException as e: # pylint: disable=protected-access + except core._NotOkStatusException as e: if name is not None: message = e.message + " name: " + name else: message = e.message - raise core._status_to_exception(e.code, message) # pylint: disable=protected-access + raise core._status_to_exception(e.code, message) # pylint: enable=protected-access tensors = [tensor._tensor_from_handle(x) for x in outh] # pylint: disable=protected-access diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 227520eea8a..480fb19195f 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -26,14 +26,13 @@ import threading from autograd import core as ag_core import numpy as np -from tensorflow.python import pywrap_tensorflow from tensorflow.python.eager import context from tensorflow.python.eager import execute from tensorflow.python.eager import tape from tensorflow.python.eager import tensor from tensorflow.python.eager.graph_only_ops import graph_placeholder +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors from tensorflow.python.framework import graph_to_function_def from tensorflow.python.framework import ops from tensorflow.python.ops import gradients_impl @@ -59,10 +58,8 @@ def capture_tensors(captures): _scoped_captures.tensors = old -def _convert_to_graph_constant(value, dtype=None, name=None, as_ref=False): - """Captures a tfe Tensor while building a graph mode function. - - Creates a placeholder to pass the tensor as an argument. +def _convert_to_graph_tensor(value, dtype=None, name=None, as_ref=False): + """Captures a Tensor while building a graph mode function. Arguments: value: A tfe.Tensor object @@ -71,19 +68,17 @@ def _convert_to_graph_constant(value, dtype=None, name=None, as_ref=False): as_ref: Ignored (required by register_tensor_conversion_function). Returns: - A placeholder which will, at runtime, have the value of this tensor. - - Raises: - ValueError: if called outside a defun context. + Returns a constant (the current value of the tensor) if capturing + is not enabled. A placeholder which will have the value of the + tensor at runtime otherwise. """ if context.in_eager_mode(): return value _ = as_ref tensor_map = _scoped_captures.tensors if tensor_map is None: - raise ValueError( - "Trying to use tfe.Tensor objects in a graph outside graph mode. " - "To build a graph use tfe.defun or tfe.make_template.") + # Capturing is not enabled. 
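Since execute.py above no longer wraps op_name in str(), the op name must reach TFE_Py_Execute as bytes, which is what the updated tests and the generated b"..." wrappers now pass. A minimal sketch of a direct call, assuming an eager-enabled build of this era and spelling out MatMul attrs that the hunks above elide:

    from tensorflow.python.eager import execute, tensor
    from tensorflow.python.framework import dtypes

    # The op name is bytes (b'MatMul'), not str.
    product, = execute.execute(
        b'MatMul',
        num_outputs=1,
        inputs=[tensor.Tensor([[3]]), tensor.Tensor([[5]])],
        attrs=('transpose_a', False, 'transpose_b', False,
               'T', dtypes.int32.as_datatype_enum))
    # product holds [[15]]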
+ return constant_op.constant(value.numpy()) captured_value = tensor_map.get(ops.tensor_id(value), None) if captured_value is None: captured_value = graph_placeholder( @@ -100,7 +95,7 @@ def _convert_to_graph_constant(value, dtype=None, name=None, as_ref=False): # Note that we register this at a higher priority than ops.Tensor since we want # to handle subclass specific conversion before a superclass conversion. ops.register_tensor_conversion_function( - tensor.Tensor, _convert_to_graph_constant, priority=-1) + tensor.Tensor, _convert_to_graph_tensor, priority=-1) class _CapturingContext(object): @@ -261,7 +256,7 @@ class _GraphModeFunction(object): outputs[i].set_shape(s) else: outputs = execute.execute( - signature.name, + str(signature.name), num_outputs=len(signature.output_arg), inputs=all_args) real_outputs = outputs[:len(self._returns)] @@ -321,7 +316,7 @@ class _GraphModeFunction(object): for x in tensor_inputs ] result = execute.execute( - self._func_name, + str(self._func_name), num_outputs=self._num_outputs, inputs=tensor_inputs + self._extra_inputs) @@ -438,20 +433,10 @@ def _cache_key(x): return x -def register_function_def(fdef): - fdef_string = fdef.SerializeToString() - with errors.raise_exception_on_not_ok_status() as status: - pywrap_tensorflow.TFE_ContextAddFunctionDef( - context.get_default_context()._handle, # pylint: disable=protected-access - fdef_string, - len(fdef_string), - status) - - def _register_with_name(name, fdef): """Registers the function `fdef` with the name `name`.""" fdef.signature.name = name - register_function_def(fdef) + context.context().add_function_def(fdef) # TODO(apassos): better error messages for non-hashable arguments. diff --git a/tensorflow/python/eager/python_eager_op_gen.cc b/tensorflow/python/eager/python_eager_op_gen.cc index a526856794d..c7eb405f9c8 100644 --- a/tensorflow/python/eager/python_eager_op_gen.cc +++ b/tensorflow/python/eager/python_eager_op_gen.cc @@ -650,7 +650,7 @@ void GenEagerPythonOp::AddEagerAttrs() { void GenEagerPythonOp::AddEagerExecute(const string& num_outputs_expr) { const string return_prefix = " _result = _execute.execute("; const string return_args = - strings::StrCat("\"", op_def_.name(), "\", ", num_outputs_expr, + strings::StrCat("b\"", op_def_.name(), "\", ", num_outputs_expr, ", inputs=_inputs_flat, attrs=_attrs, name=name)"); strings::StrAppend(&result_, // Wrap the arguments, and indent to the (. diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD index 167f9b10543..d060faa4efd 100644 --- a/tensorflow/python/estimator/BUILD +++ b/tensorflow/python/estimator/BUILD @@ -22,6 +22,7 @@ py_library( ":model_fn", ":parsing_utils", ":run_config", + ":training", "//tensorflow/python:util", ], ) @@ -70,6 +71,27 @@ py_test( ], ) +py_library( + name = "training", + srcs = ["training.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:training", + "@six_archive//:six", + ], +) + +py_test( + name = "training_test", + size = "small", + srcs = ["training_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":training", + "//tensorflow/python:client_testlib", + ], +) + py_library( name = "run_config", srcs = ["run_config.py"], diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py index 1554c271541..b15c89aaae2 100644 --- a/tensorflow/python/estimator/estimator.py +++ b/tensorflow/python/estimator/estimator.py @@ -211,17 +211,17 @@ class Estimator(object): hooks: List of `SessionRunHook` subclass instances. 
Used for callbacks inside the training loop. steps: Number of steps for which to train model. If `None`, train forever - or train until input_fn generates the `OutOfRange` or `StopIteration` - error. 'steps' works incrementally. If you call two times - train(steps=10) then training occurs in total 20 steps. If `OutOfRange` - or `StopIteration` error occurs in the middle, training stops before 20 - steps. If you don't want to have incremental behavior please set - `max_steps` instead. If set, `max_steps` must be `None`. + or train until input_fn generates the `OutOfRange` error or + `StopIteration` exception. 'steps' works incrementally: if you call + train(steps=10) twice, training occurs in total 20 steps. If + `OutOfRange` or `StopIteration` occurs in the middle, training stops + before 20 steps. If you don't want incremental behavior, please + set `max_steps` instead. If set, `max_steps` must be `None`. max_steps: Number of total steps for which to train model. If `None`, - train forever or train until input_fn generates the `OutOfRange` or - `StopIteration` error. If set, `steps` must be `None`. If `OutOfRange` - or `StopIteration` error occurs in the middle, training stops before - `max_steps` steps. + train forever or train until input_fn generates the `OutOfRange` error + or `StopIteration` exception. If set, `steps` must be `None`. If + `OutOfRange` or `StopIteration` occurs in the middle, training stops + before `max_steps` steps. Two calls to `train(steps=100)` means 200 training iterations. On the other hand, two calls to `train(max_steps=100)` means diff --git a/tensorflow/python/estimator/training.py b/tensorflow/python/estimator/training.py new file mode 100644 index 00000000000..9a8a0db66ee --- /dev/null +++ b/tensorflow/python/estimator/training.py @@ -0,0 +1,179 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Classes and functions related to train_and_evaluate.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +import six + +from tensorflow.python.training import session_run_hook + + +def _validate_input_fn(input_fn): + """Validates the `input_fn`.""" + if not callable(input_fn): + raise TypeError( + '`input_fn` must be callable, given: {}'.format(input_fn)) + + +def _validate_hooks(hooks): + """Validates the `hooks`.""" + hooks = tuple(hooks or []) + for hook in hooks: + if not isinstance(hook, session_run_hook.SessionRunHook): + raise TypeError( + 'All hooks must be `SessionRunHook` instances, given: {}'.format( + hook)) + return hooks + + +class TrainSpec( + collections.namedtuple('TrainSpec', ['input_fn', 'max_steps', 'hooks'])): + """Objects passed to `train_and_evaluate`. + + `TrainSpec` fully defines the objects to be run by `Estimator.train`.
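A usage sketch for the new specs (TrainSpec above, EvalSpec just below; the values mirror the unit tests added to training_test.py further on):

    from tensorflow.python.estimator import training

    def my_input_fn():
      # Stands in for a real input_fn returning (features, labels).
      return 1

    train_spec = training.TrainSpec(input_fn=my_input_fn, max_steps=1000)
    eval_spec = training.EvalSpec(input_fn=my_input_fn, steps=50)
    assert train_spec.hooks == ()   # hooks are normalized to an empty tuple
    assert eval_spec.delay_secs == 120 and eval_spec.throttle_secs == 600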
+ """ + + def __new__(cls, + input_fn, + max_steps=None, + hooks=None): + """Creates a validated `TrainSpec` instance. + + Args: + input_fn: Training input function returning a tuple of: + features - `Tensor` or dictionary of string feature name to `Tensor`. + labels - `Tensor` or dictionary of `Tensor` with labels. + max_steps: Int. Number of total steps for which to train model. If `None`, + train forever or train until `input_fn` generates the `OutOfRange` error + or `StopIteration` exception. See `Estimator.train` for details. + hooks: Iterable of `tf.train.SessionRunHook` objects to run + on all workers (including chief) during training. + + Returns: + A validated `TrainSpec` object. + + Raises: + ValueError: If validation fails. + TypeError: If any of the arguments is not the expected type. + """ + # Validate input_fn. + _validate_input_fn(input_fn) + + # Validate max_steps. + if max_steps is not None and max_steps <= 0: + raise ValueError( + 'Must specify max_steps > 0, given: {}'.format(max_steps)) + + # Validate hooks. + hooks = _validate_hooks(hooks) + + return super(TrainSpec, cls).__new__( + cls, + input_fn=input_fn, + max_steps=max_steps, + hooks=hooks) + + +class EvalSpec( + collections.namedtuple('EvalSpec', [ + 'input_fn', 'steps', 'name', 'hooks', 'export_strategies', + 'delay_secs', 'throttle_secs' + ])): + """Objects passed to `train_and_evaluate`. + + `EvalSpec` fully defines the objects to be run by `Estimator.evaluate` and + `Estimator.export_savedmodel`. + """ + + def __new__(cls, + input_fn, + steps=100, + name=None, + hooks=None, + export_strategies=None, + delay_secs=120, + throttle_secs=600): + """Creates a validated `EvalSpec` instance. + + Args: + input_fn: Training input function returning a tuple of: + features - `Tensor` or dictionary of string feature name to `Tensor`. + labels - `Tensor` or dictionary of `Tensor` with labels. + steps: Int. Number of total steps for which to train model. If `None`, + train forever or train until `input_fn` generates the `OutOfRange` error + or `StopIteration` exception. See `Estimator.train` for details. + name: String. Name of the evaluation if user needs to run multiple + evaluations on different data sets. Metrics for different evaluations + are saved in separate folders, and appear separately in tensorboard. + hooks: Iterable of `tf.train.SessionRunHook` objects to run + on all workers (including chief) during training. + export_strategies: Iterable of `ExportStrategy`s, or a single one, or + `None`. `export_strategies` will be invoked after each evaluation. + delay_secs: Int. Start evaluating after waiting for this many seconds. + throttle_secs: Int. Do not re-evaluate unless the last evaluation was + started at least this many seconds ago. Of course, evaluation does not + occur if no new checkpoint is available, hence, this is the minimum. + + Returns: + A validated `TrainSpec` object. + + Raises: + ValueError: If validation fails. + TypeError: If any of the arguments is not the expected type. + """ + # Validate input_fn. + _validate_input_fn(input_fn) + + # Validate steps. + if steps is not None and steps <= 0: + raise ValueError('Must specify steps > 0, given: {}'.format(steps)) + + # Validate name. + if name is not None and not isinstance(name, six.string_types): + raise TypeError('`name` must be string, given: {}'.format(name)) + + # Validate hooks. + hooks = _validate_hooks(hooks) + + # Validate export_strategies. 
+ export_strategies = tuple(export_strategies or []) + # TODO(b/65169058): Validate export_strategies once `ExportStrategy` is defined. + + # Validate delay_secs. + if delay_secs < 0: + raise ValueError( + 'Must specify delay_secs >= 0, given: {}'.format(delay_secs)) + + # Validate throttle_secs. + if throttle_secs < 0: + raise ValueError( + 'Must specify throttle_secs >= 0, given: {}'.format(throttle_secs)) + + return super(EvalSpec, cls).__new__( + cls, + input_fn=input_fn, + steps=steps, + name=name, + hooks=hooks, + export_strategies=export_strategies, + delay_secs=delay_secs, + throttle_secs=throttle_secs) + diff --git a/tensorflow/python/estimator/training_test.py b/tensorflow/python/estimator/training_test.py new file mode 100644 index 00000000000..654a1659b29 --- /dev/null +++ b/tensorflow/python/estimator/training_test.py @@ -0,0 +1,133 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Tests for training.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.estimator import training +from tensorflow.python.platform import test +from tensorflow.python.training import session_run_hook + +_DEFAULT_EVAL_STEPS = 100 +_DEFAULT_EVAL_DELAY_SECS = 120 +_DEFAULT_EVAL_THROTTLE_SECS = 600 +_INVALID_INPUT_FN_MSG = '`input_fn` must be callable' +_INVALID_HOOK_MSG = 'All hooks must be `SessionRunHook` instances' +_INVALID_MAX_STEPS_MSG = 'Must specify max_steps > 0' +_INVALID_STEPS_MSG = 'Must specify steps > 0' +_INVALID_NAME_MSG = '`name` must be string' +_INVALID_EVAL_DELAY_SECS_MSG = 'Must specify delay_secs >= 0' +_INVALID_EVAL_THROTTLE_SECS_MSG = 'Must specify throttle_secs >= 0' + + +class _FakeHook(session_run_hook.SessionRunHook): + """Fake implementation of `SessionRunHook`.""" + + +class _InvalidHook(object): + """Invalid hook (not a subclass of `SessionRunHook`).""" + + +class TrainSpecTest(test.TestCase): + """Tests TrainSpec.""" + + def testRequiredArgumentsSet(self): + """Tests that no errors are raised when all required arguments are set.""" + spec = training.TrainSpec(input_fn=lambda: 1) + self.assertEqual(1, spec.input_fn()) + self.assertIsNone(spec.max_steps) + self.assertEqual(0, len(spec.hooks)) + + def testAllArgumentsSet(self): + """Tests that no errors are raised when all arguments are set.""" + hooks = [_FakeHook()] + spec = training.TrainSpec(input_fn=lambda: 1, max_steps=2, hooks=hooks) + self.assertEqual(1, spec.input_fn()) + self.assertEqual(2, spec.max_steps) + self.assertEqual(tuple(hooks), spec.hooks) + + def testInvalidInputFn(self): + with self.assertRaisesRegexp(TypeError, _INVALID_INPUT_FN_MSG): + training.TrainSpec(input_fn='invalid') + + def testInvalidMaxStep(self): + with self.assertRaisesRegexp(ValueError, _INVALID_MAX_STEPS_MSG): + training.TrainSpec(input_fn=lambda: 1, max_steps=0) + + def testInvalidHook(self): + with
self.assertRaisesRegexp(TypeError, _INVALID_HOOK_MSG): + training.TrainSpec(input_fn=lambda: 1, hooks=[_InvalidHook()]) + + +class EvalSpecTest(test.TestCase): + """Tests EvalSpec.""" + + def testRequiredArgumentsSet(self): + """Tests that no errors are raised when all required arguments are set.""" + spec = training.EvalSpec(input_fn=lambda: 1) + self.assertEqual(1, spec.input_fn()) + self.assertEqual(_DEFAULT_EVAL_STEPS, spec.steps) + self.assertIsNone(spec.name) + self.assertEqual(0, len(spec.hooks)) + self.assertEqual(0, len(spec.export_strategies)) + self.assertEqual(_DEFAULT_EVAL_DELAY_SECS, spec.delay_secs) + self.assertEqual(_DEFAULT_EVAL_THROTTLE_SECS, spec.throttle_secs) + + def testAllArgumentsSet(self): + """Tests that no errors are raised when all arguments are set.""" + hooks = [_FakeHook()] + + # TODO(b/65169058): Replace the export_strategies with valid instances. + spec = training.EvalSpec(input_fn=lambda: 1, steps=2, name='name', + hooks=hooks, export_strategies=hooks, + delay_secs=3, throttle_secs=4) + self.assertEqual(1, spec.input_fn()) + self.assertEqual(2, spec.steps) + self.assertEqual('name', spec.name) + self.assertEqual(tuple(hooks), spec.hooks) + self.assertEqual(tuple(hooks), spec.export_strategies) + self.assertEqual(3, spec.delay_secs) + self.assertEqual(4, spec.throttle_secs) + + def testInvalidInputFn(self): + with self.assertRaisesRegexp(TypeError, _INVALID_INPUT_FN_MSG): + training.EvalSpec(input_fn='invalid') + + def testInvalidMaxStep(self): + with self.assertRaisesRegexp(ValueError, _INVALID_STEPS_MSG): + training.EvalSpec(input_fn=lambda: 1, steps=0) + + def testInvalidName(self): + with self.assertRaisesRegexp(TypeError, _INVALID_NAME_MSG): + training.EvalSpec(input_fn=lambda: 1, name=123) + + def testInvalidHook(self): + with self.assertRaisesRegexp(TypeError, _INVALID_HOOK_MSG): + training.EvalSpec(input_fn=lambda: 1, hooks=[_InvalidHook()]) + + def testInvalidDelaySecs(self): + with self.assertRaisesRegexp(ValueError, _INVALID_EVAL_DELAY_SECS_MSG): + training.EvalSpec(input_fn=lambda: 1, delay_secs=-1) + + def testInvalidThrottleSecs(self): + with self.assertRaisesRegexp(ValueError, _INVALID_EVAL_THROTTLE_SECS_MSG): + training.EvalSpec(input_fn=lambda: 1, throttle_secs=-1) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/framework/constant_op.py b/tensorflow/python/framework/constant_op.py index 9176720e864..a859645950d 100644 --- a/tensorflow/python/framework/constant_op.py +++ b/tensorflow/python/framework/constant_op.py @@ -60,7 +60,7 @@ def _eager_reshape(tensor, shape): attr_tshape = attr_tshape.as_datatype_enum inputs_flat = [tensor, shape] attrs = ("T", attr_t, "Tshape", attr_tshape) - result, = execute.execute("Reshape", 1, inputs=inputs_flat, attrs=attrs) + result, = execute.execute(b"Reshape", 1, inputs=inputs_flat, attrs=attrs) return result @@ -70,38 +70,29 @@ def _eager_fill(dims, value): dims = convert_to_eager_tensor(dims, dtypes.int32) inputs_flat = [dims, value] attrs = ("T", attr_t) - result, = execute.execute("Fill", 1, inputs=inputs_flat, attrs=attrs) + result, = execute.execute(b"Fill", 1, inputs=inputs_flat, attrs=attrs) return result -# Rely on the GIL for thread-safety. 
-_scalar_cache = {} - - def convert_to_eager_tensor(t, dtype=None): """Converts the given `value` to an `EagerTensor`.""" if isinstance(ag_core.getval(t), ops.EagerTensor): if dtype is not None and t.dtype != dtype: raise TypeError("Expected tensor with type %r not %r" % (dtype, t.dtype)) return t - # Handle converting ResourceVariable to Tensor. - # TODO(josh11b): get rid of this explicit ugly conversion once we have a more - # general scheme in place. - try: - return t._dense_var_to_tensor(dtype=dtype, as_ref=False) # pylint: disable=protected-access - except AttributeError: - pass if isinstance(t, (int, float)): # Use a scalar cache. This will put each scalar of each type only once on # each device. Scalars don't use much device memory but copying scalars can # trigger memcpys which are slow. - device = context.context().device_name + ctx = context.context() + device = ctx.device_name cache_key = device, t, dtype, type(t) - tensor = _scalar_cache.get(cache_key, None) + scalar_cache = ctx.scalar_cache() + tensor = scalar_cache.get(cache_key, None) if tensor is not None: return tensor value = ops.EagerTensor(t, dtype=dtype) - _scalar_cache[cache_key] = value + scalar_cache[cache_key] = value return value return ops.EagerTensor(t, dtype=dtype) diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index 9cf222a63ab..c6a7d0833e5 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -208,6 +208,71 @@ def NHWCToNCHW(input_tensor): return [input_tensor[a] for a in new_axes[ndims]] +def NHWCToNCHW_VECT_C(input_shape_or_tensor): + """Transforms the input from the NHWC layout to NCHW_VECT_C layout. + + Note: Does not include quantization or type conversion steps, which should + be applied afterwards. + + Args: + input_shape_or_tensor: a 4- or 5-D tensor, or an array representing shape + + Returns: + tensor or shape array transformed into NCHW_VECT_C + + Raises: + ValueError: if last dimension of `input_shape_or_tensor` is not evenly + divisible by 4. + """ + permutations = {5: [0, 3, 1, 2, 4], 6: [0, 4, 1, 2, 3, 5]} + is_tensor = isinstance(input_shape_or_tensor, ops.Tensor) + temp_shape = (input_shape_or_tensor.shape.as_list() + if is_tensor else input_shape_or_tensor) + if temp_shape[-1] % 4 != 0: + raise ValueError( + "Last dimension of input must be evenly divisible by 4 to convert to " + "NCHW_VECT_C.") + temp_shape[-1] //= 4 + temp_shape.append(4) + permutation = permutations[len(temp_shape)] + if is_tensor: + t = array_ops.reshape(input_shape_or_tensor, temp_shape) + return array_ops.transpose(t, permutation) + else: + return [temp_shape[a] for a in permutation] + + +def NCHW_VECT_CToNHWC(input_shape_or_tensor): + """Transforms the input from the NCHW_VECT_C layout to NHWC layout. + + Note: Does not include de-quantization or type conversion steps, which should + be applied beforehand. + + Args: + input_shape_or_tensor: a 5- or 6-D tensor, or an array representing shape + + Returns: + tensor or shape array transformed into NHWC + + Raises: + ValueError: if last dimension of `input_shape_or_tensor` is not 4. 
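The convert_to_eager_tensor change above moves the scalar cache from a module-level dict onto the context (ctx.scalar_cache()); the observable behavior stays the same: converting the same Python scalar twice on one device returns the same EagerTensor. A sketch, assuming eager execution is active in this 2017-era build:

    from tensorflow.python.framework import constant_op

    t1 = constant_op.convert_to_eager_tensor(3.0)
    t2 = constant_op.convert_to_eager_tensor(3.0)
    assert t1 is t2   # second call is served from ctx.scalar_cache()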
+ """ + permutations = {5: [0, 2, 3, 1, 4], 6: [0, 2, 3, 4, 1, 5]} + is_tensor = isinstance(input_shape_or_tensor, ops.Tensor) + input_shape = (input_shape_or_tensor.shape.as_list() + if is_tensor else input_shape_or_tensor) + if input_shape[-1] != 4: + raise ValueError("Last dimension of NCHW_VECT_C must be 4.") + permutation = permutations[len(input_shape)] + nhwc_shape = [input_shape[a] for a in permutation[:-1]] + nhwc_shape[-1] *= input_shape[-1] + if is_tensor: + t = array_ops.transpose(input_shape_or_tensor, permutation) + return array_ops.reshape(t, nhwc_shape) + else: + return nhwc_shape + + def NCHWToNHWC(input_tensor): """Converts the input from the NCHW format to NHWC. @@ -392,7 +457,7 @@ class TensorFlowTestCase(googletest.TestCase): self._cached_session = None def setUp(self): - logging.info("SET UP: %s" % str(self)) + logging.info("SET UP: %s", str(self)) self._ClearCachedSession() random.seed(random_seed.DEFAULT_GRAPH_SEED) np.random.seed(random_seed.DEFAULT_GRAPH_SEED) @@ -407,7 +472,7 @@ class TensorFlowTestCase(googletest.TestCase): ops.get_default_graph().seed = random_seed.DEFAULT_GRAPH_SEED def tearDown(self): - logging.info("TEAR DOWN: %s" % str(self)) + logging.info("TEAR DOWN: %s", str(self)) for thread in self._threads: self.assertFalse(thread.is_alive(), "A checkedThread did not terminate") diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index d9c5f3bce99..501382901c7 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -528,7 +528,6 @@ tf_py_test( "//tensorflow/python:linalg_ops", "//tensorflow/python:math_ops", ], - tags = ["nomsan"], # fails in msan from numpy calls ) tf_py_test( @@ -2677,7 +2676,6 @@ cuda_py_test( "//tensorflow/python:math_ops", ], shard_count = 20, - tags = ["nomsan"], # fails in msan from numpy calls ) cuda_py_test( @@ -2693,7 +2691,6 @@ cuda_py_test( "//tensorflow/python:math_ops", ], shard_count = 20, - tags = ["nomsan"], # fails in msan from numpy calls ) cuda_py_test( diff --git a/tensorflow/python/kernel_tests/aggregate_ops_test.py b/tensorflow/python/kernel_tests/aggregate_ops_test.py index f56917f7e9b..0a08c01dad3 100644 --- a/tensorflow/python/kernel_tests/aggregate_ops_test.py +++ b/tensorflow/python/kernel_tests/aggregate_ops_test.py @@ -20,8 +20,12 @@ from __future__ import print_function import numpy as np +from tensorflow.core.framework import tensor_pb2 +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops +from tensorflow.python.ops import logging_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import test @@ -74,6 +78,42 @@ class AddNTest(test.TestCase): tol = 5e-3 if dtype == dtypes.float16 else 5e-7 self.assertAllClose(expected, actual, rtol=tol, atol=tol) + def testVariant(self): + + def create_constant_variant(value): + return constant_op.constant( + tensor_pb2.TensorProto( + dtype=dtypes.variant.as_datatype_enum, + tensor_shape=tensor_shape.TensorShape([]).as_proto(), + variant_val=[ + tensor_pb2.VariantTensorDataProto( + # Match registration in variant_op_registry.cc + type_name=b"int", + metadata=np.array(value, dtype=np.int32).tobytes()) + ])) + + # TODO(ebrevdo): Re-enable use_gpu=True once non-DMA Variant + # copying between CPU and GPU is supported. 
+ with self.test_session(use_gpu=False): + variant_const_3 = create_constant_variant(3) + variant_const_4 = create_constant_variant(4) + variant_const_5 = create_constant_variant(5) + # 3 + 3 + 5 + 4 = 15. + result = math_ops.add_n((variant_const_3, variant_const_3, + variant_const_5, variant_const_4)) + + # Smoke test -- ensure this executes without trouble. + # Right now, non-numpy-compatible objects cannot be returned from a + # session.run call; similarly, objects that can't be converted to + # native numpy types cannot be passed to ops.convert_to_tensor. + # For now, run the test and examine the output to see that the result is + # equal to 15. + result_op = logging_ops.Print( + result, [variant_const_3, variant_const_4, variant_const_5, result], + message=("Variants stored an int: c(3), c(4), c(5), " + "add_n(c(3), c(3), c(5), c(4)): ")).op + result_op.run() + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/kernel_tests/functional_ops_test.py b/tensorflow/python/kernel_tests/functional_ops_test.py index a7bedc7199c..9ee7c0c5611 100644 --- a/tensorflow/python/kernel_tests/functional_ops_test.py +++ b/tensorflow/python/kernel_tests/functional_ops_test.py @@ -500,6 +500,54 @@ class FunctionalOpsTest(test.TestCase): mul = sess.run(remote_op) self.assertEqual(mul, [6]) + def testRemoteFunctionCPUGPU(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + @function.Defun(dtypes.float32, dtypes.float32) + def _remote_fn(a, b): + return math_ops.multiply(a, b) + + with ops.device("/job:localhost/replica:0/task:0/cpu:0"): + a = variables.Variable(2, dtype=dtypes.float32) + b = variables.Variable(3, dtype=dtypes.float32) + + with ops.device("/job:localhost/replica:0/task:0/cpu:0"): + remote_op = functional_ops.remote_call( + args=[a, b], + Tout=[dtypes.float32], + f=_remote_fn, + target="/job:localhost/replica:0/task:0/device:GPU:0")[0] + 3.0 + + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + mul = sess.run(remote_op) + self.assertEqual(mul, 9.0) + + def testRemoteFunctionGPUCPU(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + @function.Defun(dtypes.float32, dtypes.float32) + def _remote_fn(a, b): + return math_ops.multiply(a, b) + + with ops.device("/job:localhost/replica:0/task:0/device:GPU:0"): + a = variables.Variable(2, dtype=dtypes.float32) + b = variables.Variable(3, dtype=dtypes.float32) + + with ops.device("/job:localhost/replica:0/task:0/device:GPU:0"): + remote_op = functional_ops.remote_call( + args=[a, b], + Tout=[dtypes.float32], + f=_remote_fn, + target="/job:localhost/replica:0/task:0/cpu:0")[0] + 3.0 + + with self.test_session() as sess: + sess.run(variables.global_variables_initializer()) + mul = sess.run(remote_op) + self.assertEqual(mul, 9.0) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/kernel_tests/pooling_ops_test.py b/tensorflow/python/kernel_tests/pooling_ops_test.py index 9eb1fea8037..c699d50c02d 100644 --- a/tensorflow/python/kernel_tests/pooling_ops_test.py +++ b/tensorflow/python/kernel_tests/pooling_ops_test.py @@ -25,25 +25,40 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import gen_nn_ops from tensorflow.python.ops import gradient_checker from tensorflow.python.ops import gradients_impl 
from tensorflow.python.ops import nn_ops -from tensorflow.python.framework import ops import tensorflow.python.ops.nn_grad # pylint: disable=unused-import from tensorflow.python.platform import test +from tensorflow.python.platform import tf_logging -def GetTestConfigs(): +def GetTestConfigs(include_nchw_vect_c=False): """Get all the valid tests configs to run. + Args: + include_nchw_vect_c: Whether to include NCHW_VECT_C in the test configs. + Returns: all the valid test configs as tuples of data_format and use_gpu. """ test_configs = [("NHWC", False), ("NHWC", True)] - if test.is_gpu_available(cuda_only=True): - # "NCHW" format is currently supported exclusively on CUDA GPUs. - test_configs += [("NCHW", True)] + if not test.is_gpu_available(cuda_only=True): + tf_logging.info("NCHW and NCHW_VECT_C tests skipped because not run with " + "--config=cuda or no GPUs available.") + return test_configs + # "NCHW" format is currently supported exclusively on CUDA GPUs. + test_configs += [("NCHW", True)] + if include_nchw_vect_c: + if test.is_gpu_available( + cuda_only=True, min_cuda_compute_capability=(6, 1)): + test_configs += [("NCHW_VECT_C", True)] + else: + tf_logging.info("NCHW_VECT_C test skipped because no GPUs with " + "compute capability >= 6.1 are available.") + return test_configs @@ -95,16 +110,32 @@ class PoolingTest(test.TestCase): total_size = 1 for s in input_sizes: total_size *= s + if v2 and data_format != "NHWC": + tf_logging.info("v2 not supported for %s", data_format) + return + if data_format == "NCHW_VECT_C": + if data_type != dtypes.float32: + tf_logging.info("quantization to qint8 not implemented for %r", + data_type) + return + if input_sizes[-1] % 4 != 0: + tf_logging.info("Skipping test for depth %d", input_sizes[-1]) + return + tf_logging.info("Running %s test. %r %r %d %r %r %r", data_format, v2, + input_sizes, total_size, pool_func, ksize, strides) # Initializes the input tensor with array containing incrementing - # numbers from 1. - x = [f * 1.0 for f in range(1, total_size + 1)] + # numbers from 1, wrapping round to -127 after 127 to support int8. + x = [((f + 128) % 255) - 127 for f in range(total_size)] with self.test_session(use_gpu=use_gpu): t = constant_op.constant(x, shape=input_sizes, dtype=data_type) - if data_format == "NCHW": - t = test_util.NHWCToNCHW(t) + if data_format in ("NCHW", "NCHW_VECT_C"): + if data_format == "NCHW_VECT_C": + t = test_util.NHWCToNCHW_VECT_C(t) + t, _, _ = gen_array_ops.quantize_v2(t, -128.0, 127.0, dtypes.qint8) + else: + t = test_util.NHWCToNCHW(t) ksize = test_util.NHWCToNCHW(ksize) strides = test_util.NHWCToNCHW(strides) - v2 = v2 and data_format != "NCHW" ksize_placeholder = array_ops.placeholder(dtypes.int32, shape=[4]) strides_placeholder = array_ops.placeholder(dtypes.int32, shape=[4]) if v2: @@ -121,7 +152,10 @@ class PoolingTest(test.TestCase): strides=strides, padding=padding, data_format=data_format) - if data_format == "NCHW": + if data_format == "NCHW_VECT_C": + t = gen_array_ops.dequantize(t, -128, 127) + t = test_util.NCHW_VECT_CToNHWC(t) + elif data_format == "NCHW": t = test_util.NCHWToNHWC(t) if v2: actual = t.eval(feed_dict={ksize_placeholder: ksize, @@ -146,6 +180,13 @@ class PoolingTest(test.TestCase): expected: An array containing the expected operation outputs. use_gpu: Whether we are running on GPU. 
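The new test input generator above keeps values inside the int8 range so they survive the quantize_v2/dequantize round trip used for NCHW_VECT_C; a quick check of the wrapping formula:

    x = [((f + 128) % 255) - 127 for f in range(256)]
    assert x[0] == 1                          # still counts up from 1
    assert x[126] == 127 and x[127] == -127   # wraps around after 127
    assert x[254] == 0 and x[255] == 1        # period of 255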
""" + if data_format == "NCHW_VECT_C": + avg_pool_func = nn_ops.avg_pool + tf_logging.info("pool_func=%s", pool_func) + if pool_func == avg_pool_func: + tf_logging.info("NCHW_VECT_C not yet implemented for avg_pool") + return + self._VerifyOneType(pool_func, input_sizes, ksize, strides, padding, data_format, dtypes.float32, expected, use_gpu, v2) @@ -167,7 +208,7 @@ class PoolingTest(test.TestCase): expected: An array containing the expected operation outputs. use_gpu: Whether we are running on GPU. """ - for (data_format, use_gpu_2) in GetTestConfigs(): + for (data_format, use_gpu_2) in GetTestConfigs(True): if use_gpu_2 == use_gpu: self._VerifyOneTest(pool_func, input_sizes, ksize, strides, padding, data_format, expected, use_gpu, v2) @@ -296,20 +337,20 @@ class PoolingTest(test.TestCase): def _testAvgPoolSamePaddingPacket8(self, use_gpu): expected_output = [ - 73.0, 74.0, 75.0, 76.0, 77.0, 78.0, 79.0, 80.0, 89.0, 90.0, 91.0, 92.0, - 93.0, 94.0, 95.0, 96.0, 105.0, 106.0, 107.0, 108.0, 109.0, 110.0, 111.0, - 112.0, 117.0, 118.0, 119.0, 120.0, 121.0, 122.0, 123.0, 124.0, 201.0, - 202.0, 203.0, 204.0, 205.0, 206.0, 207.0, 208.0, 217.0, 218.0, 219.0, - 220.0, 221.0, 222.0, 223.0, 224.0, 233.0, 234.0, 235.0, 236.0, 237.0, - 238.0, 239.0, 240.0, 245.0, 246.0, 247.0, 248.0, 249.0, 250.0, 251.0, - 252.0, 329.0, 330.0, 331.0, 332.0, 333.0, 334.0, 335.0, 336.0, 345.0, - 346.0, 347.0, 348.0, 349.0, 350.0, 351.0, 352.0, 361.0, 362.0, 363.0, - 364.0, 365.0, 366.0, 367.0, 368.0, 373.0, 374.0, 375.0, 376.0, 377.0, - 378.0, 379.0, 380.0, 425.0, 426.0, 427.0, 428.0, 429.0, 430.0, 431.0, - 432.0, 441.0, 442.0, 443.0, 444.0, 445.0, 446.0, 447.0, 448.0, 457.0, - 458.0, 459.0, 460.0, 461.0, 462.0, 463.0, 464.0, 469.0, 470.0, 471.0, - 472.0, 473.0, 474.0, 475.0, 476.0 + -12.0, -11.0, -10.0, -9.0, -8.0, -7.0, -6.0, -5.0, 4.0, 5.0, 6.0, 7.0, + 8.0, 9.0, 10.0, 11.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, + 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, -3.5, -54.0, -53.0, -52.0, + -51.0, -50.0, -49.0, -48.0, -47.0, -38.0, -37.0, -36.0, -35.0, -34.0, + -33.0, -32.0, -31.0, -22.0, -21.0, -20.0, -19.0, -18.0, -17.0, -16.0, + -15.0, -10.0, -9.0, -8.0, -7.0, -6.0, -5.0, -4.0, -3.0, -11.0, -10.0, + -9.0, -8.0, -7.0, -6.0, -5.0, -4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, + 12.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 33.0, 34.0, 35.0, + 36.0, 37.0, 38.0, -3.5, -2.5, -85.0, -84.0, -83.0, -82.0, -81.0, -80.0, + -79.0, -78.0, -69.0, -68.0, -67.0, -66.0, -65.0, -64.0, -63.0, -62.0, + -53.0, -52.0, -51.0, -50.0, -49.0, -48.0, -47.0, -46.0, -41.0, -40.0, + -39.0, -38.0, -37.0, -36.0, -35.0, -34.0 ] + self._VerifyValues( nn_ops.avg_pool, input_sizes=[1, 8, 8, 8], @@ -468,19 +509,18 @@ class PoolingTest(test.TestCase): def _testMaxPoolSamePaddingPacket8(self, use_gpu): expected_output = [ - 145.0, 146.0, 147.0, 148.0, 149.0, 150.0, 151.0, 152.0, 161.0, 162.0, - 163.0, 164.0, 165.0, 166.0, 167.0, 168.0, 177.0, 178.0, 179.0, 180.0, - 181.0, 182.0, 183.0, 184.0, 185.0, 186.0, 187.0, 188.0, 189.0, 190.0, - 191.0, 192.0, 273.0, 274.0, 275.0, 276.0, 277.0, 278.0, 279.0, 280.0, - 289.0, 290.0, 291.0, 292.0, 293.0, 294.0, 295.0, 296.0, 305.0, 306.0, - 307.0, 308.0, 309.0, 310.0, 311.0, 312.0, 313.0, 314.0, 315.0, 316.0, - 317.0, 318.0, 319.0, 320.0, 401.0, 402.0, 403.0, 404.0, 405.0, 406.0, - 407.0, 408.0, 417.0, 418.0, 419.0, 420.0, 421.0, 422.0, 423.0, 424.0, - 433.0, 434.0, 435.0, 436.0, 437.0, 438.0, 439.0, 440.0, 441.0, 442.0, - 443.0, 444.0, 445.0, 446.0, 447.0, 448.0, 465.0, 466.0, 467.0, 468.0, - 469.0, 470.0, 
471.0, 472.0, 481.0, 482.0, 483.0, 484.0, 485.0, 486.0, - 487.0, 488.0, 497.0, 498.0, 499.0, 500.0, 501.0, 502.0, 503.0, 504.0, - 505.0, 506.0, 507.0, 508.0, 509.0, 510.0, 511.0, 512.0 + 81.0, 82.0, 83.0, 84.0, 85.0, 86.0, 87.0, 88.0, 97.0, 98.0, 99.0, 100.0, + 101.0, 102.0, 103.0, 104.0, 113.0, 114.0, 115.0, 116.0, 117.0, 118.0, + 119.0, 120.0, 121.0, 122.0, 123.0, 124.0, 125.0, 126.0, 127.0, 120.0, + 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 34.0, 35.0, 36.0, 37.0, + 38.0, 39.0, 40.0, 41.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, + 58.0, 59.0, 60.0, 61.0, 62.0, 63.0, 64.0, 65.0, 82.0, 83.0, 84.0, 85.0, + 86.0, 87.0, 88.0, 89.0, 98.0, 99.0, 100.0, 101.0, 102.0, 103.0, 104.0, + 105.0, 114.0, 115.0, 116.0, 117.0, 118.0, 119.0, 120.0, 121.0, 122.0, + 123.0, 124.0, 125.0, 126.0, 127.0, 120.0, 121.0, -45.0, -44.0, -43.0, + -42.0, -41.0, -40.0, -39.0, -38.0, -29.0, -28.0, -27.0, -26.0, -25.0, + -24.0, -23.0, -22.0, -13.0, -12.0, -11.0, -10.0, -9.0, -8.0, -7.0, -6.0, + -5.0, -4.0, -3.0, -2.0, -1.0, 0.0, 1.0, 2.0 ] self._VerifyValues( nn_ops.max_pool, diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index 33ba5df7a6e..274eda4f643 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -1428,6 +1428,8 @@ def zeros(shape, dtype=dtypes.float32, name=None): zero = "" else: zero = 0 + if context.in_eager_mode(): + return fill(shape, constant(zero, dtype=dtype), name=name) try: shape = tensor_shape.as_shape(shape) output = constant(zero, shape=shape, dtype=dtype, name=name) @@ -1466,6 +1468,13 @@ def zeros_like(tensor, dtype=None, name=None, optimize=True): with ops.name_scope(name, "zeros_like", [tensor]) as name: tensor = ops.convert_to_tensor(tensor, name="tensor") + if context.in_eager_mode(): + if dtype is not None and dtype != tensor.dtype: + return zeros( + shape_internal(tensor, optimize=optimize), dtype=dtype, name=name) + with ops.device(tensor.device): + return gen_array_ops._zeros_like(tensor, name=name) + # For now, variant types must be created via zeros_like; as we need to # pass the input variant object to the proper zeros callback. diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index 6559929560c..3cd82d60417 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -736,10 +736,9 @@ def cast(x, dtype, name=None): values_cast = cast(x.values, base_type, name=name) return sparse_tensor.SparseTensor(x.indices, values_cast, x.dense_shape) else: - # TODO(touts): Handle what Josh said. - # - # Could return ops.convert_to_tensor(x, dtype=dtype, ...) here, but that - # allows some conversions that cast() can't do, e.g. casting numbers to + # TODO(josh11b): If x is not already a Tensor, we could return + # ops.convert_to_tensor(x, dtype=dtype, ...) here, but that + # allows some conversions that cast() can't do, e.g. casting numbers to # strings. x = ops.convert_to_tensor(x, name="x") if x.dtype.base_dtype == base_type: diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py index d4b16635071..a2e75dd7f27 100644 --- a/tensorflow/python/ops/nn_ops.py +++ b/tensorflow/python/ops/nn_ops.py @@ -37,6 +37,7 @@ from tensorflow.python.ops import random_ops from tensorflow.python.ops.gen_nn_ops import * # pylint: enable=wildcard-import + # Aliases for some automatically-generated names. 
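With the array_ops changes above, zeros() short-circuits to an eager fill, and zeros_like() stays on the source tensor's device unless a different dtype forces fresh zeros. A sketch, assuming eager execution is active:

    from tensorflow.python.framework import dtypes
    from tensorflow.python.ops import array_ops

    z = array_ops.zeros([2, 3])                       # eager path: fill(shape, 0)
    zi = array_ops.zeros_like(z, dtype=dtypes.int32)  # dtype differs: builds new zeros
    zf = array_ops.zeros_like(z)                      # same dtype: gen_array_ops._zeros_like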
local_response_normalization = gen_nn_ops.lrn @@ -1750,19 +1751,19 @@ def max_pool(value, ksize, strides, padding, data_format="NHWC", name=None): """Performs the max pooling on the input. Args: - value: A 4-D `Tensor` with shape `[batch, height, width, channels]` and - type `tf.float32`. + value: A 4-D `Tensor` of the format specified by `data_format`. ksize: A 1-D int Tensor of 4 elements. The size of the window for each dimension of the input tensor. strides: A 1-D int Tensor of 4 elements. The stride of the sliding window for each dimension of the input tensor. padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm. See the @{tf.nn.convolution$comment here} - data_format: A string. 'NHWC' and 'NCHW' are supported. + data_format: A string. 'NHWC', 'NCHW' and 'NCHW_VECT_C' are supported. name: Optional name for the operation. Returns: - A `Tensor` with type `tf.float32`. The max pooled output tensor. + A `Tensor` of format specified by `data_format`. + The max pooled output tensor. """ with ops.name_scope(name, "MaxPool", [value]) as name: value = ops.convert_to_tensor(value, name="input") diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py index fdc8a5843fe..c735be06983 100644 --- a/tensorflow/python/ops/resource_variable_ops.py +++ b/tensorflow/python/ops/resource_variable_ops.py @@ -42,14 +42,14 @@ from tensorflow.python.ops.gen_resource_variable_ops import * from tensorflow.python.util import compat -def _eager_safe_variable_handle(shape, dtype, shared_name, name, +def _eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode, container=None): """Creates a variable handle with information to do shape inference.""" handle = gen_resource_variable_ops.var_handle_op(shape=shape, dtype=dtype, shared_name=shared_name, name=name, container=container) - if context.in_graph_mode(): + if graph_mode: return handle with context.graph_mode(), ops.Graph().as_default(): h = gen_resource_variable_ops.var_handle_op(shape=shape, dtype=dtype, @@ -152,8 +152,8 @@ class ResourceVariable(variables.Variable): uniquified automatically. dtype: If set, initial_value will be converted to the given type. If None, either the datatype will be kept (if initial_value is - a Tensor) or float32 will be used (if it is a Python object convertible - to a Tensor). + a Tensor) or float32 will be used (if it is a Python object convertible + to a Tensor). variable_def: `VariableDef` protocol buffer. If not None, recreates the `ResourceVariable` object with its contents. `variable_def` and other arguments (except for import_scope) are mutually exclusive. @@ -172,7 +172,7 @@ class ResourceVariable(variables.Variable): shape and `validate_shape` is `True`. 
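The max_pool docstring above now states its contract per data_format; the portable NHWC case for reference (NCHW and NCHW_VECT_C use the same call shape but are CUDA-only):

    import tensorflow as tf

    value = tf.ones([1, 4, 4, 3])  # NHWC: batch, height, width, channels
    out = tf.nn.max_pool(value, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1],
                         padding='SAME', data_format='NHWC')
    # out has shape [1, 2, 2, 3]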
""" if variable_def: - if initial_value: + if initial_value is not None: raise ValueError("variable_def and initial_value are mutually " "exclusive.") if not context.in_graph_mode(): @@ -277,7 +277,8 @@ class ResourceVariable(variables.Variable): shape=initial_value.get_shape(), dtype=initial_value.dtype.base_dtype, shared_name=handle_name, - name=name) + name=name, + graph_mode=self._in_graph_mode) self._handle_device = ( self._handle.device if self._in_graph_mode else context.get_default_context().device_name) @@ -291,6 +292,7 @@ class ResourceVariable(variables.Variable): dtype=initial_value.dtype.base_dtype, shared_name=handle_name, name=name, + graph_mode=False, container="") self._handle_device = ( self._handle.device if self._in_graph_mode else @@ -316,6 +318,7 @@ class ResourceVariable(variables.Variable): dtype=initial_value.dtype.base_dtype, shared_name=handle_name, name=name, + graph_mode=self._in_graph_mode, container="") self._handle_device = (self._handle.device if self._in_graph_mode else context.get_default_context().device_name) @@ -372,6 +375,7 @@ class ResourceVariable(variables.Variable): """Initializes from `VariableDef` proto.""" # Note that init_from_proto is currently not supported in Eager mode. assert context.in_graph_mode() + self._in_graph_mode = True assert isinstance(variable_def, variable_pb2.VariableDef) if not variable_def.is_resource: raise ValueError("Trying to restore Variable as ResourceVariable.") @@ -434,7 +438,7 @@ class ResourceVariable(variables.Variable): @property def create(self): """The op responsible for initializing this variable.""" - if not context.in_graph_mode(): + if not self._in_graph_mode: raise RuntimeError("Calling create in EAGER mode not supported.") return self._initializer_op @@ -520,7 +524,7 @@ class ResourceVariable(variables.Variable): # In graph mode, ensure we read the variable in the same device as the # handle. In eager mode, however, this sometimes tries to read a GPU # variable in the CPU because the handle is host memory. For now, then, we - # need to skip the device block in eager. TODO(apassos) eager should have + # need to skip the device block in eager. TODO(apassos): eager should have # separate notions of device and memory, so handle.device can be GPU while # handle.memory_space is always CPU. 
if context.in_graph_mode(): diff --git a/tensorflow/python/profiler/internal/run_metadata_test.py b/tensorflow/python/profiler/internal/run_metadata_test.py index 1e26a9897eb..c0de08cad6b 100644 --- a/tensorflow/python/profiler/internal/run_metadata_test.py +++ b/tensorflow/python/profiler/internal/run_metadata_test.py @@ -70,6 +70,7 @@ def _run_model(): opts = builder.time_and_memory() opts['min_micros'] = 0 opts['min_bytes'] = 0 + opts['order_by'] = 'name' opts['output'] = 'none' _ = sess.run(y, options=config_pb2.RunOptions( @@ -95,6 +96,7 @@ def _run_loop_model(): run_metadata=run_meta) opts = builder.time_and_memory() + opts['order_by'] = 'name' opts['output'] = 'none' tfprof_node = model_analyzer.profile( diff --git a/tensorflow/python/profiler/model_analyzer.py b/tensorflow/python/profiler/model_analyzer.py index 98d3e58f2af..2071325c7bb 100644 --- a/tensorflow/python/profiler/model_analyzer.py +++ b/tensorflow/python/profiler/model_analyzer.py @@ -180,8 +180,7 @@ class Profiler(object): """ # pylint: disable=protected-access op_log = tfprof_logger._merge_default_with_oplog( - self._graph, run_meta=run_meta, add_trace=False, - add_trainable_var=False) + self._graph, run_meta=run_meta) # pylint: enable=protected-access # TODO(xpan): P1: Better to find the current graph. print_mdl.AddStep( diff --git a/tensorflow/python/profiler/model_analyzer_test.py b/tensorflow/python/profiler/model_analyzer_test.py index dcdda1ffa25..494ba2e2a0d 100644 --- a/tensorflow/python/profiler/model_analyzer_test.py +++ b/tensorflow/python/profiler/model_analyzer_test.py @@ -305,7 +305,7 @@ class PrintModelAnalysisTest(test.TestCase): _ = model_analyzer.profile( sess.graph, run_meta, cmd='graph', options=opts) - with gfile.Open(outfile, 'r') as f: + with gfile.Open(outfile + '_0', 'r') as f: # Test that a json file is created. # TODO(xpan): tfprof Timeline isn't quite correct on Windows. # Investigate why. diff --git a/tensorflow/python/profiler/option_builder.py b/tensorflow/python/profiler/option_builder.py index 641895ffe54..13942ad6a2a 100644 --- a/tensorflow/python/profiler/option_builder.py +++ b/tensorflow/python/profiler/option_builder.py @@ -177,7 +177,7 @@ class ProfileOptionBuilder(object): 'min_params': 0, 'min_float_ops': 0, 'min_occurrence': 0, - 'order_by': 'name', + 'order_by': 'micros', 'account_type_regexes': ['.*'], 'start_name_regexes': ['.*'], 'trim_name_regexes': [], diff --git a/tensorflow/python/profiler/profile_context.py b/tensorflow/python/profiler/profile_context.py index 07adcb9c3f5..49fa22e3479 100644 --- a/tensorflow/python/profiler/profile_context.py +++ b/tensorflow/python/profiler/profile_context.py @@ -20,6 +20,7 @@ from __future__ import print_function import contextlib import os +import threading from tensorflow.core.protobuf import config_pb2 from tensorflow.python import pywrap_tensorflow as print_mdl @@ -163,6 +164,7 @@ class ProfileContext(object): self._traced_steps = 0 self._auto_profiles = [] self._profiler = None + self._lock = threading.Lock() def add_auto_profiling(self, cmd, options, profile_steps): """Traces and profiles at some session run steps. 
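Because the time_and_memory() default above flips from 'name' to 'micros', the run_metadata tests pin the ordering explicitly; callers that depend on name-sorted output can do the same:

    from tensorflow.python.profiler import option_builder

    opts = option_builder.ProfileOptionBuilder.time_and_memory()
    opts['order_by'] = 'name'   # restore deterministic, name-sorted output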
@@ -181,9 +183,10 @@ class ProfileContext(object): @property def profiler(self): """Returns the current profiler object.""" - if not self._profiler: - self._profiler = model_analyzer.Profiler(ops.get_default_graph()) - return self._profiler + with self._lock: + if not self._profiler: + self._profiler = model_analyzer.Profiler(ops.get_default_graph()) + return self._profiler def trace_next_step(self): """Enables tracing and add traces to profiler at next step.""" diff --git a/tensorflow/tools/ci_build/windows/cpu/cmake/run_build.bat b/tensorflow/tools/ci_build/windows/cpu/cmake/run_build.bat index 07ad70dd344..2a6db1015e6 100644 --- a/tensorflow/tools/ci_build/windows/cpu/cmake/run_build.bat +++ b/tensorflow/tools/ci_build/windows/cpu/cmake/run_build.bat @@ -34,7 +34,7 @@ SET CMAKE_DIR=%REPO_ROOT%\tensorflow\contrib\cmake SET MSBUILD_EXE="C:\Program Files (x86)\MSBuild\14.0\Bin\msbuild.exe" :: Run cmake to create Visual Studio Project files. -%CMAKE_EXE% %CMAKE_DIR% -A x64 -DSWIG_EXECUTABLE=%SWIG_EXE% -DPYTHON_EXECUTABLE=%PY_EXE% -DCMAKE_BUILD_TYPE=Release -DPYTHON_LIBRARIES=%PY_LIB% -Dtensorflow_BUILD_PYTHON_TESTS=%BUILD_PYTHON_TESTS% -Dtensorflow_BUILD_CC_TESTS=%BUILD_CC_TESTS% +%CMAKE_EXE% %CMAKE_DIR% -A x64 -DSWIG_EXECUTABLE=%SWIG_EXE% -DPYTHON_EXECUTABLE=%PY_EXE% -DCMAKE_BUILD_TYPE=Release -DPYTHON_LIBRARIES=%PY_LIB% -Dtensorflow_BUILD_PYTHON_TESTS=%BUILD_PYTHON_TESTS% -Dtensorflow_BUILD_CC_TESTS=%BUILD_CC_TESTS% -Dtensorflow_TF_NIGHTLY=%TF_NIGHTLY% :: Run msbuild in the resulting VS project files to build a pip package. -%MSBUILD_EXE% /p:Configuration=Release /maxcpucount:32 tf_python_build_pip_package.vcxproj +%MSBUILD_EXE% /p:Configuration=Release /maxcpucount:32 /verbosity:quiet tf_python_build_pip_package.vcxproj diff --git a/tensorflow/tools/ci_build/windows/cpu/cmake/run_py.bat b/tensorflow/tools/ci_build/windows/cpu/cmake/run_py.bat index 96fbadd1767..2f6d53e171c 100644 --- a/tensorflow/tools/ci_build/windows/cpu/cmake/run_py.bat +++ b/tensorflow/tools/ci_build/windows/cpu/cmake/run_py.bat @@ -19,8 +19,11 @@ MKDIR %BUILD_DIR% CD %BUILD_DIR% :: Set which tests to build -SET BUILD_CC_TESTS=OFF -SET BUILD_PYTHON_TESTS=ON +IF DEFINED BUILD_CC_TESTS (ECHO BUILD_CC_TESTS is set to %BUILD_CC_TESTS%) ELSE (SET BUILD_CC_TESTS=OFF) +IF DEFINED BUILD_PYTHON_TESTS (ECHO BUILD_PYTHON_TESTS is set to %BUILD_PYTHON_TESTS%) ELSE (SET BUILD_PYTHON_TESTS=ON) + +:: Set if this build is a nightly +IF DEFINED TF_NIGHTLY (ECHO TF_NIGHTLY is set to %TF_NIGHTLY%) ELSE (SET TF_NIGHTLY=OFF) :: Set pip binary location. Do not override if it is set already. IF DEFINED PIP_EXE (ECHO PIP_EXE is set to %PIP_EXE%) ELSE (SET PIP_EXE="C:\Program Files\Anaconda3\Scripts\pip.exe") diff --git a/tensorflow/tools/ci_build/windows/gpu/cmake/run_build.bat b/tensorflow/tools/ci_build/windows/gpu/cmake/run_build.bat index b4f9cc84762..cbb72b16a62 100644 --- a/tensorflow/tools/ci_build/windows/gpu/cmake/run_build.bat +++ b/tensorflow/tools/ci_build/windows/gpu/cmake/run_build.bat @@ -35,7 +35,7 @@ SET CMAKE_DIR=%REPO_ROOT%\tensorflow\contrib\cmake SET MSBUILD_EXE="C:\Program Files (x86)\MSBuild\14.0\Bin\msbuild.exe" :: Run cmake to create Visual Studio Project files. 
-%CMAKE_EXE% %CMAKE_DIR% -A x64 -DSWIG_EXECUTABLE=%SWIG_EXE% -DPYTHON_EXECUTABLE=%PY_EXE% -DCMAKE_BUILD_TYPE=Release -DPYTHON_LIBRARIES=%PY_LIB% -Dtensorflow_BUILD_PYTHON_TESTS=%BUILD_PYTHON_TESTS% -Dtensorflow_BUILD_CC_TESTS=%BUILD_CC_TESTS% -Dtensorflow_ENABLE_GPU=ON -DCUDNN_HOME=%CUDNN_HOME% +%CMAKE_EXE% %CMAKE_DIR% -A x64 -DSWIG_EXECUTABLE=%SWIG_EXE% -DPYTHON_EXECUTABLE=%PY_EXE% -DCMAKE_BUILD_TYPE=Release -DPYTHON_LIBRARIES=%PY_LIB% -Dtensorflow_BUILD_PYTHON_TESTS=%BUILD_PYTHON_TESTS% -Dtensorflow_BUILD_CC_TESTS=%BUILD_CC_TESTS% -Dtensorflow_ENABLE_GPU=ON -DCUDNN_HOME=%CUDNN_HOME% -Dtensorflow_TF_NIGHTLY=%TF_NIGHTLY% :: Run msbuild in the resulting VS project files to build a pip package. -%MSBUILD_EXE% /p:Configuration=Release /maxcpucount:32 tf_python_build_pip_package.vcxproj +%MSBUILD_EXE% /p:Configuration=Release /maxcpucount:32 /verbosity:quiet tf_python_build_pip_package.vcxproj diff --git a/tensorflow/tools/ci_build/windows/gpu/cmake/run_py.bat b/tensorflow/tools/ci_build/windows/gpu/cmake/run_py.bat index e774a6e9168..02e24c85de1 100644 --- a/tensorflow/tools/ci_build/windows/gpu/cmake/run_py.bat +++ b/tensorflow/tools/ci_build/windows/gpu/cmake/run_py.bat @@ -19,9 +19,13 @@ MKDIR %BUILD_DIR% CD %BUILD_DIR% :: Set which tests to build -SET BUILD_CC_TESTS=OFF -SET BUILD_PYTHON_TESTS=ON +IF DEFINED BUILD_CC_TESTS (ECHO BUILD_CC_TESTS is set to %BUILD_CC_TESTS%) ELSE (SET BUILD_CC_TESTS=OFF) +IF DEFINED BUILD_PYTHON_TESTS (ECHO BUILD_PYTHON_TESTS is set to %BUILD_PYTHON_TESTS%) ELSE (SET BUILD_PYTHON_TESTS=ON) +:: Set if this build is a nightly +IF DEFINED TF_NIGHTLY (ECHO TF_NIGHTLY is set to %TF_NIGHTLY%) ELSE (SET TF_NIGHTLY=OFF) + +:: Set pip binary location. Do not override if it is set already. IF DEFINED PIP_EXE (ECHO PIP_EXE is set to %PIP_EXE%) ELSE (SET PIP_EXE="C:\Program Files\Anaconda3\Scripts\pip.exe") :: Run the CMAKE build to build the pip package. diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD index d62316964f8..493d06a0d6e 100644 --- a/tensorflow/tools/pip_package/BUILD +++ b/tensorflow/tools/pip_package/BUILD @@ -152,6 +152,7 @@ sh_binary( "//tensorflow:tensorflow_py", "//tensorflow/contrib/boosted_trees:boosted_trees_pip", "//tensorflow/contrib/cluster_resolver:cluster_resolver_pip", + "//tensorflow/contrib/gan:gan", "//tensorflow/contrib/graph_editor:graph_editor_pip", "//tensorflow/contrib/keras:keras", "//tensorflow/contrib/labeled_tensor:labeled_tensor_pip",
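With //tensorflow/contrib/gan added to the pip package deps above, the module ships in the wheel; the import path as of this change (assuming a wheel built from this revision):

    import tensorflow.contrib.gan as tfgan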