Merge pull request #12956 from yifeif/branch_168186374
Branch 168186374
This commit is contained in:
commit
40eef4473b
@ -79,6 +79,8 @@ tf_cc_test(
|
|||||||
"//tensorflow/cc:ops",
|
"//tensorflow/cc:ops",
|
||||||
"//tensorflow/cc:scope",
|
"//tensorflow/cc:scope",
|
||||||
"//tensorflow/core:core_cpu_internal",
|
"//tensorflow/core:core_cpu_internal",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
"//tensorflow/core:test_main",
|
"//tensorflow/core:test_main",
|
||||||
],
|
],
|
||||||
|
@ -479,20 +479,16 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
|
|||||||
if (kernel == nullptr) {
|
if (kernel == nullptr) {
|
||||||
const tensorflow::NodeDef& ndef = op->attrs.BuildNodeDef();
|
const tensorflow::NodeDef& ndef = op->attrs.BuildNodeDef();
|
||||||
kernel = new tensorflow::KernelAndDevice(ctx->rendezvous);
|
kernel = new tensorflow::KernelAndDevice(ctx->rendezvous);
|
||||||
if (!op->is_function()) {
|
// Knowledge of the implementation of Init (and in-turn
|
||||||
status->status =
|
// FunctionLibraryRuntime::CreateKernel) tells us that ctx->func_lib_def
|
||||||
tensorflow::KernelAndDevice::InitOp(device, ndef, kernel);
|
// will be accessed, so grab on to the lock.
|
||||||
} else {
|
// See WARNING comment below - would be nice to rework to avoid this
|
||||||
// Knowledge of the implementation of InitFn (and in-turn
|
// subtlety.
|
||||||
// FunctionLibraryRuntime::CreateKernel) tells us that ctx->func_lib_def
|
tensorflow::tf_shared_lock l(ctx->functions_mu);
|
||||||
// will be accessed, so grab on to the lock.
|
status->status =
|
||||||
// See WARNING comment below - would be nice to rework to avoid this
|
tensorflow::KernelAndDevice::Init(ndef, ctx->func_lib(device), kernel);
|
||||||
// subtlety.
|
|
||||||
tensorflow::mutex_lock l(ctx->functions_mu);
|
|
||||||
status->status = tensorflow::KernelAndDevice::InitFn(
|
|
||||||
ndef, ctx->func_lib(device), kernel);
|
|
||||||
}
|
|
||||||
if (!status->status.ok()) {
|
if (!status->status.ok()) {
|
||||||
|
delete kernel;
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
tensorflow::gtl::InsertOrUpdate(&(ctx->kernel_cache), cache_key, kernel);
|
tensorflow::gtl::InsertOrUpdate(&(ctx->kernel_cache), cache_key, kernel);
|
||||||
|
@ -238,9 +238,8 @@ Status KernelAndDevice::InitOp(Device* device, const NodeDef& ndef,
|
|||||||
}
|
}
|
||||||
|
|
||||||
// static
|
// static
|
||||||
Status KernelAndDevice::InitFn(const NodeDef& ndef,
|
Status KernelAndDevice::Init(const NodeDef& ndef, FunctionLibraryRuntime* flib,
|
||||||
FunctionLibraryRuntime* flib,
|
KernelAndDevice* out) {
|
||||||
KernelAndDevice* out) {
|
|
||||||
OpKernel* k = nullptr;
|
OpKernel* k = nullptr;
|
||||||
Status s = flib->CreateKernel(ndef, &k);
|
Status s = flib->CreateKernel(ndef, &k);
|
||||||
out->device_ = flib->device();
|
out->device_ = flib->device();
|
||||||
|
@ -150,28 +150,19 @@ class KernelAndDevice {
|
|||||||
public:
|
public:
|
||||||
// Populates 'out' with a kernel appropriate for 'ndef'.
|
// Populates 'out' with a kernel appropriate for 'ndef'.
|
||||||
//
|
//
|
||||||
// Assumes that 'ndef' refers to a primitive op (as opposed to a function).
|
|
||||||
static Status InitOp(Device* device, const NodeDef& ndef,
|
|
||||||
KernelAndDevice* out);
|
|
||||||
|
|
||||||
// Like InitOp but for functions defined in flib (i.e., ndef.op() refers to a
|
|
||||||
// TensorFlow function in the FunctionLibraryRuntime).
|
|
||||||
//
|
|
||||||
// The provided FunctionLibraryRuntime MUST outlive all calls to
|
// The provided FunctionLibraryRuntime MUST outlive all calls to
|
||||||
// Run() on the returned KernelAndDevice.
|
// Run() on the returned KernelAndDevice.
|
||||||
//
|
//
|
||||||
// TODO(ashankar): There shouldn't be a need for a separate InitOp and InitFn.
|
// TODO(ashankar): Figure out thread-safety concerns around
|
||||||
// The implementation of InitFn should work for both because
|
// FunctionLibraryRuntime (in particular, how the underlying
|
||||||
// FunctionLibraryRuntime::CreateKernel will create a primitive op kernel if
|
// FunctionLibraryDefinition might be mutated by another thread as new
|
||||||
// appropriate. However, for now we keep them separate because I haven't
|
// functions are registered with it). Conservatively, thread-safe usage of
|
||||||
// figured out thread-safety concerns around FunctionLibraryRuntime (in
|
// the FunctionLibraryRuntime is pushed on to the caller (see locking in
|
||||||
// particular, how the underlying FunctionLibraryDefinition might be mutated
|
// c_api.cc).
|
||||||
// by another thread as new functions are registered with it).
|
static Status Init(const NodeDef& ndef, FunctionLibraryRuntime* flib,
|
||||||
// Conservatively, thread-safe usage of the FunctionLibraryRuntime is pushed
|
KernelAndDevice* out);
|
||||||
// on to the caller (see locking in c_api.cc) for now. But I really should
|
// TODO(ashankar): Remove this
|
||||||
// dig into this so that both InitOp and InitFn can be collapsed to
|
static Status InitOp(Device* device, const NodeDef& ndef,
|
||||||
// FunctionLibraryRuntime::CreateKernel.
|
|
||||||
static Status InitFn(const NodeDef& ndef, FunctionLibraryRuntime* flib,
|
|
||||||
KernelAndDevice* out);
|
KernelAndDevice* out);
|
||||||
|
|
||||||
KernelAndDevice(tensorflow::Rendezvous* rendez)
|
KernelAndDevice(tensorflow::Rendezvous* rendez)
|
||||||
@ -184,10 +175,10 @@ class KernelAndDevice {
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
std::unique_ptr<OpKernel> kernel_;
|
std::unique_ptr<OpKernel> kernel_;
|
||||||
tensorflow::Device* device_;
|
Device* device_;
|
||||||
tensorflow::FunctionLibraryRuntime* flib_;
|
FunctionLibraryRuntime* flib_;
|
||||||
tensorflow::checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_;
|
checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_;
|
||||||
tensorflow::Rendezvous* rendez_;
|
Rendezvous* rendez_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -23,15 +23,36 @@ limitations under the License.
|
|||||||
#include "tensorflow/cc/framework/scope.h"
|
#include "tensorflow/cc/framework/scope.h"
|
||||||
#include "tensorflow/cc/ops/standard_ops.h"
|
#include "tensorflow/cc/ops/standard_ops.h"
|
||||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||||
|
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||||
|
#include "tensorflow/core/common_runtime/function.h"
|
||||||
|
#include "tensorflow/core/platform/env.h"
|
||||||
#include "tensorflow/core/platform/test.h"
|
#include "tensorflow/core/platform/test.h"
|
||||||
#include "tensorflow/core/platform/test_benchmark.h"
|
#include "tensorflow/core/platform/test_benchmark.h"
|
||||||
|
#include "tensorflow/core/public/version.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
Device* CPUDevice() {
|
class TestEnv {
|
||||||
return DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0");
|
public:
|
||||||
}
|
TestEnv() : flib_def_(OpRegistry::Global(), {}) {
|
||||||
|
Device* device =
|
||||||
|
DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0");
|
||||||
|
device_mgr_.reset(new DeviceMgr({device}));
|
||||||
|
flib_runtime_ = NewFunctionLibraryRuntime(device_mgr_.get(), Env::Default(),
|
||||||
|
device, TF_GRAPH_DEF_VERSION,
|
||||||
|
&flib_def_, {}, nullptr);
|
||||||
|
}
|
||||||
|
|
||||||
|
FunctionLibraryRuntime* function_library_runtime() const {
|
||||||
|
return flib_runtime_.get();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
FunctionLibraryDefinition flib_def_;
|
||||||
|
std::unique_ptr<DeviceMgr> device_mgr_;
|
||||||
|
std::unique_ptr<FunctionLibraryRuntime> flib_runtime_;
|
||||||
|
};
|
||||||
|
|
||||||
TEST(AttrTypeMap, Lookup) {
|
TEST(AttrTypeMap, Lookup) {
|
||||||
const AttrTypeMap* m = nullptr;
|
const AttrTypeMap* m = nullptr;
|
||||||
@ -69,9 +90,10 @@ TEST(KernelAndDevice, Run) {
|
|||||||
.Set("transpose_b", false)
|
.Set("transpose_b", false)
|
||||||
.NumInputs(inputs.size())
|
.NumInputs(inputs.size())
|
||||||
.BuildNodeDef());
|
.BuildNodeDef());
|
||||||
std::unique_ptr<Device> device(CPUDevice());
|
TestEnv env;
|
||||||
KernelAndDevice kernel(nullptr);
|
KernelAndDevice kernel(nullptr);
|
||||||
Status s = KernelAndDevice::InitOp(device.get(), ndef, &kernel);
|
Status s =
|
||||||
|
KernelAndDevice::Init(ndef, env.function_library_runtime(), &kernel);
|
||||||
ASSERT_TRUE(s.ok()) << s;
|
ASSERT_TRUE(s.ok()) << s;
|
||||||
std::vector<Tensor> outputs;
|
std::vector<Tensor> outputs;
|
||||||
s = kernel.Run(&inputs, &outputs);
|
s = kernel.Run(&inputs, &outputs);
|
||||||
@ -132,11 +154,12 @@ void BM_KernelAndDeviceInit(int iters) {
|
|||||||
.Set("transpose_b", false)
|
.Set("transpose_b", false)
|
||||||
.NumInputs(2)
|
.NumInputs(2)
|
||||||
.BuildNodeDef());
|
.BuildNodeDef());
|
||||||
std::unique_ptr<Device> device(CPUDevice());
|
TestEnv env;
|
||||||
KernelAndDevice k(nullptr);
|
KernelAndDevice k(nullptr);
|
||||||
tensorflow::testing::StartTiming();
|
tensorflow::testing::StartTiming();
|
||||||
for (int i = 0; i < iters; ++i) {
|
for (int i = 0; i < iters; ++i) {
|
||||||
TF_CHECK_OK(KernelAndDevice::InitOp(device.get(), ndef, &k));
|
TF_CHECK_OK(
|
||||||
|
KernelAndDevice::Init(ndef, env.function_library_runtime(), &k));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
BENCHMARK(BM_KernelAndDeviceInit);
|
BENCHMARK(BM_KernelAndDeviceInit);
|
||||||
@ -154,9 +177,10 @@ void BM_KernelAndDeviceRun(int iters) {
|
|||||||
.Set("transpose_b", false)
|
.Set("transpose_b", false)
|
||||||
.NumInputs(inputs.size())
|
.NumInputs(inputs.size())
|
||||||
.BuildNodeDef());
|
.BuildNodeDef());
|
||||||
std::unique_ptr<Device> device(CPUDevice());
|
TestEnv env;
|
||||||
KernelAndDevice kernel(nullptr);
|
KernelAndDevice kernel(nullptr);
|
||||||
TF_CHECK_OK(KernelAndDevice::InitOp(device.get(), ndef, &kernel));
|
TF_CHECK_OK(
|
||||||
|
KernelAndDevice::Init(ndef, env.function_library_runtime(), &kernel));
|
||||||
tensorflow::testing::StartTiming();
|
tensorflow::testing::StartTiming();
|
||||||
for (int i = 0; i < iters; ++i) {
|
for (int i = 0; i < iters; ++i) {
|
||||||
TF_CHECK_OK(kernel.Run(&inputs, &outputs));
|
TF_CHECK_OK(kernel.Run(&inputs, &outputs));
|
||||||
|
@ -286,6 +286,7 @@ cc_library(
|
|||||||
srcs = ["call_inliner.cc"],
|
srcs = ["call_inliner.cc"],
|
||||||
hdrs = ["call_inliner.h"],
|
hdrs = ["call_inliner.h"],
|
||||||
deps = [
|
deps = [
|
||||||
|
":call_graph",
|
||||||
":hlo_pass",
|
":hlo_pass",
|
||||||
"//tensorflow/compiler/xla:statusor",
|
"//tensorflow/compiler/xla:statusor",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
|
@ -17,33 +17,11 @@ limitations under the License.
|
|||||||
|
|
||||||
#include <deque>
|
#include <deque>
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/xla/service/call_graph.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
namespace {
|
||||||
StatusOr<bool> CallInliner::Run(HloModule* module) {
|
|
||||||
std::deque<HloInstruction*> work_queue;
|
|
||||||
|
|
||||||
// Seed the work queue with call instructions from the main computation.
|
|
||||||
TF_RETURN_IF_ERROR(
|
|
||||||
module->entry_computation()->Accept([&](HloInstruction* hlo) {
|
|
||||||
if (hlo->opcode() == HloOpcode::kCall) {
|
|
||||||
work_queue.push_back(hlo);
|
|
||||||
}
|
|
||||||
return Status::OK();
|
|
||||||
}));
|
|
||||||
|
|
||||||
VLOG(1) << "Work queue seeded with " << work_queue.size() << " entries.";
|
|
||||||
|
|
||||||
bool mutated = false;
|
|
||||||
while (!work_queue.empty()) {
|
|
||||||
mutated = true;
|
|
||||||
HloInstruction* call = work_queue.front();
|
|
||||||
work_queue.pop_front();
|
|
||||||
TF_RETURN_IF_ERROR(ReplaceWithInlinedBody(call, &work_queue));
|
|
||||||
}
|
|
||||||
return mutated;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Traverses the callee computation, inlining cloned nodes into the caller
|
// Traverses the callee computation, inlining cloned nodes into the caller
|
||||||
// computation and connecting them to producers/consumers appropriately.
|
// computation and connecting them to producers/consumers appropriately.
|
||||||
@ -52,14 +30,18 @@ StatusOr<bool> CallInliner::Run(HloModule* module) {
|
|||||||
// computation have been added to the work_queue.
|
// computation have been added to the work_queue.
|
||||||
class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault {
|
class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault {
|
||||||
public:
|
public:
|
||||||
SubcomputationInsertionVisitor(HloInstruction* call,
|
// call is the call operation -- it will be replaced with the body of the
|
||||||
std::deque<HloInstruction*>* work_queue)
|
// called computation.
|
||||||
: call_(call), outer_(call->parent()), work_queue_(work_queue) {}
|
explicit SubcomputationInsertionVisitor(HloInstruction* call)
|
||||||
|
: call_(call), outer_(call->parent()) {
|
||||||
|
CHECK_EQ(HloOpcode::kCall, call_->opcode());
|
||||||
|
}
|
||||||
|
|
||||||
// Resolves the operands to the HLO instruction in the inlined (caller) graph,
|
// Resolves the operands to the HLO instruction in the inlined (caller) graph,
|
||||||
// and clones the HLO instruction into that graph with the new operands.
|
// and clones the HLO instruction into that graph with the new operands.
|
||||||
// If the instruction is a call, it is added to the work queue.
|
// If the instruction is a call, it is added to the work queue.
|
||||||
Status DefaultAction(HloInstruction* hlo) override {
|
Status DefaultAction(HloInstruction* hlo) override {
|
||||||
|
TF_RET_CHECK(hlo->opcode() != HloOpcode::kCall);
|
||||||
std::vector<HloInstruction*> new_operands;
|
std::vector<HloInstruction*> new_operands;
|
||||||
for (HloInstruction* operand : hlo->operands()) {
|
for (HloInstruction* operand : hlo->operands()) {
|
||||||
TF_ASSIGN_OR_RETURN(HloInstruction * new_operand, Resolve(operand));
|
TF_ASSIGN_OR_RETURN(HloInstruction * new_operand, Resolve(operand));
|
||||||
@ -79,12 +61,6 @@ class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault {
|
|||||||
new_control_predecessor->AddControlDependencyTo(new_hlo_pointer));
|
new_control_predecessor->AddControlDependencyTo(new_hlo_pointer));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (new_hlo_pointer->opcode() == HloOpcode::kCall) {
|
|
||||||
VLOG(1) << "Adding new call HLO to work queue.";
|
|
||||||
// Call instructions we observe in the subcomputation are added to the
|
|
||||||
// inliner work queue.
|
|
||||||
work_queue_->push_back(new_hlo_pointer);
|
|
||||||
}
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -141,16 +117,30 @@ class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault {
|
|||||||
std::deque<HloInstruction*>* work_queue_;
|
std::deque<HloInstruction*>* work_queue_;
|
||||||
};
|
};
|
||||||
|
|
||||||
Status CallInliner::ReplaceWithInlinedBody(
|
} // namespace
|
||||||
HloInstruction* call, std::deque<HloInstruction*>* work_queue) {
|
|
||||||
TF_RET_CHECK(call->opcode() == HloOpcode::kCall);
|
|
||||||
TF_RET_CHECK(call->called_computations().size() == 1);
|
|
||||||
HloComputation* called = call->called_computations()[0];
|
|
||||||
VLOG(1) << "Replacing call " << call->ToString() << " with inlined body of "
|
|
||||||
<< called->name();
|
|
||||||
|
|
||||||
SubcomputationInsertionVisitor visitor(call, work_queue);
|
StatusOr<bool> CallInliner::Run(HloModule* module) {
|
||||||
return called->Accept(&visitor);
|
std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
|
||||||
|
// Because call graph nodes are visited in post-order (callees before callers)
|
||||||
|
// we'll always inline kCalls into their callers in the appropriate order.
|
||||||
|
bool did_mutate = false;
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
call_graph->VisitNodes([&](const CallGraphNode& node) -> Status {
|
||||||
|
for (const CallSite& callsite : node.caller_callsites()) {
|
||||||
|
VLOG(1) << "Visiting callsite: " << callsite.ToString();
|
||||||
|
if (callsite.instruction()->opcode() == HloOpcode::kCall) {
|
||||||
|
did_mutate = true;
|
||||||
|
const auto& callees = callsite.called_computations();
|
||||||
|
TF_RET_CHECK(callees.size() == 1);
|
||||||
|
HloComputation* callee = callees[0];
|
||||||
|
// We visit the callee, cloning its body into its caller.
|
||||||
|
SubcomputationInsertionVisitor visitor(callsite.instruction());
|
||||||
|
TF_RETURN_IF_ERROR(callee->Accept(&visitor));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}));
|
||||||
|
return did_mutate;
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
@ -31,16 +31,6 @@ class CallInliner : public HloPassInterface {
|
|||||||
tensorflow::StringPiece name() const override { return "CallInliner"; }
|
tensorflow::StringPiece name() const override { return "CallInliner"; }
|
||||||
|
|
||||||
StatusOr<bool> Run(HloModule* module) override;
|
StatusOr<bool> Run(HloModule* module) override;
|
||||||
|
|
||||||
private:
|
|
||||||
// Replaces the given call operation -- which must be an operation inside the
|
|
||||||
// entry computation with opcode kCall -- with the called computation's body,
|
|
||||||
// such that the called computation is inline in the entry computation.
|
|
||||||
//
|
|
||||||
// On successful inlining, the inlined computation may have itself contained
|
|
||||||
// calls; if so, they are added to the work_queue.
|
|
||||||
Status ReplaceWithInlinedBody(HloInstruction* call,
|
|
||||||
std::deque<HloInstruction*>* work_queue);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
@ -44,6 +44,8 @@ namespace {
|
|||||||
using CallInlinerTest = HloTestBase;
|
using CallInlinerTest = HloTestBase;
|
||||||
|
|
||||||
TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) {
|
TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) {
|
||||||
|
// "inner" computation just has a control dependency from the "zero" value to
|
||||||
|
// the "one" value.
|
||||||
HloComputation::Builder inner(TestName() + ".inner");
|
HloComputation::Builder inner(TestName() + ".inner");
|
||||||
HloInstruction* zero = inner.AddInstruction(
|
HloInstruction* zero = inner.AddInstruction(
|
||||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(24.0f)));
|
HloInstruction::CreateConstant(Literal::CreateR0<float>(24.0f)));
|
||||||
@ -54,6 +56,7 @@ TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) {
|
|||||||
HloComputation* inner_computation =
|
HloComputation* inner_computation =
|
||||||
module->AddEmbeddedComputation(inner.Build());
|
module->AddEmbeddedComputation(inner.Build());
|
||||||
|
|
||||||
|
// "outer" computation just calls the "inner" computation.
|
||||||
HloComputation::Builder outer(TestName() + ".outer");
|
HloComputation::Builder outer(TestName() + ".outer");
|
||||||
Shape r0f32 = ShapeUtil::MakeShape(F32, {});
|
Shape r0f32 = ShapeUtil::MakeShape(F32, {});
|
||||||
outer.AddInstruction(
|
outer.AddInstruction(
|
||||||
@ -73,5 +76,44 @@ TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) {
|
|||||||
EXPECT_EQ(prior->literal().GetFirstElement<float>(), 24);
|
EXPECT_EQ(prior->literal().GetFirstElement<float>(), 24);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Tests for referential transparency (a function that calls a function that
|
||||||
|
// returns false should be identical to just returning false).
|
||||||
|
TEST_F(CallInlinerTest, CallsWithinWhileBodiesAreInlined) {
|
||||||
|
const Shape pred = ShapeUtil::MakeShape(PRED, {});
|
||||||
|
auto module = CreateNewModule();
|
||||||
|
|
||||||
|
// Create a lambda that calls a function that returns the false predicate.
|
||||||
|
// Note we also use this lambda twice by reference, just to make the test a
|
||||||
|
// little trickier.
|
||||||
|
HloComputation::Builder just_false(TestName() + ".false");
|
||||||
|
just_false.AddInstruction(
|
||||||
|
HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
|
||||||
|
HloComputation* false_computation =
|
||||||
|
module->AddEmbeddedComputation(just_false.Build());
|
||||||
|
|
||||||
|
HloComputation::Builder call_false_builder(TestName() + ".call_false");
|
||||||
|
call_false_builder.AddInstruction(
|
||||||
|
HloInstruction::CreateCall(pred, {}, false_computation));
|
||||||
|
HloComputation* call_false =
|
||||||
|
module->AddEmbeddedComputation(call_false_builder.Build());
|
||||||
|
|
||||||
|
HloComputation::Builder outer(TestName() + ".outer");
|
||||||
|
HloInstruction* init_value = outer.AddInstruction(
|
||||||
|
HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
|
||||||
|
outer.AddInstruction(
|
||||||
|
HloInstruction::CreateWhile(pred, call_false, call_false, init_value));
|
||||||
|
|
||||||
|
auto computation = module->AddEntryComputation(outer.Build());
|
||||||
|
|
||||||
|
CallInliner call_inliner;
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get()));
|
||||||
|
ASSERT_TRUE(mutated);
|
||||||
|
EXPECT_THAT(
|
||||||
|
computation->root_instruction()->while_condition()->root_instruction(),
|
||||||
|
op::Constant());
|
||||||
|
EXPECT_THAT(computation->root_instruction()->while_body()->root_instruction(),
|
||||||
|
op::Constant());
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
@ -216,8 +216,7 @@ class CollectProfileCandidates : public DfsHloVisitorWithDefault {
|
|||||||
Status HandleCall(HloInstruction* call) override {
|
Status HandleCall(HloInstruction* call) override {
|
||||||
TF_RETURN_IF_ERROR(DefaultAction(call));
|
TF_RETURN_IF_ERROR(DefaultAction(call));
|
||||||
CollectProfileCandidates candidates_for_call(hlo_to_profile_idx_);
|
CollectProfileCandidates candidates_for_call(hlo_to_profile_idx_);
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(call->to_apply()->Accept(&candidates_for_call));
|
||||||
call->to_apply()->root_instruction()->Accept(&candidates_for_call));
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -45,8 +45,7 @@ string HloExecutionProfile::ToString(
|
|||||||
const HloComputation& computation,
|
const HloComputation& computation,
|
||||||
const DeviceDescription& device_description,
|
const DeviceDescription& device_description,
|
||||||
HloCostAnalysis* cost_analysis) const {
|
HloCostAnalysis* cost_analysis) const {
|
||||||
tensorflow::Status analysis_status =
|
tensorflow::Status analysis_status = computation.Accept(cost_analysis);
|
||||||
computation.root_instruction()->Accept(cost_analysis);
|
|
||||||
if (!analysis_status.ok()) {
|
if (!analysis_status.ok()) {
|
||||||
return "";
|
return "";
|
||||||
}
|
}
|
||||||
|
@ -1179,8 +1179,7 @@ tensorflow::Status Service::GetComputationStats(
|
|||||||
HloCostAnalysis analysis(
|
HloCostAnalysis analysis(
|
||||||
execute_backend_->compiler()->ShapeSizeBytesFunction());
|
execute_backend_->compiler()->ShapeSizeBytesFunction());
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&analysis));
|
||||||
module->entry_computation()->root_instruction()->Accept(&analysis));
|
|
||||||
|
|
||||||
ComputationStats stats;
|
ComputationStats stats;
|
||||||
stats.set_flop_count(analysis.flop_count());
|
stats.set_flop_count(analysis.flop_count());
|
||||||
|
@ -151,19 +151,6 @@ XLA_TEST_F(ScalarComputationsTest, SubtractTwoScalarsS32) {
|
|||||||
ComputeAndCompareR0<int32>(&builder, -3, {});
|
ComputeAndCompareR0<int32>(&builder, -3, {});
|
||||||
}
|
}
|
||||||
|
|
||||||
XLA_TEST_F(ScalarComputationsTest, CastS64ToF32) {
|
|
||||||
ComputationBuilder builder(client_, TestName());
|
|
||||||
auto a = builder.Parameter(0, ShapeUtil::MakeShape(S64, {}), "a");
|
|
||||||
builder.ConvertElementType(a, F32);
|
|
||||||
|
|
||||||
int64 value = 3LL << 32;
|
|
||||||
std::unique_ptr<Literal> a_literal = Literal::CreateR0<int64>(value);
|
|
||||||
std::unique_ptr<GlobalData> a_data =
|
|
||||||
client_->TransferToServer(*a_literal).ConsumeValueOrDie();
|
|
||||||
ComputeAndCompareR0<float>(&builder, static_cast<float>(value),
|
|
||||||
{a_data.get()});
|
|
||||||
}
|
|
||||||
|
|
||||||
XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsF32) {
|
XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsF32) {
|
||||||
ComputationBuilder builder(client_, TestName());
|
ComputationBuilder builder(client_, TestName());
|
||||||
builder.Mul(builder.Mul(builder.ConstantR0<float>(2.1f),
|
builder.Mul(builder.Mul(builder.ConstantR0<float>(2.1f),
|
||||||
|
@ -94,7 +94,7 @@ void RealMain(tensorflow::gtl::ArraySlice<char*> args) {
|
|||||||
|
|
||||||
OperationDumper dumper(arg);
|
OperationDumper dumper(arg);
|
||||||
for (auto& computation : module.computations()) {
|
for (auto& computation : module.computations()) {
|
||||||
TF_CHECK_OK(computation->root_instruction()->Accept(&dumper));
|
TF_CHECK_OK(computation->Accept(&dumper));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -113,6 +113,7 @@ py_test(
|
|||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
tags = [
|
tags = [
|
||||||
"nomac", # b/63258195
|
"nomac", # b/63258195
|
||||||
|
"notsan", # b/62863147
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":gbdt_batch",
|
":gbdt_batch",
|
||||||
|
@ -19,6 +19,14 @@ set(GIFLIB_INCLUDES
|
|||||||
"lib/gif_lib.h"
|
"lib/gif_lib.h"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if (WIN32)
|
||||||
|
# Suppress warnings to reduce build log size.
|
||||||
|
add_definitions(/wd4267 /wd4244 /wd4800 /wd4503 /wd4554 /wd4996 /wd4348 /wd4018)
|
||||||
|
add_definitions(/wd4099 /wd4146 /wd4267 /wd4305 /wd4307)
|
||||||
|
add_definitions(/wd4715 /wd4722 /wd4723 /wd4838 /wd4309 /wd4334)
|
||||||
|
add_definitions(/wd4003 /wd4244 /wd4267 /wd4503 /wd4506 /wd4800 /wd4996)
|
||||||
|
endif()
|
||||||
|
|
||||||
include_directories("${CMAKE_CURRENT_SOURCE_DIR}/lib")
|
include_directories("${CMAKE_CURRENT_SOURCE_DIR}/lib")
|
||||||
|
|
||||||
add_library(giflib ${GIFLIB_SRCS})
|
add_library(giflib ${GIFLIB_SRCS})
|
||||||
|
@ -62,6 +62,14 @@ set(LIBJPEG_INCLUDES
|
|||||||
"jversion.h"
|
"jversion.h"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if (WIN32)
|
||||||
|
# Suppress warnings to reduce build log size.
|
||||||
|
add_definitions(/wd4267 /wd4244 /wd4800 /wd4503 /wd4554 /wd4996 /wd4348 /wd4018)
|
||||||
|
add_definitions(/wd4099 /wd4146 /wd4267 /wd4305 /wd4307)
|
||||||
|
add_definitions(/wd4715 /wd4722 /wd4723 /wd4838 /wd4309 /wd4334)
|
||||||
|
add_definitions(/wd4003 /wd4244 /wd4267 /wd4503 /wd4506 /wd4800 /wd4996)
|
||||||
|
endif()
|
||||||
|
|
||||||
include_directories("${CMAKE_CURRENT_SOURCE_DIR}")
|
include_directories("${CMAKE_CURRENT_SOURCE_DIR}")
|
||||||
|
|
||||||
add_library(libjpeg ${LIBJPEG_SRCS})
|
add_library(libjpeg ${LIBJPEG_SRCS})
|
||||||
|
@ -12,6 +12,14 @@ set(LIBLMDB_INCLUDES
|
|||||||
"libraries/liblmdb/midl.h"
|
"libraries/liblmdb/midl.h"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if (WIN32)
|
||||||
|
# Suppress warnings to reduce build log size.
|
||||||
|
add_definitions(/wd4267 /wd4244 /wd4800 /wd4503 /wd4554 /wd4996 /wd4348 /wd4018)
|
||||||
|
add_definitions(/wd4099 /wd4146 /wd4267 /wd4305 /wd4307)
|
||||||
|
add_definitions(/wd4715 /wd4722 /wd4723 /wd4838 /wd4309 /wd4334)
|
||||||
|
add_definitions(/wd4003 /wd4244 /wd4267 /wd4503 /wd4506 /wd4800 /wd4996)
|
||||||
|
endif()
|
||||||
|
|
||||||
include_directories("${CMAKE_CURRENT_SOURCE_DIR}")
|
include_directories("${CMAKE_CURRENT_SOURCE_DIR}")
|
||||||
|
|
||||||
add_library(lmdb ${LIBLMDB_SRCS})
|
add_library(lmdb ${LIBLMDB_SRCS})
|
||||||
|
@ -361,6 +361,11 @@ add_python_module("tensorflow/contrib/framework/python")
|
|||||||
add_python_module("tensorflow/contrib/framework/python/framework")
|
add_python_module("tensorflow/contrib/framework/python/framework")
|
||||||
add_python_module("tensorflow/contrib/framework/python/ops")
|
add_python_module("tensorflow/contrib/framework/python/ops")
|
||||||
add_python_module("tensorflow/contrib/gan")
|
add_python_module("tensorflow/contrib/gan")
|
||||||
|
add_python_module("tensorflow/contrib/gan/python")
|
||||||
|
add_python_module("tensorflow/contrib/gan/python/features")
|
||||||
|
add_python_module("tensorflow/contrib/gan/python/features/python")
|
||||||
|
add_python_module("tensorflow/contrib/gan/python/losses")
|
||||||
|
add_python_module("tensorflow/contrib/gan/python/losses/python")
|
||||||
add_python_module("tensorflow/contrib/graph_editor")
|
add_python_module("tensorflow/contrib/graph_editor")
|
||||||
add_python_module("tensorflow/contrib/graph_editor/examples")
|
add_python_module("tensorflow/contrib/graph_editor/examples")
|
||||||
add_python_module("tensorflow/contrib/graph_editor/tests")
|
add_python_module("tensorflow/contrib/graph_editor/tests")
|
||||||
@ -1147,12 +1152,24 @@ add_custom_command(TARGET tf_python_build_pip_package POST_BUILD
|
|||||||
COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_CURRENT_BINARY_DIR}/eigen/src/eigen/unsupported/Eigen
|
COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_CURRENT_BINARY_DIR}/eigen/src/eigen/unsupported/Eigen
|
||||||
${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/include/unsupported/Eigen)
|
${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/include/unsupported/Eigen)
|
||||||
|
|
||||||
if(${tensorflow_ENABLE_GPU})
|
if(${tensorflow_TF_NIGHTLY})
|
||||||
add_custom_command(TARGET tf_python_build_pip_package POST_BUILD
|
if(${tensorflow_ENABLE_GPU})
|
||||||
COMMAND ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_BINARY_DIR}/tf_python/setup.py bdist_wheel --project_name tensorflow_gpu
|
add_custom_command(TARGET tf_python_build_pip_package POST_BUILD
|
||||||
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/tf_python)
|
COMMAND ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_BINARY_DIR}/tf_python/setup.py bdist_wheel --project_name tf_nightly_gpu
|
||||||
|
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/tf_python)
|
||||||
|
else()
|
||||||
|
add_custom_command(TARGET tf_python_build_pip_package POST_BUILD
|
||||||
|
COMMAND ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_BINARY_DIR}/tf_python/setup.py bdist_wheel --project_name tf_nightly
|
||||||
|
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/tf_python)
|
||||||
|
endif(${tensorflow_ENABLE_GPU})
|
||||||
else()
|
else()
|
||||||
add_custom_command(TARGET tf_python_build_pip_package POST_BUILD
|
if(${tensorflow_ENABLE_GPU})
|
||||||
COMMAND ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_BINARY_DIR}/tf_python/setup.py bdist_wheel
|
add_custom_command(TARGET tf_python_build_pip_package POST_BUILD
|
||||||
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/tf_python)
|
COMMAND ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_BINARY_DIR}/tf_python/setup.py bdist_wheel --project_name tensorflow_gpu
|
||||||
endif(${tensorflow_ENABLE_GPU})
|
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/tf_python)
|
||||||
|
else()
|
||||||
|
add_custom_command(TARGET tf_python_build_pip_package POST_BUILD
|
||||||
|
COMMAND ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_BINARY_DIR}/tf_python/setup.py bdist_wheel
|
||||||
|
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/tf_python)
|
||||||
|
endif(${tensorflow_ENABLE_GPU})
|
||||||
|
endif(${tensorflow_TF_NIGHTLY})
|
||||||
|
@ -24,6 +24,8 @@ py_test(
|
|||||||
"//tensorflow/python:functional_ops",
|
"//tensorflow/python:functional_ops",
|
||||||
"//tensorflow/python:gradients",
|
"//tensorflow/python:gradients",
|
||||||
"//tensorflow/python:math_ops",
|
"//tensorflow/python:math_ops",
|
||||||
|
"//tensorflow/python:parsing_ops",
|
||||||
|
"//tensorflow/python:script_ops",
|
||||||
"//tensorflow/python:training",
|
"//tensorflow/python:training",
|
||||||
"//third_party/py/numpy",
|
"//third_party/py/numpy",
|
||||||
],
|
],
|
||||||
|
@ -27,10 +27,13 @@ from tensorflow.python.framework import dtypes
|
|||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
from tensorflow.python.framework import function
|
from tensorflow.python.framework import function
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import functional_ops
|
from tensorflow.python.ops import functional_ops
|
||||||
from tensorflow.python.ops import gradients_impl
|
from tensorflow.python.ops import gradients_impl
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
|
from tensorflow.python.ops import parsing_ops
|
||||||
|
from tensorflow.python.ops import script_ops
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
from tensorflow.python.training import server_lib
|
from tensorflow.python.training import server_lib
|
||||||
|
|
||||||
@ -420,7 +423,7 @@ class IteratorTest(test.TestCase):
|
|||||||
|
|
||||||
def testRemoteIteratorUsingRemoteCallOpDirectSession(self):
|
def testRemoteIteratorUsingRemoteCallOpDirectSession(self):
|
||||||
worker_config = config_pb2.ConfigProto()
|
worker_config = config_pb2.ConfigProto()
|
||||||
worker_config.device_count["CPU"] = 2
|
worker_config.device_count["CPU"] = 3
|
||||||
|
|
||||||
with ops.device("/job:localhost/replica:0/task:0/cpu:1"):
|
with ops.device("/job:localhost/replica:0/task:0/cpu:1"):
|
||||||
dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
|
dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
|
||||||
@ -448,12 +451,12 @@ class IteratorTest(test.TestCase):
|
|||||||
target_placeholder: "/job:localhost/replica:0/task:0/cpu:1"
|
target_placeholder: "/job:localhost/replica:0/task:0/cpu:1"
|
||||||
})
|
})
|
||||||
self.assertEqual(elem, [1])
|
self.assertEqual(elem, [1])
|
||||||
# Fails when target is cpu:0 where the resource is not located.
|
# Fails when target is cpu:2 where the resource is not located.
|
||||||
with self.assertRaises(errors.InvalidArgumentError):
|
with self.assertRaises(errors.InvalidArgumentError):
|
||||||
sess.run(
|
sess.run(
|
||||||
remote_op,
|
remote_op,
|
||||||
feed_dict={
|
feed_dict={
|
||||||
target_placeholder: "/job:localhost/replica:0/task:0/cpu:0"
|
target_placeholder: "/job:localhost/replica:0/task:0/cpu:2"
|
||||||
})
|
})
|
||||||
elem = sess.run(
|
elem = sess.run(
|
||||||
remote_op,
|
remote_op,
|
||||||
@ -474,6 +477,61 @@ class IteratorTest(test.TestCase):
|
|||||||
target_placeholder: "/job:localhost/replica:0/task:0/cpu:1"
|
target_placeholder: "/job:localhost/replica:0/task:0/cpu:1"
|
||||||
})
|
})
|
||||||
|
|
||||||
|
def testRemoteIteratorUsingRemoteCallOpDirectSessionGPUCPU(self):
|
||||||
|
if not test_util.is_gpu_available():
|
||||||
|
self.skipTest("No GPU available")
|
||||||
|
|
||||||
|
with ops.device("/job:localhost/replica:0/task:0/cpu:0"):
|
||||||
|
dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
|
||||||
|
iterator_3 = dataset_3.make_one_shot_iterator()
|
||||||
|
iterator_3_handle = iterator_3.string_handle()
|
||||||
|
|
||||||
|
def _encode_raw(byte_array):
|
||||||
|
return "".join([chr(item) for item in byte_array])
|
||||||
|
|
||||||
|
@function.Defun(dtypes.uint8)
|
||||||
|
def _remote_fn(h):
|
||||||
|
handle = script_ops.py_func(_encode_raw, [h], dtypes.string)
|
||||||
|
remote_iterator = dataset_ops.Iterator.from_string_handle(
|
||||||
|
handle, dataset_3.output_types, dataset_3.output_shapes)
|
||||||
|
return remote_iterator.get_next()
|
||||||
|
|
||||||
|
with ops.device("/job:localhost/replica:0/task:0/device:GPU:0"):
|
||||||
|
target_placeholder = array_ops.placeholder(dtypes.string, shape=[])
|
||||||
|
iterator_3_handle_uint8 = parsing_ops.decode_raw(
|
||||||
|
bytes=iterator_3_handle, out_type=dtypes.uint8)
|
||||||
|
remote_op = functional_ops.remote_call(
|
||||||
|
args=[iterator_3_handle_uint8],
|
||||||
|
Tout=[dtypes.int32],
|
||||||
|
f=_remote_fn,
|
||||||
|
target=target_placeholder)
|
||||||
|
|
||||||
|
with self.test_session() as sess:
|
||||||
|
elem = sess.run(
|
||||||
|
remote_op,
|
||||||
|
feed_dict={
|
||||||
|
target_placeholder: "/job:localhost/replica:0/task:0/cpu:0"
|
||||||
|
})
|
||||||
|
self.assertEqual(elem, [1])
|
||||||
|
elem = sess.run(
|
||||||
|
remote_op,
|
||||||
|
feed_dict={
|
||||||
|
target_placeholder: "/job:localhost/replica:0/task:0/cpu:0"
|
||||||
|
})
|
||||||
|
self.assertEqual(elem, [2])
|
||||||
|
elem = sess.run(
|
||||||
|
remote_op,
|
||||||
|
feed_dict={
|
||||||
|
target_placeholder: "/job:localhost/replica:0/task:0/cpu:0"
|
||||||
|
})
|
||||||
|
self.assertEqual(elem, [3])
|
||||||
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
|
sess.run(
|
||||||
|
remote_op,
|
||||||
|
feed_dict={
|
||||||
|
target_placeholder: "/job:localhost/replica:0/task:0/cpu:0"
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -235,7 +235,7 @@ class SloppyInterleaveDatasetTest(test.TestCase):
|
|||||||
self.read_coordination_events[expected_element].acquire()
|
self.read_coordination_events[expected_element].acquire()
|
||||||
else:
|
else:
|
||||||
self.write_coordination_events[expected_element].set()
|
self.write_coordination_events[expected_element].set()
|
||||||
time.sleep(0.01) # Sleep to consistently "avoid" the race condition.
|
time.sleep(0.1) # Sleep to consistently "avoid" the race condition.
|
||||||
actual_element = sess.run(self.next_element)
|
actual_element = sess.run(self.next_element)
|
||||||
if not done_first_event:
|
if not done_first_event:
|
||||||
done_first_event = True
|
done_first_event = True
|
||||||
@ -300,7 +300,7 @@ class SloppyInterleaveDatasetTest(test.TestCase):
|
|||||||
self.read_coordination_events[expected_element].acquire()
|
self.read_coordination_events[expected_element].acquire()
|
||||||
else:
|
else:
|
||||||
self.write_coordination_events[expected_element].set()
|
self.write_coordination_events[expected_element].set()
|
||||||
time.sleep(0.01) # Sleep to consistently "avoid" the race condition.
|
time.sleep(0.1) # Sleep to consistently "avoid" the race condition.
|
||||||
actual_element = sess.run(self.next_element)
|
actual_element = sess.run(self.next_element)
|
||||||
if not done_first_event:
|
if not done_first_event:
|
||||||
done_first_event = True
|
done_first_event = True
|
||||||
|
@ -49,25 +49,46 @@ class SqlDatasetTest(test.TestCase):
|
|||||||
c = conn.cursor()
|
c = conn.cursor()
|
||||||
c.execute("DROP TABLE IF EXISTS students")
|
c.execute("DROP TABLE IF EXISTS students")
|
||||||
c.execute("DROP TABLE IF EXISTS people")
|
c.execute("DROP TABLE IF EXISTS people")
|
||||||
|
c.execute("DROP TABLE IF EXISTS townspeople")
|
||||||
c.execute(
|
c.execute(
|
||||||
"CREATE TABLE IF NOT EXISTS students (id INTEGER NOT NULL PRIMARY KEY,"
|
"CREATE TABLE IF NOT EXISTS students (id INTEGER NOT NULL PRIMARY KEY, "
|
||||||
" first_name VARCHAR(100), last_name VARCHAR(100), motto VARCHAR(100),"
|
"first_name VARCHAR(100), last_name VARCHAR(100), motto VARCHAR(100), "
|
||||||
" school_id VARCHAR(100), favorite_nonsense_word VARCHAR(100), "
|
"school_id VARCHAR(100), favorite_nonsense_word VARCHAR(100), "
|
||||||
"grade_level INTEGER, income INTEGER, favorite_number INTEGER)")
|
"desk_number INTEGER, income INTEGER, favorite_number INTEGER, "
|
||||||
|
"favorite_big_number INTEGER, favorite_negative_number INTEGER, "
|
||||||
|
"favorite_medium_sized_number INTEGER, brownie_points INTEGER, "
|
||||||
|
"account_balance INTEGER, registration_complete INTEGER)")
|
||||||
c.executemany(
|
c.executemany(
|
||||||
"INSERT INTO students (first_name, last_name, motto, school_id, "
|
"INSERT INTO students (first_name, last_name, motto, school_id, "
|
||||||
"favorite_nonsense_word, grade_level, income, favorite_number) "
|
"favorite_nonsense_word, desk_number, income, favorite_number, "
|
||||||
"VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
|
"favorite_big_number, favorite_negative_number, "
|
||||||
[("John", "Doe", "Hi!", "123", "n\0nsense", 9, 0, 2147483647),
|
"favorite_medium_sized_number, brownie_points, account_balance, "
|
||||||
("Jane", "Moe", "Hi again!", "1000", "nonsense\0", 11, -20000,
|
"registration_complete) "
|
||||||
-2147483648)])
|
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||||
|
[("John", "Doe", "Hi!", "123", "n\0nsense", 9, 0, 2147483647,
|
||||||
|
9223372036854775807, -2, 32767, 0, 0, 1),
|
||||||
|
("Jane", "Moe", "Hi again!", "1000", "nonsense\0", 127, -20000,
|
||||||
|
-2147483648, -9223372036854775808, -128, -32768, 255, 65535, 0)])
|
||||||
c.execute(
|
c.execute(
|
||||||
"CREATE TABLE IF NOT EXISTS people (id INTEGER NOT NULL PRIMARY KEY, "
|
"CREATE TABLE IF NOT EXISTS people (id INTEGER NOT NULL PRIMARY KEY, "
|
||||||
"first_name VARCHAR(100), last_name VARCHAR(100), state VARCHAR(100))")
|
"first_name VARCHAR(100), last_name VARCHAR(100), state VARCHAR(100))")
|
||||||
c.executemany(
|
c.executemany(
|
||||||
"INSERT INTO people (first_name, last_name, state) VALUES (?, ?, ?)",
|
"INSERT INTO PEOPLE (first_name, last_name, state) VALUES (?, ?, ?)",
|
||||||
[("Benjamin", "Franklin", "Pennsylvania"), ("John", "Doe",
|
[("Benjamin", "Franklin", "Pennsylvania"), ("John", "Doe",
|
||||||
"California")])
|
"California")])
|
||||||
|
c.execute(
|
||||||
|
"CREATE TABLE IF NOT EXISTS townspeople (id INTEGER NOT NULL PRIMARY "
|
||||||
|
"KEY, first_name VARCHAR(100), last_name VARCHAR(100), victories "
|
||||||
|
"FLOAT, accolades FLOAT, triumphs FLOAT)")
|
||||||
|
c.executemany(
|
||||||
|
"INSERT INTO townspeople (first_name, last_name, victories, "
|
||||||
|
"accolades, triumphs) VALUES (?, ?, ?, ?, ?)",
|
||||||
|
[("George", "Washington", 20.00,
|
||||||
|
1331241.321342132321324589798264627463827647382647382643874,
|
||||||
|
9007199254740991.0),
|
||||||
|
("John", "Adams", -19.95,
|
||||||
|
1331241321342132321324589798264627463827647382647382643874.0,
|
||||||
|
9007199254740992.0)])
|
||||||
conn.commit()
|
conn.commit()
|
||||||
conn.close()
|
conn.close()
|
||||||
|
|
||||||
@ -80,7 +101,6 @@ class SqlDatasetTest(test.TestCase):
|
|||||||
sess.run(
|
sess.run(
|
||||||
init_op,
|
init_op,
|
||||||
feed_dict={
|
feed_dict={
|
||||||
self.driver_name: "sqlite",
|
|
||||||
self.query: "SELECT first_name, last_name, motto FROM students "
|
self.query: "SELECT first_name, last_name, motto FROM students "
|
||||||
"ORDER BY first_name DESC"
|
"ORDER BY first_name DESC"
|
||||||
})
|
})
|
||||||
@ -98,7 +118,6 @@ class SqlDatasetTest(test.TestCase):
|
|||||||
sess.run(
|
sess.run(
|
||||||
init_op,
|
init_op,
|
||||||
feed_dict={
|
feed_dict={
|
||||||
self.driver_name: "sqlite",
|
|
||||||
self.query:
|
self.query:
|
||||||
"SELECT students.first_name, state, motto FROM students "
|
"SELECT students.first_name, state, motto FROM students "
|
||||||
"INNER JOIN people "
|
"INNER JOIN people "
|
||||||
@ -118,7 +137,6 @@ class SqlDatasetTest(test.TestCase):
|
|||||||
sess.run(
|
sess.run(
|
||||||
init_op,
|
init_op,
|
||||||
feed_dict={
|
feed_dict={
|
||||||
self.driver_name: "sqlite",
|
|
||||||
self.query:
|
self.query:
|
||||||
"SELECT first_name, last_name, favorite_nonsense_word "
|
"SELECT first_name, last_name, favorite_nonsense_word "
|
||||||
"FROM students ORDER BY first_name DESC"
|
"FROM students ORDER BY first_name DESC"
|
||||||
@ -249,20 +267,124 @@ class SqlDatasetTest(test.TestCase):
|
|||||||
with self.assertRaises(errors.InvalidArgumentError):
|
with self.assertRaises(errors.InvalidArgumentError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
|
# Test that `SqlDataset` can read an integer from a SQLite database table and
|
||||||
|
# place it in an `int8` tensor.
|
||||||
|
def testReadResultSetInt8(self):
|
||||||
|
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int8))
|
||||||
|
with self.test_session() as sess:
|
||||||
|
sess.run(
|
||||||
|
init_op,
|
||||||
|
feed_dict={
|
||||||
|
self.query: "SELECT first_name, desk_number FROM students "
|
||||||
|
"ORDER BY first_name DESC"
|
||||||
|
})
|
||||||
|
self.assertEqual((b"John", 9), sess.run(get_next))
|
||||||
|
self.assertEqual((b"Jane", 127), sess.run(get_next))
|
||||||
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
|
sess.run(get_next)
|
||||||
|
|
||||||
|
# Test that `SqlDataset` can read a negative or 0-valued integer from a
|
||||||
|
# SQLite database table and place it in an `int8` tensor.
|
||||||
|
def testReadResultSetInt8NegativeAndZero(self):
|
||||||
|
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int8,
|
||||||
|
dtypes.int8))
|
||||||
|
with self.test_session() as sess:
|
||||||
|
sess.run(
|
||||||
|
init_op,
|
||||||
|
feed_dict={
|
||||||
|
self.query: "SELECT first_name, income, favorite_negative_number "
|
||||||
|
"FROM students "
|
||||||
|
"WHERE first_name = 'John' ORDER BY first_name DESC"
|
||||||
|
})
|
||||||
|
self.assertEqual((b"John", 0, -2), sess.run(get_next))
|
||||||
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
|
sess.run(get_next)
|
||||||
|
|
||||||
|
# Test that `SqlDataset` can read a large (positive or negative) integer from
|
||||||
|
# a SQLite database table and place it in an `int8` tensor.
|
||||||
|
def testReadResultSetInt8MaxValues(self):
|
||||||
|
init_op, get_next = self._createSqlDataset((dtypes.int8, dtypes.int8))
|
||||||
|
with self.test_session() as sess:
|
||||||
|
sess.run(
|
||||||
|
init_op,
|
||||||
|
feed_dict={
|
||||||
|
self.query:
|
||||||
|
"SELECT desk_number, favorite_negative_number FROM students "
|
||||||
|
"ORDER BY first_name DESC"
|
||||||
|
})
|
||||||
|
self.assertEqual((9, -2), sess.run(get_next))
|
||||||
|
# Max and min values of int8
|
||||||
|
self.assertEqual((127, -128), sess.run(get_next))
|
||||||
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
|
sess.run(get_next)
|
||||||
|
|
||||||
|
# Test that `SqlDataset` can read an integer from a SQLite database table and
|
||||||
|
# place it in an `int16` tensor.
|
||||||
|
def testReadResultSetInt16(self):
|
||||||
|
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int16))
|
||||||
|
with self.test_session() as sess:
|
||||||
|
sess.run(
|
||||||
|
init_op,
|
||||||
|
feed_dict={
|
||||||
|
self.query: "SELECT first_name, desk_number FROM students "
|
||||||
|
"ORDER BY first_name DESC"
|
||||||
|
})
|
||||||
|
self.assertEqual((b"John", 9), sess.run(get_next))
|
||||||
|
self.assertEqual((b"Jane", 127), sess.run(get_next))
|
||||||
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
|
sess.run(get_next)
|
||||||
|
|
||||||
|
# Test that `SqlDataset` can read a negative or 0-valued integer from a
|
||||||
|
# SQLite database table and place it in an `int16` tensor.
|
||||||
|
def testReadResultSetInt16NegativeAndZero(self):
|
||||||
|
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int16,
|
||||||
|
dtypes.int16))
|
||||||
|
with self.test_session() as sess:
|
||||||
|
sess.run(
|
||||||
|
init_op,
|
||||||
|
feed_dict={
|
||||||
|
self.query: "SELECT first_name, income, favorite_negative_number "
|
||||||
|
"FROM students "
|
||||||
|
"WHERE first_name = 'John' ORDER BY first_name DESC"
|
||||||
|
})
|
||||||
|
self.assertEqual((b"John", 0, -2), sess.run(get_next))
|
||||||
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
|
sess.run(get_next)
|
||||||
|
|
||||||
|
# Test that `SqlDataset` can read a large (positive or negative) integer from
|
||||||
|
# a SQLite database table and place it in an `int16` tensor.
|
||||||
|
def testReadResultSetInt16MaxValues(self):
|
||||||
|
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int16))
|
||||||
|
with self.test_session() as sess:
|
||||||
|
sess.run(
|
||||||
|
init_op,
|
||||||
|
feed_dict={
|
||||||
|
self.query: "SELECT first_name, favorite_medium_sized_number "
|
||||||
|
"FROM students ORDER BY first_name DESC"
|
||||||
|
})
|
||||||
|
# Max value of int16
|
||||||
|
self.assertEqual((b"John", 32767), sess.run(get_next))
|
||||||
|
# Min value of int16
|
||||||
|
self.assertEqual((b"Jane", -32768), sess.run(get_next))
|
||||||
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
|
sess.run(get_next)
|
||||||
|
|
||||||
|
# Test that `SqlDataset` can read an integer from a SQLite database table and
|
||||||
|
# place it in an `int32` tensor.
|
||||||
def testReadResultSetInt32(self):
|
def testReadResultSetInt32(self):
|
||||||
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
|
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
sess.run(
|
sess.run(
|
||||||
init_op,
|
init_op,
|
||||||
feed_dict={
|
feed_dict={
|
||||||
self.query: "SELECT first_name, grade_level FROM students "
|
self.query: "SELECT first_name, desk_number FROM students "
|
||||||
"ORDER BY first_name DESC"
|
"ORDER BY first_name DESC"
|
||||||
})
|
})
|
||||||
self.assertEqual((b"John", 9), sess.run(get_next))
|
self.assertEqual((b"John", 9), sess.run(get_next))
|
||||||
self.assertEqual((b"Jane", 11), sess.run(get_next))
|
self.assertEqual((b"Jane", 127), sess.run(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
|
||||||
sess.run(get_next)
|
|
||||||
|
|
||||||
|
# Test that `SqlDataset` can read a negative or 0-valued integer from a
|
||||||
|
# SQLite database table and place it in an `int32` tensor.
|
||||||
def testReadResultSetInt32NegativeAndZero(self):
|
def testReadResultSetInt32NegativeAndZero(self):
|
||||||
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
|
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
@ -277,6 +399,8 @@ class SqlDatasetTest(test.TestCase):
|
|||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
|
# Test that `SqlDataset` can read a large (positive or negative) integer from
|
||||||
|
# a SQLite database table and place it in an `int32` tensor.
|
||||||
def testReadResultSetInt32MaxValues(self):
|
def testReadResultSetInt32MaxValues(self):
|
||||||
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
|
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
@ -286,7 +410,9 @@ class SqlDatasetTest(test.TestCase):
|
|||||||
self.query: "SELECT first_name, favorite_number FROM students "
|
self.query: "SELECT first_name, favorite_number FROM students "
|
||||||
"ORDER BY first_name DESC"
|
"ORDER BY first_name DESC"
|
||||||
})
|
})
|
||||||
|
# Max value of int32
|
||||||
self.assertEqual((b"John", 2147483647), sess.run(get_next))
|
self.assertEqual((b"John", 2147483647), sess.run(get_next))
|
||||||
|
# Min value of int32
|
||||||
self.assertEqual((b"Jane", -2147483648), sess.run(get_next))
|
self.assertEqual((b"Jane", -2147483648), sess.run(get_next))
|
||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
@ -307,6 +433,224 @@ class SqlDatasetTest(test.TestCase):
|
|||||||
with self.assertRaises(errors.OutOfRangeError):
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
sess.run(get_next)
|
sess.run(get_next)
|
||||||
|
|
||||||
|
# Test that `SqlDataset` can read an integer from a SQLite database table
|
||||||
|
# and place it in an `int64` tensor.
|
||||||
|
def testReadResultSetInt64(self):
|
||||||
|
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int64))
|
||||||
|
with self.test_session() as sess:
|
||||||
|
sess.run(
|
||||||
|
init_op,
|
||||||
|
feed_dict={
|
||||||
|
self.query: "SELECT first_name, desk_number FROM students "
|
||||||
|
"ORDER BY first_name DESC"
|
||||||
|
})
|
||||||
|
self.assertEqual((b"John", 9), sess.run(get_next))
|
||||||
|
self.assertEqual((b"Jane", 127), sess.run(get_next))
|
||||||
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
|
sess.run(get_next)
|
||||||
|
|
||||||
|
# Test that `SqlDataset` can read a negative or 0-valued integer from a
|
||||||
|
# SQLite database table and place it in an `int64` tensor.
|
||||||
|
def testReadResultSetInt64NegativeAndZero(self):
|
||||||
|
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int64))
|
||||||
|
with self.test_session() as sess:
|
||||||
|
sess.run(
|
||||||
|
init_op,
|
||||||
|
feed_dict={
|
||||||
|
self.query: "SELECT first_name, income FROM students "
|
||||||
|
"ORDER BY first_name DESC"
|
||||||
|
})
|
||||||
|
self.assertEqual((b"John", 0), sess.run(get_next))
|
||||||
|
self.assertEqual((b"Jane", -20000), sess.run(get_next))
|
||||||
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
|
sess.run(get_next)
|
||||||
|
|
||||||
|
# Test that `SqlDataset` can read a large (positive or negative) integer from
|
||||||
|
# a SQLite database table and place it in an `int64` tensor.
|
||||||
|
def testReadResultSetInt64MaxValues(self):
|
||||||
|
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int64))
|
||||||
|
with self.test_session() as sess:
|
||||||
|
sess.run(
|
||||||
|
init_op,
|
||||||
|
feed_dict={
|
||||||
|
self.query:
|
||||||
|
"SELECT first_name, favorite_big_number FROM students "
|
||||||
|
"ORDER BY first_name DESC"
|
||||||
|
})
|
||||||
|
# Max value of int64
|
||||||
|
self.assertEqual((b"John", 9223372036854775807), sess.run(get_next))
|
||||||
|
# Min value of int64
|
||||||
|
self.assertEqual((b"Jane", -9223372036854775808), sess.run(get_next))
|
||||||
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
|
sess.run(get_next)
|
||||||
|
|
||||||
|
# Test that `SqlDataset` can read an integer from a SQLite database table and
|
||||||
|
# place it in a `uint8` tensor.
|
||||||
|
def testReadResultSetUInt8(self):
|
||||||
|
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint8))
|
||||||
|
with self.test_session() as sess:
|
||||||
|
sess.run(
|
||||||
|
init_op,
|
||||||
|
feed_dict={
|
||||||
|
self.query: "SELECT first_name, desk_number FROM students "
|
||||||
|
"ORDER BY first_name DESC"
|
||||||
|
})
|
||||||
|
self.assertEqual((b"John", 9), sess.run(get_next))
|
||||||
|
self.assertEqual((b"Jane", 127), sess.run(get_next))
|
||||||
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
|
sess.run(get_next)
|
||||||
|
|
||||||
|
# Test that `SqlDataset` can read the minimum and maximum uint8 values from a
|
||||||
|
# SQLite database table and place them in `uint8` tensors.
|
||||||
|
def testReadResultSetUInt8MinAndMaxValues(self):
|
||||||
|
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint8))
|
||||||
|
with self.test_session() as sess:
|
||||||
|
sess.run(
|
||||||
|
init_op,
|
||||||
|
feed_dict={
|
||||||
|
self.query: "SELECT first_name, brownie_points FROM students "
|
||||||
|
"ORDER BY first_name DESC"
|
||||||
|
})
|
||||||
|
# Min value of uint8
|
||||||
|
self.assertEqual((b"John", 0), sess.run(get_next))
|
||||||
|
# Max value of uint8
|
||||||
|
self.assertEqual((b"Jane", 255), sess.run(get_next))
|
||||||
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
|
sess.run(get_next)
|
||||||
|
|
||||||
|
# Test that `SqlDataset` can read an integer from a SQLite database table
|
||||||
|
# and place it in a `uint16` tensor.
|
||||||
|
def testReadResultSetUInt16(self):
|
||||||
|
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint16))
|
||||||
|
with self.test_session() as sess:
|
||||||
|
sess.run(
|
||||||
|
init_op,
|
||||||
|
feed_dict={
|
||||||
|
self.query: "SELECT first_name, desk_number FROM students "
|
||||||
|
"ORDER BY first_name DESC"
|
||||||
|
})
|
||||||
|
self.assertEqual((b"John", 9), sess.run(get_next))
|
||||||
|
self.assertEqual((b"Jane", 127), sess.run(get_next))
|
||||||
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
|
sess.run(get_next)
|
||||||
|
|
||||||
|
# Test that `SqlDataset` can read the minimum and maximum uint16 values from a
|
||||||
|
# SQLite database table and place them in `uint16` tensors.
|
||||||
|
def testReadResultSetUInt16MinAndMaxValues(self):
|
||||||
|
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint16))
|
||||||
|
with self.test_session() as sess:
|
||||||
|
sess.run(
|
||||||
|
init_op,
|
||||||
|
feed_dict={
|
||||||
|
self.query: "SELECT first_name, account_balance FROM students "
|
||||||
|
"ORDER BY first_name DESC"
|
||||||
|
})
|
||||||
|
# Min value of uint16
|
||||||
|
self.assertEqual((b"John", 0), sess.run(get_next))
|
||||||
|
# Max value of uint16
|
||||||
|
self.assertEqual((b"Jane", 65535), sess.run(get_next))
|
||||||
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
|
sess.run(get_next)
|
||||||
|
|
||||||
|
# Test that `SqlDataset` can read a 0-valued and 1-valued integer from a
|
||||||
|
# SQLite database table and place them as `True` and `False` respectively
|
||||||
|
# in `bool` tensors.
|
||||||
|
def testReadResultSetBool(self):
|
||||||
|
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.bool))
|
||||||
|
with self.test_session() as sess:
|
||||||
|
sess.run(
|
||||||
|
init_op,
|
||||||
|
feed_dict={
|
||||||
|
self.query:
|
||||||
|
"SELECT first_name, registration_complete FROM students "
|
||||||
|
"ORDER BY first_name DESC"
|
||||||
|
})
|
||||||
|
self.assertEqual((b"John", True), sess.run(get_next))
|
||||||
|
self.assertEqual((b"Jane", False), sess.run(get_next))
|
||||||
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
|
sess.run(get_next)
|
||||||
|
|
||||||
|
# Test that `SqlDataset` can read an integer that is not 0-valued or 1-valued
|
||||||
|
# from a SQLite database table and place it as `True` in a `bool` tensor.
|
||||||
|
def testReadResultSetBoolNotZeroOrOne(self):
|
||||||
|
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.bool))
|
||||||
|
with self.test_session() as sess:
|
||||||
|
sess.run(
|
||||||
|
init_op,
|
||||||
|
feed_dict={
|
||||||
|
self.query: "SELECT first_name, favorite_medium_sized_number "
|
||||||
|
"FROM students ORDER BY first_name DESC"
|
||||||
|
})
|
||||||
|
self.assertEqual((b"John", True), sess.run(get_next))
|
||||||
|
self.assertEqual((b"Jane", True), sess.run(get_next))
|
||||||
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
|
sess.run(get_next)
|
||||||
|
|
||||||
|
# Test that `SqlDataset` can read a float from a SQLite database table
|
||||||
|
# and place it in a `float64` tensor.
|
||||||
|
def testReadResultSetFloat64(self):
|
||||||
|
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
|
||||||
|
dtypes.float64))
|
||||||
|
with self.test_session() as sess:
|
||||||
|
sess.run(
|
||||||
|
init_op,
|
||||||
|
feed_dict={
|
||||||
|
self.query:
|
||||||
|
"SELECT first_name, last_name, victories FROM townspeople "
|
||||||
|
"ORDER BY first_name"
|
||||||
|
})
|
||||||
|
self.assertEqual((b"George", b"Washington", 20.0), sess.run(get_next))
|
||||||
|
self.assertEqual((b"John", b"Adams", -19.95), sess.run(get_next))
|
||||||
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
|
sess.run(get_next)
|
||||||
|
|
||||||
|
# Test that `SqlDataset` can read a float from a SQLite database table beyond
|
||||||
|
# the precision of 64-bit IEEE, without throwing an error. Test that
|
||||||
|
# `SqlDataset` identifies such a value as equal to itself.
|
||||||
|
def testReadResultSetFloat64OverlyPrecise(self):
|
||||||
|
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
|
||||||
|
dtypes.float64))
|
||||||
|
with self.test_session() as sess:
|
||||||
|
sess.run(
|
||||||
|
init_op,
|
||||||
|
feed_dict={
|
||||||
|
self.query:
|
||||||
|
"SELECT first_name, last_name, accolades FROM townspeople "
|
||||||
|
"ORDER BY first_name"
|
||||||
|
})
|
||||||
|
self.assertEqual(
|
||||||
|
(b"George", b"Washington",
|
||||||
|
1331241.321342132321324589798264627463827647382647382643874),
|
||||||
|
sess.run(get_next))
|
||||||
|
self.assertEqual(
|
||||||
|
(b"John", b"Adams",
|
||||||
|
1331241321342132321324589798264627463827647382647382643874.0),
|
||||||
|
sess.run(get_next))
|
||||||
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
|
sess.run(get_next)
|
||||||
|
|
||||||
|
# Test that `SqlDataset` can read a float from a SQLite database table,
|
||||||
|
# representing the largest integer representable as a 64-bit IEEE float
|
||||||
|
# such that the previous integer is also representable as a 64-bit IEEE float.
|
||||||
|
# Test that `SqlDataset` can distinguish these two numbers.
|
||||||
|
def testReadResultSetFloat64LargestConsecutiveWholeNumbersNotEqual(self):
|
||||||
|
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
|
||||||
|
dtypes.float64))
|
||||||
|
with self.test_session() as sess:
|
||||||
|
sess.run(
|
||||||
|
init_op,
|
||||||
|
feed_dict={
|
||||||
|
self.query:
|
||||||
|
"SELECT first_name, last_name, triumphs FROM townspeople "
|
||||||
|
"ORDER BY first_name"
|
||||||
|
})
|
||||||
|
self.assertNotEqual((b"George", b"Washington", 9007199254740992.0),
|
||||||
|
sess.run(get_next))
|
||||||
|
self.assertNotEqual((b"John", b"Adams", 9007199254740991.0),
|
||||||
|
sess.run(get_next))
|
||||||
|
with self.assertRaises(errors.OutOfRangeError):
|
||||||
|
sess.run(get_next)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -2276,6 +2276,23 @@ class SqlDataset(Dataset):
|
|||||||
def __init__(self, driver_name, data_source_name, query, output_types):
|
def __init__(self, driver_name, data_source_name, query, output_types):
|
||||||
"""Creates a `SqlDataset`.
|
"""Creates a `SqlDataset`.
|
||||||
|
|
||||||
|
`SqlDataset` allows a user to read data from the result set of a SQL query.
|
||||||
|
For example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
dataset = tf.contrib.data.SqlDataset("sqlite", "/foo/bar.sqlite3",
|
||||||
|
"SELECT name, age FROM people",
|
||||||
|
(tf.string, tf.int32))
|
||||||
|
iterator = dataset.make_one_shot_iterator()
|
||||||
|
next_element = iterator.get_next()
|
||||||
|
# Prints the rows of the result set of the above query.
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
print(sess.run(next_element))
|
||||||
|
except tf.errors.OutOfRangeError:
|
||||||
|
break
|
||||||
|
```
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
driver_name: A 0-D `tf.string` tensor containing the database type.
|
driver_name: A 0-D `tf.string` tensor containing the database type.
|
||||||
Currently, the only supported value is 'sqlite'.
|
Currently, the only supported value is 'sqlite'.
|
||||||
|
@ -21,6 +21,16 @@ py_library(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cuda_py_test(
|
||||||
|
name = "tfe_test",
|
||||||
|
srcs = ["tfe_test.py"],
|
||||||
|
additional_deps = [
|
||||||
|
":tfe",
|
||||||
|
"//tensorflow/python:client_testlib",
|
||||||
|
"//tensorflow/python:platform_test",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name = "saver",
|
name = "saver",
|
||||||
srcs = ["saver.py"],
|
srcs = ["saver.py"],
|
||||||
|
@ -18,9 +18,9 @@ EXPERIMENTAL: APIs here are unstable and likely to change without notice.
|
|||||||
|
|
||||||
To use, at program startup, call `tfe.enable_eager_execution()`.
|
To use, at program startup, call `tfe.enable_eager_execution()`.
|
||||||
|
|
||||||
@@list_devices
|
|
||||||
@@device
|
@@device
|
||||||
|
@@list_devices
|
||||||
|
@@num_gpus
|
||||||
|
|
||||||
@@defun
|
@@defun
|
||||||
@@implicit_gradients
|
@@implicit_gradients
|
||||||
@ -58,9 +58,10 @@ from tensorflow.python.util.all_util import remove_undocumented
|
|||||||
from tensorflow.python.eager import backprop
|
from tensorflow.python.eager import backprop
|
||||||
from tensorflow.python.eager.custom_gradient import custom_gradient
|
from tensorflow.python.eager.custom_gradient import custom_gradient
|
||||||
from tensorflow.python.eager import function
|
from tensorflow.python.eager import function
|
||||||
from tensorflow.python.eager.context import context
|
|
||||||
from tensorflow.python.eager.context import device
|
from tensorflow.python.eager.context import device
|
||||||
from tensorflow.python.eager.context import enable_eager_execution
|
from tensorflow.python.eager.context import enable_eager_execution
|
||||||
|
from tensorflow.python.eager.context import list_devices
|
||||||
|
from tensorflow.python.eager.context import num_gpus
|
||||||
from tensorflow.python.eager.context import run
|
from tensorflow.python.eager.context import run
|
||||||
from tensorflow.python.eager.core import enable_tracing
|
from tensorflow.python.eager.core import enable_tracing
|
||||||
from tensorflow.python.eager.execution_callbacks import add_execution_callback
|
from tensorflow.python.eager.execution_callbacks import add_execution_callback
|
||||||
@ -70,10 +71,6 @@ from tensorflow.python.eager.execution_callbacks import inf_nan_callback
|
|||||||
from tensorflow.python.eager.execution_callbacks import nan_callback
|
from tensorflow.python.eager.execution_callbacks import nan_callback
|
||||||
from tensorflow.python.eager.execution_callbacks import seterr
|
from tensorflow.python.eager.execution_callbacks import seterr
|
||||||
|
|
||||||
|
|
||||||
def list_devices():
|
|
||||||
return context().devices()
|
|
||||||
|
|
||||||
defun = function.defun
|
defun = function.defun
|
||||||
implicit_gradients = backprop.implicit_grad
|
implicit_gradients = backprop.implicit_grad
|
||||||
implicit_value_and_gradients = backprop.implicit_val_and_grad
|
implicit_value_and_gradients = backprop.implicit_val_and_grad
|
||||||
|
36
tensorflow/contrib/eager/python/tfe_test.py
Normal file
36
tensorflow/contrib/eager/python/tfe_test.py
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Tests for tfe.py."""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.contrib.eager.python import tfe
|
||||||
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
|
class TFETest(test.TestCase):
|
||||||
|
|
||||||
|
def testListDevices(self):
|
||||||
|
# Expect at least one device.
|
||||||
|
self.assertTrue(tfe.list_devices())
|
||||||
|
|
||||||
|
def testNumGPUs(self):
|
||||||
|
devices = tfe.list_devices()
|
||||||
|
self.assertEqual(len(devices) - 1, tfe.num_gpus())
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test.main()
|
@ -26,6 +26,7 @@ py_library(
|
|||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = [
|
deps = [
|
||||||
":extenders",
|
":extenders",
|
||||||
|
":head",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -59,3 +60,14 @@ py_test(
|
|||||||
"//third_party/py/numpy",
|
"//third_party/py/numpy",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "head",
|
||||||
|
srcs = [
|
||||||
|
"python/estimator/head.py",
|
||||||
|
],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/python/estimator:head",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@ -20,10 +20,16 @@ from __future__ import print_function
|
|||||||
|
|
||||||
# pylint: disable=unused-import,line-too-long,wildcard-import
|
# pylint: disable=unused-import,line-too-long,wildcard-import
|
||||||
from tensorflow.contrib.estimator.python.estimator.extenders import *
|
from tensorflow.contrib.estimator.python.estimator.extenders import *
|
||||||
|
from tensorflow.contrib.estimator.python.estimator.head import *
|
||||||
|
|
||||||
from tensorflow.python.util.all_util import remove_undocumented
|
from tensorflow.python.util.all_util import remove_undocumented
|
||||||
# pylint: enable=unused-import,line-too-long,wildcard-import
|
# pylint: enable=unused-import,line-too-long,wildcard-import
|
||||||
|
|
||||||
_allowed_symbols = ['add_metrics']
|
_allowed_symbols = [
|
||||||
|
'add_metrics',
|
||||||
|
'binary_classification_head',
|
||||||
|
'multi_class_head',
|
||||||
|
'regression_head',
|
||||||
|
]
|
||||||
|
|
||||||
remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
|
remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
|
||||||
|
125
tensorflow/contrib/estimator/python/estimator/head.py
Normal file
125
tensorflow/contrib/estimator/python/estimator/head.py
Normal file
@ -0,0 +1,125 @@
|
|||||||
|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Abstractions for the head(s) of a model."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.python.estimator.canned import head as head_lib
|
||||||
|
|
||||||
|
|
||||||
|
def multi_class_head(n_classes,
|
||||||
|
weight_column=None,
|
||||||
|
label_vocabulary=None,
|
||||||
|
head_name=None):
|
||||||
|
"""Creates a `_Head` for multi class classification.
|
||||||
|
|
||||||
|
Uses `sparse_softmax_cross_entropy` loss.
|
||||||
|
|
||||||
|
This head expects to be fed integer labels specifying the class index.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
n_classes: Number of classes, must be greater than 2 (for 2 classes, use
|
||||||
|
`_BinaryLogisticHeadWithSigmoidCrossEntropyLoss`).
|
||||||
|
weight_column: A string or a `_NumericColumn` created by
|
||||||
|
`tf.feature_column.numeric_column` defining feature column representing
|
||||||
|
weights. It is used to down weight or boost examples during training. It
|
||||||
|
will be multiplied by the loss of the example.
|
||||||
|
label_vocabulary: A list of strings represents possible label values. If it
|
||||||
|
is not given, that means labels are already encoded as integer within
|
||||||
|
[0, n_classes). If given, labels must be string type and have any value in
|
||||||
|
`label_vocabulary`. Also there will be errors if vocabulary is not
|
||||||
|
provided and labels are string.
|
||||||
|
head_name: name of the head. If provided, summary and metrics keys will be
|
||||||
|
suffixed by `"/" + head_name`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An instance of `_Head` for multi class classification.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: if `n_classes`, `metric_class_ids` or `label_keys` is invalid.
|
||||||
|
"""
|
||||||
|
return head_lib._multi_class_head_with_softmax_cross_entropy_loss( # pylint:disable=protected-access
|
||||||
|
n_classes=n_classes,
|
||||||
|
weight_column=weight_column,
|
||||||
|
label_vocabulary=label_vocabulary,
|
||||||
|
head_name=head_name)
|
||||||
|
|
||||||
|
|
||||||
|
def binary_classification_head(
|
||||||
|
weight_column=None, thresholds=None, label_vocabulary=None, head_name=None):
|
||||||
|
"""Creates a `_Head` for single label binary classification.
|
||||||
|
|
||||||
|
This head uses `sigmoid_cross_entropy_with_logits` loss.
|
||||||
|
|
||||||
|
This head expects to be fed float labels of shape `(batch_size, 1)`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
weight_column: A string or a `_NumericColumn` created by
|
||||||
|
`tf.feature_column.numeric_column` defining feature column representing
|
||||||
|
weights. It is used to down weight or boost examples during training. It
|
||||||
|
will be multiplied by the loss of the example.
|
||||||
|
thresholds: Iterable of floats in the range `(0, 1)`. For binary
|
||||||
|
classification metrics such as precision and recall, an eval metric is
|
||||||
|
generated for each threshold value. This threshold is applied to the
|
||||||
|
logistic values to determine the binary classification (i.e., above the
|
||||||
|
threshold is `true`, below is `false`.
|
||||||
|
label_vocabulary: A list of strings represents possible label values. If it
|
||||||
|
is not given, that means labels are already encoded within [0, 1]. If
|
||||||
|
given, labels must be string type and have any value in
|
||||||
|
`label_vocabulary`. Also there will be errors if vocabulary is not
|
||||||
|
provided and labels are string.
|
||||||
|
head_name: name of the head. If provided, summary and metrics keys will be
|
||||||
|
suffixed by `"/" + head_name`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An instance of `_Head` for binary classification.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: if `thresholds` contains a value outside of `(0, 1)`.
|
||||||
|
"""
|
||||||
|
return head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( # pylint:disable=protected-access
|
||||||
|
weight_column=weight_column,
|
||||||
|
thresholds=thresholds,
|
||||||
|
label_vocabulary=label_vocabulary,
|
||||||
|
head_name=head_name)
|
||||||
|
|
||||||
|
|
||||||
|
def regression_head(weight_column=None,
|
||||||
|
label_dimension=1,
|
||||||
|
head_name=None):
|
||||||
|
"""Creates a `_Head` for regression using the mean squared loss.
|
||||||
|
|
||||||
|
Uses `mean_squared_error` loss.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
weight_column: A string or a `_NumericColumn` created by
|
||||||
|
`tf.feature_column.numeric_column` defining feature column representing
|
||||||
|
weights. It is used to down weight or boost examples during training. It
|
||||||
|
will be multiplied by the loss of the example.
|
||||||
|
label_dimension: Number of regression labels per example. This is the size
|
||||||
|
of the last dimension of the labels `Tensor` (typically, this has shape
|
||||||
|
`[batch_size, label_dimension]`).
|
||||||
|
head_name: name of the head. If provided, summary and metrics keys will be
|
||||||
|
suffixed by `"/" + head_name`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An instance of `_Head` for linear regression.
|
||||||
|
"""
|
||||||
|
return head_lib._regression_head_with_mean_squared_error_loss( # pylint:disable=protected-access
|
||||||
|
weight_column=weight_column,
|
||||||
|
label_dimension=label_dimension,
|
||||||
|
head_name=head_name)
|
@ -162,6 +162,7 @@ tf_py_test(
|
|||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
"//tensorflow/python:variables",
|
"//tensorflow/python:variables",
|
||||||
],
|
],
|
||||||
|
tags = ["notsan"], # b/62863147
|
||||||
)
|
)
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
|
@ -1,9 +1,12 @@
|
|||||||
|
# Files for using TFGAN framework.
|
||||||
package(default_visibility = ["//tensorflow:__subpackages__"])
|
package(default_visibility = ["//tensorflow:__subpackages__"])
|
||||||
|
|
||||||
licenses(["notice"]) # Apache 2.0
|
licenses(["notice"]) # Apache 2.0
|
||||||
|
|
||||||
exports_files(["LICENSE"])
|
exports_files(["LICENSE"])
|
||||||
|
|
||||||
|
load("//tensorflow:tensorflow.bzl", "py_test")
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name = "gan",
|
name = "gan",
|
||||||
srcs = [
|
srcs = [
|
||||||
@ -11,6 +14,192 @@ py_library(
|
|||||||
],
|
],
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = [
|
deps = [
|
||||||
|
":features",
|
||||||
|
":losses",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "losses",
|
||||||
|
srcs = ["python/losses/__init__.py"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
":losses_impl",
|
||||||
|
":tuple_losses",
|
||||||
|
"//tensorflow/python:util",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "features",
|
||||||
|
srcs = ["python/features/__init__.py"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
":clip_weights",
|
||||||
|
":conditioning_utils",
|
||||||
|
":virtual_batchnorm",
|
||||||
|
"//tensorflow/python:util",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "losses_impl",
|
||||||
|
srcs = ["python/losses/python/losses_impl.py"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/contrib/framework:framework_py",
|
||||||
|
"//tensorflow/python:array_ops",
|
||||||
|
"//tensorflow/python:clip_ops",
|
||||||
|
"//tensorflow/python:framework_ops",
|
||||||
|
"//tensorflow/python:gradients",
|
||||||
|
"//tensorflow/python:math_ops",
|
||||||
|
"//tensorflow/python:random_ops",
|
||||||
|
"//tensorflow/python:summary",
|
||||||
|
"//tensorflow/python:tensor_util",
|
||||||
|
"//tensorflow/python:variable_scope",
|
||||||
|
"//tensorflow/python/ops/distributions",
|
||||||
|
"//tensorflow/python/ops/losses",
|
||||||
|
"//third_party/py/numpy",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "losses_impl_test",
|
||||||
|
srcs = ["python/losses/python/losses_impl_test.py"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
":losses_impl",
|
||||||
|
"//tensorflow/python:array_ops",
|
||||||
|
"//tensorflow/python:client_testlib",
|
||||||
|
"//tensorflow/python:clip_ops",
|
||||||
|
"//tensorflow/python:constant_op",
|
||||||
|
"//tensorflow/python:dtypes",
|
||||||
|
"//tensorflow/python:framework_ops",
|
||||||
|
"//tensorflow/python:math_ops",
|
||||||
|
"//tensorflow/python:random_ops",
|
||||||
|
"//tensorflow/python:random_seed",
|
||||||
|
"//tensorflow/python:variable_scope",
|
||||||
|
"//tensorflow/python:variables",
|
||||||
|
"//tensorflow/python/ops/distributions",
|
||||||
|
"//tensorflow/python/ops/losses",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "tuple_losses",
|
||||||
|
srcs = [
|
||||||
|
"python/losses/python/losses_wargs.py",
|
||||||
|
"python/losses/python/tuple_losses.py",
|
||||||
|
"python/losses/python/tuple_losses_impl.py",
|
||||||
|
],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
":losses_impl",
|
||||||
|
"//tensorflow/python:util",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "tuple_losses_test",
|
||||||
|
srcs = ["python/losses/python/tuple_losses_test.py"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
":tuple_losses",
|
||||||
|
"//tensorflow/python:client_testlib",
|
||||||
|
"//third_party/py/numpy",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "conditioning_utils",
|
||||||
|
srcs = [
|
||||||
|
"python/features/python/conditioning_utils.py",
|
||||||
|
"python/features/python/conditioning_utils_impl.py",
|
||||||
|
],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/contrib/layers:layers_py",
|
||||||
|
"//tensorflow/python:array_ops",
|
||||||
|
"//tensorflow/python:embedding_ops",
|
||||||
|
"//tensorflow/python:math_ops",
|
||||||
|
"//tensorflow/python:tensor_util",
|
||||||
|
"//tensorflow/python:variable_scope",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "conditioning_utils_test",
|
||||||
|
srcs = ["python/features/python/conditioning_utils_test.py"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
":conditioning_utils",
|
||||||
|
"//tensorflow/python:array_ops",
|
||||||
|
"//tensorflow/python:client_testlib",
|
||||||
|
"//tensorflow/python:dtypes",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "virtual_batchnorm",
|
||||||
|
srcs = [
|
||||||
|
"python/features/python/virtual_batchnorm.py",
|
||||||
|
"python/features/python/virtual_batchnorm_impl.py",
|
||||||
|
],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/python:array_ops",
|
||||||
|
"//tensorflow/python:dtypes",
|
||||||
|
"//tensorflow/python:framework_ops",
|
||||||
|
"//tensorflow/python:init_ops",
|
||||||
|
"//tensorflow/python:math_ops",
|
||||||
|
"//tensorflow/python:nn",
|
||||||
|
"//tensorflow/python:tensor_shape",
|
||||||
|
"//tensorflow/python:tensor_util",
|
||||||
|
"//tensorflow/python:variable_scope",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "virtual_batchnorm_test",
|
||||||
|
srcs = ["python/features/python/virtual_batchnorm_test.py"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
":virtual_batchnorm",
|
||||||
|
"//tensorflow/contrib/framework:framework_py",
|
||||||
|
"//tensorflow/python:array_ops",
|
||||||
|
"//tensorflow/python:client_testlib",
|
||||||
|
"//tensorflow/python:constant_op",
|
||||||
|
"//tensorflow/python:dtypes",
|
||||||
|
"//tensorflow/python:layers",
|
||||||
|
"//tensorflow/python:math_ops",
|
||||||
|
"//tensorflow/python:nn",
|
||||||
|
"//tensorflow/python:random_ops",
|
||||||
|
"//tensorflow/python:random_seed",
|
||||||
|
"//tensorflow/python:variable_scope",
|
||||||
|
"//tensorflow/python:variables",
|
||||||
|
"//third_party/py/numpy",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "clip_weights",
|
||||||
|
srcs = [
|
||||||
|
"python/features/python/clip_weights.py",
|
||||||
|
"python/features/python/clip_weights_impl.py",
|
||||||
|
],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = ["//tensorflow/contrib/opt:opt_py"],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name = "clip_weights_test",
|
||||||
|
srcs = ["python/features/python/clip_weights_test.py"],
|
||||||
|
srcs_version = "PY2AND3",
|
||||||
|
deps = [
|
||||||
|
":clip_weights",
|
||||||
|
"//tensorflow/python:client_testlib",
|
||||||
|
"//tensorflow/python:training",
|
||||||
|
"//tensorflow/python:variables",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
# Copyright 2017 Google Inc. All Rights Reserved.
|
# Copyright 2016 Google Inc. All Rights Reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@ -12,8 +12,16 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""TFGAN grouped API."""
|
"""TFGAN grouped API. Please see README.md for details and usage."""
|
||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
# Collapse TFGAN into a tiered namespace.
|
||||||
|
from tensorflow.contrib.gan.python import features
|
||||||
|
from tensorflow.contrib.gan.python import losses
|
||||||
|
|
||||||
|
del absolute_import
|
||||||
|
del division
|
||||||
|
del print_function
|
||||||
|
37
tensorflow/contrib/gan/python/features/__init__.py
Normal file
37
tensorflow/contrib/gan/python/features/__init__.py
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
# Copyright 2017 Google Inc. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""TFGAN grouped API. Please see README.md for details and usage."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
# Collapse features into a single namespace.
|
||||||
|
# pylint: disable=unused-import,wildcard-import
|
||||||
|
from tensorflow.contrib.gan.python.features.python import clip_weights
|
||||||
|
from tensorflow.contrib.gan.python.features.python import conditioning_utils
|
||||||
|
from tensorflow.contrib.gan.python.features.python import virtual_batchnorm
|
||||||
|
|
||||||
|
from tensorflow.contrib.gan.python.features.python.clip_weights import *
|
||||||
|
from tensorflow.contrib.gan.python.features.python.conditioning_utils import *
|
||||||
|
from tensorflow.contrib.gan.python.features.python.virtual_batchnorm import *
|
||||||
|
# pylint: enable=unused-import,wildcard-import
|
||||||
|
|
||||||
|
from tensorflow.python.util.all_util import remove_undocumented
|
||||||
|
|
||||||
|
_allowed_symbols = clip_weights.__all__
|
||||||
|
_allowed_symbols += conditioning_utils.__all__
|
||||||
|
_allowed_symbols += virtual_batchnorm.__all__
|
||||||
|
remove_undocumented(__name__, _allowed_symbols)
|
@ -0,0 +1,28 @@
|
|||||||
|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Utilities to clip weights."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.contrib.gan.python.features.python import clip_weights_impl
|
||||||
|
# pylint: disable=wildcard-import
|
||||||
|
from tensorflow.contrib.gan.python.features.python.clip_weights_impl import *
|
||||||
|
# pylint: enable=wildcard-import
|
||||||
|
from tensorflow.python.util.all_util import remove_undocumented
|
||||||
|
|
||||||
|
__all__ = clip_weights_impl.__all__
|
||||||
|
remove_undocumented(__name__, __all__)
|
@ -0,0 +1,80 @@
|
|||||||
|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Utilities to clip weights.
|
||||||
|
|
||||||
|
This is useful in the original formulation of the Wasserstein loss, which
|
||||||
|
requires that the discriminator be K-Lipschitz. See
|
||||||
|
https://arxiv.org/pdf/1701.07875 for more details.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.contrib.opt.python.training import variable_clipping_optimizer
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'clip_variables',
|
||||||
|
'clip_discriminator_weights',
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def clip_discriminator_weights(optimizer, model, weight_clip):
|
||||||
|
"""Modifies an optimizer so it clips weights to a certain value.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
optimizer: An optimizer to perform variable weight clipping.
|
||||||
|
model: A GANModel namedtuple.
|
||||||
|
weight_clip: Positive python float to clip discriminator weights. Used to
|
||||||
|
enforce a K-lipschitz condition, which is useful for some GAN training
|
||||||
|
schemes (ex WGAN: https://arxiv.org/pdf/1701.07875).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An optimizer to perform weight clipping after updates.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If `weight_clip` is less than 0.
|
||||||
|
"""
|
||||||
|
return clip_variables(optimizer, model.discriminator_variables, weight_clip)
|
||||||
|
|
||||||
|
|
||||||
|
def clip_variables(optimizer, variables, weight_clip):
|
||||||
|
"""Modifies an optimizer so it clips weights to a certain value.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
optimizer: An optimizer to perform variable weight clipping.
|
||||||
|
variables: A list of TensorFlow variables.
|
||||||
|
weight_clip: Positive python float to clip discriminator weights. Used to
|
||||||
|
enforce a K-lipschitz condition, which is useful for some GAN training
|
||||||
|
schemes (ex WGAN: https://arxiv.org/pdf/1701.07875).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An optimizer to perform weight clipping after updates.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If `weight_clip` is less than 0.
|
||||||
|
"""
|
||||||
|
if weight_clip < 0:
|
||||||
|
raise ValueError(
|
||||||
|
'`discriminator_weight_clip` must be positive. Instead, was %s',
|
||||||
|
weight_clip)
|
||||||
|
return variable_clipping_optimizer.VariableClippingOptimizer(
|
||||||
|
opt=optimizer,
|
||||||
|
# Do no reduction, so clipping happens per-value.
|
||||||
|
vars_to_clip_dims={var: [] for var in variables},
|
||||||
|
max_norm=weight_clip,
|
||||||
|
use_locking=True,
|
||||||
|
colocate_clip_ops_with_vars=True)
|
@ -0,0 +1,81 @@
|
|||||||
|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Tests for tfgan.python.features.clip_weights."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import collections
|
||||||
|
|
||||||
|
from tensorflow.contrib.gan.python.features.python import clip_weights_impl as clip_weights
|
||||||
|
|
||||||
|
from tensorflow.python.ops import variables
|
||||||
|
from tensorflow.python.platform import test
|
||||||
|
from tensorflow.python.training import training
|
||||||
|
|
||||||
|
|
||||||
|
class ClipWeightsTest(test.TestCase):
|
||||||
|
"""Tests for `discriminator_weight_clip`."""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.variables = [variables.Variable(2.0)]
|
||||||
|
self.tuple = collections.namedtuple(
|
||||||
|
'VarTuple', ['discriminator_variables'])(self.variables)
|
||||||
|
|
||||||
|
def _test_weight_clipping_helper(self, use_tuple):
|
||||||
|
loss = self.variables[0] * 2.0
|
||||||
|
opt = training.GradientDescentOptimizer(1.0)
|
||||||
|
if use_tuple:
|
||||||
|
opt_clip = clip_weights.weight_clip(opt, self.variables, 0.1)
|
||||||
|
else:
|
||||||
|
opt_clip = clip_weights.discriminator_weight_clip(opt, self.tuple, 0.1)
|
||||||
|
|
||||||
|
train_op1 = opt.minimize(loss, var_list=self.variables)
|
||||||
|
train_op2 = opt_clip.minimize(loss, var_list=self.variables)
|
||||||
|
|
||||||
|
with self.test_session(use_gpu=True) as sess:
|
||||||
|
sess.run(variables.global_variables_initializer())
|
||||||
|
self.assertEqual(2.0, self.variables[0].eval())
|
||||||
|
sess.run(train_op1)
|
||||||
|
self.assertLess(0.1, self.variables[0].eval())
|
||||||
|
|
||||||
|
with self.test_session(use_gpu=True) as sess:
|
||||||
|
sess.run(variables.global_variables_initializer())
|
||||||
|
self.assertEqual(2.0, self.variables[0].eval())
|
||||||
|
sess.run(train_op2)
|
||||||
|
self.assertNear(0.1, self.variables[0].eval(), 1e-7)
|
||||||
|
|
||||||
|
def test_weight_clipping_argsonly(self):
|
||||||
|
self._test_weight_clipping_helper(False)
|
||||||
|
|
||||||
|
def test_weight_clipping_ganmodel(self):
|
||||||
|
self._test_weight_clipping_helper(True)
|
||||||
|
|
||||||
|
def _test_incorrect_weight_clip_value_helper(self, use_tuple):
|
||||||
|
opt = training.GradientDescentOptimizer(1.0)
|
||||||
|
|
||||||
|
if use_tuple:
|
||||||
|
with self.assertRaisesRegexp(ValueError, 'must be positive'):
|
||||||
|
clip_weights.clip_discriminator_weights(opt, self.tuple, weight_clip=-1)
|
||||||
|
else:
|
||||||
|
with self.assertRaisesRegexp(ValueError, 'must be positive'):
|
||||||
|
clip_weights.clip_weights(opt, self.variables, weight_clip=-1)
|
||||||
|
|
||||||
|
def test_incorrect_weight_clip_value_argsonly(self):
|
||||||
|
self._test_incorrect_weight_clip_value_helper(False)
|
||||||
|
|
||||||
|
def test_incorrect_weight_clip_value_tuple(self):
|
||||||
|
self._test_incorrect_weight_clip_value_helper(True)
|
@ -0,0 +1,28 @@
|
|||||||
|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Miscellanous utilities for TFGAN code and examples."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.contrib.gan.python.features.python import conditioning_utils_impl
|
||||||
|
# pylint: disable=wildcard-import
|
||||||
|
from tensorflow.contrib.gan.python.features.python.conditioning_utils_impl import *
|
||||||
|
# pylint: enable=wildcard-import
|
||||||
|
from tensorflow.python.util.all_util import remove_undocumented
|
||||||
|
|
||||||
|
__all__ = conditioning_utils_impl.__all__
|
||||||
|
remove_undocumented(__name__, __all__)
|
@ -0,0 +1,112 @@
|
|||||||
|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Miscellanous utilities for TFGAN code and examples.
|
||||||
|
|
||||||
|
Includes:
|
||||||
|
1) Conditioning the value of a Tensor, based on techniques from
|
||||||
|
https://arxiv.org/abs/1609.03499.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.contrib.layers.python.layers import layers
|
||||||
|
from tensorflow.python.framework import tensor_util
|
||||||
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import embedding_ops
|
||||||
|
from tensorflow.python.ops import math_ops
|
||||||
|
from tensorflow.python.ops import variable_scope
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'condition_tensor',
|
||||||
|
'condition_tensor_from_onehot',
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _get_shape(tensor):
|
||||||
|
tensor_shape = array_ops.shape(tensor)
|
||||||
|
static_tensor_shape = tensor_util.constant_value(tensor_shape)
|
||||||
|
return (static_tensor_shape if static_tensor_shape is not None else
|
||||||
|
tensor_shape)
|
||||||
|
|
||||||
|
|
||||||
|
def condition_tensor(tensor, conditioning):
|
||||||
|
"""Condition the value of a tensor.
|
||||||
|
|
||||||
|
Conditioning scheme based on https://arxiv.org/abs/1609.03499.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tensor: A minibatch tensor to be conditioned.
|
||||||
|
conditioning: A minibatch Tensor of to condition on. Must be 2D, with first
|
||||||
|
dimension the same as `tensor`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`tensor` conditioned on `conditioning`.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the non-batch dimensions of `tensor` aren't fully defined.
|
||||||
|
ValueError: If `conditioning` isn't at least 2D.
|
||||||
|
ValueError: If the batch dimension for the input Tensors don't match.
|
||||||
|
"""
|
||||||
|
tensor.shape[1:].assert_is_fully_defined()
|
||||||
|
num_features = tensor.shape[1:].num_elements()
|
||||||
|
|
||||||
|
mapped_conditioning = layers.linear(
|
||||||
|
layers.flatten(conditioning), num_features)
|
||||||
|
if not mapped_conditioning.shape.is_compatible_with(tensor.shape):
|
||||||
|
mapped_conditioning = array_ops.reshape(
|
||||||
|
mapped_conditioning, _get_shape(tensor))
|
||||||
|
return tensor + mapped_conditioning
|
||||||
|
|
||||||
|
|
||||||
|
def _one_hot_to_embedding(one_hot, embedding_size):
|
||||||
|
"""Get a dense embedding vector from a one-hot encoding."""
|
||||||
|
num_tokens = one_hot.shape[1]
|
||||||
|
label_id = math_ops.argmax(one_hot, axis=1)
|
||||||
|
embedding = variable_scope.get_variable(
|
||||||
|
'embedding', [num_tokens, embedding_size])
|
||||||
|
return embedding_ops.embedding_lookup(
|
||||||
|
embedding, label_id, name='token_to_embedding')
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_onehot(one_hot_labels):
|
||||||
|
one_hot_labels.shape.assert_has_rank(2)
|
||||||
|
one_hot_labels.shape[1:].assert_is_fully_defined()
|
||||||
|
|
||||||
|
|
||||||
|
def condition_tensor_from_onehot(tensor, one_hot_labels, embedding_size=256):
|
||||||
|
"""Condition a tensor based on a one-hot tensor.
|
||||||
|
|
||||||
|
Conditioning scheme based on https://arxiv.org/abs/1609.03499.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tensor: Tensor to be conditioned.
|
||||||
|
one_hot_labels: A Tensor of one-hot labels. Shape is
|
||||||
|
[batch_size, num_classes].
|
||||||
|
embedding_size: The size of the class embedding.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`tensor` conditioned on `one_hot_labels`.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: `one_hot_labels` isn't 2D, if non-batch dimensions aren't
|
||||||
|
fully defined, or if batch sizes don't match.
|
||||||
|
"""
|
||||||
|
_validate_onehot(one_hot_labels)
|
||||||
|
|
||||||
|
conditioning = _one_hot_to_embedding(one_hot_labels, embedding_size)
|
||||||
|
return condition_tensor(tensor, conditioning)
|
@ -0,0 +1,76 @@
|
|||||||
|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Tests for tfgan.python.features.conditioning_utils."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.contrib.gan.python.features.python import conditioning_utils_impl as conditioning_utils
|
||||||
|
|
||||||
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
|
class ConditioningUtilsTest(test.TestCase):
|
||||||
|
|
||||||
|
def test_condition_tensor_multiple_shapes(self):
|
||||||
|
for tensor_shape in [(4, 1), (4, 2), (4, 2, 6), (None, 5, 3)]:
|
||||||
|
for conditioning_shape in [(4, 1), (4, 8), (4, 5, 3)]:
|
||||||
|
conditioning_utils.condition_tensor(
|
||||||
|
array_ops.placeholder(dtypes.float32, tensor_shape),
|
||||||
|
array_ops.placeholder(dtypes.float32, conditioning_shape))
|
||||||
|
|
||||||
|
def test_condition_tensor_asserts(self):
|
||||||
|
with self.assertRaisesRegexp(ValueError, 'Cannot reshape'):
|
||||||
|
conditioning_utils.condition_tensor(
|
||||||
|
array_ops.placeholder(dtypes.float32, (4, 1)),
|
||||||
|
array_ops.placeholder(dtypes.float32, (5, 1)))
|
||||||
|
|
||||||
|
with self.assertRaisesRegexp(ValueError, 'Shape .* is not fully defined'):
|
||||||
|
conditioning_utils.condition_tensor(
|
||||||
|
array_ops.placeholder(dtypes.float32, (5, None)),
|
||||||
|
array_ops.placeholder(dtypes.float32, (5, 1)))
|
||||||
|
|
||||||
|
with self.assertRaisesRegexp(ValueError, 'must have a least 2 dimensions.'):
|
||||||
|
conditioning_utils.condition_tensor(
|
||||||
|
array_ops.placeholder(dtypes.float32, (5, 2)),
|
||||||
|
array_ops.placeholder(dtypes.float32, (5)))
|
||||||
|
|
||||||
|
def test_condition_tensor_from_onehot(self):
|
||||||
|
conditioning_utils.condition_tensor_from_onehot(
|
||||||
|
array_ops.placeholder(dtypes.float32, (5, 4, 1)),
|
||||||
|
array_ops.placeholder(dtypes.float32, (5, 10)))
|
||||||
|
|
||||||
|
def test_condition_tensor_from_onehot_asserts(self):
|
||||||
|
with self.assertRaisesRegexp(ValueError, 'Shape .* must have rank 2'):
|
||||||
|
conditioning_utils.condition_tensor_from_onehot(
|
||||||
|
array_ops.placeholder(dtypes.float32, (5, 1)),
|
||||||
|
array_ops.placeholder(dtypes.float32, (5)))
|
||||||
|
|
||||||
|
with self.assertRaisesRegexp(ValueError, 'Shape .* is not fully defined'):
|
||||||
|
conditioning_utils.condition_tensor_from_onehot(
|
||||||
|
array_ops.placeholder(dtypes.float32, (5, 1)),
|
||||||
|
array_ops.placeholder(dtypes.float32, (5, None)))
|
||||||
|
|
||||||
|
with self.assertRaisesRegexp(ValueError, 'Cannot reshape a tensor'):
|
||||||
|
conditioning_utils.condition_tensor_from_onehot(
|
||||||
|
array_ops.placeholder(dtypes.float32, (5, 1)),
|
||||||
|
array_ops.placeholder(dtypes.float32, (4, 6)))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test.main()
|
@ -0,0 +1,27 @@
|
|||||||
|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Virtual batch normalization."""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.contrib.gan.python.features.python import virtual_batchnorm_impl
|
||||||
|
# pylint: disable=wildcard-import
|
||||||
|
from tensorflow.contrib.gan.python.features.python.virtual_batchnorm_impl import *
|
||||||
|
# pylint: enable=wildcard-import
|
||||||
|
from tensorflow.python.util.all_util import remove_undocumented
|
||||||
|
|
||||||
|
__all__ = virtual_batchnorm_impl.__all__
|
||||||
|
remove_undocumented(__name__, __all__)
|
@ -0,0 +1,306 @@
|
|||||||
|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Virtual batch normalization.
|
||||||
|
|
||||||
|
This technique was first introduced in `Improved Techniques for Training GANs`
|
||||||
|
(Salimans et al, https://arxiv.org/abs/1606.03498). Instead of using batch
|
||||||
|
normalization on a minibatch, it fixes a reference subset of the data to use for
|
||||||
|
calculating normalization statistics.
|
||||||
|
"""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.framework import tensor_shape
|
||||||
|
from tensorflow.python.framework import tensor_util
|
||||||
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import init_ops
|
||||||
|
from tensorflow.python.ops import math_ops
|
||||||
|
from tensorflow.python.ops import nn
|
||||||
|
from tensorflow.python.ops import variable_scope
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'VBN',
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _static_or_dynamic_batch_size(tensor, batch_axis):
|
||||||
|
"""Returns the static or dynamic batch size."""
|
||||||
|
batch_size = array_ops.shape(tensor)[batch_axis]
|
||||||
|
static_batch_size = tensor_util.constant_value(batch_size)
|
||||||
|
return static_batch_size or batch_size
|
||||||
|
|
||||||
|
|
||||||
|
def _statistics(x, axes):
|
||||||
|
"""Calculate the mean and mean square of `x`.
|
||||||
|
|
||||||
|
Modified from the implementation of `tf.nn.moments`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x: A `Tensor`.
|
||||||
|
axes: Array of ints. Axes along which to compute mean and
|
||||||
|
variance.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Two `Tensor` objects: `mean` and `square mean`.
|
||||||
|
"""
|
||||||
|
# The dynamic range of fp16 is too limited to support the collection of
|
||||||
|
# sufficient statistics. As a workaround we simply perform the operations
|
||||||
|
# on 32-bit floats before converting the mean and variance back to fp16
|
||||||
|
y = math_ops.cast(x, dtypes.float32) if x.dtype == dtypes.float16 else x
|
||||||
|
|
||||||
|
# Compute true mean while keeping the dims for proper broadcasting.
|
||||||
|
shift = array_ops.stop_gradient(math_ops.reduce_mean(y, axes, keep_dims=True))
|
||||||
|
|
||||||
|
shifted_mean = math_ops.reduce_mean(y - shift, axes, keep_dims=True)
|
||||||
|
mean = shifted_mean + shift
|
||||||
|
mean_squared = math_ops.reduce_mean(math_ops.square(y), axes, keep_dims=True)
|
||||||
|
|
||||||
|
mean = array_ops.squeeze(mean, axes)
|
||||||
|
mean_squared = array_ops.squeeze(mean_squared, axes)
|
||||||
|
if x.dtype == dtypes.float16:
|
||||||
|
return (math_ops.cast(mean, dtypes.float16),
|
||||||
|
math_ops.cast(mean_squared, dtypes.float16))
|
||||||
|
else:
|
||||||
|
return (mean, mean_squared)
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_init_input_and_get_axis(reference_batch, axis):
|
||||||
|
"""Validate input and return the used axis value."""
|
||||||
|
if reference_batch.shape.ndims is None:
|
||||||
|
raise ValueError('`reference_batch` has unknown dimensions.')
|
||||||
|
|
||||||
|
ndims = reference_batch.shape.ndims
|
||||||
|
if axis < 0:
|
||||||
|
used_axis = ndims + axis
|
||||||
|
else:
|
||||||
|
used_axis = axis
|
||||||
|
if used_axis < 0 or used_axis >= ndims:
|
||||||
|
raise ValueError('Value of `axis` argument ' + str(used_axis) +
|
||||||
|
' is out of range for input with rank ' + str(ndims))
|
||||||
|
return used_axis
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_call_input(tensor_list, batch_dim):
|
||||||
|
"""Verifies that tensor shapes are compatible, except for `batch_dim`."""
|
||||||
|
def _get_shape(tensor):
|
||||||
|
shape = tensor.shape.as_list()
|
||||||
|
del shape[batch_dim]
|
||||||
|
return shape
|
||||||
|
base_shape = tensor_shape.TensorShape(_get_shape(tensor_list[0]))
|
||||||
|
for tensor in tensor_list:
|
||||||
|
base_shape.assert_is_compatible_with(_get_shape(tensor))
|
||||||
|
|
||||||
|
|
||||||
|
class VBN(object):
|
||||||
|
"""A class to perform virtual batch normalization.
|
||||||
|
|
||||||
|
This technique was first introduced in `Improved Techniques for Training GANs`
|
||||||
|
(Salimans et al, https://arxiv.org/abs/1606.03498). Instead of using batch
|
||||||
|
normalization on a minibatch, it fixes a reference subset of the data to use
|
||||||
|
for calculating normalization statistics.
|
||||||
|
|
||||||
|
To do this, we calculate the reference batch mean and mean square, and modify
|
||||||
|
those statistics for each example. We use mean square instead of variance,
|
||||||
|
since it is linear.
|
||||||
|
|
||||||
|
Note that if `center` or `scale` variables are created, they are shared
|
||||||
|
between all calls to this object.
|
||||||
|
|
||||||
|
The `__init__` API is intended to mimic `tf.layers.batch_normalization` as
|
||||||
|
closely as possible.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
reference_batch,
|
||||||
|
axis=-1,
|
||||||
|
epsilon=1e-3,
|
||||||
|
center=True,
|
||||||
|
scale=True,
|
||||||
|
beta_initializer=init_ops.zeros_initializer(),
|
||||||
|
gamma_initializer=init_ops.ones_initializer(),
|
||||||
|
beta_regularizer=None,
|
||||||
|
gamma_regularizer=None,
|
||||||
|
trainable=True,
|
||||||
|
name=None,
|
||||||
|
batch_axis=0):
|
||||||
|
"""Initialize virtual batch normalization object.
|
||||||
|
|
||||||
|
We precompute the 'mean' and 'mean squared' of the reference batch, so that
|
||||||
|
`__call__` is efficient. This means that the axis must be supplied when the
|
||||||
|
object is created, not when it is called.
|
||||||
|
|
||||||
|
We precompute 'square mean' instead of 'variance', because the square mean
|
||||||
|
can be easily adjusted on a per-example basis.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
reference_batch: A minibatch tensors. This will form the reference data
|
||||||
|
from which the normalization statistics are calculated. See
|
||||||
|
https://arxiv.org/abs/1606.03498 for more details.
|
||||||
|
axis: Integer, the axis that should be normalized (typically the features
|
||||||
|
axis). For instance, after a `Convolution2D` layer with
|
||||||
|
`data_format="channels_first"`, set `axis=1` in `BatchNormalization`.
|
||||||
|
epsilon: Small float added to variance to avoid dividing by zero.
|
||||||
|
center: If True, add offset of `beta` to normalized tensor. If False,
|
||||||
|
`beta` is ignored.
|
||||||
|
scale: If True, multiply by `gamma`. If False, `gamma` is
|
||||||
|
not used. When the next layer is linear (also e.g. `nn.relu`), this can
|
||||||
|
be disabled since the scaling can be done by the next layer.
|
||||||
|
beta_initializer: Initializer for the beta weight.
|
||||||
|
gamma_initializer: Initializer for the gamma weight.
|
||||||
|
beta_regularizer: Optional regularizer for the beta weight.
|
||||||
|
gamma_regularizer: Optional regularizer for the gamma weight.
|
||||||
|
trainable: Boolean, if `True` also add variables to the graph collection
|
||||||
|
`GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
|
||||||
|
name: String, the name of the ops.
|
||||||
|
batch_axis: The axis of the batch dimension. This dimension is treated
|
||||||
|
differently in `virtual batch normalization` vs `batch normalization`.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If `reference_batch` has unknown dimensions at graph
|
||||||
|
construction.
|
||||||
|
ValueError: If `batch_axis` is the same as `axis`.
|
||||||
|
"""
|
||||||
|
axis = _validate_init_input_and_get_axis(reference_batch, axis)
|
||||||
|
self._epsilon = epsilon
|
||||||
|
self._beta = 0
|
||||||
|
self._gamma = 1
|
||||||
|
self._batch_axis = _validate_init_input_and_get_axis(
|
||||||
|
reference_batch, batch_axis)
|
||||||
|
|
||||||
|
if axis == self._batch_axis:
|
||||||
|
raise ValueError('`axis` and `batch_axis` cannot be the same.')
|
||||||
|
|
||||||
|
with variable_scope.variable_scope(name, 'VBN',
|
||||||
|
values=[reference_batch]) as self._vs:
|
||||||
|
self._reference_batch = reference_batch
|
||||||
|
|
||||||
|
# Calculate important shapes:
|
||||||
|
# 1) Reduction axes for the reference batch
|
||||||
|
# 2) Broadcast shape, if necessary
|
||||||
|
# 3) Reduction axes for the virtual batchnormed batch
|
||||||
|
# 4) Shape for optional parameters
|
||||||
|
input_shape = self._reference_batch.shape
|
||||||
|
ndims = input_shape.ndims
|
||||||
|
reduction_axes = list(range(ndims))
|
||||||
|
del reduction_axes[axis]
|
||||||
|
|
||||||
|
self._broadcast_shape = [1] * len(input_shape)
|
||||||
|
self._broadcast_shape[axis] = input_shape[axis].value
|
||||||
|
|
||||||
|
self._example_reduction_axes = list(range(ndims))
|
||||||
|
del self._example_reduction_axes[max(axis, self._batch_axis)]
|
||||||
|
del self._example_reduction_axes[min(axis, self._batch_axis)]
|
||||||
|
|
||||||
|
params_shape = self._reference_batch.shape[axis]
|
||||||
|
|
||||||
|
# Determines whether broadcasting is needed. This is slightly different
|
||||||
|
# than in the `nn.batch_normalization` case, due to `batch_dim`.
|
||||||
|
self._needs_broadcasting = (
|
||||||
|
sorted(self._example_reduction_axes) != list(range(ndims))[:-2])
|
||||||
|
|
||||||
|
# Calculate the sufficient statistics for the reference batch in a way
|
||||||
|
# that can be easily modified by additional examples.
|
||||||
|
self._ref_mean, self._ref_mean_squares = _statistics(
|
||||||
|
self._reference_batch, reduction_axes)
|
||||||
|
self._ref_variance = (self._ref_mean_squares -
|
||||||
|
math_ops.square(self._ref_mean))
|
||||||
|
|
||||||
|
# Virtual batch normalization uses a weighted average between example
|
||||||
|
# statistics and the reference batch statistics.
|
||||||
|
ref_batch_size = _static_or_dynamic_batch_size(
|
||||||
|
self._reference_batch, self._batch_axis)
|
||||||
|
self._example_weight = 1. / (math_ops.to_float(ref_batch_size) + 1.)
|
||||||
|
self._ref_weight = 1. - self._example_weight
|
||||||
|
|
||||||
|
# Make the variables, if necessary.
|
||||||
|
if center:
|
||||||
|
self._beta = variable_scope.get_variable(
|
||||||
|
name='beta',
|
||||||
|
shape=(params_shape,),
|
||||||
|
initializer=beta_initializer,
|
||||||
|
regularizer=beta_regularizer,
|
||||||
|
trainable=trainable)
|
||||||
|
if scale:
|
||||||
|
self._gamma = variable_scope.get_variable(
|
||||||
|
name='gamma',
|
||||||
|
shape=(params_shape,),
|
||||||
|
initializer=gamma_initializer,
|
||||||
|
regularizer=gamma_regularizer,
|
||||||
|
trainable=trainable)
|
||||||
|
|
||||||
|
def _virtual_statistics(self, inputs, reduction_axes):
|
||||||
|
"""Compute the statistics needed for virtual batch normalization."""
|
||||||
|
cur_mean, cur_mean_sq = _statistics(inputs, reduction_axes)
|
||||||
|
vb_mean = (self._example_weight * cur_mean +
|
||||||
|
self._ref_weight * self._ref_mean)
|
||||||
|
vb_mean_sq = (self._example_weight * cur_mean_sq +
|
||||||
|
self._ref_weight * self._ref_mean_squares)
|
||||||
|
return (vb_mean, vb_mean_sq)
|
||||||
|
|
||||||
|
def _broadcast(self, v, broadcast_shape=None):
|
||||||
|
# The exact broadcast shape depends on the current batch, not the reference
|
||||||
|
# batch, unless we're calculating the batch normalization of the reference
|
||||||
|
# batch.
|
||||||
|
b_shape = broadcast_shape or self._broadcast_shape
|
||||||
|
if self._needs_broadcasting and v is not None:
|
||||||
|
return array_ops.reshape(v, b_shape)
|
||||||
|
return v
|
||||||
|
|
||||||
|
def reference_batch_normalization(self):
|
||||||
|
"""Return the reference batch, but batch normalized."""
|
||||||
|
with ops.name_scope(self._vs.name):
|
||||||
|
return nn.batch_normalization(self._reference_batch,
|
||||||
|
self._broadcast(self._ref_mean),
|
||||||
|
self._broadcast(self._ref_variance),
|
||||||
|
self._broadcast(self._beta),
|
||||||
|
self._broadcast(self._gamma),
|
||||||
|
self._epsilon)
|
||||||
|
|
||||||
|
def __call__(self, inputs):
|
||||||
|
"""Run virtual batch normalization on inputs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
inputs: Tensor input.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A virtual batch normalized version of `inputs`.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If `inputs` shape isn't compatible with the reference batch.
|
||||||
|
"""
|
||||||
|
_validate_call_input([inputs, self._reference_batch], self._batch_axis)
|
||||||
|
|
||||||
|
with ops.name_scope(self._vs.name, values=[inputs, self._reference_batch]):
|
||||||
|
# Calculate the statistics on the current input on a per-example basis.
|
||||||
|
vb_mean, vb_mean_sq = self._virtual_statistics(
|
||||||
|
inputs, self._example_reduction_axes)
|
||||||
|
vb_variance = vb_mean_sq - math_ops.square(vb_mean)
|
||||||
|
|
||||||
|
# The exact broadcast shape of the input statistic Tensors depends on the
|
||||||
|
# current batch, not the reference batch. The parameter broadcast shape
|
||||||
|
# is independent of the shape of the input statistic Tensor dimensions.
|
||||||
|
b_shape = self._broadcast_shape[:] # deep copy
|
||||||
|
b_shape[self._batch_axis] = _static_or_dynamic_batch_size(
|
||||||
|
inputs, self._batch_axis)
|
||||||
|
return nn.batch_normalization(
|
||||||
|
inputs,
|
||||||
|
self._broadcast(vb_mean, b_shape),
|
||||||
|
self._broadcast(vb_variance, b_shape),
|
||||||
|
self._broadcast(self._beta, self._broadcast_shape),
|
||||||
|
self._broadcast(self._gamma, self._broadcast_shape),
|
||||||
|
self._epsilon)
|
@ -0,0 +1,267 @@
|
|||||||
|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Tests for tfgan.python.features.virtual_batchnorm."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from tensorflow.contrib.framework.python.ops import variables as contrib_variables_lib
|
||||||
|
from tensorflow.contrib.gan.python.features.python import virtual_batchnorm_impl as virtual_batchnorm
|
||||||
|
from tensorflow.python.framework import constant_op
|
||||||
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import random_seed
|
||||||
|
from tensorflow.python.layers import normalization
|
||||||
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import math_ops
|
||||||
|
from tensorflow.python.ops import nn
|
||||||
|
from tensorflow.python.ops import random_ops
|
||||||
|
from tensorflow.python.ops import variable_scope
|
||||||
|
from tensorflow.python.ops import variables as variables_lib
|
||||||
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
|
class VirtualBatchnormTest(test.TestCase):
|
||||||
|
|
||||||
|
def test_syntax(self):
|
||||||
|
reference_batch = array_ops.zeros([5, 3, 16, 9, 15])
|
||||||
|
vbn = virtual_batchnorm.VBN(reference_batch, batch_axis=1)
|
||||||
|
vbn(array_ops.ones([5, 7, 16, 9, 15]))
|
||||||
|
|
||||||
|
def test_no_broadcast_needed(self):
|
||||||
|
"""When `axis` and `batch_axis` are at the end, no broadcast is needed."""
|
||||||
|
reference_batch = array_ops.zeros([5, 3, 16, 9, 15])
|
||||||
|
minibatch = array_ops.zeros([5, 3, 16, 3, 15])
|
||||||
|
vbn = virtual_batchnorm.VBN(reference_batch, axis=-1, batch_axis=-2)
|
||||||
|
vbn(minibatch)
|
||||||
|
|
||||||
|
def test_statistics(self):
|
||||||
|
"""Check that `_statistics` gives the same result as `nn.moments`."""
|
||||||
|
random_seed.set_random_seed(1234)
|
||||||
|
|
||||||
|
tensors = random_ops.random_normal([4, 5, 7, 3])
|
||||||
|
for axes in [(3), (0, 2), (1, 2, 3)]:
|
||||||
|
vb_mean, mean_sq = virtual_batchnorm._statistics(tensors, axes)
|
||||||
|
mom_mean, mom_var = nn.moments(tensors, axes)
|
||||||
|
vb_var = mean_sq - math_ops.square(vb_mean)
|
||||||
|
|
||||||
|
with self.test_session(use_gpu=True) as sess:
|
||||||
|
vb_mean_np, vb_var_np, mom_mean_np, mom_var_np = sess.run([
|
||||||
|
vb_mean, vb_var, mom_mean, mom_var])
|
||||||
|
|
||||||
|
self.assertAllClose(mom_mean_np, vb_mean_np)
|
||||||
|
self.assertAllClose(mom_var_np, vb_var_np)
|
||||||
|
|
||||||
|
def test_virtual_statistics(self):
|
||||||
|
"""Check that `_virtual_statistics` gives same result as `nn.moments`."""
|
||||||
|
random_seed.set_random_seed(1234)
|
||||||
|
|
||||||
|
batch_axis = 0
|
||||||
|
partial_batch = random_ops.random_normal([4, 5, 7, 3])
|
||||||
|
single_example = random_ops.random_normal([1, 5, 7, 3])
|
||||||
|
full_batch = array_ops.concat([partial_batch, single_example], axis=0)
|
||||||
|
|
||||||
|
for reduction_axis in range(1, 4):
|
||||||
|
# Get `nn.moments` on the full batch.
|
||||||
|
reduction_axes = list(range(4))
|
||||||
|
del reduction_axes[reduction_axis]
|
||||||
|
mom_mean, mom_variance = nn.moments(full_batch, reduction_axes)
|
||||||
|
|
||||||
|
# Get virtual batch statistics.
|
||||||
|
vb_reduction_axes = list(range(4))
|
||||||
|
del vb_reduction_axes[reduction_axis]
|
||||||
|
del vb_reduction_axes[batch_axis]
|
||||||
|
vbn = virtual_batchnorm.VBN(partial_batch, reduction_axis)
|
||||||
|
vb_mean, mean_sq = vbn._virtual_statistics(
|
||||||
|
single_example, vb_reduction_axes)
|
||||||
|
vb_variance = mean_sq - math_ops.square(vb_mean)
|
||||||
|
# Remove singleton batch dim for easy comparisons.
|
||||||
|
vb_mean = array_ops.squeeze(vb_mean, batch_axis)
|
||||||
|
vb_variance = array_ops.squeeze(vb_variance, batch_axis)
|
||||||
|
|
||||||
|
with self.test_session(use_gpu=True) as sess:
|
||||||
|
vb_mean_np, vb_var_np, mom_mean_np, mom_var_np = sess.run([
|
||||||
|
vb_mean, vb_variance, mom_mean, mom_variance])
|
||||||
|
|
||||||
|
self.assertAllClose(mom_mean_np, vb_mean_np)
|
||||||
|
self.assertAllClose(mom_var_np, vb_var_np)
|
||||||
|
|
||||||
|
def test_reference_batch_normalization(self):
|
||||||
|
"""Check that batch norm from VBN agrees with opensource implementation."""
|
||||||
|
random_seed.set_random_seed(1234)
|
||||||
|
|
||||||
|
batch = random_ops.random_normal([6, 5, 7, 3, 3])
|
||||||
|
|
||||||
|
for axis in range(5):
|
||||||
|
# Get `layers` batchnorm result.
|
||||||
|
bn_normalized = normalization.batch_normalization(
|
||||||
|
batch, axis, training=True)
|
||||||
|
|
||||||
|
# Get VBN's batch normalization on reference batch.
|
||||||
|
batch_axis = 0 if axis is not 0 else 1 # axis and batch_axis can't same
|
||||||
|
vbn = virtual_batchnorm.VBN(batch, axis, batch_axis=batch_axis)
|
||||||
|
vbn_normalized = vbn.reference_batch_normalization()
|
||||||
|
|
||||||
|
with self.test_session(use_gpu=True) as sess:
|
||||||
|
variables_lib.global_variables_initializer().run()
|
||||||
|
|
||||||
|
bn_normalized_np, vbn_normalized_np = sess.run(
|
||||||
|
[bn_normalized, vbn_normalized])
|
||||||
|
self.assertAllClose(bn_normalized_np, vbn_normalized_np)
|
||||||
|
|
||||||
|
  def test_same_as_batchnorm(self):
    """Check that batch norm on set X is the same as ref of X / y on `y`."""
    # Fixed seed so the random examples are reproducible across runs.
    random_seed.set_random_seed(1234)

    num_examples = 4
    examples = [random_ops.random_normal([5, 7, 3]) for _ in
                range(num_examples)]

    # Get the result of the opensource batch normalization.
    batch_normalized = normalization.batch_normalization(
        array_ops.stack(examples), training=True)

    # For each example i: build a VBN whose reference batch is every example
    # EXCEPT i, then virtually-normalize example i against it. The result
    # must match row i of standard batchnorm over the full set.
    for i in range(num_examples):
      examples_except_i = array_ops.stack(examples[:i] + examples[i+1:])
      # Get the result of VBN's batch normalization.
      vbn = virtual_batchnorm.VBN(examples_except_i)
      # Example i is fed with a singleton batch dim, which is then squeezed
      # away for comparison.
      vb_normed = array_ops.squeeze(
          vbn(array_ops.expand_dims(examples[i], [0])), [0])

      with self.test_session(use_gpu=True) as sess:
        variables_lib.global_variables_initializer().run()
        bn_np, vb_np = sess.run([batch_normalized, vb_normed])
      self.assertAllClose(bn_np[i, ...], vb_np)
|
||||||
|
|
||||||
|
  def test_minibatch_independent(self):
    """Test that virtual batch normalized examples are independent.

    Unlike batch normalization, virtual batch normalization has the property
    that the virtual batch normalized value of an example is independent of the
    other examples in the minibatch. In this test, we verify this property.
    """
    random_seed.set_random_seed(1234)

    # These can be random, but must be the same for all session calls.
    # `constant_op.constant` over numpy values (rather than `random_ops`)
    # guarantees identical tensors on every `sess.run`.
    reference_batch = constant_op.constant(
        np.random.normal(size=[4, 7, 3]), dtype=dtypes.float32)
    fixed_example = constant_op.constant(np.random.normal(size=[7, 3]),
                                         dtype=dtypes.float32)

    # Get the VBN object and the virtual batch normalized value for
    # `fixed_example`.
    vbn = virtual_batchnorm.VBN(reference_batch)
    vbn_fixed_example = array_ops.squeeze(
        vbn(array_ops.expand_dims(fixed_example, 0)), 0)
    with self.test_session(use_gpu=True):
      variables_lib.global_variables_initializer().run()
      # Baseline: normalize `fixed_example` alone.
      vbn_fixed_example_np = vbn_fixed_example.eval()

    # Check that the value is the same for different minibatches, and different
    # sized minibatches.
    for minibatch_size in range(1, 6):
      examples = [random_ops.random_normal([7, 3]) for _ in
                  range(minibatch_size)]

      # Prepend the fixed example to a minibatch of random companions; its
      # normalized value must not depend on them.
      minibatch = array_ops.stack([fixed_example] + examples)
      vbn_minibatch = vbn(minibatch)
      cur_vbn_fixed_example = vbn_minibatch[0, ...]
      with self.test_session(use_gpu=True):
        variables_lib.global_variables_initializer().run()
        cur_vbn_fixed_example_np = cur_vbn_fixed_example.eval()
      self.assertAllClose(vbn_fixed_example_np, cur_vbn_fixed_example_np)
|
||||||
|
|
||||||
|
  def test_variable_reuse(self):
    """Test that variable scopes work and inference on a real-ish case."""
    tensor1_ref = array_ops.zeros([6, 5, 7, 3, 3])
    tensor1_examples = array_ops.zeros([4, 5, 7, 3, 3])
    tensor2_ref = array_ops.zeros([4, 2, 3])
    tensor2_examples = array_ops.zeros([2, 2, 3])

    # Constructing a VBN inside a reuse=True scope must fail, since its
    # variables don't exist yet.
    with variable_scope.variable_scope('dummy_scope', reuse=True):
      with self.assertRaisesRegexp(
          ValueError, 'does not exist, or was not created with '
          'tf.get_variable()'):
        virtual_batchnorm.VBN(tensor1_ref)

    vbn1 = virtual_batchnorm.VBN(tensor1_ref, name='vbn1')
    vbn2 = virtual_batchnorm.VBN(tensor2_ref, name='vbn2')

    # Fetch reference and examples after virtual batch normalization. Also
    # fetch in variable reuse case.
    to_fetch = []

    to_fetch.append(vbn1.reference_batch_normalization())
    to_fetch.append(vbn2.reference_batch_normalization())
    to_fetch.append(vbn1(tensor1_examples))
    to_fetch.append(vbn2(tensor2_examples))

    # Repeat every call under variable reuse; no new variables should be
    # created below.
    variable_scope.get_variable_scope().reuse_variables()

    to_fetch.append(vbn1.reference_batch_normalization())
    to_fetch.append(vbn2.reference_batch_normalization())
    to_fetch.append(vbn1(tensor1_examples))
    to_fetch.append(vbn2(tensor2_examples))

    # Two VBNs, presumably two variables each (4 total) — reuse means the
    # second round of calls added none.
    self.assertEqual(4, len(contrib_variables_lib.get_variables()))

    with self.test_session(use_gpu=True) as sess:
      variables_lib.global_variables_initializer().run()
      # Smoke-run everything; the test only checks that execution succeeds.
      sess.run(to_fetch)
|
||||||
|
|
||||||
|
  def test_invalid_input(self):
    """Check that every invalid-argument path raises a helpful ValueError."""
    # Reference batch has unknown dimensions.
    with self.assertRaisesRegexp(
        ValueError, '`reference_batch` has unknown dimensions.'):
      virtual_batchnorm.VBN(array_ops.placeholder(dtypes.float32), name='vbn1')

    # Axis too negative.
    with self.assertRaisesRegexp(
        ValueError, 'Value of `axis` argument .* is out of range'):
      virtual_batchnorm.VBN(array_ops.zeros([1, 2]), axis=-3, name='vbn2')

    # Axis too large.
    with self.assertRaisesRegexp(
        ValueError, 'Value of `axis` argument .* is out of range'):
      virtual_batchnorm.VBN(array_ops.zeros([1, 2]), axis=2, name='vbn3')

    # Batch axis too negative.
    with self.assertRaisesRegexp(
        ValueError, 'Value of `axis` argument .* is out of range'):
      virtual_batchnorm.VBN(array_ops.zeros([1, 2]), name='vbn4', batch_axis=-3)

    # Batch axis too large.
    with self.assertRaisesRegexp(
        ValueError, 'Value of `axis` argument .* is out of range'):
      virtual_batchnorm.VBN(array_ops.zeros([1, 2]), name='vbn5', batch_axis=2)

    # Axis and batch axis are the same.
    with self.assertRaisesRegexp(
        ValueError, '`axis` and `batch_axis` cannot be the same.'):
      virtual_batchnorm.VBN(array_ops.zeros(
          [1, 2]), axis=1, name='vbn6', batch_axis=1)

    # Reference Tensor and example Tensor have incompatible shapes.
    # (Batch axis 1 ties the non-batch dim 0: 5 vs 3 is incompatible.)
    tensor_ref = array_ops.zeros([5, 2, 3])
    tensor_examples = array_ops.zeros([3, 2, 3])
    vbn = virtual_batchnorm.VBN(tensor_ref, name='vbn7', batch_axis=1)
    with self.assertRaisesRegexp(ValueError, 'Shapes .* are incompatible'):
      vbn(tensor_examples)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
  # Run all test cases in this file via the TensorFlow test runner.
  test.main()
|
32
tensorflow/contrib/gan/python/losses/__init__.py
Normal file
32
tensorflow/contrib/gan/python/losses/__init__.py
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
# Copyright 2017 Google Inc. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""TFGAN grouped API. Please see README.md for details and usage."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
# Collapse losses into a single namespace.
|
||||||
|
from tensorflow.contrib.gan.python.losses.python import losses_wargs as wargs
|
||||||
|
from tensorflow.contrib.gan.python.losses.python import tuple_losses
|
||||||
|
|
||||||
|
# pylint: disable=wildcard-import
|
||||||
|
from tensorflow.contrib.gan.python.losses.python.tuple_losses import *
|
||||||
|
# pylint: enable=wildcard-import
|
||||||
|
|
||||||
|
from tensorflow.python.util.all_util import remove_undocumented
|
||||||
|
|
||||||
|
# Restrict this module's public namespace to the `wargs` sub-namespace plus
# the symbols `tuple_losses` declares public; everything else is stripped.
_allowed_symbols = ['wargs'] + tuple_losses.__all__
remove_undocumented(__name__, _allowed_symbols)
|
887
tensorflow/contrib/gan/python/losses/python/losses_impl.py
Normal file
887
tensorflow/contrib/gan/python/losses/python/losses_impl.py
Normal file
@ -0,0 +1,887 @@
|
|||||||
|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Losses that are useful for training GANs.
|
||||||
|
|
||||||
|
The losses belong to two main groups, but there are others that do not:
|
||||||
|
1) xxxxx_generator_loss
|
||||||
|
2) xxxxx_discriminator_loss
|
||||||
|
|
||||||
|
Example:
|
||||||
|
1) wasserstein_generator_loss
|
||||||
|
2) wasserstein_discriminator_loss
|
||||||
|
|
||||||
|
Other example:
|
||||||
|
wasserstein_gradient_penalty
|
||||||
|
|
||||||
|
All losses must be able to accept 1D or 2D Tensors, so as to be compatible with
|
||||||
|
patchGAN style losses (https://arxiv.org/abs/1611.07004).
|
||||||
|
|
||||||
|
To make these losses usable in the TFGAN framework, please create a tuple
|
||||||
|
version of the losses with `losses_utils.py`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from tensorflow.contrib.framework.python.ops import variables as contrib_variables_lib
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.framework import tensor_util
|
||||||
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import clip_ops
|
||||||
|
from tensorflow.python.ops import gradients_impl
|
||||||
|
from tensorflow.python.ops import math_ops
|
||||||
|
from tensorflow.python.ops import random_ops
|
||||||
|
from tensorflow.python.ops import variable_scope
|
||||||
|
from tensorflow.python.ops.distributions import distribution as ds
|
||||||
|
from tensorflow.python.ops.losses import losses
|
||||||
|
from tensorflow.python.ops.losses import util
|
||||||
|
from tensorflow.python.summary import summary
|
||||||
|
|
||||||
|
|
||||||
|
# Public symbols exported by this module (consumed by wildcard imports).
__all__ = [
    'acgan_discriminator_loss',
    'acgan_generator_loss',
    'least_squares_discriminator_loss',
    'least_squares_generator_loss',
    'modified_discriminator_loss',
    'modified_generator_loss',
    'minimax_discriminator_loss',
    'minimax_generator_loss',
    'wasserstein_discriminator_loss',
    'wasserstein_generator_loss',
    'wasserstein_gradient_penalty',
    'mutual_information_penalty',
    'combine_adversarial_loss',
]
|
||||||
|
|
||||||
|
|
||||||
|
# Wasserstein losses from `Wasserstein GAN` (https://arxiv.org/abs/1701.07875).
|
||||||
|
def wasserstein_generator_loss(
    discriminator_gen_outputs,
    weights=1.0,
    scope=None,
    loss_collection=ops.GraphKeys.LOSSES,
    reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
    add_summaries=False):
  """Wasserstein generator loss for GANs.

  The generator tries to maximize the critic's score on generated data, so
  the loss is the negated discriminator output.

  See `Wasserstein GAN` (https://arxiv.org/abs/1701.07875) for more details.

  Args:
    discriminator_gen_outputs: Discriminator output on generated data. Expected
      to be in the range of (-inf, inf).
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `losses` dimension).
    scope: The scope for the operations performed in computing the loss.
    loss_collection: collection to which this loss will be added.
    reduction: A `tf.losses.Reduction` to apply to loss.
    add_summaries: Whether or not to add detailed summaries for the loss.

  Returns:
    A loss Tensor. The shape depends on `reduction`.
  """
  with ops.name_scope(scope, 'generator_wasserstein_loss', (
      discriminator_gen_outputs, weights)) as scope:
    gen_outputs = math_ops.to_float(discriminator_gen_outputs)

    # Maximizing the critic score == minimizing its negation.
    loss = losses.compute_weighted_loss(
        -gen_outputs, weights, scope, loss_collection, reduction)

    if add_summaries:
      summary.scalar('generator_wass_loss', loss)

    return loss
|
||||||
|
|
||||||
|
|
||||||
|
def wasserstein_discriminator_loss(
    discriminator_real_outputs,
    discriminator_gen_outputs,
    real_weights=1.0,
    generated_weights=1.0,
    scope=None,
    loss_collection=ops.GraphKeys.LOSSES,
    reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
    add_summaries=False):
  """Wasserstein discriminator loss for GANs.

  See `Wasserstein GAN` (https://arxiv.org/abs/1701.07875) for more details.

  Args:
    discriminator_real_outputs: Discriminator output on real data.
    discriminator_gen_outputs: Discriminator output on generated data. Expected
      to be in the range of (-inf, inf).
    real_weights: A scalar or a `Tensor` of size [batch_size, K] used to rescale
      the real loss.
    generated_weights: A scalar or a `Tensor` of size [batch_size, K] used to
      rescale the generated loss.
    scope: The scope for the operations performed in computing the loss.
    loss_collection: collection to which this loss will be added.
    reduction: A `tf.losses.Reduction` to apply to loss.
    add_summaries: Whether or not to add summaries for the loss.

  Returns:
    A loss Tensor. The shape depends on `reduction`.
  """
  with ops.name_scope(scope, 'discriminator_wasserstein_loss', (
      discriminator_real_outputs, discriminator_gen_outputs, real_weights,
      generated_weights)) as scope:
    discriminator_real_outputs = math_ops.to_float(discriminator_real_outputs)
    discriminator_gen_outputs = math_ops.to_float(discriminator_gen_outputs)
    # The critic must score real and generated batches of identical shape.
    discriminator_real_outputs.shape.assert_is_compatible_with(
        discriminator_gen_outputs.shape)

    # Each sub-loss uses loss_collection=None so only the combined loss is
    # registered (once, via util.add_loss below).
    loss_on_generated = losses.compute_weighted_loss(
        discriminator_gen_outputs, generated_weights, scope,
        loss_collection=None, reduction=reduction)
    loss_on_real = losses.compute_weighted_loss(
        discriminator_real_outputs, real_weights, scope, loss_collection=None,
        reduction=reduction)
    # Critic minimizes E[D(G(z))] - E[D(x)].
    loss = loss_on_generated - loss_on_real
    util.add_loss(loss, loss_collection)

    if add_summaries:
      summary.scalar('discriminator_gen_wass_loss', loss_on_generated)
      summary.scalar('discriminator_real_wass_loss', loss_on_real)
      summary.scalar('discriminator_wass_loss', loss)

    return loss
|
||||||
|
|
||||||
|
|
||||||
|
# ACGAN losses from `Conditional Image Synthesis With Auxiliary Classifier GANs`
|
||||||
|
# (https://arxiv.org/abs/1610.09585).
|
||||||
|
def acgan_discriminator_loss(
    discriminator_gen_classification_logits,
    discriminator_real_classification_logits,
    one_hot_labels,
    label_smoothing=0.0,
    real_weights=1.0,
    generated_weights=1.0,
    scope=None,
    loss_collection=ops.GraphKeys.LOSSES,
    reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
    add_summaries=False):
  """ACGAN loss for the discriminator.

  The ACGAN loss adds a classification loss to the conditional discriminator.
  Therefore, the discriminator must output a tuple consisting of
    (1) the real/fake prediction and
    (2) the logits for the classification (usually the last conv layer,
        flattened).

  For more details:
    ACGAN: https://arxiv.org/abs/1610.09585

  Args:
    discriminator_gen_classification_logits: Classification logits for generated
      data.
    discriminator_real_classification_logits: Classification logits for real
      data.
    one_hot_labels: A Tensor holding one-hot labels for the batch.
    label_smoothing: A float in [0, 1]. If greater than 0, smooth the labels for
      "discriminator on real data" as suggested in
      https://arxiv.org/pdf/1701.00160
    real_weights: A scalar or a `Tensor` of size [batch_size, K] used to rescale
      the real loss.
    generated_weights: A scalar or a `Tensor` of size [batch_size, K] used to
      rescale the generated loss.
    scope: The scope for the operations performed in computing the loss.
    loss_collection: collection to which this loss will be added.
    reduction: A `tf.losses.Reduction` to apply to loss.
    add_summaries: Whether or not to add summaries for the loss.

  Returns:
    A loss Tensor. Shape depends on `reduction`.

  Raises:
    TypeError: If the discriminator does not output a tuple.
  """
  # Both real and generated samples should be classified as their true class.
  # Sub-losses use loss_collection=None so only the combined loss is
  # registered (once, via util.add_loss below).
  loss_on_generated = losses.softmax_cross_entropy(
      one_hot_labels, discriminator_gen_classification_logits,
      weights=generated_weights, scope=scope, loss_collection=None,
      reduction=reduction)
  # Label smoothing is applied only to the real-data term.
  loss_on_real = losses.softmax_cross_entropy(
      one_hot_labels, discriminator_real_classification_logits,
      weights=real_weights, label_smoothing=label_smoothing, scope=scope,
      loss_collection=None, reduction=reduction)
  loss = loss_on_generated + loss_on_real
  util.add_loss(loss, loss_collection)

  if add_summaries:
    summary.scalar('discriminator_gen_ac_loss', loss_on_generated)
    summary.scalar('discriminator_real_ac_loss', loss_on_real)
    summary.scalar('discriminator_ac_loss', loss)

  return loss
|
||||||
|
|
||||||
|
|
||||||
|
def acgan_generator_loss(
    discriminator_gen_classification_logits,
    one_hot_labels,
    weights=1.0,
    scope=None,
    loss_collection=ops.GraphKeys.LOSSES,
    reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
    add_summaries=False):
  """ACGAN loss for the generator.

  The ACGAN loss adds a classification loss to the conditional discriminator.
  Therefore, the discriminator must output a tuple consisting of
    (1) the real/fake prediction and
    (2) the logits for the classification (usually the last conv layer,
        flattened).

  For more details:
    ACGAN: https://arxiv.org/abs/1610.09585

  Args:
    discriminator_gen_classification_logits: Classification logits for generated
      data.
    one_hot_labels: A Tensor holding one-hot labels for the batch.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `losses` dimension).
    scope: The scope for the operations performed in computing the loss.
    loss_collection: collection to which this loss will be added.
    reduction: A `tf.losses.Reduction` to apply to loss.
    add_summaries: Whether or not to add summaries for the loss.

  Returns:
    A loss Tensor. Shape depends on `reduction`.

  Raises:
    ValueError: if arg module not either `generator` or `discriminator`
    TypeError: if the discriminator does not output a tuple.
  """
  # The generator wants its samples classified as their conditioning class.
  ac_loss = losses.softmax_cross_entropy(
      one_hot_labels,
      discriminator_gen_classification_logits,
      weights=weights,
      scope=scope,
      loss_collection=loss_collection,
      reduction=reduction)

  if add_summaries:
    summary.scalar('generator_ac_loss', ac_loss)

  return ac_loss
|
||||||
|
|
||||||
|
|
||||||
|
# Wasserstein Gradient Penalty losses from `Improved Training of Wasserstein
|
||||||
|
# GANs` (https://arxiv.org/abs/1704.00028).
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(joelshor): Figure out why this function can't be inside a name scope.
|
||||||
|
# TODO(joelshor): Figure out why this function can't be inside a name scope.
def wasserstein_gradient_penalty(
    generated_data,
    real_data,
    generator_inputs,
    discriminator_fn,
    discriminator_scope,
    epsilon=1e-10,
    weights=1.0,
    scope=None,
    loss_collection=ops.GraphKeys.LOSSES,
    reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
    add_summaries=False):
  """The gradient penalty for the Wasserstein discriminator loss.

  See `Improved Training of Wasserstein GANs`
  (https://arxiv.org/abs/1704.00028) for more details.

  Args:
    generated_data: Output of the generator.
    real_data: Real data.
    generator_inputs: Exact argument to pass to the generator, which is used
      as optional conditioning to the discriminator.
    discriminator_fn: A discriminator function that conforms to TFGAN API.
    discriminator_scope: If not `None`, reuse discriminators from this scope.
    epsilon: A small positive number added for numerical stability when
      computing the gradient norm.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `losses` dimension).
    scope: The scope for the operations performed in computing the loss.
    loss_collection: collection to which this loss will be added.
    reduction: A `tf.losses.Reduction` to apply to loss.
    add_summaries: Whether or not to add summaries for the loss.

  Returns:
    A loss Tensor. The shape depends on `reduction`.

  Raises:
    ValueError: If the rank of data Tensors is unknown.
  """
  # Static rank is required below to build the broadcastable alpha shape.
  if generated_data.shape.ndims is None:
    raise ValueError('`generated_data` can\'t have unknown rank.')
  if real_data.shape.ndims is None:
    raise ValueError('`real_data` can\'t have unknown rank.')

  differences = generated_data - real_data
  # Prefer the statically-known batch size; fall back to a dynamic shape op.
  batch_size = differences.shape[0].value or array_ops.shape(differences)[0]
  # One uniform sample per example, broadcast over all non-batch dims.
  alpha_shape = [batch_size] + [1] * (differences.shape.ndims - 1)
  alpha = random_ops.random_uniform(shape=alpha_shape)
  # Random points on the line segments between real and generated examples.
  interpolates = real_data + (alpha * differences)

  # Reuse variables if a discriminator scope already exists.
  reuse = False if discriminator_scope is None else True
  with variable_scope.variable_scope(discriminator_scope, 'gpenalty_dscope',
                                     reuse=reuse):
    disc_interpolates = discriminator_fn(interpolates, generator_inputs)

  if isinstance(disc_interpolates, tuple):
    # ACGAN case: disc outputs more than one tensor
    disc_interpolates = disc_interpolates[0]

  # Squared L2 norm of dD/dx at the interpolates, per example.
  gradients = gradients_impl.gradients(disc_interpolates, interpolates)[0]
  gradient_squares = math_ops.reduce_sum(
      math_ops.square(gradients), axis=list(range(1, gradients.shape.ndims)))
  # Propagate shape information, if possible.
  if isinstance(batch_size, int):
    gradient_squares.set_shape([
        batch_size] + gradient_squares.shape.as_list()[1:])
  # For numerical stability, add epsilon to the sum before taking the square
  # root. Note tf.norm does not add epsilon.
  slopes = math_ops.sqrt(gradient_squares + epsilon)
  # Penalize deviation of the gradient norm from 1 (the 1-Lipschitz target).
  penalties = math_ops.square(slopes - 1.0)
  penalty = losses.compute_weighted_loss(
      penalties, weights, scope=scope, loss_collection=loss_collection,
      reduction=reduction)

  if add_summaries:
    summary.scalar('gradient_penalty_loss', penalty)

  return penalty
|
||||||
|
|
||||||
|
|
||||||
|
# Original losses from `Generative Adversarial Nets`
|
||||||
|
# (https://arxiv.org/abs/1406.2661).
|
||||||
|
|
||||||
|
|
||||||
|
def minimax_discriminator_loss(
    discriminator_real_outputs,
    discriminator_gen_outputs,
    label_smoothing=0.25,
    real_weights=1.0,
    generated_weights=1.0,
    scope=None,
    loss_collection=ops.GraphKeys.LOSSES,
    reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
    add_summaries=False):
  """Original minimax discriminator loss for GANs, with label smoothing.

  Note that the authors don't recommend using this loss. A more practically
  useful loss is `modified_discriminator_loss`.

  L = - real_weights * log(sigmoid(D(x)))
      - generated_weights * log(1 - sigmoid(D(G(z))))

  See `Generative Adversarial Nets` (https://arxiv.org/abs/1406.2661) for more
  details.

  Args:
    discriminator_real_outputs: Discriminator output on real data.
    discriminator_gen_outputs: Discriminator output on generated data. Expected
      to be in the range of (-inf, inf).
    label_smoothing: The amount of smoothing for positive labels. This technique
      is taken from `Improved Techniques for Training GANs`
      (https://arxiv.org/abs/1606.03498). `0.0` means no smoothing.
    real_weights: A scalar or a `Tensor` of size [batch_size, K] used to rescale
      the real loss.
    generated_weights: A scalar or a `Tensor` of size [batch_size, K] used to
      rescale the generated loss.
    scope: The scope for the operations performed in computing the loss.
    loss_collection: collection to which this loss will be added.
    reduction: A `tf.losses.Reduction` to apply to loss.
    add_summaries: Whether or not to add summaries for the loss.

  Returns:
    A loss Tensor. The shape depends on `reduction`.
  """
  with ops.name_scope(scope, 'discriminator_minimax_loss', (
      discriminator_real_outputs, discriminator_gen_outputs, real_weights,
      generated_weights, label_smoothing)) as scope:

    # Real term: sigmoid cross entropy against "real" labels smoothed down
    # to (1 - label_smoothing), i.e. approximately -log(sigmoid(D(x))).
    # loss_collection=None: only the combined loss is registered below.
    loss_on_real = losses.sigmoid_cross_entropy(
        array_ops.ones_like(discriminator_real_outputs),
        discriminator_real_outputs, real_weights, label_smoothing, scope,
        loss_collection=None, reduction=reduction)
    # Generated term: -log(1 - sigmoid(D(G(z)))); no smoothing on the
    # negative ("fake") labels.
    loss_on_generated = losses.sigmoid_cross_entropy(
        array_ops.zeros_like(discriminator_gen_outputs),
        discriminator_gen_outputs, generated_weights, scope=scope,
        loss_collection=None, reduction=reduction)

    loss = loss_on_real + loss_on_generated
    util.add_loss(loss, loss_collection)

    if add_summaries:
      summary.scalar('discriminator_gen_minimax_loss', loss_on_generated)
      summary.scalar('discriminator_real_minimax_loss', loss_on_real)
      summary.scalar('discriminator_minimax_loss', loss)

    return loss
|
||||||
|
|
||||||
|
|
||||||
|
def minimax_generator_loss(
    discriminator_gen_outputs,
    label_smoothing=0.0,
    weights=1.0,
    scope=None,
    loss_collection=ops.GraphKeys.LOSSES,
    reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
    add_summaries=False):
  """Original minimax generator loss for GANs.

  Note that the authors don't recommend using this loss. A more practically
  useful loss is `modified_generator_loss`.

  L = log(sigmoid(D(x))) + log(1 - sigmoid(D(G(z))))

  See `Generative Adversarial Nets` (https://arxiv.org/abs/1406.2661) for more
  details.

  Args:
    discriminator_gen_outputs: Discriminator output on generated data. Expected
      to be in the range of (-inf, inf).
    label_smoothing: The amount of smoothing for positive labels. This technique
      is taken from `Improved Techniques for Training GANs`
      (https://arxiv.org/abs/1606.03498). `0.0` means no smoothing.
    weights: A scalar or a `Tensor` of size [batch_size, K] used to rescale
      the loss.
    scope: The scope for the operations performed in computing the loss.
    loss_collection: collection to which this loss will be added.
    reduction: A `tf.losses.Reduction` to apply to loss.
    add_summaries: Whether or not to add summaries for the loss.

  Returns:
    A loss Tensor. The shape depends on `reduction`.
  """
  with ops.name_scope(scope, 'generator_minimax_loss') as scope:
    # The generator loss is the negated discriminator loss, with a constant
    # all-ones tensor standing in for the "real" outputs (the real-data term
    # then contributes no gradient w.r.t. the generator).
    # NOTE(review): summaries are disabled in the inner call so this function
    # emits only its own summary below.
    loss = - minimax_discriminator_loss(
        array_ops.ones_like(discriminator_gen_outputs),
        discriminator_gen_outputs, label_smoothing, weights, weights, scope,
        loss_collection, reduction, add_summaries=False)

    if add_summaries:
      summary.scalar('generator_minimax_loss', loss)

    return loss
|
||||||
|
|
||||||
|
|
||||||
|
def modified_discriminator_loss(
|
||||||
|
discriminator_real_outputs,
|
||||||
|
discriminator_gen_outputs,
|
||||||
|
label_smoothing=0.25,
|
||||||
|
real_weights=1.0,
|
||||||
|
generated_weights=1.0,
|
||||||
|
scope=None,
|
||||||
|
loss_collection=ops.GraphKeys.LOSSES,
|
||||||
|
reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
|
||||||
|
add_summaries=False):
|
||||||
|
"""Same as minimax discriminator loss.
|
||||||
|
|
||||||
|
See `Generative Adversarial Nets` (https://arxiv.org/abs/1406.2661) for more
|
||||||
|
details.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
discriminator_real_outputs: Discriminator output on real data.
|
||||||
|
discriminator_gen_outputs: Discriminator output on generated data. Expected
|
||||||
|
to be in the range of (-inf, inf).
|
||||||
|
label_smoothing: The amount of smoothing for positive labels. This technique
|
||||||
|
is taken from `Improved Techniques for Training GANs`
|
||||||
|
(https://arxiv.org/abs/1606.03498). `0.0` means no smoothing.
|
||||||
|
real_weights: A scalar or a `Tensor` of size [batch_size, K] used to rescale
|
||||||
|
the real loss.
|
||||||
|
generated_weights: A scalar or a `Tensor` of size [batch_size, K] used to
|
||||||
|
rescale the generated loss.
|
||||||
|
scope: The scope for the operations performed in computing the loss.
|
||||||
|
loss_collection: collection to which this loss will be added.
|
||||||
|
reduction: A `tf.losses.Reduction` to apply to loss.
|
||||||
|
add_summaries: Whether or not to add summaries for the loss.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A loss Tensor. The shape depends on `reduction`.
|
||||||
|
"""
|
||||||
|
return minimax_discriminator_loss(
|
||||||
|
discriminator_real_outputs,
|
||||||
|
discriminator_gen_outputs,
|
||||||
|
label_smoothing,
|
||||||
|
real_weights,
|
||||||
|
generated_weights,
|
||||||
|
scope or 'discriminator_modified_loss',
|
||||||
|
loss_collection,
|
||||||
|
reduction,
|
||||||
|
add_summaries)
|
||||||
|
|
||||||
|
|
||||||
|
def modified_generator_loss(
|
||||||
|
discriminator_gen_outputs,
|
||||||
|
label_smoothing=0.0,
|
||||||
|
weights=1.0,
|
||||||
|
scope='generator_modified_loss',
|
||||||
|
loss_collection=ops.GraphKeys.LOSSES,
|
||||||
|
reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
|
||||||
|
add_summaries=False):
|
||||||
|
"""Modified generator loss for GANs.
|
||||||
|
|
||||||
|
L = -log(sigmoid(D(G(z))))
|
||||||
|
|
||||||
|
This is the trick used in the original paper to avoid vanishing gradients
|
||||||
|
early in training. See `Generative Adversarial Nets`
|
||||||
|
(https://arxiv.org/abs/1406.2661) for more details.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
discriminator_gen_outputs: Discriminator output on generated data. Expected
|
||||||
|
to be in the range of (-inf, inf).
|
||||||
|
label_smoothing: The amount of smoothing for positive labels. This technique
|
||||||
|
is taken from `Improved Techniques for Training GANs`
|
||||||
|
(https://arxiv.org/abs/1606.03498). `0.0` means no smoothing.
|
||||||
|
weights: Optional `Tensor` whose rank is either 0, or the same rank as
|
||||||
|
`labels`, and must be broadcastable to `labels` (i.e., all dimensions must
|
||||||
|
be either `1`, or the same as the corresponding `losses` dimension).
|
||||||
|
scope: The scope for the operations performed in computing the loss.
|
||||||
|
loss_collection: collection to which this loss will be added.
|
||||||
|
reduction: A `tf.losses.Reduction` to apply to loss.
|
||||||
|
add_summaries: Whether or not to add summaries for the loss.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A loss Tensor. The shape depends on `reduction`.
|
||||||
|
"""
|
||||||
|
loss = losses.sigmoid_cross_entropy(
|
||||||
|
array_ops.ones_like(discriminator_gen_outputs), discriminator_gen_outputs,
|
||||||
|
weights, label_smoothing, scope, loss_collection, reduction)
|
||||||
|
|
||||||
|
if add_summaries:
|
||||||
|
summary.scalar('generator_modified_loss', loss)
|
||||||
|
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
# Least Squares loss from `Least Squares Generative Adversarial Networks`
|
||||||
|
# (https://arxiv.org/abs/1611.04076).
|
||||||
|
|
||||||
|
|
||||||
|
def least_squares_generator_loss(
|
||||||
|
discriminator_gen_outputs,
|
||||||
|
real_label=1,
|
||||||
|
weights=1.0,
|
||||||
|
scope=None,
|
||||||
|
loss_collection=ops.GraphKeys.LOSSES,
|
||||||
|
reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
|
||||||
|
add_summaries=False):
|
||||||
|
"""Least squares generator loss.
|
||||||
|
|
||||||
|
This loss comes from `Least Squares Generative Adversarial Networks`
|
||||||
|
(https://arxiv.org/abs/1611.04076).
|
||||||
|
|
||||||
|
L = 1/2 * (D(G(z)) - `real_label`) ** 2
|
||||||
|
|
||||||
|
where D(y) are discriminator logits.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
discriminator_gen_outputs: Discriminator output on generated data. Expected
|
||||||
|
to be in the range of (-inf, inf).
|
||||||
|
real_label: The value that the generator is trying to get the discriminator
|
||||||
|
to output on generated data.
|
||||||
|
weights: Optional `Tensor` whose rank is either 0, or the same rank as
|
||||||
|
`labels`, and must be broadcastable to `labels` (i.e., all dimensions must
|
||||||
|
be either `1`, or the same as the corresponding `losses` dimension).
|
||||||
|
scope: The scope for the operations performed in computing the loss.
|
||||||
|
loss_collection: collection to which this loss will be added.
|
||||||
|
reduction: A `tf.losses.Reduction` to apply to loss.
|
||||||
|
add_summaries: Whether or not to add summaries for the loss.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A loss Tensor. The shape depends on `reduction`.
|
||||||
|
"""
|
||||||
|
with ops.name_scope(scope, 'lsq_generator_loss',
|
||||||
|
(discriminator_gen_outputs, real_label)) as scope:
|
||||||
|
discriminator_gen_outputs = math_ops.to_float(discriminator_gen_outputs)
|
||||||
|
loss = math_ops.squared_difference(
|
||||||
|
discriminator_gen_outputs, real_label) / 2.0
|
||||||
|
loss = losses.compute_weighted_loss(
|
||||||
|
loss, weights, scope, loss_collection, reduction)
|
||||||
|
|
||||||
|
if add_summaries:
|
||||||
|
summary.scalar('generator_lsq_loss', loss)
|
||||||
|
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
def least_squares_discriminator_loss(
|
||||||
|
discriminator_real_outputs,
|
||||||
|
discriminator_gen_outputs,
|
||||||
|
real_label=1,
|
||||||
|
fake_label=0,
|
||||||
|
real_weights=1.0,
|
||||||
|
generated_weights=1.0,
|
||||||
|
scope=None,
|
||||||
|
loss_collection=ops.GraphKeys.LOSSES,
|
||||||
|
reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
|
||||||
|
add_summaries=False):
|
||||||
|
"""Least squares generator loss.
|
||||||
|
|
||||||
|
This loss comes from `Least Squares Generative Adversarial Networks`
|
||||||
|
(https://arxiv.org/abs/1611.04076).
|
||||||
|
|
||||||
|
L = 1/2 * (D(x) - `real`) ** 2 +
|
||||||
|
1/2 * (D(G(z)) - `fake_label`) ** 2
|
||||||
|
|
||||||
|
where D(y) are discriminator logits.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
discriminator_real_outputs: Discriminator output on real data.
|
||||||
|
discriminator_gen_outputs: Discriminator output on generated data. Expected
|
||||||
|
to be in the range of (-inf, inf).
|
||||||
|
real_label: The value that the discriminator tries to output for real data.
|
||||||
|
fake_label: The value that the discriminator tries to output for fake data.
|
||||||
|
real_weights: A scalar or a `Tensor` of size [batch_size, K] used to rescale
|
||||||
|
the real loss.
|
||||||
|
generated_weights: A scalar or a `Tensor` of size [batch_size, K] used to
|
||||||
|
rescale the generated loss.
|
||||||
|
scope: The scope for the operations performed in computing the loss.
|
||||||
|
loss_collection: collection to which this loss will be added.
|
||||||
|
reduction: A `tf.losses.Reduction` to apply to loss.
|
||||||
|
add_summaries: Whether or not to add summaries for the loss.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A loss Tensor. The shape depends on `reduction`.
|
||||||
|
"""
|
||||||
|
with ops.name_scope(scope, 'lsq_discriminator_loss',
|
||||||
|
(discriminator_gen_outputs, real_label)) as scope:
|
||||||
|
discriminator_real_outputs = math_ops.to_float(discriminator_real_outputs)
|
||||||
|
discriminator_gen_outputs = math_ops.to_float(discriminator_gen_outputs)
|
||||||
|
discriminator_real_outputs.shape.assert_is_compatible_with(
|
||||||
|
discriminator_gen_outputs.shape)
|
||||||
|
|
||||||
|
real_losses = math_ops.squared_difference(
|
||||||
|
discriminator_real_outputs, real_label) / 2.0
|
||||||
|
fake_losses = math_ops.squared_difference(
|
||||||
|
discriminator_gen_outputs, fake_label) / 2.0
|
||||||
|
|
||||||
|
loss_on_real = losses.compute_weighted_loss(
|
||||||
|
real_losses, real_weights, scope, loss_collection=None,
|
||||||
|
reduction=reduction)
|
||||||
|
loss_on_generated = losses.compute_weighted_loss(
|
||||||
|
fake_losses, generated_weights, scope, loss_collection=None,
|
||||||
|
reduction=reduction)
|
||||||
|
|
||||||
|
loss = loss_on_real + loss_on_generated
|
||||||
|
util.add_loss(loss, loss_collection)
|
||||||
|
|
||||||
|
if add_summaries:
|
||||||
|
summary.scalar('discriminator_gen_lsq_loss', loss_on_generated)
|
||||||
|
summary.scalar('discriminator_real_lsq_loss', loss_on_real)
|
||||||
|
summary.scalar('discriminator_lsq_loss', loss)
|
||||||
|
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
# InfoGAN loss from `InfoGAN: Interpretable Representation Learning by
|
||||||
|
# `Information Maximizing Generative Adversarial Nets`
|
||||||
|
# https://arxiv.org/abs/1606.03657
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_distributions(distributions):
|
||||||
|
if not isinstance(distributions, (list, tuple)):
|
||||||
|
raise ValueError('`distributions` must be a list or tuple. Instead, '
|
||||||
|
'found %s.', type(distributions))
|
||||||
|
for x in distributions:
|
||||||
|
if not isinstance(x, ds.Distribution):
|
||||||
|
raise ValueError('`distributions` must be a list of `Distributions`. '
|
||||||
|
'Instead, found %s.', type(x))
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_information_penalty_inputs(
|
||||||
|
structured_generator_inputs, predicted_distributions):
|
||||||
|
"""Validate input to `mutual_information_penalty`."""
|
||||||
|
_validate_distributions(predicted_distributions)
|
||||||
|
if len(structured_generator_inputs) != len(predicted_distributions):
|
||||||
|
raise ValueError('`structured_generator_inputs` length %i must be the same '
|
||||||
|
'as `predicted_distributions` length %i.' % (
|
||||||
|
len(structured_generator_inputs),
|
||||||
|
len(predicted_distributions)))
|
||||||
|
|
||||||
|
|
||||||
|
def mutual_information_penalty(
|
||||||
|
structured_generator_inputs,
|
||||||
|
predicted_distributions,
|
||||||
|
weights=1.0,
|
||||||
|
scope='generator_modified_loss',
|
||||||
|
loss_collection=ops.GraphKeys.LOSSES,
|
||||||
|
reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
|
||||||
|
add_summaries=False):
|
||||||
|
"""Returns a penalty on the mutual information in an InfoGAN model.
|
||||||
|
|
||||||
|
This loss comes from an InfoGAN paper https://arxiv.org/abs/1606.03657.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
structured_generator_inputs: A list of Tensors representing the random noise
|
||||||
|
that must have high mutual information with the generator output. List
|
||||||
|
length should match `predicted_distributions`.
|
||||||
|
predicted_distributions: A list of tf.Distributions. Predicted by the
|
||||||
|
recognizer, and used to evaluate the likelihood of the structured noise.
|
||||||
|
List length should match `structured_generator_inputs`.
|
||||||
|
weights: Optional `Tensor` whose rank is either 0, or the same rank as
|
||||||
|
`labels`, and must be broadcastable to `labels` (i.e., all dimensions must
|
||||||
|
be either `1`, or the same as the corresponding `losses` dimension).
|
||||||
|
scope: The scope for the operations performed in computing the loss.
|
||||||
|
loss_collection: collection to which this loss will be added.
|
||||||
|
reduction: A `tf.losses.Reduction` to apply to loss.
|
||||||
|
add_summaries: Whether or not to add summaries for the loss.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A scalar Tensor representing the mutual information loss.
|
||||||
|
"""
|
||||||
|
_validate_information_penalty_inputs(
|
||||||
|
structured_generator_inputs, predicted_distributions)
|
||||||
|
|
||||||
|
# Calculate the negative log-likelihood of the reconstructed noise.
|
||||||
|
log_probs = [math_ops.reduce_mean(dist.log_prob(noise)) for dist, noise in
|
||||||
|
zip(predicted_distributions, structured_generator_inputs)]
|
||||||
|
loss = -1 * losses.compute_weighted_loss(
|
||||||
|
log_probs, weights, scope, loss_collection=loss_collection,
|
||||||
|
reduction=reduction)
|
||||||
|
|
||||||
|
if add_summaries:
|
||||||
|
summary.scalar('mutual_information_penalty', loss)
|
||||||
|
|
||||||
|
return loss
|
||||||
|
|
||||||
|
|
||||||
|
def _numerically_stable_global_norm(tensor_list):
|
||||||
|
"""Compute the global norm of a list of Tensors, with improved stability.
|
||||||
|
|
||||||
|
The global norm computation sometimes overflows due to the intermediate L2
|
||||||
|
step. To avoid this, we divide by a cheap-to-compute max over the
|
||||||
|
matrix elements.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tensor_list: A list of tensors, or `None`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A scalar tensor with the global norm.
|
||||||
|
"""
|
||||||
|
if np.all([x is None for x in tensor_list]):
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
list_max = math_ops.reduce_max([math_ops.reduce_max(math_ops.abs(x)) for x in
|
||||||
|
tensor_list if x is not None])
|
||||||
|
return list_max * clip_ops.global_norm([x / list_max for x in tensor_list
|
||||||
|
if x is not None])
|
||||||
|
|
||||||
|
|
||||||
|
def _used_weight(weights_list):
|
||||||
|
for weight in weights_list:
|
||||||
|
if weight is not None:
|
||||||
|
return tensor_util.constant_value(ops.convert_to_tensor(weight))
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_args(losses_list, weight_factor, gradient_ratio):
|
||||||
|
for loss in losses_list:
|
||||||
|
loss.shape.assert_is_compatible_with([])
|
||||||
|
if weight_factor is None and gradient_ratio is None:
|
||||||
|
raise ValueError(
|
||||||
|
'`weight_factor` and `gradient_ratio` cannot both be `None.`')
|
||||||
|
if weight_factor is not None and gradient_ratio is not None:
|
||||||
|
raise ValueError(
|
||||||
|
'`weight_factor` and `gradient_ratio` cannot both be specified.')
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(joelshor): Add ability to pass in gradients, to avoid recomputing.
|
||||||
|
def combine_adversarial_loss(main_loss,
|
||||||
|
adversarial_loss,
|
||||||
|
weight_factor=None,
|
||||||
|
gradient_ratio=None,
|
||||||
|
gradient_ratio_epsilon=1e-6,
|
||||||
|
variables=None,
|
||||||
|
scalar_summaries=True,
|
||||||
|
gradient_summaries=True,
|
||||||
|
scope=None):
|
||||||
|
"""Utility to combine main and adversarial losses.
|
||||||
|
|
||||||
|
This utility combines the main and adversarial losses in one of two ways.
|
||||||
|
1) Fixed coefficient on adversarial loss. Use `weight_factor` in this case.
|
||||||
|
2) Fixed ratio of gradients. Use `gradient_ratio` in this case. This is often
|
||||||
|
used to make sure both losses affect weights roughly equally, as in
|
||||||
|
https://arxiv.org/pdf/1705.05823.
|
||||||
|
|
||||||
|
One can optionally also visualize the scalar and gradient behavior of the
|
||||||
|
losses.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
main_loss: A floating scalar Tensor indicating the main loss.
|
||||||
|
adversarial_loss: A floating scalar Tensor indication the adversarial loss.
|
||||||
|
weight_factor: If not `None`, the coefficient by which to multiply the
|
||||||
|
adversarial loss. Exactly one of this and `gradient_ratio` must be
|
||||||
|
non-None.
|
||||||
|
gradient_ratio: If not `None`, the ratio of the magnitude of the gradients.
|
||||||
|
Specifically,
|
||||||
|
gradient_ratio = grad_mag(main_loss) / grad_mag(adversarial_loss)
|
||||||
|
Exactly one of this and `weight_factor` must be non-None.
|
||||||
|
gradient_ratio_epsilon: An epsilon to add to the adversarial loss
|
||||||
|
coefficient denominator, to avoid division-by-zero.
|
||||||
|
variables: List of variables to calculate gradients with respect to. If not
|
||||||
|
present, defaults to all trainable variables.
|
||||||
|
scalar_summaries: Create scalar summaries of losses.
|
||||||
|
gradient_summaries: Create gradient summaries of losses.
|
||||||
|
scope: Optional name scope.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A floating scalar Tensor indicating the desired combined loss.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: Malformed input.
|
||||||
|
"""
|
||||||
|
_validate_args([main_loss, adversarial_loss], weight_factor, gradient_ratio)
|
||||||
|
if variables is None:
|
||||||
|
variables = contrib_variables_lib.get_trainable_variables()
|
||||||
|
|
||||||
|
with ops.name_scope(scope, 'adversarial_loss',
|
||||||
|
values=[main_loss, adversarial_loss]):
|
||||||
|
# Compute gradients if we will need them.
|
||||||
|
if gradient_summaries or gradient_ratio is not None:
|
||||||
|
main_loss_grad_mag = _numerically_stable_global_norm(
|
||||||
|
gradients_impl.gradients(main_loss, variables))
|
||||||
|
adv_loss_grad_mag = _numerically_stable_global_norm(
|
||||||
|
gradients_impl.gradients(adversarial_loss, variables))
|
||||||
|
|
||||||
|
# Add summaries, if applicable.
|
||||||
|
if scalar_summaries:
|
||||||
|
summary.scalar('main_loss', main_loss)
|
||||||
|
summary.scalar('adversarial_loss', adversarial_loss)
|
||||||
|
if gradient_summaries:
|
||||||
|
summary.scalar('main_loss_gradients', main_loss_grad_mag)
|
||||||
|
summary.scalar('adversarial_loss_gradients', adv_loss_grad_mag)
|
||||||
|
|
||||||
|
# Combine losses in the appropriate way.
|
||||||
|
# If `weight_factor` is always `0`, avoid computing the adversarial loss
|
||||||
|
# tensor entirely.
|
||||||
|
if _used_weight((weight_factor, gradient_ratio)) == 0:
|
||||||
|
final_loss = main_loss
|
||||||
|
elif weight_factor is not None:
|
||||||
|
final_loss = (main_loss +
|
||||||
|
array_ops.stop_gradient(weight_factor) * adversarial_loss)
|
||||||
|
elif gradient_ratio is not None:
|
||||||
|
grad_mag_ratio = main_loss_grad_mag / (
|
||||||
|
adv_loss_grad_mag + gradient_ratio_epsilon)
|
||||||
|
adv_coeff = grad_mag_ratio / gradient_ratio
|
||||||
|
summary.scalar('adversarial_coefficient', adv_coeff)
|
||||||
|
final_loss = (main_loss +
|
||||||
|
array_ops.stop_gradient(adv_coeff) * adversarial_loss)
|
||||||
|
|
||||||
|
return final_loss
|
606
tensorflow/contrib/gan/python/losses/python/losses_impl_test.py
Normal file
606
tensorflow/contrib/gan/python/losses/python/losses_impl_test.py
Normal file
@ -0,0 +1,606 @@
|
|||||||
|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Tests for TFGAN losses."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.contrib.gan.python.losses.python import losses_impl as tfgan_losses
|
||||||
|
from tensorflow.python.framework import constant_op
|
||||||
|
from tensorflow.python.framework import dtypes
|
||||||
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.framework import random_seed
|
||||||
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import clip_ops
|
||||||
|
from tensorflow.python.ops import math_ops
|
||||||
|
from tensorflow.python.ops import random_ops
|
||||||
|
from tensorflow.python.ops import variable_scope
|
||||||
|
from tensorflow.python.ops import variables
|
||||||
|
from tensorflow.python.ops.distributions import categorical
|
||||||
|
from tensorflow.python.ops.distributions import normal
|
||||||
|
from tensorflow.python.ops.losses import losses as tf_losses
|
||||||
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(joelshor): Use `parameterized` tests when opensourced.
|
||||||
|
class _LossesTest(object):
|
||||||
|
|
||||||
|
def init_constants(self):
|
||||||
|
self._discriminator_real_outputs_np = [-5.0, 1.4, 12.5, 2.7]
|
||||||
|
self._discriminator_gen_outputs_np = [10.0, 4.4, -5.5, 3.6]
|
||||||
|
self._weights = 2.3
|
||||||
|
self._discriminator_real_outputs = constant_op.constant(
|
||||||
|
self._discriminator_real_outputs_np, dtype=dtypes.float32)
|
||||||
|
self._discriminator_gen_outputs = constant_op.constant(
|
||||||
|
self._discriminator_gen_outputs_np, dtype=dtypes.float32)
|
||||||
|
|
||||||
|
def test_generator_all_correct(self):
|
||||||
|
loss = self._g_loss_fn(self._discriminator_gen_outputs)
|
||||||
|
self.assertEqual(self._discriminator_gen_outputs.dtype, loss.dtype)
|
||||||
|
self.assertEqual(self._generator_loss_name, loss.op.name)
|
||||||
|
with self.test_session():
|
||||||
|
self.assertAlmostEqual(self._expected_g_loss, loss.eval(), 5)
|
||||||
|
|
||||||
|
def test_discriminator_all_correct(self):
|
||||||
|
loss = self._d_loss_fn(
|
||||||
|
self._discriminator_real_outputs, self._discriminator_gen_outputs)
|
||||||
|
self.assertEqual(self._discriminator_gen_outputs.dtype, loss.dtype)
|
||||||
|
self.assertEqual(self._discriminator_loss_name, loss.op.name)
|
||||||
|
with self.test_session():
|
||||||
|
self.assertAlmostEqual(self._expected_d_loss, loss.eval(), 5)
|
||||||
|
|
||||||
|
def test_generator_loss_collection(self):
|
||||||
|
self.assertEqual(0, len(ops.get_collection('collection')))
|
||||||
|
self._g_loss_fn(
|
||||||
|
self._discriminator_gen_outputs, loss_collection='collection')
|
||||||
|
self.assertEqual(1, len(ops.get_collection('collection')))
|
||||||
|
|
||||||
|
def test_discriminator_loss_collection(self):
|
||||||
|
self.assertEqual(0, len(ops.get_collection('collection')))
|
||||||
|
self._d_loss_fn(
|
||||||
|
self._discriminator_real_outputs, self._discriminator_gen_outputs,
|
||||||
|
loss_collection='collection')
|
||||||
|
self.assertEqual(1, len(ops.get_collection('collection')))
|
||||||
|
|
||||||
|
def test_generator_no_reduction(self):
|
||||||
|
loss = self._g_loss_fn(
|
||||||
|
self._discriminator_gen_outputs, reduction=tf_losses.Reduction.NONE)
|
||||||
|
self.assertAllEqual([4], loss.shape)
|
||||||
|
|
||||||
|
def test_discriminator_no_reduction(self):
|
||||||
|
loss = self._d_loss_fn(
|
||||||
|
self._discriminator_real_outputs, self._discriminator_gen_outputs,
|
||||||
|
reduction=tf_losses.Reduction.NONE)
|
||||||
|
self.assertAllEqual([4], loss.shape)
|
||||||
|
|
||||||
|
def test_generator_patch(self):
|
||||||
|
loss = self._g_loss_fn(
|
||||||
|
array_ops.reshape(self._discriminator_gen_outputs, [2, 2]))
|
||||||
|
self.assertEqual(self._discriminator_gen_outputs.dtype, loss.dtype)
|
||||||
|
with self.test_session():
|
||||||
|
self.assertAlmostEqual(self._expected_g_loss, loss.eval(), 5)
|
||||||
|
|
||||||
|
def test_discriminator_patch(self):
|
||||||
|
loss = self._d_loss_fn(
|
||||||
|
array_ops.reshape(self._discriminator_real_outputs, [2, 2]),
|
||||||
|
array_ops.reshape(self._discriminator_gen_outputs, [2, 2]))
|
||||||
|
self.assertEqual(self._discriminator_gen_outputs.dtype, loss.dtype)
|
||||||
|
with self.test_session():
|
||||||
|
self.assertAlmostEqual(self._expected_d_loss, loss.eval(), 5)
|
||||||
|
|
||||||
|
def test_generator_loss_with_placeholder_for_logits(self):
|
||||||
|
logits = array_ops.placeholder(dtypes.float32, shape=(None, 4))
|
||||||
|
weights = array_ops.ones_like(logits, dtype=dtypes.float32)
|
||||||
|
|
||||||
|
loss = self._g_loss_fn(logits, weights=weights)
|
||||||
|
self.assertEqual(logits.dtype, loss.dtype)
|
||||||
|
|
||||||
|
with self.test_session() as sess:
|
||||||
|
loss = sess.run(loss,
|
||||||
|
feed_dict={
|
||||||
|
logits: [[10.0, 4.4, -5.5, 3.6]],
|
||||||
|
})
|
||||||
|
self.assertAlmostEqual(self._expected_g_loss, loss, 5)
|
||||||
|
|
||||||
|
def test_discriminator_loss_with_placeholder_for_logits(self):
|
||||||
|
logits = array_ops.placeholder(dtypes.float32, shape=(None, 4))
|
||||||
|
logits2 = array_ops.placeholder(dtypes.float32, shape=(None, 4))
|
||||||
|
real_weights = array_ops.ones_like(logits, dtype=dtypes.float32)
|
||||||
|
generated_weights = array_ops.ones_like(logits, dtype=dtypes.float32)
|
||||||
|
|
||||||
|
loss = self._d_loss_fn(
|
||||||
|
logits, logits2, real_weights=real_weights,
|
||||||
|
generated_weights=generated_weights)
|
||||||
|
|
||||||
|
with self.test_session() as sess:
|
||||||
|
loss = sess.run(loss,
|
||||||
|
feed_dict={
|
||||||
|
logits: [self._discriminator_real_outputs_np],
|
||||||
|
logits2: [self._discriminator_gen_outputs_np],
|
||||||
|
})
|
||||||
|
self.assertAlmostEqual(self._expected_d_loss, loss, 5)
|
||||||
|
|
||||||
|
def test_generator_with_python_scalar_weight(self):
|
||||||
|
loss = self._g_loss_fn(
|
||||||
|
self._discriminator_gen_outputs, weights=self._weights)
|
||||||
|
with self.test_session():
|
||||||
|
self.assertAlmostEqual(self._expected_g_loss * self._weights,
|
||||||
|
loss.eval(), 4)
|
||||||
|
|
||||||
|
def test_discriminator_with_python_scalar_weight(self):
|
||||||
|
loss = self._d_loss_fn(
|
||||||
|
self._discriminator_real_outputs, self._discriminator_gen_outputs,
|
||||||
|
real_weights=self._weights, generated_weights=self._weights)
|
||||||
|
with self.test_session():
|
||||||
|
self.assertAlmostEqual(self._expected_d_loss * self._weights,
|
||||||
|
loss.eval(), 4)
|
||||||
|
|
||||||
|
def test_generator_with_scalar_tensor_weight(self):
|
||||||
|
loss = self._g_loss_fn(self._discriminator_gen_outputs,
|
||||||
|
weights=constant_op.constant(self._weights))
|
||||||
|
with self.test_session():
|
||||||
|
self.assertAlmostEqual(self._expected_g_loss * self._weights,
|
||||||
|
loss.eval(), 4)
|
||||||
|
|
||||||
|
def test_discriminator_with_scalar_tensor_weight(self):
|
||||||
|
weights = constant_op.constant(self._weights)
|
||||||
|
loss = self._d_loss_fn(
|
||||||
|
self._discriminator_real_outputs, self._discriminator_gen_outputs,
|
||||||
|
real_weights=weights, generated_weights=weights)
|
||||||
|
with self.test_session():
|
||||||
|
self.assertAlmostEqual(self._expected_d_loss * self._weights,
|
||||||
|
loss.eval(), 4)
|
||||||
|
|
||||||
|
def test_generator_add_summaries(self):
|
||||||
|
self.assertEqual(0, len(ops.get_collection(ops.GraphKeys.SUMMARIES)))
|
||||||
|
self._g_loss_fn(self._discriminator_gen_outputs, add_summaries=True)
|
||||||
|
self.assertLess(0, len(ops.get_collection(ops.GraphKeys.SUMMARIES)))
|
||||||
|
|
||||||
|
def test_discriminator_add_summaries(self):
|
||||||
|
self.assertEqual(0, len(ops.get_collection(ops.GraphKeys.SUMMARIES)))
|
||||||
|
self._d_loss_fn(
|
||||||
|
self._discriminator_real_outputs, self._discriminator_gen_outputs,
|
||||||
|
add_summaries=True)
|
||||||
|
self.assertLess(0, len(ops.get_collection(ops.GraphKeys.SUMMARIES)))
|
||||||
|
|
||||||
|
|
||||||
|
class LeastSquaresLossTest(test.TestCase, _LossesTest):
|
||||||
|
"""Tests for least_squares_xxx_loss."""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super(LeastSquaresLossTest, self).setUp()
|
||||||
|
self.init_constants()
|
||||||
|
self._expected_g_loss = 17.69625
|
||||||
|
self._expected_d_loss = 41.73375
|
||||||
|
self._generator_loss_name = 'lsq_generator_loss/value'
|
||||||
|
self._discriminator_loss_name = 'lsq_discriminator_loss/add'
|
||||||
|
self._g_loss_fn = tfgan_losses.least_squares_generator_loss
|
||||||
|
self._d_loss_fn = tfgan_losses.least_squares_discriminator_loss
|
||||||
|
|
||||||
|
|
||||||
|
class ModifiedLossTest(test.TestCase, _LossesTest):
|
||||||
|
"""Tests for modified_xxx_loss."""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super(ModifiedLossTest, self).setUp()
|
||||||
|
self.init_constants()
|
||||||
|
self._expected_g_loss = 1.38582
|
||||||
|
self._expected_d_loss = 6.19637
|
||||||
|
self._generator_loss_name = 'generator_modified_loss/value'
|
||||||
|
self._discriminator_loss_name = 'discriminator_modified_loss/add_1'
|
||||||
|
self._g_loss_fn = tfgan_losses.modified_generator_loss
|
||||||
|
self._d_loss_fn = tfgan_losses.modified_discriminator_loss
|
||||||
|
|
||||||
|
|
||||||
|
class MinimaxLossTest(test.TestCase, _LossesTest):
|
||||||
|
"""Tests for minimax_xxx_loss."""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super(MinimaxLossTest, self).setUp()
|
||||||
|
self.init_constants()
|
||||||
|
self._expected_g_loss = -4.82408
|
||||||
|
self._expected_d_loss = 6.19637
|
||||||
|
self._generator_loss_name = 'generator_minimax_loss/Neg'
|
||||||
|
self._discriminator_loss_name = 'discriminator_minimax_loss/add_1'
|
||||||
|
self._g_loss_fn = tfgan_losses.minimax_generator_loss
|
||||||
|
self._d_loss_fn = tfgan_losses.minimax_discriminator_loss
|
||||||
|
|
||||||
|
|
||||||
|
class WassersteinLossTest(test.TestCase, _LossesTest):
|
||||||
|
"""Tests for wasserstein_xxx_loss."""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super(WassersteinLossTest, self).setUp()
|
||||||
|
self.init_constants()
|
||||||
|
self._expected_g_loss = -3.12500
|
||||||
|
self._expected_d_loss = 0.22500
|
||||||
|
self._generator_loss_name = 'generator_wasserstein_loss/value'
|
||||||
|
self._discriminator_loss_name = 'discriminator_wasserstein_loss/sub'
|
||||||
|
self._g_loss_fn = tfgan_losses.wasserstein_generator_loss
|
||||||
|
self._d_loss_fn = tfgan_losses.wasserstein_discriminator_loss
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(joelshor): Use `parameterized` tests when opensourced.
|
||||||
|
# TODO(joelshor): Refactor this test to use the same code as the other losses.
|
||||||
|
class ACGANLossTest(test.TestCase):
|
||||||
|
"""Tests for wasserstein_xxx_loss."""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super(ACGANLossTest, self).setUp()
|
||||||
|
self._g_loss_fn = tfgan_losses.acgan_generator_loss
|
||||||
|
self._d_loss_fn = tfgan_losses.acgan_discriminator_loss
|
||||||
|
self._discriminator_gen_classification_logits_np = [[10.0, 4.4, -5.5, 3.6],
|
||||||
|
[-4.0, 4.4, 5.2, 4.6],
|
||||||
|
[1.1, 2.4, -3.5, 5.6],
|
||||||
|
[1.1, 2.4, -3.5, 5.6]]
|
||||||
|
self._discriminator_real_classification_logits_np = [[-2.0, 0.4, 12.5, 2.7],
|
||||||
|
[-1.2, 1.9, 12.3, 2.6],
|
||||||
|
[-2.4, -1.7, 2.5, 2.7],
|
||||||
|
[1.1, 2.4, -3.5, 5.6]]
|
||||||
|
self._one_hot_labels_np = [[0, 1, 0, 0],
|
||||||
|
[0, 0, 1, 0],
|
||||||
|
[1, 0, 0, 0],
|
||||||
|
[1, 0, 0, 0]]
|
||||||
|
self._weights = 2.3
|
||||||
|
|
||||||
|
self._discriminator_gen_classification_logits = constant_op.constant(
|
||||||
|
self._discriminator_gen_classification_logits_np, dtype=dtypes.float32)
|
||||||
|
self._discriminator_real_classification_logits = constant_op.constant(
|
||||||
|
self._discriminator_real_classification_logits_np, dtype=dtypes.float32)
|
||||||
|
self._one_hot_labels = constant_op.constant(
|
||||||
|
self._one_hot_labels_np, dtype=dtypes.float32)
|
||||||
|
self._generator_kwargs = {
|
||||||
|
'discriminator_gen_classification_logits':
|
||||||
|
self._discriminator_gen_classification_logits,
|
||||||
|
'one_hot_labels': self._one_hot_labels,
|
||||||
|
}
|
||||||
|
self._discriminator_kwargs = {
|
||||||
|
'discriminator_gen_classification_logits':
|
||||||
|
self._discriminator_gen_classification_logits,
|
||||||
|
'discriminator_real_classification_logits':
|
||||||
|
self._discriminator_real_classification_logits,
|
||||||
|
'one_hot_labels': self._one_hot_labels,
|
||||||
|
}
|
||||||
|
self._generator_loss_name = 'softmax_cross_entropy_loss/value'
|
||||||
|
self._discriminator_loss_name = 'add'
|
||||||
|
self._expected_g_loss = 3.84974
|
||||||
|
self._expected_d_loss = 9.43950
|
||||||
|
|
||||||
|
def test_generator_all_correct(self):
|
||||||
|
loss = self._g_loss_fn(**self._generator_kwargs)
|
||||||
|
self.assertEqual(
|
||||||
|
self._discriminator_gen_classification_logits.dtype, loss.dtype)
|
||||||
|
self.assertEqual(self._generator_loss_name, loss.op.name)
|
||||||
|
with self.test_session():
|
||||||
|
self.assertAlmostEqual(self._expected_g_loss, loss.eval(), 5)
|
||||||
|
|
||||||
|
def test_discriminator_all_correct(self):
|
||||||
|
loss = self._d_loss_fn(**self._discriminator_kwargs)
|
||||||
|
self.assertEqual(
|
||||||
|
self._discriminator_gen_classification_logits.dtype, loss.dtype)
|
||||||
|
self.assertEqual(self._discriminator_loss_name, loss.op.name)
|
||||||
|
with self.test_session():
|
||||||
|
self.assertAlmostEqual(self._expected_d_loss, loss.eval(), 5)
|
||||||
|
|
||||||
|
def test_generator_loss_collection(self):
|
||||||
|
self.assertEqual(0, len(ops.get_collection('collection')))
|
||||||
|
self._g_loss_fn(loss_collection='collection', **self._generator_kwargs)
|
||||||
|
self.assertEqual(1, len(ops.get_collection('collection')))
|
||||||
|
|
||||||
|
def test_discriminator_loss_collection(self):
|
||||||
|
self.assertEqual(0, len(ops.get_collection('collection')))
|
||||||
|
self._d_loss_fn(loss_collection='collection', **self._discriminator_kwargs)
|
||||||
|
self.assertEqual(1, len(ops.get_collection('collection')))
|
||||||
|
|
||||||
|
def test_generator_no_reduction(self):
|
||||||
|
loss = self._g_loss_fn(
|
||||||
|
reduction=tf_losses.Reduction.NONE, **self._generator_kwargs)
|
||||||
|
self.assertAllEqual([4], loss.shape)
|
||||||
|
|
||||||
|
def test_discriminator_no_reduction(self):
|
||||||
|
loss = self._d_loss_fn(
|
||||||
|
reduction=tf_losses.Reduction.NONE, **self._discriminator_kwargs)
|
||||||
|
self.assertAllEqual([4], loss.shape)
|
||||||
|
|
||||||
|
def test_generator_patch(self):
|
||||||
|
patch_args = {x: array_ops.reshape(y, [2, 2, 4]) for x, y in
|
||||||
|
self._generator_kwargs.items()}
|
||||||
|
loss = self._g_loss_fn(**patch_args)
|
||||||
|
with self.test_session():
|
||||||
|
self.assertAlmostEqual(self._expected_g_loss, loss.eval(), 5)
|
||||||
|
|
||||||
|
def test_discriminator_patch(self):
|
||||||
|
patch_args = {x: array_ops.reshape(y, [2, 2, 4]) for x, y in
|
||||||
|
self._discriminator_kwargs.items()}
|
||||||
|
loss = self._d_loss_fn(**patch_args)
|
||||||
|
with self.test_session():
|
||||||
|
self.assertAlmostEqual(self._expected_d_loss, loss.eval(), 5)
|
||||||
|
|
||||||
|
def test_generator_loss_with_placeholder_for_logits(self):
|
||||||
|
gen_logits = array_ops.placeholder(dtypes.float32, shape=(None, 4))
|
||||||
|
one_hot_labels = array_ops.placeholder(dtypes.int32, shape=(None, 4))
|
||||||
|
|
||||||
|
loss = self._g_loss_fn(gen_logits, one_hot_labels)
|
||||||
|
with self.test_session() as sess:
|
||||||
|
loss = sess.run(
|
||||||
|
loss, feed_dict={
|
||||||
|
gen_logits: self._discriminator_gen_classification_logits_np,
|
||||||
|
one_hot_labels: self._one_hot_labels_np,
|
||||||
|
})
|
||||||
|
self.assertAlmostEqual(self._expected_g_loss, loss, 5)
|
||||||
|
|
||||||
|
def test_discriminator_loss_with_placeholder_for_logits_and_weights(self):
|
||||||
|
gen_logits = array_ops.placeholder(dtypes.float32, shape=(None, 4))
|
||||||
|
real_logits = array_ops.placeholder(dtypes.float32, shape=(None, 4))
|
||||||
|
one_hot_labels = array_ops.placeholder(dtypes.int32, shape=(None, 4))
|
||||||
|
|
||||||
|
loss = self._d_loss_fn(gen_logits, real_logits, one_hot_labels)
|
||||||
|
|
||||||
|
with self.test_session() as sess:
|
||||||
|
loss = sess.run(
|
||||||
|
loss, feed_dict={
|
||||||
|
gen_logits: self._discriminator_gen_classification_logits_np,
|
||||||
|
real_logits: self._discriminator_real_classification_logits_np,
|
||||||
|
one_hot_labels: self._one_hot_labels_np,
|
||||||
|
})
|
||||||
|
self.assertAlmostEqual(self._expected_d_loss, loss, 5)
|
||||||
|
|
||||||
|
def test_generator_with_python_scalar_weight(self):
|
||||||
|
loss = self._g_loss_fn(weights=self._weights, **self._generator_kwargs)
|
||||||
|
with self.test_session():
|
||||||
|
self.assertAlmostEqual(self._expected_g_loss * self._weights,
|
||||||
|
loss.eval(), 4)
|
||||||
|
|
||||||
|
def test_discriminator_with_python_scalar_weight(self):
|
||||||
|
loss = self._d_loss_fn(
|
||||||
|
real_weights=self._weights, generated_weights=self._weights,
|
||||||
|
**self._discriminator_kwargs)
|
||||||
|
with self.test_session():
|
||||||
|
self.assertAlmostEqual(self._expected_d_loss * self._weights,
|
||||||
|
loss.eval(), 4)
|
||||||
|
|
||||||
|
def test_generator_with_scalar_tensor_weight(self):
|
||||||
|
loss = self._g_loss_fn(
|
||||||
|
weights=constant_op.constant(self._weights), **self._generator_kwargs)
|
||||||
|
with self.test_session():
|
||||||
|
self.assertAlmostEqual(self._expected_g_loss * self._weights,
|
||||||
|
loss.eval(), 4)
|
||||||
|
|
||||||
|
def test_discriminator_with_scalar_tensor_weight(self):
|
||||||
|
weights = constant_op.constant(self._weights)
|
||||||
|
loss = self._d_loss_fn(real_weights=weights, generated_weights=weights,
|
||||||
|
**self._discriminator_kwargs)
|
||||||
|
with self.test_session():
|
||||||
|
self.assertAlmostEqual(self._expected_d_loss * self._weights,
|
||||||
|
loss.eval(), 4)
|
||||||
|
|
||||||
|
def test_generator_add_summaries(self):
|
||||||
|
self.assertEqual(0, len(ops.get_collection(ops.GraphKeys.SUMMARIES)))
|
||||||
|
self._g_loss_fn(add_summaries=True, **self._generator_kwargs)
|
||||||
|
self.assertLess(0, len(ops.get_collection(ops.GraphKeys.SUMMARIES)))
|
||||||
|
|
||||||
|
def test_discriminator_add_summaries(self):
|
||||||
|
self.assertEqual(0, len(ops.get_collection(ops.GraphKeys.SUMMARIES)))
|
||||||
|
self._d_loss_fn(add_summaries=True, **self._discriminator_kwargs)
|
||||||
|
self.assertLess(0, len(ops.get_collection(ops.GraphKeys.SUMMARIES)))
|
||||||
|
|
||||||
|
|
||||||
|
class _PenaltyTest(object):
|
||||||
|
|
||||||
|
def test_all_correct(self):
|
||||||
|
loss = self._penalty_fn(**self._kwargs)
|
||||||
|
self.assertEqual(self._expected_dtype, loss.dtype)
|
||||||
|
self.assertEqual(self._expected_op_name, loss.op.name)
|
||||||
|
with self.test_session():
|
||||||
|
variables.global_variables_initializer().run()
|
||||||
|
self.assertAlmostEqual(self._expected_loss, loss.eval(), 6)
|
||||||
|
|
||||||
|
def test_loss_collection(self):
|
||||||
|
self.assertEqual(0, len(ops.get_collection('collection')))
|
||||||
|
self._penalty_fn(loss_collection='collection', **self._kwargs)
|
||||||
|
self.assertEqual(1, len(ops.get_collection('collection')))
|
||||||
|
|
||||||
|
def test_no_reduction(self):
|
||||||
|
loss = self._penalty_fn(reduction=tf_losses.Reduction.NONE, **self._kwargs)
|
||||||
|
self.assertAllEqual([self._batch_size], loss.shape)
|
||||||
|
|
||||||
|
def test_python_scalar_weight(self):
|
||||||
|
loss = self._penalty_fn(weights=2.3, **self._kwargs)
|
||||||
|
with self.test_session():
|
||||||
|
variables.global_variables_initializer().run()
|
||||||
|
self.assertAlmostEqual(self._expected_loss * 2.3, loss.eval(), 3)
|
||||||
|
|
||||||
|
def test_scalar_tensor_weight(self):
|
||||||
|
loss = self._penalty_fn(weights=constant_op.constant(2.3), **self._kwargs)
|
||||||
|
with self.test_session():
|
||||||
|
variables.global_variables_initializer().run()
|
||||||
|
self.assertAlmostEqual(self._expected_loss * 2.3, loss.eval(), 3)
|
||||||
|
|
||||||
|
|
||||||
|
class GradientPenaltyTest(test.TestCase, _PenaltyTest):
|
||||||
|
"""Tests for wasserstein_gradient_penalty."""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super(GradientPenaltyTest, self).setUp()
|
||||||
|
self._penalty_fn = tfgan_losses.wasserstein_gradient_penalty
|
||||||
|
self._generated_data_np = [[3.1, 2.3, -12.3, 32.1]]
|
||||||
|
self._real_data_np = [[-12.3, 23.2, 16.3, -43.2]]
|
||||||
|
self._expected_dtype = dtypes.float32
|
||||||
|
|
||||||
|
with variable_scope.variable_scope('fake_scope') as self._scope:
|
||||||
|
self._discriminator_fn(0.0, 0.0)
|
||||||
|
|
||||||
|
self._kwargs = {
|
||||||
|
'generated_data': constant_op.constant(
|
||||||
|
self._generated_data_np, dtype=self._expected_dtype),
|
||||||
|
'real_data': constant_op.constant(
|
||||||
|
self._real_data_np, dtype=self._expected_dtype),
|
||||||
|
'generator_inputs': None,
|
||||||
|
'discriminator_fn': self._discriminator_fn,
|
||||||
|
'discriminator_scope': self._scope,
|
||||||
|
}
|
||||||
|
self._expected_loss = 9.00000
|
||||||
|
self._expected_op_name = 'weighted_loss/value'
|
||||||
|
self._batch_size = 1
|
||||||
|
|
||||||
|
def _discriminator_fn(self, inputs, _):
|
||||||
|
return variable_scope.get_variable('dummy_d', initializer=2.0) * inputs
|
||||||
|
|
||||||
|
def test_loss_with_placeholder(self):
|
||||||
|
generated_data = array_ops.placeholder(dtypes.float32, shape=(None, None))
|
||||||
|
real_data = array_ops.placeholder(dtypes.float32, shape=(None, None))
|
||||||
|
|
||||||
|
loss = tfgan_losses.wasserstein_gradient_penalty(
|
||||||
|
generated_data,
|
||||||
|
real_data,
|
||||||
|
self._kwargs['generator_inputs'],
|
||||||
|
self._kwargs['discriminator_fn'],
|
||||||
|
self._kwargs['discriminator_scope'])
|
||||||
|
self.assertEqual(generated_data.dtype, loss.dtype)
|
||||||
|
|
||||||
|
with self.test_session() as sess:
|
||||||
|
variables.global_variables_initializer().run()
|
||||||
|
loss = sess.run(loss,
|
||||||
|
feed_dict={
|
||||||
|
generated_data: self._generated_data_np,
|
||||||
|
real_data: self._real_data_np,
|
||||||
|
})
|
||||||
|
self.assertAlmostEqual(self._expected_loss, loss, 5)
|
||||||
|
|
||||||
|
def test_reuses_scope(self):
|
||||||
|
"""Test that gradient penalty reuses discriminator scope."""
|
||||||
|
num_vars = len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
|
||||||
|
tfgan_losses.wasserstein_gradient_penalty(**self._kwargs)
|
||||||
|
self.assertEqual(
|
||||||
|
num_vars, len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)))
|
||||||
|
|
||||||
|
|
||||||
|
class MutualInformationPenaltyTest(test.TestCase, _PenaltyTest):
|
||||||
|
"""Tests for mutual_information_penalty."""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super(MutualInformationPenaltyTest, self).setUp()
|
||||||
|
self._penalty_fn = tfgan_losses.mutual_information_penalty
|
||||||
|
self._structured_generator_inputs = [1.0, 2.0]
|
||||||
|
self._predicted_distributions = [categorical.Categorical(logits=[1.0, 2.0]),
|
||||||
|
normal.Normal([0.0], [1.0])]
|
||||||
|
self._expected_dtype = dtypes.float32
|
||||||
|
|
||||||
|
self._kwargs = {
|
||||||
|
'structured_generator_inputs': self._structured_generator_inputs,
|
||||||
|
'predicted_distributions': self._predicted_distributions,
|
||||||
|
}
|
||||||
|
self._expected_loss = 1.61610
|
||||||
|
self._expected_op_name = 'mul'
|
||||||
|
self._batch_size = 2
|
||||||
|
|
||||||
|
|
||||||
|
class CombineAdversarialLossTest(test.TestCase):
|
||||||
|
"""Tests for combine_adversarial_loss."""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
super(CombineAdversarialLossTest, self).setUp()
|
||||||
|
self._generated_data_np = [[3.1, 2.3, -12.3, 32.1]]
|
||||||
|
self._real_data_np = [[-12.3, 23.2, 16.3, -43.2]]
|
||||||
|
self._generated_data = constant_op.constant(
|
||||||
|
self._generated_data_np, dtype=dtypes.float32)
|
||||||
|
self._real_data = constant_op.constant(
|
||||||
|
self._real_data_np, dtype=dtypes.float32)
|
||||||
|
self._generated_inputs = None
|
||||||
|
self._expected_loss = 9.00000
|
||||||
|
|
||||||
|
def _test_correct_helper(self, use_weight_factor):
|
||||||
|
variable_list = [variables.Variable(1.0)]
|
||||||
|
main_loss = variable_list[0] * 2
|
||||||
|
adversarial_loss = variable_list[0] * 3
|
||||||
|
gradient_ratio_epsilon = 1e-6
|
||||||
|
if use_weight_factor:
|
||||||
|
weight_factor = constant_op.constant(2.0)
|
||||||
|
gradient_ratio = None
|
||||||
|
adv_coeff = 2.0
|
||||||
|
expected_loss = 1.0 * 2 + adv_coeff * 1.0 * 3
|
||||||
|
else:
|
||||||
|
weight_factor = None
|
||||||
|
gradient_ratio = constant_op.constant(0.5)
|
||||||
|
adv_coeff = 2.0 / (3 * 0.5 + gradient_ratio_epsilon)
|
||||||
|
expected_loss = 1.0 * 2 + adv_coeff * 1.0 * 3
|
||||||
|
combined_loss = tfgan_losses.combine_adversarial_loss(
|
||||||
|
main_loss,
|
||||||
|
adversarial_loss,
|
||||||
|
weight_factor=weight_factor,
|
||||||
|
gradient_ratio=gradient_ratio,
|
||||||
|
gradient_ratio_epsilon=gradient_ratio_epsilon,
|
||||||
|
variables=variable_list)
|
||||||
|
|
||||||
|
with self.test_session(use_gpu=True):
|
||||||
|
variables.global_variables_initializer().run()
|
||||||
|
self.assertNear(expected_loss, combined_loss.eval(), 1e-5)
|
||||||
|
|
||||||
|
def test_correct_useweightfactor(self):
|
||||||
|
self._test_correct_helper(True)
|
||||||
|
|
||||||
|
def test_correct_nouseweightfactor(self):
|
||||||
|
self._test_correct_helper(False)
|
||||||
|
|
||||||
|
def _test_no_weight_skips_adversarial_loss_helper(self, use_weight_factor):
|
||||||
|
"""Test the 0 adversarial weight or grad ratio skips adversarial loss."""
|
||||||
|
main_loss = constant_op.constant(1.0)
|
||||||
|
adversarial_loss = constant_op.constant(1.0)
|
||||||
|
|
||||||
|
weight_factor = 0.0 if use_weight_factor else None
|
||||||
|
gradient_ratio = None if use_weight_factor else 0.0
|
||||||
|
|
||||||
|
combined_loss = tfgan_losses.combine_adversarial_loss(
|
||||||
|
main_loss,
|
||||||
|
adversarial_loss,
|
||||||
|
weight_factor=weight_factor,
|
||||||
|
gradient_ratio=gradient_ratio,
|
||||||
|
gradient_summaries=False)
|
||||||
|
|
||||||
|
with self.test_session(use_gpu=True):
|
||||||
|
self.assertEqual(1.0, combined_loss.eval())
|
||||||
|
|
||||||
|
def test_no_weight_skips_adversarial_loss_useweightfactor(self):
|
||||||
|
self._test_no_weight_skips_adversarial_loss_helper(True)
|
||||||
|
|
||||||
|
def test_no_weight_skips_adversarial_loss_nouseweightfactor(self):
|
||||||
|
self._test_no_weight_skips_adversarial_loss_helper(False)
|
||||||
|
|
||||||
|
def test_stable_global_norm_avoids_overflow(self):
|
||||||
|
tensors = [array_ops.ones([4]), array_ops.ones([4, 4]) * 1e19, None]
|
||||||
|
gnorm_is_inf = math_ops.is_inf(clip_ops.global_norm(tensors))
|
||||||
|
stable_gnorm_is_inf = math_ops.is_inf(
|
||||||
|
tfgan_losses._numerically_stable_global_norm(tensors))
|
||||||
|
|
||||||
|
with self.test_session(use_gpu=True):
|
||||||
|
self.assertTrue(gnorm_is_inf.eval())
|
||||||
|
self.assertFalse(stable_gnorm_is_inf.eval())
|
||||||
|
|
||||||
|
def test_stable_global_norm_unchanged(self):
|
||||||
|
"""Test that preconditioning doesn't change global norm value."""
|
||||||
|
random_seed.set_random_seed(1234)
|
||||||
|
tensors = [random_ops.random_uniform([3]*i, -10.0, 10.0) for i in range(6)]
|
||||||
|
gnorm = clip_ops.global_norm(tensors)
|
||||||
|
precond_gnorm = tfgan_losses._numerically_stable_global_norm(tensors)
|
||||||
|
|
||||||
|
with self.test_session(use_gpu=True) as sess:
|
||||||
|
for _ in range(10): # spot check closeness on more than one sample.
|
||||||
|
gnorm_np, precond_gnorm_np = sess.run([gnorm, precond_gnorm])
|
||||||
|
self.assertNear(gnorm_np, precond_gnorm_np, 1e-5)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test.main()
|
27
tensorflow/contrib/gan/python/losses/python/losses_wargs.py
Normal file
27
tensorflow/contrib/gan/python/losses/python/losses_wargs.py
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
# Copyright 2017 Google Inc. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""TFGAN grouped API. Please see README.md for details and usage."""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
# pylint: disable=wildcard-import
|
||||||
|
from tensorflow.contrib.gan.python.losses.python import losses_impl
|
||||||
|
from tensorflow.contrib.gan.python.losses.python.losses_impl import *
|
||||||
|
# pylint: enable=wildcard-import
|
||||||
|
|
||||||
|
from tensorflow.python.util.all_util import remove_undocumented
|
||||||
|
|
||||||
|
remove_undocumented(__name__, losses_impl.__all__)
|
27
tensorflow/contrib/gan/python/losses/python/tuple_losses.py
Normal file
27
tensorflow/contrib/gan/python/losses/python/tuple_losses.py
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""TFGAN utilities for loss functions that accept GANModel namedtuples."""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
# pylint: disable=wildcard-import
|
||||||
|
from tensorflow.contrib.gan.python.losses.python import tuple_losses_impl
|
||||||
|
from tensorflow.contrib.gan.python.losses.python.tuple_losses_impl import *
|
||||||
|
# pylint: enable=wildcard-import
|
||||||
|
from tensorflow.python.util.all_util import remove_undocumented
|
||||||
|
|
||||||
|
__all__ = tuple_losses_impl.__all__
|
||||||
|
remove_undocumented(__name__, __all__)
|
203
tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py
Normal file
203
tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py
Normal file
@ -0,0 +1,203 @@
|
|||||||
|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""TFGAN utilities for loss functions that accept GANModel namedtuples.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
# `tfgan.losses.args` losses take individual arguments.
|
||||||
|
w_loss = tfgan.losses.args.wasserstein_discriminator_loss(
|
||||||
|
discriminator_real_outputs,
|
||||||
|
discriminator_gen_outputs)
|
||||||
|
|
||||||
|
# `tfgan.losses` losses take GANModel namedtuples.
|
||||||
|
w_loss2 = tfgan.losses.wasserstein_discriminator_loss(gan_model)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from tensorflow.contrib.gan.python.losses.python import losses_impl
|
||||||
|
from tensorflow.python.util import tf_inspect
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'acgan_discriminator_loss',
|
||||||
|
'acgan_generator_loss',
|
||||||
|
'least_squares_discriminator_loss',
|
||||||
|
'least_squares_generator_loss',
|
||||||
|
'modified_discriminator_loss',
|
||||||
|
'modified_generator_loss',
|
||||||
|
'minimax_discriminator_loss',
|
||||||
|
'minimax_generator_loss',
|
||||||
|
'wasserstein_discriminator_loss',
|
||||||
|
'wasserstein_generator_loss',
|
||||||
|
'wasserstein_gradient_penalty',
|
||||||
|
'mutual_information_penalty',
|
||||||
|
'combine_adversarial_loss',
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _args_to_gan_model(loss_fn):
|
||||||
|
"""Converts a loss taking individual args to one taking a GANModel namedtuple.
|
||||||
|
|
||||||
|
The new function has the same name as the original one.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
loss_fn: A python function taking a `GANModel` object and returning a loss
|
||||||
|
Tensor calculated from that object. The shape of the loss depends on
|
||||||
|
`reduction`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A new function that takes a GANModel namedtuples and returns the same loss.
|
||||||
|
"""
|
||||||
|
# Match arguments in `loss_fn` to elements of `namedtuple`.
|
||||||
|
# TODO(joelshor): Properly handle `varargs` and `keywords`.
|
||||||
|
argspec = tf_inspect.getargspec(loss_fn)
|
||||||
|
defaults = argspec.defaults or []
|
||||||
|
|
||||||
|
required_args = set(argspec.args[:-len(defaults)])
|
||||||
|
args_with_defaults = argspec.args[-len(defaults):]
|
||||||
|
default_args_dict = dict(zip(args_with_defaults, defaults))
|
||||||
|
|
||||||
|
def new_loss_fn(gan_model, **kwargs): # pylint:disable=missing-docstring
|
||||||
|
gan_model_dict = gan_model._asdict()
|
||||||
|
|
||||||
|
# Make sure non-tuple required args are supplied.
|
||||||
|
args_from_tuple = set(argspec.args).intersection(set(gan_model._fields))
|
||||||
|
required_args_not_from_tuple = required_args - args_from_tuple
|
||||||
|
for arg in required_args_not_from_tuple:
|
||||||
|
if arg not in kwargs:
|
||||||
|
raise ValueError('`%s` must be supplied to %s loss function.' % (
|
||||||
|
arg, loss_fn.__name__))
|
||||||
|
|
||||||
|
# Make sure tuple args aren't also supplied as keyword args.
|
||||||
|
ambiguous_args = set(gan_model._fields).intersection(set(kwargs.keys()))
|
||||||
|
if ambiguous_args:
|
||||||
|
raise ValueError(
|
||||||
|
'The following args are present in both the tuple and keyword args '
|
||||||
|
'for %s: %s' % (loss_fn.__name__, ambiguous_args))
|
||||||
|
|
||||||
|
# Add required args to arg dictionary.
|
||||||
|
required_args_from_tuple = required_args.intersection(args_from_tuple)
|
||||||
|
for arg in required_args_from_tuple:
|
||||||
|
assert arg not in kwargs
|
||||||
|
kwargs[arg] = gan_model_dict[arg]
|
||||||
|
|
||||||
|
# Add arguments that have defaults.
|
||||||
|
for arg in default_args_dict:
|
||||||
|
val_from_tuple = gan_model_dict[arg] if arg in gan_model_dict else None
|
||||||
|
val_from_kwargs = kwargs[arg] if arg in kwargs else None
|
||||||
|
assert not (val_from_tuple is not None and val_from_kwargs is not None)
|
||||||
|
kwargs[arg] = (val_from_tuple if val_from_tuple is not None else
|
||||||
|
val_from_kwargs if val_from_kwargs is not None else
|
||||||
|
default_args_dict[arg])
|
||||||
|
|
||||||
|
return loss_fn(**kwargs)
|
||||||
|
|
||||||
|
new_docstring = """The gan_model version of %s.""" % loss_fn.__name__
|
||||||
|
new_loss_fn.__docstring__ = new_docstring
|
||||||
|
new_loss_fn.__name__ = loss_fn.__name__
|
||||||
|
new_loss_fn.__module__ = loss_fn.__module__
|
||||||
|
return new_loss_fn
|
||||||
|
|
||||||
|
|
||||||
|
# Wasserstein losses from `Wasserstein GAN` (https://arxiv.org/abs/1701.07875).
|
||||||
|
wasserstein_generator_loss = _args_to_gan_model(
|
||||||
|
losses_impl.wasserstein_generator_loss)
|
||||||
|
wasserstein_discriminator_loss = _args_to_gan_model(
|
||||||
|
losses_impl.wasserstein_discriminator_loss)
|
||||||
|
wasserstein_gradient_penalty = _args_to_gan_model(
|
||||||
|
losses_impl.wasserstein_gradient_penalty)
|
||||||
|
|
||||||
|
# ACGAN losses from `Conditional Image Synthesis With Auxiliary Classifier GANs`
|
||||||
|
# (https://arxiv.org/abs/1610.09585).
|
||||||
|
acgan_discriminator_loss = _args_to_gan_model(
|
||||||
|
losses_impl.acgan_discriminator_loss)
|
||||||
|
acgan_generator_loss = _args_to_gan_model(
|
||||||
|
losses_impl.acgan_generator_loss)
|
||||||
|
|
||||||
|
|
||||||
|
# Original losses from `Generative Adversarial Nets`
|
||||||
|
# (https://arxiv.org/abs/1406.2661).
|
||||||
|
minimax_discriminator_loss = _args_to_gan_model(
|
||||||
|
losses_impl.minimax_discriminator_loss)
|
||||||
|
minimax_generator_loss = _args_to_gan_model(
|
||||||
|
losses_impl.minimax_generator_loss)
|
||||||
|
modified_discriminator_loss = _args_to_gan_model(
|
||||||
|
losses_impl.modified_discriminator_loss)
|
||||||
|
modified_generator_loss = _args_to_gan_model(
|
||||||
|
losses_impl.modified_generator_loss)
|
||||||
|
|
||||||
|
|
||||||
|
# Least Squares loss from `Least Squares Generative Adversarial Networks`
|
||||||
|
# (https://arxiv.org/abs/1611.04076).
|
||||||
|
least_squares_generator_loss = _args_to_gan_model(
|
||||||
|
losses_impl.least_squares_generator_loss)
|
||||||
|
least_squares_discriminator_loss = _args_to_gan_model(
|
||||||
|
losses_impl.least_squares_discriminator_loss)
|
||||||
|
|
||||||
|
|
||||||
|
# InfoGAN loss from `InfoGAN: Interpretable Representation Learning by
|
||||||
|
# `Information Maximizing Generative Adversarial Nets`
|
||||||
|
# https://arxiv.org/abs/1606.03657
|
||||||
|
mutual_information_penalty = _args_to_gan_model(
|
||||||
|
losses_impl.mutual_information_penalty)
|
||||||
|
|
||||||
|
|
||||||
|
def combine_adversarial_loss(gan_loss,
|
||||||
|
gan_model,
|
||||||
|
non_adversarial_loss,
|
||||||
|
weight_factor=None,
|
||||||
|
gradient_ratio=None,
|
||||||
|
gradient_ratio_epsilon=1e-6,
|
||||||
|
scalar_summaries=True,
|
||||||
|
gradient_summaries=True):
|
||||||
|
"""Combine adversarial loss and main loss.
|
||||||
|
|
||||||
|
Uses `combine_adversarial_loss` to combine the losses, and returns
|
||||||
|
a modified GANLoss namedtuple.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
gan_loss: A GANLoss namedtuple. Assume the GANLoss.generator_loss is the
|
||||||
|
adversarial loss.
|
||||||
|
gan_model: A GANModel namedtuple. Used to access the generator's variables.
|
||||||
|
non_adversarial_loss: Same as `main_loss` from
|
||||||
|
`combine_adversarial_loss`.
|
||||||
|
weight_factor: Same as `weight_factor` from
|
||||||
|
`combine_adversarial_loss`.
|
||||||
|
gradient_ratio: Same as `gradient_ratio` from
|
||||||
|
`combine_adversarial_loss`.
|
||||||
|
gradient_ratio_epsilon: Same as `gradient_ratio_epsilon` from
|
||||||
|
`combine_adversarial_loss`.
|
||||||
|
scalar_summaries: Same as `scalar_summaries` from
|
||||||
|
`combine_adversarial_loss`.
|
||||||
|
gradient_summaries: Same as `gradient_summaries` from
|
||||||
|
`combine_adversarial_loss`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A modified GANLoss namedtuple, with `non_adversarial_loss` included
|
||||||
|
appropriately.
|
||||||
|
"""
|
||||||
|
combined_loss = losses_impl.combine_adversarial_loss(
|
||||||
|
non_adversarial_loss,
|
||||||
|
gan_loss.generator_loss,
|
||||||
|
weight_factor,
|
||||||
|
gradient_ratio,
|
||||||
|
gradient_ratio_epsilon,
|
||||||
|
gan_model.generator_variables,
|
||||||
|
scalar_summaries,
|
||||||
|
gradient_summaries)
|
||||||
|
return gan_loss._replace(generator_loss=combined_loss)
|
134
tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py
Normal file
134
tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py
Normal file
@ -0,0 +1,134 @@
|
|||||||
|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""Tests for contrib.gan.python.losses."""
|
||||||
|
|
||||||
|
from __future__ import absolute_import
|
||||||
|
from __future__ import division
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import collections
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from tensorflow.contrib.gan.python.losses.python import tuple_losses_impl as tfgan_losses
|
||||||
|
|
||||||
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
|
class ArgsToGanModelTest(test.TestCase):
|
||||||
|
|
||||||
|
def test_args_to_gan_model(self):
|
||||||
|
"""Test `_args_to_gan_model`."""
|
||||||
|
tuple_type = collections.namedtuple('fake_type', ['arg1', 'arg3'])
|
||||||
|
|
||||||
|
def args_loss(arg1, arg2, arg3=3, arg4=4):
|
||||||
|
return arg1 + arg2 + arg3 + arg4
|
||||||
|
|
||||||
|
gan_model_loss = tfgan_losses._args_to_gan_model(args_loss)
|
||||||
|
|
||||||
|
# Value is correct.
|
||||||
|
self.assertEqual(1 + 2 + 5 + 6,
|
||||||
|
gan_model_loss(tuple_type(1, 2), arg2=5, arg4=6))
|
||||||
|
|
||||||
|
# Uses tuple argument with defaults.
|
||||||
|
self.assertEqual(1 + 5 + 3 + 7,
|
||||||
|
gan_model_loss(tuple_type(1, None), arg2=5, arg4=7))
|
||||||
|
|
||||||
|
# Uses non-tuple argument with defaults.
|
||||||
|
self.assertEqual(1 + 5 + 2 + 4,
|
||||||
|
gan_model_loss(tuple_type(1, 2), arg2=5))
|
||||||
|
|
||||||
|
# Requires non-tuple, non-default arguments.
|
||||||
|
with self.assertRaisesRegexp(ValueError, '`arg2` must be supplied'):
|
||||||
|
gan_model_loss(tuple_type(1, 2))
|
||||||
|
|
||||||
|
# Can't pass tuple argument outside tuple.
|
||||||
|
with self.assertRaisesRegexp(
|
||||||
|
ValueError, 'present in both the tuple and keyword args'):
|
||||||
|
gan_model_loss(tuple_type(1, 2), arg2=1, arg3=5)
|
||||||
|
|
||||||
|
def test_args_to_gan_model_name(self):
|
||||||
|
"""Test that `_args_to_gan_model` produces correctly named functions."""
|
||||||
|
def loss_fn(x):
|
||||||
|
return x
|
||||||
|
new_loss_fn = tfgan_losses._args_to_gan_model(loss_fn)
|
||||||
|
self.assertEqual('loss_fn', new_loss_fn.__name__)
|
||||||
|
self.assertTrue('The gan_model version of' in new_loss_fn.__docstring__)
|
||||||
|
|
||||||
|
def test_tuple_respects_optional_args(self):
|
||||||
|
"""Test that optional args can be changed with tuple losses."""
|
||||||
|
tuple_type = collections.namedtuple('fake_type', ['arg1', 'arg2'])
|
||||||
|
def args_loss(arg1, arg2, arg3=3):
|
||||||
|
return arg1 + 2 * arg2 + 3 * arg3
|
||||||
|
|
||||||
|
loss_fn = tfgan_losses._args_to_gan_model(args_loss)
|
||||||
|
loss = loss_fn(tuple_type(arg1=-1, arg2=2), arg3=4)
|
||||||
|
|
||||||
|
# If `arg3` were not set properly, this value would be different.
|
||||||
|
self.assertEqual(-1 + 2 * 2 + 3 * 4, loss)
|
||||||
|
|
||||||
|
|
||||||
|
class ConsistentLossesTest(test.TestCase):
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def _tuple_from_dict(args_dict):
|
||||||
|
return collections.namedtuple('Tuple', args_dict.keys())(**args_dict)
|
||||||
|
|
||||||
|
|
||||||
|
def add_loss_consistency_test(test_class, loss_name_str, loss_args):
|
||||||
|
tuple_loss = getattr(tfgan_losses, loss_name_str)
|
||||||
|
arg_loss = getattr(tfgan_losses.losses_impl, loss_name_str)
|
||||||
|
|
||||||
|
def consistency_test(self):
|
||||||
|
self.assertEqual(arg_loss.__name__, tuple_loss.__name__)
|
||||||
|
with self.test_session():
|
||||||
|
self.assertEqual(arg_loss(**loss_args).eval(),
|
||||||
|
tuple_loss(_tuple_from_dict(loss_args)).eval())
|
||||||
|
|
||||||
|
test_name = 'test_loss_consistency_%s' % loss_name_str
|
||||||
|
setattr(test_class, test_name, consistency_test)
|
||||||
|
|
||||||
|
|
||||||
|
# A list of consistency tests which need to be manually written.
|
||||||
|
manual_tests = [
|
||||||
|
'acgan_discriminator_loss',
|
||||||
|
'acgan_generator_loss',
|
||||||
|
'combine_adversarial_loss',
|
||||||
|
'mutual_information_penalty',
|
||||||
|
'wasserstein_gradient_penalty',
|
||||||
|
]
|
||||||
|
|
||||||
|
discriminator_keyword_args = {
|
||||||
|
'discriminator_real_outputs': np.array([[3.4, 2.3, -2.3],
|
||||||
|
[6.3, -2.1, 0.2]]),
|
||||||
|
'discriminator_gen_outputs': np.array([[6.2, -1.5, 2.3],
|
||||||
|
[-2.9, -5.1, 0.1]]),
|
||||||
|
}
|
||||||
|
generator_keyword_args = {
|
||||||
|
'discriminator_gen_outputs': np.array([[6.2, -1.5, 2.3],
|
||||||
|
[-2.9, -5.1, 0.1]]),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
for loss_name in tfgan_losses.__all__:
|
||||||
|
if loss_name in manual_tests: continue
|
||||||
|
keyword_args = (generator_keyword_args if 'generator' in loss_name else
|
||||||
|
discriminator_keyword_args)
|
||||||
|
add_loss_consistency_test(ConsistentLossesTest, loss_name, keyword_args)
|
||||||
|
|
||||||
|
test.main()
|
@ -127,7 +127,6 @@ py_test(
|
|||||||
name = "sdca_estimator_test",
|
name = "sdca_estimator_test",
|
||||||
srcs = ["python/sdca_estimator_test.py"],
|
srcs = ["python/sdca_estimator_test.py"],
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
tags = ["notsan"],
|
|
||||||
deps = [
|
deps = [
|
||||||
":sdca_estimator_py",
|
":sdca_estimator_py",
|
||||||
"//tensorflow/contrib/layers:layers_py",
|
"//tensorflow/contrib/layers:layers_py",
|
||||||
|
@ -61,6 +61,7 @@ tf_kernel_library(
|
|||||||
srcs = ["kernels/hyperplane_lsh_probes.cc"],
|
srcs = ["kernels/hyperplane_lsh_probes.cc"],
|
||||||
deps = [
|
deps = [
|
||||||
":hyperplane_lsh_probes",
|
":hyperplane_lsh_probes",
|
||||||
|
":nearest_neighbor_ops_op_lib",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//third_party/eigen3",
|
"//third_party/eigen3",
|
||||||
|
@ -20,6 +20,7 @@
|
|||||||
@@ARModel
|
@@ARModel
|
||||||
|
|
||||||
@@CSVReader
|
@@CSVReader
|
||||||
|
@@NumpyReader
|
||||||
@@RandomWindowInputFn
|
@@RandomWindowInputFn
|
||||||
@@WholeDatasetInputFn
|
@@WholeDatasetInputFn
|
||||||
@@predict_continuation_input_fn
|
@@predict_continuation_input_fn
|
||||||
|
@ -1788,6 +1788,7 @@ tf_cuda_library(
|
|||||||
"common_runtime/process_util.cc",
|
"common_runtime/process_util.cc",
|
||||||
"common_runtime/renamed_device.cc",
|
"common_runtime/renamed_device.cc",
|
||||||
"common_runtime/rendezvous_mgr.cc",
|
"common_runtime/rendezvous_mgr.cc",
|
||||||
|
"common_runtime/rendezvous_util.cc",
|
||||||
"common_runtime/resource_variable_read_optimizer.cc",
|
"common_runtime/resource_variable_read_optimizer.cc",
|
||||||
"common_runtime/session.cc",
|
"common_runtime/session.cc",
|
||||||
"common_runtime/session_factory.cc",
|
"common_runtime/session_factory.cc",
|
||||||
@ -1831,6 +1832,7 @@ tf_cuda_library(
|
|||||||
"common_runtime/profile_handler.h",
|
"common_runtime/profile_handler.h",
|
||||||
"common_runtime/renamed_device.h",
|
"common_runtime/renamed_device.h",
|
||||||
"common_runtime/rendezvous_mgr.h",
|
"common_runtime/rendezvous_mgr.h",
|
||||||
|
"common_runtime/rendezvous_util.h",
|
||||||
"common_runtime/session_factory.h",
|
"common_runtime/session_factory.h",
|
||||||
"common_runtime/graph_execution_state.h",
|
"common_runtime/graph_execution_state.h",
|
||||||
"common_runtime/placer.h",
|
"common_runtime/placer.h",
|
||||||
@ -2675,29 +2677,29 @@ tf_cc_test(
|
|||||||
srcs = ["common_runtime/process_function_library_runtime_test.cc"],
|
srcs = ["common_runtime/process_function_library_runtime_test.cc"],
|
||||||
linkstatic = tf_kernel_tests_linkstatic(),
|
linkstatic = tf_kernel_tests_linkstatic(),
|
||||||
deps = [
|
deps = [
|
||||||
":core",
|
|
||||||
":core_cpu",
|
":core_cpu",
|
||||||
":core_cpu_internal",
|
":core_cpu_internal",
|
||||||
":direct_session_internal",
|
|
||||||
":framework",
|
":framework",
|
||||||
":framework_internal",
|
|
||||||
":lib",
|
|
||||||
":lib_internal",
|
|
||||||
":ops",
|
|
||||||
":protos_all_cc",
|
|
||||||
":test",
|
":test",
|
||||||
":test_main",
|
":test_main",
|
||||||
":testlib",
|
":testlib",
|
||||||
"//tensorflow/cc:cc_ops",
|
|
||||||
"//tensorflow/cc:cc_ops_internal",
|
|
||||||
"//tensorflow/cc:function_ops",
|
"//tensorflow/cc:function_ops",
|
||||||
"//tensorflow/cc:functional_ops",
|
|
||||||
"//tensorflow/core/kernels:cast_op",
|
"//tensorflow/core/kernels:cast_op",
|
||||||
"//tensorflow/core/kernels:cwise_op",
|
"//tensorflow/core/kernels:cwise_op",
|
||||||
"//tensorflow/core/kernels:function_ops",
|
"//tensorflow/core/kernels:function_ops",
|
||||||
"//tensorflow/core/kernels:matmul_op",
|
],
|
||||||
"//tensorflow/core/kernels:shape_ops",
|
)
|
||||||
"//third_party/eigen3",
|
|
||||||
|
tf_cc_test(
|
||||||
|
name = "common_runtime_rendezvous_util_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = ["common_runtime/rendezvous_util_test.cc"],
|
||||||
|
linkstatic = tf_kernel_tests_linkstatic(),
|
||||||
|
deps = [
|
||||||
|
":core_cpu_internal",
|
||||||
|
":lib",
|
||||||
|
":test",
|
||||||
|
":test_main",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -213,6 +213,9 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
|
|||||||
FunctionBody** g_body);
|
FunctionBody** g_body);
|
||||||
bool IsLocalTarget(const AttrSlice& attrs);
|
bool IsLocalTarget(const AttrSlice& attrs);
|
||||||
AttrValueMap FixAttrs(const AttrSlice& attrs);
|
AttrValueMap FixAttrs(const AttrSlice& attrs);
|
||||||
|
void RunRemote(const Options& opts, Handle handle,
|
||||||
|
gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
|
||||||
|
Executor::Args* exec_args, Item* item, DoneCallback done);
|
||||||
|
|
||||||
TF_DISALLOW_COPY_AND_ASSIGN(FunctionLibraryRuntimeImpl);
|
TF_DISALLOW_COPY_AND_ASSIGN(FunctionLibraryRuntimeImpl);
|
||||||
};
|
};
|
||||||
@ -557,52 +560,130 @@ Status FunctionLibraryRuntimeImpl::GetOrCreateItem(Handle handle, Item** item) {
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle,
|
||||||
|
gtl::ArraySlice<Tensor> args,
|
||||||
|
std::vector<Tensor>* rets,
|
||||||
|
Executor::Args* exec_args,
|
||||||
|
Item* item, DoneCallback done) {
|
||||||
|
FunctionCallFrame* frame = exec_args->call_frame;
|
||||||
|
string target_device = parent_->GetDeviceName(handle);
|
||||||
|
string source_device = opts.source_device;
|
||||||
|
Rendezvous* rendezvous = opts.rendezvous;
|
||||||
|
// TODO(rohanj): Handle alloc_attrs in Rendezvous::Args.
|
||||||
|
Rendezvous::Args rendez_args;
|
||||||
|
Status s =
|
||||||
|
parent_->GetDeviceContext(target_device, &rendez_args.device_context);
|
||||||
|
if (!s.ok()) {
|
||||||
|
delete frame;
|
||||||
|
delete exec_args;
|
||||||
|
done(s);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// The ProcFLR sends the arguments to the function from the source_device to
|
||||||
|
// the target_device. So here we receive those arguments. Similarly, when the
|
||||||
|
// computation is done and stored in *rets, we send the return values back
|
||||||
|
// to the source_device (caller) so that the ProcFLR can receive them later.
|
||||||
|
std::vector<Tensor>* remote_args = new std::vector<Tensor>;
|
||||||
|
ProcessFunctionLibraryRuntime::ReceiveTensorsAsync(
|
||||||
|
source_device, target_device, "arg_", args.size(), rendez_args,
|
||||||
|
rendezvous, remote_args,
|
||||||
|
[frame, remote_args, item, source_device, target_device, rendezvous,
|
||||||
|
rendez_args, rets, done, exec_args](const Status& status) {
|
||||||
|
Status s = status;
|
||||||
|
s = frame->SetArgs(*remote_args);
|
||||||
|
if (!s.ok()) {
|
||||||
|
delete frame;
|
||||||
|
delete remote_args;
|
||||||
|
delete exec_args;
|
||||||
|
done(s);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
item->exec->RunAsync(
|
||||||
|
*exec_args,
|
||||||
|
[item, frame, rets, done, source_device, target_device, rendezvous,
|
||||||
|
rendez_args, remote_args, exec_args](const Status& status) {
|
||||||
|
item->Unref();
|
||||||
|
Status s = status;
|
||||||
|
if (s.ok()) {
|
||||||
|
s = frame->ConsumeRetvals(rets);
|
||||||
|
}
|
||||||
|
delete frame;
|
||||||
|
if (!s.ok()) {
|
||||||
|
delete remote_args;
|
||||||
|
delete exec_args;
|
||||||
|
done(s);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
s = ProcessFunctionLibraryRuntime::SendTensors(
|
||||||
|
target_device, source_device, "ret_", *rets, rendez_args,
|
||||||
|
rendezvous);
|
||||||
|
delete remote_args;
|
||||||
|
delete exec_args;
|
||||||
|
done(s);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
|
void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
|
||||||
gtl::ArraySlice<Tensor> args,
|
gtl::ArraySlice<Tensor> args,
|
||||||
std::vector<Tensor>* rets,
|
std::vector<Tensor>* rets,
|
||||||
DoneCallback done) {
|
DoneCallback done) {
|
||||||
if (opts.cancellation_manager && opts.cancellation_manager->IsCancelled()) {
|
if (opts.cancellation_manager && opts.cancellation_manager->IsCancelled()) {
|
||||||
return done(errors::Cancelled(""));
|
done(errors::Cancelled(""));
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
if (!parent_->IsInstantiatedOnDevice(device_name_, handle)) {
|
if (!parent_->IsInstantiatedOnDevice(device_name_, handle)) {
|
||||||
return parent_->Run(opts, handle, args, rets, done);
|
parent_->Run(opts, handle, args, rets, done);
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
const FunctionBody* fbody = GetFunctionBody(handle);
|
const FunctionBody* fbody = GetFunctionBody(handle);
|
||||||
FunctionCallFrame* frame =
|
FunctionCallFrame* frame =
|
||||||
new FunctionCallFrame(fbody->arg_types, fbody->ret_types);
|
new FunctionCallFrame(fbody->arg_types, fbody->ret_types);
|
||||||
Status s = frame->SetArgs(args);
|
|
||||||
if (!s.ok()) {
|
|
||||||
delete frame;
|
|
||||||
return done(s);
|
|
||||||
}
|
|
||||||
Item* item = nullptr;
|
Item* item = nullptr;
|
||||||
s = GetOrCreateItem(handle, &item);
|
Status s = GetOrCreateItem(handle, &item);
|
||||||
if (!s.ok()) {
|
if (!s.ok()) {
|
||||||
delete frame;
|
delete frame;
|
||||||
return done(s);
|
done(s);
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
DCHECK(opts.runner != nullptr);
|
DCHECK(opts.runner != nullptr);
|
||||||
|
|
||||||
Executor::Args exec_args;
|
Executor::Args* exec_args = new Executor::Args;
|
||||||
// Inherit the step_id from the caller.
|
// Inherit the step_id from the caller.
|
||||||
exec_args.step_id = opts.step_id;
|
exec_args->step_id = opts.step_id;
|
||||||
exec_args.rendezvous = opts.rendezvous;
|
exec_args->rendezvous = opts.rendezvous;
|
||||||
exec_args.stats_collector = opts.stats_collector;
|
exec_args->stats_collector = opts.stats_collector;
|
||||||
exec_args.call_frame = frame;
|
exec_args->call_frame = frame;
|
||||||
exec_args.cancellation_manager = opts.cancellation_manager;
|
exec_args->cancellation_manager = opts.cancellation_manager;
|
||||||
exec_args.step_container = opts.step_container;
|
exec_args->step_container = opts.step_container;
|
||||||
exec_args.runner = *opts.runner;
|
exec_args->runner = *opts.runner;
|
||||||
|
|
||||||
|
if (opts.remote_execution) {
|
||||||
|
RunRemote(opts, handle, args, rets, exec_args, item, done);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
s = frame->SetArgs(args);
|
||||||
|
if (!s.ok()) {
|
||||||
|
delete frame;
|
||||||
|
delete exec_args;
|
||||||
|
done(s);
|
||||||
|
return;
|
||||||
|
}
|
||||||
item->exec->RunAsync(
|
item->exec->RunAsync(
|
||||||
// Executor args
|
// Executor args
|
||||||
exec_args,
|
*exec_args,
|
||||||
// Done callback.
|
// Done callback.
|
||||||
[item, frame, rets, done](const Status& status) {
|
[item, frame, rets, done, exec_args](const Status& status) {
|
||||||
item->Unref();
|
item->Unref();
|
||||||
Status s = status;
|
Status s = status;
|
||||||
if (s.ok()) {
|
if (s.ok()) {
|
||||||
s = frame->GetRetvals(rets);
|
s = frame->ConsumeRetvals(rets);
|
||||||
}
|
}
|
||||||
delete frame;
|
delete frame;
|
||||||
|
delete exec_args;
|
||||||
done(s);
|
done(s);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
@ -26,6 +26,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||||
#include "tensorflow/core/common_runtime/executor.h"
|
#include "tensorflow/core/common_runtime/executor.h"
|
||||||
#include "tensorflow/core/common_runtime/function_testlib.h"
|
#include "tensorflow/core/common_runtime/function_testlib.h"
|
||||||
|
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
|
||||||
#include "tensorflow/core/framework/function.h"
|
#include "tensorflow/core/framework/function.h"
|
||||||
#include "tensorflow/core/framework/function_testlib.h"
|
#include "tensorflow/core/framework/function_testlib.h"
|
||||||
#include "tensorflow/core/framework/op.h"
|
#include "tensorflow/core/framework/op.h"
|
||||||
@ -155,6 +156,7 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
|
|||||||
}
|
}
|
||||||
|
|
||||||
Status Run(FunctionLibraryRuntime* flr, FunctionLibraryRuntime::Handle handle,
|
Status Run(FunctionLibraryRuntime* flr, FunctionLibraryRuntime::Handle handle,
|
||||||
|
FunctionLibraryRuntime::Options opts,
|
||||||
const std::vector<Tensor>& args, std::vector<Tensor*> rets) {
|
const std::vector<Tensor>& args, std::vector<Tensor*> rets) {
|
||||||
std::atomic<int32> call_count(0);
|
std::atomic<int32> call_count(0);
|
||||||
std::function<void(std::function<void()>)> runner =
|
std::function<void(std::function<void()>)> runner =
|
||||||
@ -164,7 +166,6 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
|
|||||||
};
|
};
|
||||||
|
|
||||||
Notification done;
|
Notification done;
|
||||||
FunctionLibraryRuntime::Options opts;
|
|
||||||
opts.runner = &runner;
|
opts.runner = &runner;
|
||||||
std::vector<Tensor> out;
|
std::vector<Tensor> out;
|
||||||
Status status;
|
Status status;
|
||||||
@ -205,7 +206,8 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
|
|||||||
if (!status.ok()) {
|
if (!status.ok()) {
|
||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
return Run(flr, handle, args, std::move(rets));
|
FunctionLibraryRuntime::Options opts;
|
||||||
|
return Run(flr, handle, opts, args, std::move(rets));
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<Graph> GetFuncBody(FunctionLibraryRuntime* flr,
|
std::unique_ptr<Graph> GetFuncBody(FunctionLibraryRuntime* flr,
|
||||||
@ -963,15 +965,21 @@ TEST_F(FunctionLibraryRuntimeTest, CrossDevice) {
|
|||||||
{{"_target", "/job:localhost/replica:0/task:0/cpu:1"}}, &handle));
|
{{"_target", "/job:localhost/replica:0/task:0/cpu:1"}}, &handle));
|
||||||
|
|
||||||
Tensor y;
|
Tensor y;
|
||||||
|
FunctionLibraryRuntime::Options opts;
|
||||||
|
opts.rendezvous = new IntraProcessRendezvous(device_mgr_.get());
|
||||||
|
opts.source_device = "/device:CPU:1";
|
||||||
// Run on flr1_, flr2_ and make sure that the device it ran on was cpu:1.
|
// Run on flr1_, flr2_ and make sure that the device it ran on was cpu:1.
|
||||||
TF_CHECK_OK(Run(flr1_, handle, {}, {&y}));
|
TF_CHECK_OK(Run(flr1_, handle, opts, {}, {&y}));
|
||||||
test::ExpectTensorEqual<string>(
|
test::ExpectTensorEqual<string>(
|
||||||
y, test::AsTensor<string>({"/job:localhost/replica:0/task:0/cpu:1"},
|
y, test::AsTensor<string>({"/job:localhost/replica:0/task:0/cpu:1"},
|
||||||
TensorShape({})));
|
TensorShape({})));
|
||||||
TF_CHECK_OK(Run(flr2_, handle, {}, {&y}));
|
opts.remote_execution = true;
|
||||||
|
opts.source_device = "/job:localhost/replica:0/task:0/cpu:2";
|
||||||
|
TF_CHECK_OK(Run(flr2_, handle, opts, {}, {&y}));
|
||||||
test::ExpectTensorEqual<string>(
|
test::ExpectTensorEqual<string>(
|
||||||
y, test::AsTensor<string>({"/job:localhost/replica:0/task:0/cpu:1"},
|
y, test::AsTensor<string>({"/job:localhost/replica:0/task:0/cpu:1"},
|
||||||
TensorShape({})));
|
TensorShape({})));
|
||||||
|
opts.rendezvous->Unref();
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
|||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
#include "tensorflow/core/common_runtime/function.h"
|
#include "tensorflow/core/common_runtime/function.h"
|
||||||
|
#include "tensorflow/core/common_runtime/rendezvous_util.h"
|
||||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
@ -57,6 +58,7 @@ ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* static */
|
||||||
string ProcessFunctionLibraryRuntime::ObtainFunctionTarget(
|
string ProcessFunctionLibraryRuntime::ObtainFunctionTarget(
|
||||||
const AttrSlice& attrs) {
|
const AttrSlice& attrs) {
|
||||||
const AttrValue* value;
|
const AttrValue* value;
|
||||||
@ -66,6 +68,63 @@ string ProcessFunctionLibraryRuntime::ObtainFunctionTarget(
|
|||||||
return value->s();
|
return value->s();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* static */
|
||||||
|
Status ProcessFunctionLibraryRuntime::SendTensors(
|
||||||
|
const string& source_device, const string& target_device,
|
||||||
|
const string& key_prefix, gtl::ArraySlice<Tensor> tensors_to_send,
|
||||||
|
const Rendezvous::Args& args, Rendezvous* rendezvous) {
|
||||||
|
std::vector<string> keys;
|
||||||
|
for (int i = 0; i < tensors_to_send.size(); ++i) {
|
||||||
|
string name = strings::StrCat(key_prefix, i);
|
||||||
|
string key = Rendezvous::CreateKey(source_device, i, target_device, name,
|
||||||
|
FrameAndIter(0, 0));
|
||||||
|
keys.push_back(key);
|
||||||
|
}
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
SendTensorsToRendezvous(rendezvous, args, keys, tensors_to_send));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
/* static */
|
||||||
|
void ProcessFunctionLibraryRuntime::ReceiveTensorsAsync(
|
||||||
|
const string& source_device, const string& target_device,
|
||||||
|
const string& key_prefix, int64 num_tensors, const Rendezvous::Args& args,
|
||||||
|
Rendezvous* rendezvous, std::vector<Tensor>* received_tensors,
|
||||||
|
const StatusCallback& done) {
|
||||||
|
std::vector<string> keys;
|
||||||
|
for (int64 i = 0; i < num_tensors; ++i) {
|
||||||
|
string name = strings::StrCat(key_prefix, i);
|
||||||
|
string key = Rendezvous::CreateKey(source_device, i, target_device, name,
|
||||||
|
FrameAndIter(0, 0));
|
||||||
|
keys.push_back(key);
|
||||||
|
}
|
||||||
|
RecvOutputsFromRendezvousAsync(
|
||||||
|
rendezvous, args, keys, received_tensors,
|
||||||
|
[done](const Status& status) { done(status); });
|
||||||
|
}
|
||||||
|
|
||||||
|
Status ProcessFunctionLibraryRuntime::GetDeviceContext(
|
||||||
|
const string& device_name, DeviceContext** device_context) {
|
||||||
|
*device_context = nullptr;
|
||||||
|
FunctionLibraryRuntime* flr = GetFLR(device_name);
|
||||||
|
if (flr == nullptr) {
|
||||||
|
return errors::InvalidArgument("Device name: ", device_name, " not found.");
|
||||||
|
}
|
||||||
|
Device* device = flr->device();
|
||||||
|
string device_type = device->parsed_name().type;
|
||||||
|
if (device_type == "CPU") return Status::OK();
|
||||||
|
if (device_type == "GPU") {
|
||||||
|
auto* dev_info = flr->device()->tensorflow_gpu_device_info();
|
||||||
|
if (dev_info) {
|
||||||
|
*device_context = dev_info->default_context;
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return errors::Internal("Device type: ", device_type,
|
||||||
|
" is currently unsupported for remote ",
|
||||||
|
"function executions");
|
||||||
|
}
|
||||||
|
|
||||||
FunctionLibraryRuntime* ProcessFunctionLibraryRuntime::GetFLR(
|
FunctionLibraryRuntime* ProcessFunctionLibraryRuntime::GetFLR(
|
||||||
const string& device_name) {
|
const string& device_name) {
|
||||||
if (flr_map_.find(device_name) == flr_map_.end()) {
|
if (flr_map_.find(device_name) == flr_map_.end()) {
|
||||||
@ -105,6 +164,7 @@ FunctionLibraryRuntime::LocalHandle
|
|||||||
ProcessFunctionLibraryRuntime::GetHandleOnDevice(
|
ProcessFunctionLibraryRuntime::GetHandleOnDevice(
|
||||||
const string& device_name, FunctionLibraryRuntime::Handle handle) {
|
const string& device_name, FunctionLibraryRuntime::Handle handle) {
|
||||||
mutex_lock l(mu_);
|
mutex_lock l(mu_);
|
||||||
|
CHECK_LE(handle, function_data_.size());
|
||||||
std::pair<string, FunctionLibraryRuntime::LocalHandle> p =
|
std::pair<string, FunctionLibraryRuntime::LocalHandle> p =
|
||||||
function_data_[handle];
|
function_data_[handle];
|
||||||
if (p.first != device_name) {
|
if (p.first != device_name) {
|
||||||
@ -113,6 +173,15 @@ ProcessFunctionLibraryRuntime::GetHandleOnDevice(
|
|||||||
return p.second;
|
return p.second;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
string ProcessFunctionLibraryRuntime::GetDeviceName(
|
||||||
|
FunctionLibraryRuntime::Handle handle) {
|
||||||
|
mutex_lock l(mu_);
|
||||||
|
CHECK_LE(handle, function_data_.size());
|
||||||
|
std::pair<string, FunctionLibraryRuntime::LocalHandle> p =
|
||||||
|
function_data_[handle];
|
||||||
|
return p.first;
|
||||||
|
}
|
||||||
|
|
||||||
Status ProcessFunctionLibraryRuntime::Instantiate(
|
Status ProcessFunctionLibraryRuntime::Instantiate(
|
||||||
const string& function_name, AttrSlice attrs,
|
const string& function_name, AttrSlice attrs,
|
||||||
FunctionLibraryRuntime::Handle* handle) {
|
FunctionLibraryRuntime::Handle* handle) {
|
||||||
@ -129,15 +198,58 @@ void ProcessFunctionLibraryRuntime::Run(
|
|||||||
const FunctionLibraryRuntime::Options& opts,
|
const FunctionLibraryRuntime::Options& opts,
|
||||||
FunctionLibraryRuntime::Handle handle, gtl::ArraySlice<Tensor> args,
|
FunctionLibraryRuntime::Handle handle, gtl::ArraySlice<Tensor> args,
|
||||||
std::vector<Tensor>* rets, FunctionLibraryRuntime::DoneCallback done) {
|
std::vector<Tensor>* rets, FunctionLibraryRuntime::DoneCallback done) {
|
||||||
|
if (!opts.remote_execution) {
|
||||||
|
done(errors::InvalidArgument(
|
||||||
|
"ProcessFunctionLibraryRuntime::Run should only be called when there ",
|
||||||
|
"is a remote execution."));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
FunctionLibraryRuntime* flr = nullptr;
|
FunctionLibraryRuntime* flr = nullptr;
|
||||||
|
string target_device;
|
||||||
{
|
{
|
||||||
mutex_lock l(mu_);
|
mutex_lock l(mu_);
|
||||||
|
CHECK_LE(handle, function_data_.size());
|
||||||
std::pair<string, FunctionLibraryRuntime::LocalHandle> p =
|
std::pair<string, FunctionLibraryRuntime::LocalHandle> p =
|
||||||
function_data_[handle];
|
function_data_[handle];
|
||||||
|
target_device = p.first;
|
||||||
flr = GetFLR(p.first);
|
flr = GetFLR(p.first);
|
||||||
}
|
}
|
||||||
if (flr != nullptr) {
|
if (flr != nullptr) {
|
||||||
return flr->Run(opts, handle, args, rets, std::move(done));
|
auto rendezvous = opts.rendezvous;
|
||||||
|
string source_device = opts.source_device;
|
||||||
|
Rendezvous::Args rendez_args;
|
||||||
|
Status s = GetDeviceContext(source_device, &rendez_args.device_context);
|
||||||
|
if (!s.ok()) {
|
||||||
|
done(s);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// Send the args over to the target device.
|
||||||
|
s = SendTensors(source_device, target_device, "arg_", args, rendez_args,
|
||||||
|
rendezvous);
|
||||||
|
if (!s.ok()) {
|
||||||
|
done(s);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
std::vector<Tensor>* remote_rets = new std::vector<Tensor>;
|
||||||
|
flr->Run(opts, handle, args, remote_rets,
|
||||||
|
[source_device, target_device, rendezvous, remote_rets, rets, done,
|
||||||
|
rendez_args](const Status& status) {
|
||||||
|
if (!status.ok()) {
|
||||||
|
delete remote_rets;
|
||||||
|
done(status);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
int64 num_returns = remote_rets->size();
|
||||||
|
delete remote_rets;
|
||||||
|
// Now receive the return values from the target.
|
||||||
|
ReceiveTensorsAsync(target_device, source_device, "ret_",
|
||||||
|
num_returns, rendez_args, rendezvous, rets,
|
||||||
|
done);
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
done(errors::Internal("Could not find device"));
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -45,6 +45,31 @@ class ProcessFunctionLibraryRuntime {
|
|||||||
// attribute, returns "". Canonicalizes the device name.
|
// attribute, returns "". Canonicalizes the device name.
|
||||||
static string ObtainFunctionTarget(const AttrSlice& attrs);
|
static string ObtainFunctionTarget(const AttrSlice& attrs);
|
||||||
|
|
||||||
|
// Sends `tensors_to_send` from `source_device` to `target_device` using
|
||||||
|
// `rendezvous`. `key_prefix` is used as a prefix for the keys sent to the
|
||||||
|
// Rendezvous. Method takes references on each of the `tensors_to_send`.
|
||||||
|
// Method doesn't block.
|
||||||
|
static Status SendTensors(const string& source_device,
|
||||||
|
const string& target_device,
|
||||||
|
const string& key_prefix,
|
||||||
|
gtl::ArraySlice<Tensor> tensors_to_send,
|
||||||
|
const Rendezvous::Args& args,
|
||||||
|
Rendezvous* rendezvous);
|
||||||
|
|
||||||
|
typedef std::function<void(const Status&)> StatusCallback;
|
||||||
|
|
||||||
|
// Receives `received_tensors` from `target_device` (originally sent from
|
||||||
|
// `source_device`) using `rendezvous`. Uses `key_prefix` to construct the
|
||||||
|
// keys to be retrieved. Method doesn't block and calls `done` when
|
||||||
|
// `num_tensors` are fetched.
|
||||||
|
static void ReceiveTensorsAsync(const string& source_device,
|
||||||
|
const string& target_device,
|
||||||
|
const string& key_prefix, int64 num_tensors,
|
||||||
|
const Rendezvous::Args& args,
|
||||||
|
Rendezvous* rendezvous,
|
||||||
|
std::vector<Tensor>* received_tensors,
|
||||||
|
const StatusCallback& done);
|
||||||
|
|
||||||
static const char kDefaultFLRDevice[];
|
static const char kDefaultFLRDevice[];
|
||||||
// Returns the FunctionLibraryRuntime for the corresponding device_name.
|
// Returns the FunctionLibraryRuntime for the corresponding device_name.
|
||||||
FunctionLibraryRuntime* GetFLR(const string& device_name);
|
FunctionLibraryRuntime* GetFLR(const string& device_name);
|
||||||
@ -85,6 +110,17 @@ class ProcessFunctionLibraryRuntime {
|
|||||||
FunctionLibraryRuntime::DoneCallback done);
|
FunctionLibraryRuntime::DoneCallback done);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
// For a given device_name, returns a DeviceContext for copying
|
||||||
|
// tensors to/from the device.
|
||||||
|
Status GetDeviceContext(const string& device_name,
|
||||||
|
DeviceContext** device_context);
|
||||||
|
|
||||||
|
// Looks up the information for the given `handle` and returns the name
|
||||||
|
// of the device where the function is registered.
|
||||||
|
string GetDeviceName(FunctionLibraryRuntime::Handle handle);
|
||||||
|
|
||||||
|
friend class FunctionLibraryRuntimeImpl;
|
||||||
|
|
||||||
mutable mutex mu_;
|
mutable mutex mu_;
|
||||||
|
|
||||||
// Holds all the function invocations here.
|
// Holds all the function invocations here.
|
||||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||||
#include "tensorflow/core/common_runtime/function_testlib.h"
|
#include "tensorflow/core/common_runtime/function_testlib.h"
|
||||||
|
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
|
||||||
#include "tensorflow/core/framework/function_testlib.h"
|
#include "tensorflow/core/framework/function_testlib.h"
|
||||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||||
#include "tensorflow/core/platform/test.h"
|
#include "tensorflow/core/platform/test.h"
|
||||||
@ -43,10 +44,12 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
|
|||||||
proc_flr_.reset(new ProcessFunctionLibraryRuntime(
|
proc_flr_.reset(new ProcessFunctionLibraryRuntime(
|
||||||
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
|
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
|
||||||
opts));
|
opts));
|
||||||
|
rendezvous_ = new IntraProcessRendezvous(device_mgr_.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
Status Run(const string& name, test::function::Attrs attrs,
|
Status Run(const string& name, FunctionLibraryRuntime::Options opts,
|
||||||
const std::vector<Tensor>& args, std::vector<Tensor*> rets) {
|
test::function::Attrs attrs, const std::vector<Tensor>& args,
|
||||||
|
std::vector<Tensor*> rets) {
|
||||||
FunctionLibraryRuntime::Handle handle;
|
FunctionLibraryRuntime::Handle handle;
|
||||||
Status status = proc_flr_->Instantiate(name, attrs, &handle);
|
Status status = proc_flr_->Instantiate(name, attrs, &handle);
|
||||||
if (!status.ok()) {
|
if (!status.ok()) {
|
||||||
@ -61,7 +64,6 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
|
|||||||
};
|
};
|
||||||
|
|
||||||
Notification done;
|
Notification done;
|
||||||
FunctionLibraryRuntime::Options opts;
|
|
||||||
opts.runner = &runner;
|
opts.runner = &runner;
|
||||||
std::vector<Tensor> out;
|
std::vector<Tensor> out;
|
||||||
proc_flr_->Run(opts, handle, args, &out, [&status, &done](const Status& s) {
|
proc_flr_->Run(opts, handle, args, &out, [&status, &done](const Status& s) {
|
||||||
@ -86,6 +88,7 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
|
|||||||
std::unique_ptr<DeviceMgr> device_mgr_;
|
std::unique_ptr<DeviceMgr> device_mgr_;
|
||||||
std::unique_ptr<FunctionLibraryDefinition> lib_def_;
|
std::unique_ptr<FunctionLibraryDefinition> lib_def_;
|
||||||
std::unique_ptr<ProcessFunctionLibraryRuntime> proc_flr_;
|
std::unique_ptr<ProcessFunctionLibraryRuntime> proc_flr_;
|
||||||
|
IntraProcessRendezvous* rendezvous_;
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(ProcessFunctionLibraryRuntimeTest, Basic) {
|
TEST_F(ProcessFunctionLibraryRuntimeTest, Basic) {
|
||||||
@ -99,6 +102,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, Basic) {
|
|||||||
EXPECT_EQ(flr->device(), devices_[1]);
|
EXPECT_EQ(flr->device(), devices_[1]);
|
||||||
flr = proc_flr_->GetFLR("abc");
|
flr = proc_flr_->GetFLR("abc");
|
||||||
EXPECT_EQ(flr, nullptr);
|
EXPECT_EQ(flr, nullptr);
|
||||||
|
rendezvous_->Unref();
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ProcessFunctionLibraryRuntimeTest, ObtainFunctionTarget) {
|
TEST_F(ProcessFunctionLibraryRuntimeTest, ObtainFunctionTarget) {
|
||||||
@ -118,69 +122,94 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, ObtainFunctionTarget) {
|
|||||||
|
|
||||||
TEST_F(ProcessFunctionLibraryRuntimeTest, SingleCall) {
|
TEST_F(ProcessFunctionLibraryRuntimeTest, SingleCall) {
|
||||||
Init({test::function::XTimesTwo()});
|
Init({test::function::XTimesTwo()});
|
||||||
|
FunctionLibraryRuntime::Options opts;
|
||||||
|
opts.source_device = "/job:a/replica:0/task:0/cpu:0";
|
||||||
|
opts.rendezvous = rendezvous_;
|
||||||
|
opts.remote_execution = true;
|
||||||
auto x = test::AsTensor<float>({1, 2, 3, 4});
|
auto x = test::AsTensor<float>({1, 2, 3, 4});
|
||||||
Tensor y;
|
Tensor y;
|
||||||
TF_CHECK_OK(
|
TF_CHECK_OK(
|
||||||
Run("XTimesTwo",
|
Run("XTimesTwo", opts,
|
||||||
{{"T", DT_FLOAT}, {"_target", "/job:a/replica:0/task:0/cpu:0"}}, {x},
|
{{"T", DT_FLOAT}, {"_target", "/job:a/replica:0/task:0/cpu:0"}}, {x},
|
||||||
{&y}));
|
{&y}));
|
||||||
test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8}));
|
test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8}));
|
||||||
|
rendezvous_->Unref();
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ProcessFunctionLibraryRuntimeTest, SingleCallFindDevice) {
|
TEST_F(ProcessFunctionLibraryRuntimeTest, SingleCallFindDevice) {
|
||||||
Init({test::function::FindDevice()});
|
Init({test::function::FindDevice()});
|
||||||
|
FunctionLibraryRuntime::Options opts;
|
||||||
|
opts.source_device = "/job:a/replica:0/task:0/cpu:0";
|
||||||
|
opts.rendezvous = rendezvous_;
|
||||||
|
opts.remote_execution = true;
|
||||||
Tensor y;
|
Tensor y;
|
||||||
TF_CHECK_OK(Run("FindDevice", {{"_target", "/job:a/replica:0/task:0/cpu:0"}},
|
TF_CHECK_OK(Run("FindDevice", opts,
|
||||||
{}, {&y}));
|
{{"_target", "/job:a/replica:0/task:0/cpu:0"}}, {}, {&y}));
|
||||||
test::ExpectTensorEqual<string>(
|
test::ExpectTensorEqual<string>(
|
||||||
y, test::AsTensor<string>({"/job:a/replica:0/task:0/cpu:0"},
|
y, test::AsTensor<string>({"/job:a/replica:0/task:0/cpu:0"},
|
||||||
TensorShape({})));
|
TensorShape({})));
|
||||||
|
rendezvous_->Unref();
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsSameDeviceXTimes) {
|
TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsSameDeviceXTimes) {
|
||||||
Init({test::function::XTimesTwo(), test::function::XTimesFour()});
|
Init({test::function::XTimesTwo(), test::function::XTimesFour()});
|
||||||
auto x = test::AsTensor<float>({1, 2, 3, 4});
|
auto x = test::AsTensor<float>({1, 2, 3, 4});
|
||||||
|
FunctionLibraryRuntime::Options opts;
|
||||||
|
opts.source_device = "/job:a/replica:0/task:0/cpu:0";
|
||||||
|
opts.rendezvous = rendezvous_;
|
||||||
|
opts.remote_execution = true;
|
||||||
Tensor y;
|
Tensor y;
|
||||||
TF_CHECK_OK(
|
TF_CHECK_OK(
|
||||||
Run("XTimesTwo",
|
Run("XTimesTwo", opts,
|
||||||
{{"T", DT_FLOAT}, {"_target", "/job:a/replica:0/task:0/cpu:0"}}, {x},
|
{{"T", DT_FLOAT}, {"_target", "/job:a/replica:0/task:0/cpu:0"}}, {x},
|
||||||
{&y}));
|
{&y}));
|
||||||
test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8}));
|
test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8}));
|
||||||
TF_CHECK_OK(
|
TF_CHECK_OK(
|
||||||
Run("XTimesFour",
|
Run("XTimesFour", opts,
|
||||||
{{"T", DT_FLOAT}, {"_target", "/job:a/replica:0/task:0/cpu:0"}}, {x},
|
{{"T", DT_FLOAT}, {"_target", "/job:a/replica:0/task:0/cpu:0"}}, {x},
|
||||||
{&y}));
|
{&y}));
|
||||||
test::ExpectTensorEqual<float>(y, test::AsTensor<float>({4, 8, 12, 16}));
|
test::ExpectTensorEqual<float>(y, test::AsTensor<float>({4, 8, 12, 16}));
|
||||||
|
rendezvous_->Unref();
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsSameDeviceFindDevice) {
|
TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsSameDeviceFindDevice) {
|
||||||
Init({test::function::FindDevice()});
|
Init({test::function::FindDevice()});
|
||||||
|
FunctionLibraryRuntime::Options opts;
|
||||||
|
opts.source_device = "/job:a/replica:0/task:0/cpu:0";
|
||||||
|
opts.rendezvous = rendezvous_;
|
||||||
|
opts.remote_execution = true;
|
||||||
Tensor y;
|
Tensor y;
|
||||||
TF_CHECK_OK(Run("FindDevice", {{"_target", "/job:a/replica:0/task:0/cpu:1"}},
|
TF_CHECK_OK(Run("FindDevice", opts,
|
||||||
{}, {&y}));
|
{{"_target", "/job:a/replica:0/task:0/cpu:1"}}, {}, {&y}));
|
||||||
test::ExpectTensorEqual<string>(
|
test::ExpectTensorEqual<string>(
|
||||||
y, test::AsTensor<string>({"/job:a/replica:0/task:0/cpu:1"},
|
y, test::AsTensor<string>({"/job:a/replica:0/task:0/cpu:1"},
|
||||||
TensorShape({})));
|
TensorShape({})));
|
||||||
TF_CHECK_OK(Run("FindDevice", {{"_target", "/job:a/replica:0/task:0/cpu:1"}},
|
TF_CHECK_OK(Run("FindDevice", opts,
|
||||||
{}, {&y}));
|
{{"_target", "/job:a/replica:0/task:0/cpu:1"}}, {}, {&y}));
|
||||||
test::ExpectTensorEqual<string>(
|
test::ExpectTensorEqual<string>(
|
||||||
y, test::AsTensor<string>({"/job:a/replica:0/task:0/cpu:1"},
|
y, test::AsTensor<string>({"/job:a/replica:0/task:0/cpu:1"},
|
||||||
TensorShape({})));
|
TensorShape({})));
|
||||||
|
rendezvous_->Unref();
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsDiffDeviceFindDevice) {
|
TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsDiffDeviceFindDevice) {
|
||||||
Init({test::function::FindDevice()});
|
Init({test::function::FindDevice()});
|
||||||
|
FunctionLibraryRuntime::Options opts;
|
||||||
|
opts.source_device = "/job:a/replica:0/task:0/cpu:0";
|
||||||
|
opts.rendezvous = rendezvous_;
|
||||||
|
opts.remote_execution = true;
|
||||||
Tensor y;
|
Tensor y;
|
||||||
TF_CHECK_OK(Run("FindDevice", {{"_target", "/job:a/replica:0/task:0/cpu:0"}},
|
TF_CHECK_OK(Run("FindDevice", opts,
|
||||||
{}, {&y}));
|
{{"_target", "/job:a/replica:0/task:0/cpu:0"}}, {}, {&y}));
|
||||||
test::ExpectTensorEqual<string>(
|
test::ExpectTensorEqual<string>(
|
||||||
y, test::AsTensor<string>({"/job:a/replica:0/task:0/cpu:0"},
|
y, test::AsTensor<string>({"/job:a/replica:0/task:0/cpu:0"},
|
||||||
TensorShape({})));
|
TensorShape({})));
|
||||||
TF_CHECK_OK(Run("FindDevice", {{"_target", "/job:a/replica:0/task:0/cpu:1"}},
|
TF_CHECK_OK(Run("FindDevice", opts,
|
||||||
{}, {&y}));
|
{{"_target", "/job:a/replica:0/task:0/cpu:1"}}, {}, {&y}));
|
||||||
test::ExpectTensorEqual<string>(
|
test::ExpectTensorEqual<string>(
|
||||||
y, test::AsTensor<string>({"/job:a/replica:0/task:0/cpu:1"},
|
y, test::AsTensor<string>({"/job:a/replica:0/task:0/cpu:1"},
|
||||||
TensorShape({})));
|
TensorShape({})));
|
||||||
|
rendezvous_->Unref();
|
||||||
}
|
}
|
||||||
|
|
||||||
} // anonymous namespace
|
} // anonymous namespace
|
||||||
|
119
tensorflow/core/common_runtime/rendezvous_util.cc
Normal file
119
tensorflow/core/common_runtime/rendezvous_util.cc
Normal file
@ -0,0 +1,119 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#include "tensorflow/core/common_runtime/rendezvous_util.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
Status SendTensorsToRendezvous(Rendezvous* rendezvous,
|
||||||
|
const Rendezvous::Args& args,
|
||||||
|
const std::vector<string>& keys,
|
||||||
|
gtl::ArraySlice<Tensor> tensors_to_send) {
|
||||||
|
if (keys.size() != tensors_to_send.size()) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"keys and tensors_to_send are not the same size. keys.size() = ",
|
||||||
|
keys.size(), "; tensors_to_send.size() = ", tensors_to_send.size());
|
||||||
|
}
|
||||||
|
Rendezvous::ParsedKey parsed;
|
||||||
|
for (int i = 0; i < keys.size(); ++i) {
|
||||||
|
TF_RETURN_IF_ERROR(Rendezvous::ParseKey(keys[i], &parsed));
|
||||||
|
TF_RETURN_IF_ERROR(
|
||||||
|
rendezvous->Send(parsed, args, tensors_to_send[i], false));
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
void RecvOutputsFromRendezvousAsync(Rendezvous* rendezvous,
|
||||||
|
const Rendezvous::Args& args,
|
||||||
|
const std::vector<string>& keys,
|
||||||
|
std::vector<Tensor>* received_tensors,
|
||||||
|
const StatusCallback& done) {
|
||||||
|
if (keys.empty()) {
|
||||||
|
done(Status::OK());
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
received_tensors->reserve(keys.size());
|
||||||
|
std::vector<std::tuple<string, Tensor*, Rendezvous::ParsedKey>> arguments;
|
||||||
|
for (int i = 0; i < keys.size(); ++i) {
|
||||||
|
Rendezvous::ParsedKey parsed;
|
||||||
|
Status s = Rendezvous::ParseKey(keys[i], &parsed);
|
||||||
|
received_tensors->push_back(Tensor());
|
||||||
|
if (!s.ok()) {
|
||||||
|
done(s);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
arguments.push_back(
|
||||||
|
std::make_tuple(keys[i], &((*received_tensors)[i]), parsed));
|
||||||
|
}
|
||||||
|
|
||||||
|
typedef struct {
|
||||||
|
mutex mu;
|
||||||
|
int64 done_counter;
|
||||||
|
Status shared_status = Status::OK();
|
||||||
|
} CallState;
|
||||||
|
CallState* call_state = new CallState;
|
||||||
|
call_state->done_counter = keys.size();
|
||||||
|
for (auto& p : arguments) {
|
||||||
|
const string& key = std::get<0>(p);
|
||||||
|
Tensor* val = std::get<1>(p);
|
||||||
|
Rendezvous::ParsedKey parsed = std::get<2>(p);
|
||||||
|
rendezvous->RecvAsync(
|
||||||
|
parsed, args,
|
||||||
|
[val, done, key, call_state](const Status& s,
|
||||||
|
const Rendezvous::Args& send_args,
|
||||||
|
const Rendezvous::Args& recv_args,
|
||||||
|
const Tensor& v, const bool is_dead) {
|
||||||
|
Status status = s;
|
||||||
|
if (status.ok()) {
|
||||||
|
*val = v;
|
||||||
|
if (is_dead) {
|
||||||
|
status = errors::InvalidArgument("The tensor returned for ", key,
|
||||||
|
" was not valid.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
call_state->mu.lock();
|
||||||
|
call_state->shared_status.Update(status);
|
||||||
|
call_state->done_counter--;
|
||||||
|
// If we are the last async call to return, call the done callback.
|
||||||
|
if (call_state->done_counter == 0) {
|
||||||
|
const Status& final_status = call_state->shared_status;
|
||||||
|
call_state->mu.unlock();
|
||||||
|
done(final_status);
|
||||||
|
delete call_state;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
call_state->mu.unlock();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Status RecvOutputsFromRendezvous(Rendezvous* rendezvous, NamedTensors* out,
|
||||||
|
const Rendezvous::Args& args) {
|
||||||
|
// Receives values requested by the caller.
|
||||||
|
Rendezvous::ParsedKey parsed;
|
||||||
|
for (auto& p : *out) {
|
||||||
|
const string& key = p.first;
|
||||||
|
Tensor* val = &p.second;
|
||||||
|
bool is_dead = false;
|
||||||
|
TF_RETURN_IF_ERROR(Rendezvous::ParseKey(key, &parsed));
|
||||||
|
TF_RETURN_IF_ERROR(rendezvous->Recv(parsed, args, val, &is_dead));
|
||||||
|
if (is_dead) {
|
||||||
|
return errors::InvalidArgument("The tensor returned for ", key,
|
||||||
|
" was not valid.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
44
tensorflow/core/common_runtime/rendezvous_util.h
Normal file
44
tensorflow/core/common_runtime/rendezvous_util.h
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#ifndef THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_RENDEZVOUS_UTIL_H_
|
||||||
|
#define THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_RENDEZVOUS_UTIL_H_
|
||||||
|
|
||||||
|
#include <map>
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/rendezvous.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
typedef std::map<string, Tensor> NamedTensors;
|
||||||
|
typedef std::function<void(const Status&)> StatusCallback;
|
||||||
|
|
||||||
|
// Uses `rendezvous` to send tensors in `in`.
|
||||||
|
Status SendTensorsToRendezvous(Rendezvous* rendezvous,
|
||||||
|
const Rendezvous::Args& args,
|
||||||
|
const std::vector<string>& keys,
|
||||||
|
gtl::ArraySlice<Tensor> tensors_to_send);
|
||||||
|
|
||||||
|
void RecvOutputsFromRendezvousAsync(Rendezvous* rendezvous,
|
||||||
|
const Rendezvous::Args& args,
|
||||||
|
const std::vector<string>& keys,
|
||||||
|
std::vector<Tensor>* received_tensors,
|
||||||
|
const StatusCallback& done);
|
||||||
|
|
||||||
|
Status RecvOutputsFromRendezvous(Rendezvous* rendezvous, NamedTensors* out,
|
||||||
|
const Rendezvous::Args& args);
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_RENDEZVOUS_UTIL_H_
|
94
tensorflow/core/common_runtime/rendezvous_util_test.cc
Normal file
94
tensorflow/core/common_runtime/rendezvous_util_test.cc
Normal file
@ -0,0 +1,94 @@
|
|||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#include "tensorflow/core/common_runtime/rendezvous_util.h"
|
||||||
|
|
||||||
|
#include "tensorflow/core/lib/core/notification.h"
|
||||||
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
class RendezvousUtilTest : public ::testing::Test {
|
||||||
|
public:
|
||||||
|
RendezvousUtilTest() { rendez_ = NewLocalRendezvous(); }
|
||||||
|
|
||||||
|
~RendezvousUtilTest() override { rendez_->Unref(); }
|
||||||
|
|
||||||
|
Rendezvous* rendez_;
|
||||||
|
};
|
||||||
|
|
||||||
|
// string -> Tensor<string>
|
||||||
|
Tensor V(const string& content) {
|
||||||
|
Tensor tensor(DT_STRING, TensorShape({}));
|
||||||
|
tensor.scalar<string>()() = content;
|
||||||
|
return tensor;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tensor<string> -> string
|
||||||
|
string V(const Tensor& tensor) {
|
||||||
|
CHECK_EQ(tensor.dtype(), DT_STRING);
|
||||||
|
CHECK(TensorShapeUtils::IsScalar(tensor.shape()));
|
||||||
|
return tensor.scalar<string>()();
|
||||||
|
}
|
||||||
|
|
||||||
|
string MakeStringKey(const string& name) {
|
||||||
|
return Rendezvous::CreateKey(
|
||||||
|
"/job:localhost/replica:0/task:0/device:CPU:0", 0,
|
||||||
|
"/job:localhost/replica:0/task:0/device:GPU:0", name, FrameAndIter(0, 0));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(RendezvousUtilTest, SendBeforeRecv) {
|
||||||
|
// Fire off sends before receive the tensors.
|
||||||
|
Rendezvous::Args args;
|
||||||
|
TF_ASSERT_OK(SendTensorsToRendezvous(
|
||||||
|
rendez_, args, {MakeStringKey("hello1"), MakeStringKey("hello2")},
|
||||||
|
{V("hello1"), V("hello2")}));
|
||||||
|
|
||||||
|
Notification n;
|
||||||
|
std::vector<Tensor> received_keys;
|
||||||
|
RecvOutputsFromRendezvousAsync(
|
||||||
|
rendez_, args, {MakeStringKey("hello1"), MakeStringKey("hello2")},
|
||||||
|
&received_keys, [&n](const Status& status) { n.Notify(); });
|
||||||
|
n.WaitForNotification();
|
||||||
|
|
||||||
|
EXPECT_EQ(2, received_keys.size());
|
||||||
|
EXPECT_EQ("hello1", V(received_keys[0]));
|
||||||
|
EXPECT_EQ("hello2", V(received_keys[1]));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(RendezvousUtilTest, RecvBeforeSend) {
|
||||||
|
// Fire off recvs, wait for a notification in the callback.
|
||||||
|
Rendezvous::Args args;
|
||||||
|
|
||||||
|
Notification n;
|
||||||
|
std::vector<Tensor> received_keys;
|
||||||
|
RecvOutputsFromRendezvousAsync(
|
||||||
|
rendez_, args, {MakeStringKey("hello1"), MakeStringKey("hello2")},
|
||||||
|
&received_keys, [&n](const Status& status) { n.Notify(); });
|
||||||
|
|
||||||
|
TF_ASSERT_OK(SendTensorsToRendezvous(
|
||||||
|
rendez_, args, {MakeStringKey("hello1"), MakeStringKey("hello2")},
|
||||||
|
{V("hello1"), V("hello2")}));
|
||||||
|
|
||||||
|
n.WaitForNotification();
|
||||||
|
|
||||||
|
EXPECT_EQ(2, received_keys.size());
|
||||||
|
EXPECT_EQ("hello1", V(received_keys[0]));
|
||||||
|
EXPECT_EQ("hello2", V(received_keys[1]));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace tensorflow
|
@ -26,12 +26,13 @@ limitations under the License.
|
|||||||
#include "grpc++/create_channel.h"
|
#include "grpc++/create_channel.h"
|
||||||
#else
|
#else
|
||||||
// winsock2.h is used in grpc, so Ws2_32.lib is needed
|
// winsock2.h is used in grpc, so Ws2_32.lib is needed
|
||||||
#pragma comment(lib,"Ws2_32.lib")
|
#pragma comment(lib, "Ws2_32.lib")
|
||||||
#endif // #ifndef PLATFORM_WINDOWS
|
#endif // #ifndef PLATFORM_WINDOWS
|
||||||
|
|
||||||
#include "tensorflow/core/debug/debugger_event_metadata.pb.h"
|
#include "tensorflow/core/debug/debugger_event_metadata.pb.h"
|
||||||
#include "tensorflow/core/framework/graph.pb.h"
|
#include "tensorflow/core/framework/graph.pb.h"
|
||||||
#include "tensorflow/core/framework/summary.pb.h"
|
#include "tensorflow/core/framework/summary.pb.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||||
#include "tensorflow/core/lib/core/bits.h"
|
#include "tensorflow/core/lib/core/bits.h"
|
||||||
#include "tensorflow/core/lib/hash/hash.h"
|
#include "tensorflow/core/lib/hash/hash.h"
|
||||||
#include "tensorflow/core/lib/io/path.h"
|
#include "tensorflow/core/lib/io/path.h"
|
||||||
|
@ -26,6 +26,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/common_runtime/memory_types.h"
|
#include "tensorflow/core/common_runtime/memory_types.h"
|
||||||
#include "tensorflow/core/common_runtime/optimization_registry.h"
|
#include "tensorflow/core/common_runtime/optimization_registry.h"
|
||||||
#include "tensorflow/core/common_runtime/process_util.h"
|
#include "tensorflow/core/common_runtime/process_util.h"
|
||||||
|
#include "tensorflow/core/common_runtime/rendezvous_util.h"
|
||||||
#include "tensorflow/core/common_runtime/step_stats_collector.h"
|
#include "tensorflow/core/common_runtime/step_stats_collector.h"
|
||||||
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
|
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
|
||||||
#include "tensorflow/core/framework/cancellation.h"
|
#include "tensorflow/core/framework/cancellation.h"
|
||||||
@ -321,116 +322,25 @@ Status GraphMgr::DeregisterAll() {
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status GraphMgr::SendInputsToRendezvous(Rendezvous* rendezvous,
|
|
||||||
const NamedTensors& in) {
|
|
||||||
Rendezvous::ParsedKey parsed;
|
|
||||||
for (const auto& p : in) {
|
|
||||||
const string& key = p.first;
|
|
||||||
const Tensor& val = p.second;
|
|
||||||
|
|
||||||
Status s = Rendezvous::ParseKey(key, &parsed);
|
|
||||||
if (s.ok()) {
|
|
||||||
s = rendezvous->Send(parsed, Rendezvous::Args(), val, false);
|
|
||||||
}
|
|
||||||
if (!s.ok()) {
|
|
||||||
return s;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
Status GraphMgr::RecvOutputsFromRendezvous(Rendezvous* rendezvous,
|
|
||||||
NamedTensors* out) {
|
|
||||||
// Receives values requested by the caller.
|
|
||||||
Rendezvous::ParsedKey parsed;
|
|
||||||
for (auto& p : *out) {
|
|
||||||
const string& key = p.first;
|
|
||||||
Tensor* val = &p.second;
|
|
||||||
bool is_dead = false;
|
|
||||||
Status s = Rendezvous::ParseKey(key, &parsed);
|
|
||||||
if (s.ok()) {
|
|
||||||
s = rendezvous->Recv(parsed, Rendezvous::Args(), val, &is_dead);
|
|
||||||
}
|
|
||||||
if (is_dead) {
|
|
||||||
s = errors::InvalidArgument("The tensor returned for ", key,
|
|
||||||
" was not valid.");
|
|
||||||
}
|
|
||||||
if (!s.ok()) return s;
|
|
||||||
}
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
void GraphMgr::RecvOutputsFromRendezvousAsync(Rendezvous* rendezvous,
|
|
||||||
NamedTensors* out,
|
|
||||||
const StatusCallback& done) {
|
|
||||||
if (out->empty()) {
|
|
||||||
done(Status::OK());
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
// We compute the args before calling RecvAsync because we need to ensure that
|
|
||||||
// out isn't being iterated over after done is called, since done deletes out.
|
|
||||||
std::vector<std::tuple<string, Tensor*, Rendezvous::ParsedKey>> args;
|
|
||||||
for (auto& p : *out) {
|
|
||||||
Rendezvous::ParsedKey parsed;
|
|
||||||
Status s = Rendezvous::ParseKey(p.first, &parsed);
|
|
||||||
if (!s.ok()) {
|
|
||||||
done(s);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
args.push_back(std::make_tuple(p.first, &p.second, parsed));
|
|
||||||
}
|
|
||||||
|
|
||||||
typedef struct {
|
|
||||||
mutex mu;
|
|
||||||
int done_counter;
|
|
||||||
Status shared_status = Status::OK();
|
|
||||||
} CallState;
|
|
||||||
CallState* call_state = new CallState;
|
|
||||||
call_state->done_counter = out->size();
|
|
||||||
for (auto& p : args) {
|
|
||||||
const string& key = std::get<0>(p);
|
|
||||||
Tensor* val = std::get<1>(p);
|
|
||||||
Rendezvous::ParsedKey parsed = std::get<2>(p);
|
|
||||||
rendezvous->RecvAsync(
|
|
||||||
parsed, Rendezvous::Args(),
|
|
||||||
[val, done, key, call_state](const Status& s,
|
|
||||||
const Rendezvous::Args& send_args,
|
|
||||||
const Rendezvous::Args& recv_args,
|
|
||||||
const Tensor& v, const bool is_dead) {
|
|
||||||
Status status = s;
|
|
||||||
if (status.ok()) {
|
|
||||||
*val = v;
|
|
||||||
if (is_dead) {
|
|
||||||
status = errors::InvalidArgument("The tensor returned for ", key,
|
|
||||||
" was not valid.");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
call_state->mu.lock();
|
|
||||||
call_state->shared_status.Update(status);
|
|
||||||
call_state->done_counter--;
|
|
||||||
// If we are the last async call to return, call the done callback.
|
|
||||||
if (call_state->done_counter == 0) {
|
|
||||||
const Status& final_status = call_state->shared_status;
|
|
||||||
call_state->mu.unlock();
|
|
||||||
done(final_status);
|
|
||||||
delete call_state;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
call_state->mu.unlock();
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Status GraphMgr::SendInputs(const int64 step_id, const NamedTensors& in) {
|
Status GraphMgr::SendInputs(const int64 step_id, const NamedTensors& in) {
|
||||||
Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
|
Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
|
||||||
Status s = SendInputsToRendezvous(rendezvous, in);
|
std::vector<string> keys;
|
||||||
|
std::vector<Tensor> tensors_to_send;
|
||||||
|
keys.reserve(in.size());
|
||||||
|
tensors_to_send.reserve(in.size());
|
||||||
|
for (const auto& p : in) {
|
||||||
|
keys.push_back(p.first);
|
||||||
|
tensors_to_send.push_back(p.second);
|
||||||
|
}
|
||||||
|
Status s = SendTensorsToRendezvous(rendezvous, Rendezvous::Args(), keys,
|
||||||
|
tensors_to_send);
|
||||||
rendezvous->Unref();
|
rendezvous->Unref();
|
||||||
return s;
|
return s;
|
||||||
}
|
}
|
||||||
|
|
||||||
Status GraphMgr::RecvOutputs(const int64 step_id, NamedTensors* out) {
|
Status GraphMgr::RecvOutputs(const int64 step_id, NamedTensors* out) {
|
||||||
Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
|
Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
|
||||||
Status s = RecvOutputsFromRendezvous(rendezvous, out);
|
Status s = RecvOutputsFromRendezvous(rendezvous, out, Rendezvous::Args());
|
||||||
rendezvous->Unref();
|
rendezvous->Unref();
|
||||||
return s;
|
return s;
|
||||||
}
|
}
|
||||||
@ -438,11 +348,24 @@ Status GraphMgr::RecvOutputs(const int64 step_id, NamedTensors* out) {
|
|||||||
void GraphMgr::RecvOutputsAsync(const int64 step_id, NamedTensors* out,
|
void GraphMgr::RecvOutputsAsync(const int64 step_id, NamedTensors* out,
|
||||||
StatusCallback done) {
|
StatusCallback done) {
|
||||||
Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
|
Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
|
||||||
RecvOutputsFromRendezvousAsync(rendezvous, out,
|
std::vector<string> keys;
|
||||||
[done, rendezvous](const Status s) {
|
std::vector<Tensor>* received_keys = new std::vector<Tensor>;
|
||||||
rendezvous->Unref();
|
keys.reserve(out->size());
|
||||||
done(s);
|
received_keys->reserve(out->size());
|
||||||
});
|
for (const auto& p : *out) {
|
||||||
|
keys.push_back(p.first);
|
||||||
|
received_keys->push_back(p.second);
|
||||||
|
}
|
||||||
|
RecvOutputsFromRendezvousAsync(
|
||||||
|
rendezvous, Rendezvous::Args(), keys, received_keys,
|
||||||
|
[done, rendezvous, received_keys, out, keys](const Status s) {
|
||||||
|
rendezvous->Unref();
|
||||||
|
for (int i = 0; i < keys.size(); ++i) {
|
||||||
|
(*out)[keys[i]] = (*received_keys)[i];
|
||||||
|
}
|
||||||
|
delete received_keys;
|
||||||
|
done(s);
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id,
|
void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id,
|
||||||
@ -484,7 +407,16 @@ void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id,
|
|||||||
|
|
||||||
// Sends values specified by the caller.
|
// Sends values specified by the caller.
|
||||||
if (s.ok()) {
|
if (s.ok()) {
|
||||||
s = SendInputsToRendezvous(rendezvous, in);
|
std::vector<string> keys;
|
||||||
|
std::vector<Tensor> tensors_to_send;
|
||||||
|
keys.reserve(in.size());
|
||||||
|
tensors_to_send.reserve(in.size());
|
||||||
|
for (auto& p : in) {
|
||||||
|
keys.push_back(p.first);
|
||||||
|
tensors_to_send.push_back(p.second);
|
||||||
|
}
|
||||||
|
s = SendTensorsToRendezvous(rendezvous, Rendezvous::Args(), keys,
|
||||||
|
tensors_to_send);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!s.ok()) {
|
if (!s.ok()) {
|
||||||
|
@ -169,11 +169,6 @@ class GraphMgr {
|
|||||||
void BuildCostModel(Item* item, StepStatsCollector* collector,
|
void BuildCostModel(Item* item, StepStatsCollector* collector,
|
||||||
CostGraphDef* cost_graph);
|
CostGraphDef* cost_graph);
|
||||||
|
|
||||||
Status SendInputsToRendezvous(Rendezvous* rendezvous, const NamedTensors& in);
|
|
||||||
Status RecvOutputsFromRendezvous(Rendezvous* rendezvous, NamedTensors* out);
|
|
||||||
void RecvOutputsFromRendezvousAsync(Rendezvous* rendezvous, NamedTensors* out,
|
|
||||||
const StatusCallback& done);
|
|
||||||
|
|
||||||
Status InitItem(const string& session, const GraphDef& gdef,
|
Status InitItem(const string& session, const GraphDef& gdef,
|
||||||
const GraphOptions& graph_options,
|
const GraphOptions& graph_options,
|
||||||
const DebugOptions& debug_options, Item* item);
|
const DebugOptions& debug_options, Item* item);
|
||||||
|
@ -465,7 +465,6 @@ tf_cuda_cc_test(
|
|||||||
linkstatic = tf_kernel_tests_linkstatic(),
|
linkstatic = tf_kernel_tests_linkstatic(),
|
||||||
tags = tf_cuda_tests_tags() + [
|
tags = tf_cuda_tests_tags() + [
|
||||||
"no_oss", # b/62956105: port conflicts.
|
"no_oss", # b/62956105: port conflicts.
|
||||||
"noguitar", # b/64805119
|
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":grpc_channel",
|
":grpc_channel",
|
||||||
|
@ -426,6 +426,10 @@ class FunctionLibraryRuntime {
|
|||||||
StepStatsCollector* stats_collector = nullptr;
|
StepStatsCollector* stats_collector = nullptr;
|
||||||
|
|
||||||
std::function<void(std::function<void()>)>* runner = nullptr;
|
std::function<void(std::function<void()>)>* runner = nullptr;
|
||||||
|
|
||||||
|
// Parameters for remote function execution.
|
||||||
|
bool remote_execution = false;
|
||||||
|
string source_device = ""; // Fully specified device name.
|
||||||
};
|
};
|
||||||
typedef std::function<void(const Status&)> DoneCallback;
|
typedef std::function<void(const Status&)> DoneCallback;
|
||||||
virtual void Run(const Options& opts, Handle handle,
|
virtual void Run(const Options& opts, Handle handle,
|
||||||
|
@ -110,6 +110,37 @@ bool ConsumeAttrNumber(StringPiece* sp, int64* out) {
|
|||||||
} \
|
} \
|
||||||
} while (false)
|
} while (false)
|
||||||
|
|
||||||
|
bool ConsumeCompoundAttrType(StringPiece* sp, StringPiece* out) {
|
||||||
|
auto capture_begin = sp->begin();
|
||||||
|
if (sp->Consume("numbertype") || sp->Consume("numerictype") ||
|
||||||
|
sp->Consume("quantizedtype") || sp->Consume("realnumbertype") ||
|
||||||
|
sp->Consume("realnumberictype")) {
|
||||||
|
*out = StringPiece(capture_begin, sp->begin() - capture_begin);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool ProcessCompoundType(const StringPiece type_string, AttrValue* allowed) {
|
||||||
|
if (type_string == "numbertype" || type_string == "numerictype") {
|
||||||
|
for (DataType dt : NumberTypes()) {
|
||||||
|
allowed->mutable_list()->add_type(dt);
|
||||||
|
}
|
||||||
|
} else if (type_string == "quantizedtype") {
|
||||||
|
for (DataType dt : QuantizedTypes()) {
|
||||||
|
allowed->mutable_list()->add_type(dt);
|
||||||
|
}
|
||||||
|
} else if (type_string == "realnumbertype" ||
|
||||||
|
type_string == "realnumerictype") {
|
||||||
|
for (DataType dt : RealNumberTypes()) {
|
||||||
|
allowed->mutable_list()->add_type(dt);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
void FinalizeAttr(StringPiece spec, OpDef* op_def,
|
void FinalizeAttr(StringPiece spec, OpDef* op_def,
|
||||||
std::vector<string>* errors) {
|
std::vector<string>* errors) {
|
||||||
OpDef::AttrDef* attr = op_def->add_attr();
|
OpDef::AttrDef* attr = op_def->add_attr();
|
||||||
@ -123,6 +154,7 @@ void FinalizeAttr(StringPiece spec, OpDef* op_def,
|
|||||||
// Read "<type>" or "list(<type>)".
|
// Read "<type>" or "list(<type>)".
|
||||||
bool is_list = ConsumeListPrefix(&spec);
|
bool is_list = ConsumeListPrefix(&spec);
|
||||||
string type;
|
string type;
|
||||||
|
StringPiece type_string; // Used if type == "type"
|
||||||
if (spec.Consume("string")) {
|
if (spec.Consume("string")) {
|
||||||
type = "string";
|
type = "string";
|
||||||
} else if (spec.Consume("int")) {
|
} else if (spec.Consume("int")) {
|
||||||
@ -139,29 +171,15 @@ void FinalizeAttr(StringPiece spec, OpDef* op_def,
|
|||||||
type = "tensor";
|
type = "tensor";
|
||||||
} else if (spec.Consume("func")) {
|
} else if (spec.Consume("func")) {
|
||||||
type = "func";
|
type = "func";
|
||||||
} else if (spec.Consume("numbertype") || spec.Consume("numerictype")) {
|
} else if (ConsumeCompoundAttrType(&spec, &type_string)) {
|
||||||
type = "type";
|
type = "type";
|
||||||
AttrValue* allowed = attr->mutable_allowed_values();
|
AttrValue* allowed = attr->mutable_allowed_values();
|
||||||
for (DataType dt : NumberTypes()) {
|
VERIFY(ProcessCompoundType(type_string, allowed),
|
||||||
allowed->mutable_list()->add_type(dt);
|
"Expected to see a compound type, saw: ", type_string);
|
||||||
}
|
|
||||||
} else if (spec.Consume("quantizedtype")) {
|
|
||||||
type = "type";
|
|
||||||
AttrValue* allowed = attr->mutable_allowed_values();
|
|
||||||
for (DataType dt : QuantizedTypes()) {
|
|
||||||
allowed->mutable_list()->add_type(dt);
|
|
||||||
}
|
|
||||||
} else if (spec.Consume("realnumbertype") ||
|
|
||||||
spec.Consume("realnumerictype")) {
|
|
||||||
type = "type";
|
|
||||||
AttrValue* allowed = attr->mutable_allowed_values();
|
|
||||||
for (DataType dt : RealNumberTypes()) {
|
|
||||||
allowed->mutable_list()->add_type(dt);
|
|
||||||
}
|
|
||||||
} else if (spec.Consume("{")) {
|
} else if (spec.Consume("{")) {
|
||||||
// e.g. "{ int32, float, bool }" or "{ \"foo\", \"bar\" }"
|
// e.g. "{ int32, float, bool }" or "{ \"foo\", \"bar\" }"
|
||||||
str_util::RemoveLeadingWhitespace(&spec);
|
|
||||||
AttrValue* allowed = attr->mutable_allowed_values();
|
AttrValue* allowed = attr->mutable_allowed_values();
|
||||||
|
str_util::RemoveLeadingWhitespace(&spec);
|
||||||
if (spec.starts_with("\"") || spec.starts_with("'")) {
|
if (spec.starts_with("\"") || spec.starts_with("'")) {
|
||||||
type = "string"; // "{ \"foo\", \"bar\" }" or "{ 'foo', 'bar' }"
|
type = "string"; // "{ \"foo\", \"bar\" }" or "{ 'foo', 'bar' }"
|
||||||
while (true) {
|
while (true) {
|
||||||
@ -172,8 +190,8 @@ void FinalizeAttr(StringPiece spec, OpDef* op_def,
|
|||||||
string unescaped;
|
string unescaped;
|
||||||
string error;
|
string error;
|
||||||
VERIFY(str_util::CUnescape(escaped_string, &unescaped, &error),
|
VERIFY(str_util::CUnescape(escaped_string, &unescaped, &error),
|
||||||
"Trouble unescaping \"", escaped_string, "\", got error: ",
|
"Trouble unescaping \"", escaped_string,
|
||||||
error);
|
"\", got error: ", error);
|
||||||
allowed->mutable_list()->add_s(unescaped);
|
allowed->mutable_list()->add_s(unescaped);
|
||||||
if (spec.Consume(",")) {
|
if (spec.Consume(",")) {
|
||||||
str_util::RemoveLeadingWhitespace(&spec);
|
str_util::RemoveLeadingWhitespace(&spec);
|
||||||
@ -184,16 +202,19 @@ void FinalizeAttr(StringPiece spec, OpDef* op_def,
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else { // "{ int32, float, bool }"
|
} else { // "{ bool, numbertype, string }"
|
||||||
type = "type";
|
type = "type";
|
||||||
while (true) {
|
while (true) {
|
||||||
StringPiece type_string;
|
|
||||||
VERIFY(ConsumeAttrType(&spec, &type_string),
|
VERIFY(ConsumeAttrType(&spec, &type_string),
|
||||||
"Trouble parsing type string at '", spec, "'");
|
"Trouble parsing type string at '", spec, "'");
|
||||||
DataType dt;
|
if (ProcessCompoundType(type_string, allowed)) {
|
||||||
VERIFY(DataTypeFromString(type_string, &dt),
|
// Processed a compound type.
|
||||||
"Unrecognized type string '", type_string, "'");
|
} else {
|
||||||
allowed->mutable_list()->add_type(dt);
|
DataType dt;
|
||||||
|
VERIFY(DataTypeFromString(type_string, &dt),
|
||||||
|
"Unrecognized type string '", type_string, "'");
|
||||||
|
allowed->mutable_list()->add_type(dt);
|
||||||
|
}
|
||||||
if (spec.Consume(",")) {
|
if (spec.Consume(",")) {
|
||||||
str_util::RemoveLeadingWhitespace(&spec);
|
str_util::RemoveLeadingWhitespace(&spec);
|
||||||
if (spec.Consume("}")) break; // Allow ending with ", }".
|
if (spec.Consume("}")) break; // Allow ending with ", }".
|
||||||
@ -204,7 +225,7 @@ void FinalizeAttr(StringPiece spec, OpDef* op_def,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else { // if spec.Consume("{")
|
||||||
VERIFY(false, "Trouble parsing type string at '", spec, "'");
|
VERIFY(false, "Trouble parsing type string at '", spec, "'");
|
||||||
}
|
}
|
||||||
str_util::RemoveLeadingWhitespace(&spec);
|
str_util::RemoveLeadingWhitespace(&spec);
|
||||||
|
@ -57,8 +57,10 @@ class OpDefBuilder {
|
|||||||
// (by convention only using capital letters for attrs that can be inferred)
|
// (by convention only using capital letters for attrs that can be inferred)
|
||||||
// <type> can be:
|
// <type> can be:
|
||||||
// "string", "int", "float", "bool", "type", "shape", or "tensor"
|
// "string", "int", "float", "bool", "type", "shape", or "tensor"
|
||||||
// "numbertype", "realnumbertype", "quantizedtype", "{int32,int64}"
|
// "numbertype", "realnumbertype", "quantizedtype"
|
||||||
// (meaning "type" with a restriction on valid values)
|
// (meaning "type" with a restriction on valid values)
|
||||||
|
// "{int32,int64}" or {realnumbertype,quantizedtype,string}"
|
||||||
|
// (meaning "type" with a restriction containing unions of value types)
|
||||||
// "{\"foo\", \"bar\n baz\"}", or "{'foo', 'bar\n baz'}"
|
// "{\"foo\", \"bar\n baz\"}", or "{'foo', 'bar\n baz'}"
|
||||||
// (meaning "string" with a restriction on valid values)
|
// (meaning "string" with a restriction on valid values)
|
||||||
// "list(string)", ..., "list(tensor)", "list(numbertype)", ...
|
// "list(string)", ..., "list(tensor)", "list(numbertype)", ...
|
||||||
|
@ -125,13 +125,27 @@ TEST_F(OpDefBuilderTest, AttrWithRestrictions) {
|
|||||||
"[DT_HALF, DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, DT_INT16, "
|
"[DT_HALF, DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, DT_INT16, "
|
||||||
"DT_UINT16, DT_INT8, DT_COMPLEX64, DT_COMPLEX128, DT_QINT8, DT_QUINT8, "
|
"DT_UINT16, DT_INT8, DT_COMPLEX64, DT_COMPLEX128, DT_QINT8, DT_QUINT8, "
|
||||||
"DT_QINT32] } } }");
|
"DT_QINT32] } } }");
|
||||||
|
ExpectSuccess(
|
||||||
|
b().Attr("a:{numbertype, variant}"),
|
||||||
|
"attr: { name: 'a' type: 'type' allowed_values { list { type: "
|
||||||
|
"[DT_HALF, DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, DT_INT16, "
|
||||||
|
"DT_UINT16, DT_INT8, DT_COMPLEX64, DT_COMPLEX128, DT_QINT8, DT_QUINT8, "
|
||||||
|
"DT_QINT32, DT_VARIANT] } } }");
|
||||||
ExpectSuccess(b().Attr("a:realnumbertype"),
|
ExpectSuccess(b().Attr("a:realnumbertype"),
|
||||||
"attr: { name: 'a' type: 'type' allowed_values { list { type: "
|
"attr: { name: 'a' type: 'type' allowed_values { list { type: "
|
||||||
"[DT_HALF, DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, "
|
"[DT_HALF, DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, "
|
||||||
"DT_INT16, DT_UINT16, DT_INT8] } } }");
|
"DT_INT16, DT_UINT16, DT_INT8] } } }");
|
||||||
|
ExpectSuccess(b().Attr("a:{realnumbertype, variant , string, }"),
|
||||||
|
"attr: { name: 'a' type: 'type' allowed_values { list { type: "
|
||||||
|
"[DT_HALF, DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, "
|
||||||
|
"DT_INT16, DT_UINT16, DT_INT8, DT_VARIANT, DT_STRING] } } }");
|
||||||
ExpectSuccess(b().Attr("a:quantizedtype"),
|
ExpectSuccess(b().Attr("a:quantizedtype"),
|
||||||
"attr: { name: 'a' type: 'type' allowed_values { list { type: "
|
"attr: { name: 'a' type: 'type' allowed_values { list { type: "
|
||||||
"[DT_QINT8, DT_QUINT8, DT_QINT32, DT_QINT16, DT_QUINT16]} } }");
|
"[DT_QINT8, DT_QUINT8, DT_QINT32, DT_QINT16, DT_QUINT16]} } }");
|
||||||
|
ExpectSuccess(b().Attr("a:{quantizedtype ,string}"),
|
||||||
|
"attr: { name: 'a' type: 'type' allowed_values { list { type: "
|
||||||
|
"[DT_QINT8, DT_QUINT8, DT_QINT32, DT_QINT16, DT_QUINT16, "
|
||||||
|
"DT_STRING]} } }");
|
||||||
ExpectSuccess(b().Attr("a:{string,int32}"),
|
ExpectSuccess(b().Attr("a:{string,int32}"),
|
||||||
"attr: { name: 'a' type: 'type' allowed_values { list { type: "
|
"attr: { name: 'a' type: 'type' allowed_values { list { type: "
|
||||||
"[DT_STRING, DT_INT32] } } }");
|
"[DT_STRING, DT_INT32] } } }");
|
||||||
@ -202,6 +216,11 @@ TEST_F(OpDefBuilderTest, AttrListOfRestricted) {
|
|||||||
"attr: { name: 'a' type: 'list(type)' allowed_values { list { type: "
|
"attr: { name: 'a' type: 'list(type)' allowed_values { list { type: "
|
||||||
"[DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, DT_INT16, "
|
"[DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, DT_INT16, "
|
||||||
"DT_UINT16, DT_INT8, DT_HALF] } } }");
|
"DT_UINT16, DT_INT8, DT_HALF] } } }");
|
||||||
|
ExpectSuccess(
|
||||||
|
b().Attr("a:list({realnumbertype, variant})"),
|
||||||
|
"attr: { name: 'a' type: 'list(type)' allowed_values { list { type: "
|
||||||
|
"[DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, DT_INT16, "
|
||||||
|
"DT_UINT16, DT_INT8, DT_HALF, DT_VARIANT] } } }");
|
||||||
ExpectSuccess(
|
ExpectSuccess(
|
||||||
b().Attr("a:list(quantizedtype)"),
|
b().Attr("a:list(quantizedtype)"),
|
||||||
"attr: { name: 'a' type: 'list(type)' allowed_values { list { type: "
|
"attr: { name: 'a' type: 'list(type)' allowed_values { list { type: "
|
||||||
|
@ -24,6 +24,12 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
|
std::unordered_set<string>* UnaryVariantOpRegistry::PersistentStringStorage() {
|
||||||
|
static std::unordered_set<string>* string_storage =
|
||||||
|
new std::unordered_set<string>();
|
||||||
|
return string_storage;
|
||||||
|
}
|
||||||
|
|
||||||
// static
|
// static
|
||||||
UnaryVariantOpRegistry* UnaryVariantOpRegistry::Global() {
|
UnaryVariantOpRegistry* UnaryVariantOpRegistry::Global() {
|
||||||
static UnaryVariantOpRegistry* global_unary_variant_op_registry =
|
static UnaryVariantOpRegistry* global_unary_variant_op_registry =
|
||||||
@ -32,7 +38,7 @@ UnaryVariantOpRegistry* UnaryVariantOpRegistry::Global() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
UnaryVariantOpRegistry::VariantShapeFn* UnaryVariantOpRegistry::GetShapeFn(
|
UnaryVariantOpRegistry::VariantShapeFn* UnaryVariantOpRegistry::GetShapeFn(
|
||||||
const string& type_name) {
|
StringPiece type_name) {
|
||||||
auto found = shape_fns.find(type_name);
|
auto found = shape_fns.find(type_name);
|
||||||
if (found == shape_fns.end()) return nullptr;
|
if (found == shape_fns.end()) return nullptr;
|
||||||
return &found->second;
|
return &found->second;
|
||||||
@ -45,7 +51,8 @@ void UnaryVariantOpRegistry::RegisterShapeFn(const string& type_name,
|
|||||||
CHECK_EQ(existing, nullptr)
|
CHECK_EQ(existing, nullptr)
|
||||||
<< "Unary VariantShapeFn for type_name: " << type_name
|
<< "Unary VariantShapeFn for type_name: " << type_name
|
||||||
<< " already registered";
|
<< " already registered";
|
||||||
shape_fns.insert(std::pair<string, VariantShapeFn>(type_name, shape_fn));
|
shape_fns.insert(std::pair<StringPiece, VariantShapeFn>(
|
||||||
|
GetPersistentStringPiece(type_name), shape_fn));
|
||||||
}
|
}
|
||||||
|
|
||||||
Status GetUnaryVariantShape(const Tensor& variant_tensor, TensorShape* shape) {
|
Status GetUnaryVariantShape(const Tensor& variant_tensor, TensorShape* shape) {
|
||||||
@ -65,8 +72,29 @@ Status GetUnaryVariantShape(const Tensor& variant_tensor, TensorShape* shape) {
|
|||||||
return (*shape_fn)(v, shape);
|
return (*shape_fn)(v, shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Add some basic registrations for use by others, e.g., for testing.
|
||||||
|
namespace {
|
||||||
|
template <typename T>
|
||||||
|
Status ScalarShape(const T&, TensorShape* shape) {
|
||||||
|
*shape = TensorShape({});
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
#define REGISTER_VARIANT_SHAPE_TYPE(T) \
|
||||||
|
REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(T, TF_STR(T), ScalarShape<T>);
|
||||||
|
|
||||||
|
// No encode/shape registered for std::complex<> and Eigen::half
|
||||||
|
// objects yet.
|
||||||
|
REGISTER_VARIANT_SHAPE_TYPE(int);
|
||||||
|
REGISTER_VARIANT_SHAPE_TYPE(float);
|
||||||
|
REGISTER_VARIANT_SHAPE_TYPE(bool);
|
||||||
|
REGISTER_VARIANT_SHAPE_TYPE(double);
|
||||||
|
|
||||||
|
#undef REGISTER_VARIANT_SHAPE_TYPE
|
||||||
|
|
||||||
UnaryVariantOpRegistry::VariantDecodeFn* UnaryVariantOpRegistry::GetDecodeFn(
|
UnaryVariantOpRegistry::VariantDecodeFn* UnaryVariantOpRegistry::GetDecodeFn(
|
||||||
const string& type_name) {
|
StringPiece type_name) {
|
||||||
auto found = decode_fns.find(type_name);
|
auto found = decode_fns.find(type_name);
|
||||||
if (found == decode_fns.end()) return nullptr;
|
if (found == decode_fns.end()) return nullptr;
|
||||||
return &found->second;
|
return &found->second;
|
||||||
@ -79,7 +107,8 @@ void UnaryVariantOpRegistry::RegisterDecodeFn(
|
|||||||
CHECK_EQ(existing, nullptr)
|
CHECK_EQ(existing, nullptr)
|
||||||
<< "Unary VariantDecodeFn for type_name: " << type_name
|
<< "Unary VariantDecodeFn for type_name: " << type_name
|
||||||
<< " already registered";
|
<< " already registered";
|
||||||
decode_fns.insert(std::pair<string, VariantDecodeFn>(type_name, decode_fn));
|
decode_fns.insert(std::pair<StringPiece, VariantDecodeFn>(
|
||||||
|
GetPersistentStringPiece(type_name), decode_fn));
|
||||||
}
|
}
|
||||||
|
|
||||||
bool DecodeUnaryVariant(Variant* variant) {
|
bool DecodeUnaryVariant(Variant* variant) {
|
||||||
@ -103,13 +132,6 @@ bool DecodeUnaryVariant(Variant* variant) {
|
|||||||
|
|
||||||
// Add some basic registrations for use by others, e.g., for testing.
|
// Add some basic registrations for use by others, e.g., for testing.
|
||||||
|
|
||||||
namespace {
|
|
||||||
string MaybeRemoveTFPrefix(const StringPiece& str) {
|
|
||||||
return str.starts_with("::tensorflow::") ? str.substr(14).ToString()
|
|
||||||
: str.ToString();
|
|
||||||
}
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
#define REGISTER_VARIANT_DECODE_TYPE(T) \
|
#define REGISTER_VARIANT_DECODE_TYPE(T) \
|
||||||
REGISTER_UNARY_VARIANT_DECODE_FUNCTION(T, TF_STR(T));
|
REGISTER_UNARY_VARIANT_DECODE_FUNCTION(T, TF_STR(T));
|
||||||
|
|
||||||
@ -122,30 +144,31 @@ REGISTER_VARIANT_DECODE_TYPE(double);
|
|||||||
|
|
||||||
#undef REGISTER_VARIANT_DECODE_TYPE
|
#undef REGISTER_VARIANT_DECODE_TYPE
|
||||||
|
|
||||||
// Special casing ZerosLikeFn per device.
|
// Special casing UnaryOpFn per op and per device.
|
||||||
UnaryVariantOpRegistry::VariantZerosLikeFn*
|
UnaryVariantOpRegistry::VariantUnaryOpFn* UnaryVariantOpRegistry::GetUnaryOpFn(
|
||||||
UnaryVariantOpRegistry::GetZerosLikeFn(const string& device,
|
VariantUnaryOp op, StringPiece device, StringPiece type_name) {
|
||||||
const string& type_name) {
|
auto found = unary_op_fns.find(std::make_tuple(op, device, type_name));
|
||||||
auto found = zeros_like_fns.find(std::make_pair(device, type_name));
|
if (found == unary_op_fns.end()) return nullptr;
|
||||||
if (found == zeros_like_fns.end()) return nullptr;
|
|
||||||
return &found->second;
|
return &found->second;
|
||||||
}
|
}
|
||||||
|
|
||||||
void UnaryVariantOpRegistry::RegisterZerosLikeFn(
|
void UnaryVariantOpRegistry::RegisterUnaryOpFn(
|
||||||
const string& device, const string& type_name,
|
VariantUnaryOp op, const string& device, const string& type_name,
|
||||||
const VariantZerosLikeFn& zeros_like_fn) {
|
const VariantUnaryOpFn& unary_op_fn) {
|
||||||
CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantZerosLike";
|
CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantUnaryOp";
|
||||||
VariantZerosLikeFn* existing = GetZerosLikeFn(device, type_name);
|
VariantUnaryOpFn* existing = GetUnaryOpFn(op, device, type_name);
|
||||||
CHECK_EQ(existing, nullptr)
|
CHECK_EQ(existing, nullptr)
|
||||||
<< "Unary VariantZerosLikeFn for type_name: " << type_name
|
<< "Unary VariantUnaryOpFn for type_name: " << type_name
|
||||||
<< " already registered for device type: " << device;
|
<< " already registered for device type: " << device;
|
||||||
zeros_like_fns.insert(
|
unary_op_fns.insert(
|
||||||
std::pair<std::pair<string, string>, VariantZerosLikeFn>(
|
std::pair<std::tuple<VariantUnaryOp, StringPiece, StringPiece>,
|
||||||
std::make_pair(device, type_name), zeros_like_fn));
|
VariantUnaryOpFn>(
|
||||||
|
std::make_tuple(op, GetPersistentStringPiece(device),
|
||||||
|
GetPersistentStringPiece(type_name)),
|
||||||
|
unary_op_fn));
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
Status ZerosLikeVariantPrimitiveType(OpKernelContext* ctx, const T& t,
|
Status ZerosLikeVariantPrimitiveType(OpKernelContext* ctx, const T& t,
|
||||||
T* t_out) {
|
T* t_out) {
|
||||||
@ -154,9 +177,10 @@ Status ZerosLikeVariantPrimitiveType(OpKernelContext* ctx, const T& t,
|
|||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
#define REGISTER_VARIANT_ZEROS_LIKE_TYPE(T) \
|
#define REGISTER_VARIANT_ZEROS_LIKE_TYPE(T) \
|
||||||
REGISTER_UNARY_VARIANT_ZEROS_LIKE_FUNCTION( \
|
REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP, \
|
||||||
DEVICE_CPU, T, TF_STR(T), ZerosLikeVariantPrimitiveType<T>);
|
DEVICE_CPU, T, TF_STR(T), \
|
||||||
|
ZerosLikeVariantPrimitiveType<T>);
|
||||||
|
|
||||||
// No zeros_like registered for std::complex<> or Eigen::half objects yet.
|
// No zeros_like registered for std::complex<> or Eigen::half objects yet.
|
||||||
REGISTER_VARIANT_ZEROS_LIKE_TYPE(int);
|
REGISTER_VARIANT_ZEROS_LIKE_TYPE(int);
|
||||||
@ -166,4 +190,51 @@ REGISTER_VARIANT_ZEROS_LIKE_TYPE(bool);
|
|||||||
|
|
||||||
#undef REGISTER_VARIANT_ZEROS_LIKE_TYPE
|
#undef REGISTER_VARIANT_ZEROS_LIKE_TYPE
|
||||||
|
|
||||||
|
// Special casing BinaryOpFn per op and per device.
|
||||||
|
UnaryVariantOpRegistry::VariantBinaryOpFn*
|
||||||
|
UnaryVariantOpRegistry::GetBinaryOpFn(VariantBinaryOp op, StringPiece device,
|
||||||
|
StringPiece type_name) {
|
||||||
|
auto found = binary_op_fns.find(std::make_tuple(op, device, type_name));
|
||||||
|
if (found == binary_op_fns.end()) return nullptr;
|
||||||
|
return &found->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
void UnaryVariantOpRegistry::RegisterBinaryOpFn(
|
||||||
|
VariantBinaryOp op, const string& device, const string& type_name,
|
||||||
|
const VariantBinaryOpFn& add_fn) {
|
||||||
|
CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantBinaryOp";
|
||||||
|
VariantBinaryOpFn* existing = GetBinaryOpFn(op, device, type_name);
|
||||||
|
CHECK_EQ(existing, nullptr)
|
||||||
|
<< "Unary VariantBinaryOpFn for type_name: " << type_name
|
||||||
|
<< " already registered for device type: " << device;
|
||||||
|
binary_op_fns.insert(
|
||||||
|
std::pair<std::tuple<VariantBinaryOp, StringPiece, StringPiece>,
|
||||||
|
VariantBinaryOpFn>(
|
||||||
|
std::make_tuple(op, GetPersistentStringPiece(device),
|
||||||
|
GetPersistentStringPiece(type_name)),
|
||||||
|
add_fn));
|
||||||
|
}
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
template <typename T>
|
||||||
|
Status AddVariantPrimitiveType(OpKernelContext* ctx, const T& a, const T& b,
|
||||||
|
T* out) {
|
||||||
|
*out = a + b;
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
#define REGISTER_VARIANT_ADD_TYPE(T) \
|
||||||
|
REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU, \
|
||||||
|
T, TF_STR(T), \
|
||||||
|
AddVariantPrimitiveType<T>);
|
||||||
|
|
||||||
|
// No add registered for std::complex<> or Eigen::half objects yet.
|
||||||
|
REGISTER_VARIANT_ADD_TYPE(int);
|
||||||
|
REGISTER_VARIANT_ADD_TYPE(float);
|
||||||
|
REGISTER_VARIANT_ADD_TYPE(double);
|
||||||
|
REGISTER_VARIANT_ADD_TYPE(bool);
|
||||||
|
|
||||||
|
#undef REGISTER_VARIANT_ADD_TYPE
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -17,11 +17,13 @@ limitations under the License.
|
|||||||
#define TENSORFLOW_FRAMEWORK_VARIANT_OP_REGISTRY_H_
|
#define TENSORFLOW_FRAMEWORK_VARIANT_OP_REGISTRY_H_
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
|
#include <unordered_set>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "tensorflow/core/framework/types.h"
|
#include "tensorflow/core/framework/types.h"
|
||||||
#include "tensorflow/core/framework/variant.h"
|
#include "tensorflow/core/framework/variant.h"
|
||||||
#include "tensorflow/core/framework/variant_encode_decode.h"
|
#include "tensorflow/core/framework/variant_encode_decode.h"
|
||||||
|
#include "tensorflow/core/lib/hash/hash.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
@ -30,49 +32,110 @@ class OpKernelContext;
|
|||||||
// for different variant types. To be used by ShapeOp, RankOp, and
|
// for different variant types. To be used by ShapeOp, RankOp, and
|
||||||
// SizeOp, decoding, etc.
|
// SizeOp, decoding, etc.
|
||||||
|
|
||||||
|
enum VariantUnaryOp {
|
||||||
|
INVALID_VARIANT_UNARY_OP = 0,
|
||||||
|
ZEROS_LIKE_VARIANT_UNARY_OP = 1,
|
||||||
|
};
|
||||||
|
|
||||||
|
enum VariantBinaryOp {
|
||||||
|
INVALID_VARIANT_BINARY_OP = 0,
|
||||||
|
ADD_VARIANT_BINARY_OP = 1,
|
||||||
|
};
|
||||||
|
|
||||||
class UnaryVariantOpRegistry {
|
class UnaryVariantOpRegistry {
|
||||||
public:
|
public:
|
||||||
typedef std::function<Status(const Variant& v, TensorShape*)> VariantShapeFn;
|
typedef std::function<Status(const Variant& v, TensorShape*)> VariantShapeFn;
|
||||||
typedef std::function<bool(Variant*)> VariantDecodeFn;
|
typedef std::function<bool(Variant*)> VariantDecodeFn;
|
||||||
typedef std::function<Status(OpKernelContext*, const Variant&, Variant*)>
|
typedef std::function<Status(OpKernelContext*, const Variant&, Variant*)>
|
||||||
VariantZerosLikeFn;
|
VariantUnaryOpFn;
|
||||||
|
typedef std::function<Status(OpKernelContext*, const Variant&, const Variant&,
|
||||||
|
Variant*)>
|
||||||
|
VariantBinaryOpFn;
|
||||||
|
|
||||||
// Add a shape lookup function to the registry.
|
// Add a shape lookup function to the registry.
|
||||||
void RegisterShapeFn(const string& type_name, const VariantShapeFn& shape_fn);
|
void RegisterShapeFn(const string& type_name, const VariantShapeFn& shape_fn);
|
||||||
|
|
||||||
// Returns nullptr if no shape function was found for the given TypeName.
|
// Returns nullptr if no shape function was found for the given TypeName.
|
||||||
VariantShapeFn* GetShapeFn(const string& type_name);
|
VariantShapeFn* GetShapeFn(StringPiece type_name);
|
||||||
|
|
||||||
// Add a decode function to the registry.
|
// Add a decode function to the registry.
|
||||||
void RegisterDecodeFn(const string& type_name,
|
void RegisterDecodeFn(const string& type_name,
|
||||||
const VariantDecodeFn& decode_fn);
|
const VariantDecodeFn& decode_fn);
|
||||||
|
|
||||||
// Returns nullptr if no decode function was found for the given TypeName.
|
// Returns nullptr if no decode function was found for the given TypeName.
|
||||||
VariantDecodeFn* GetDecodeFn(const string& type_name);
|
VariantDecodeFn* GetDecodeFn(StringPiece type_name);
|
||||||
|
|
||||||
// Add a zeros-like function to the registry.
|
// Add a unary op function to the registry.
|
||||||
void RegisterZerosLikeFn(const string& device, const string& type_name,
|
void RegisterUnaryOpFn(VariantUnaryOp op, const string& device,
|
||||||
const VariantZerosLikeFn& zeros_like_fn);
|
const string& type_name,
|
||||||
|
const VariantUnaryOpFn& unary_op_fn);
|
||||||
|
|
||||||
// Returns nullptr if no zeros-like function was found for the given
|
// Returns nullptr if no unary op function was found for the given
|
||||||
// device and TypeName.
|
// op, device, and TypeName.
|
||||||
VariantZerosLikeFn* GetZerosLikeFn(const string& device,
|
VariantUnaryOpFn* GetUnaryOpFn(VariantUnaryOp op, StringPiece device,
|
||||||
const string& type_name);
|
StringPiece type_name);
|
||||||
|
|
||||||
|
// Add a binary op function to the registry.
|
||||||
|
void RegisterBinaryOpFn(VariantBinaryOp op, const string& device,
|
||||||
|
const string& type_name,
|
||||||
|
const VariantBinaryOpFn& add_fn);
|
||||||
|
|
||||||
|
// Returns nullptr if no binary op function was found for the given
|
||||||
|
// op, device and TypeName.
|
||||||
|
VariantBinaryOpFn* GetBinaryOpFn(VariantBinaryOp op, StringPiece device,
|
||||||
|
StringPiece type_name);
|
||||||
|
|
||||||
|
// Get a pointer to a global UnaryVariantOpRegistry object
|
||||||
static UnaryVariantOpRegistry* Global();
|
static UnaryVariantOpRegistry* Global();
|
||||||
|
|
||||||
|
// Get a pointer to a global persistent string storage object.
|
||||||
|
// ISO/IEC C++ working draft N4296 clarifies that insertion into an
|
||||||
|
// std::unordered_set does not invalidate memory locations of
|
||||||
|
// *values* inside the set (though it may invalidate existing
|
||||||
|
// iterators). In other words, one may safely point a StringPiece to
|
||||||
|
// a value in the set without that StringPiece being invalidated by
|
||||||
|
// future insertions.
|
||||||
|
static std::unordered_set<string>* PersistentStringStorage();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::unordered_map<string, VariantShapeFn> shape_fns;
|
std::unordered_map<StringPiece, VariantShapeFn, StringPiece::Hasher>
|
||||||
std::unordered_map<string, VariantDecodeFn> decode_fns;
|
shape_fns;
|
||||||
// Map std::pair<device, type_name> to function.
|
std::unordered_map<StringPiece, VariantDecodeFn, StringPiece::Hasher>
|
||||||
struct PairHash {
|
decode_fns;
|
||||||
template <typename T, typename U>
|
|
||||||
std::size_t operator()(const std::pair<T, U>& x) const {
|
// Map std::tuple<Op, device, type_name> to function.
|
||||||
return std::hash<T>()(x.first) ^ std::hash<U>()(x.second);
|
struct TupleHash {
|
||||||
|
template <typename Op>
|
||||||
|
std::size_t operator()(
|
||||||
|
const std::tuple<Op, StringPiece, StringPiece>& x) const {
|
||||||
|
// The hash of an enum is just its value as a std::size_t.
|
||||||
|
std::size_t ret = static_cast<std::size_t>(std::get<0>(x));
|
||||||
|
StringPiece::Hasher sp_hasher;
|
||||||
|
ret = Hash64Combine(ret, sp_hasher(std::get<1>(x)));
|
||||||
|
ret = Hash64Combine(ret, sp_hasher(std::get<2>(x)));
|
||||||
|
return ret;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
std::unordered_map<std::pair<string, string>, VariantZerosLikeFn, PairHash>
|
std::unordered_map<std::tuple<VariantUnaryOp, StringPiece, StringPiece>,
|
||||||
zeros_like_fns;
|
VariantUnaryOpFn, TupleHash>
|
||||||
|
unary_op_fns;
|
||||||
|
std::unordered_map<std::tuple<VariantBinaryOp, StringPiece, StringPiece>,
|
||||||
|
VariantBinaryOpFn, TupleHash>
|
||||||
|
binary_op_fns;
|
||||||
|
|
||||||
|
// Find or insert a string into a persistent string storage
|
||||||
|
// container; return the StringPiece pointing to the permanent
|
||||||
|
// string location.
|
||||||
|
static StringPiece GetPersistentStringPiece(const string& str) {
|
||||||
|
const auto string_storage = PersistentStringStorage();
|
||||||
|
auto found = string_storage->find(str);
|
||||||
|
if (found == string_storage->end()) {
|
||||||
|
auto inserted = string_storage->insert(str);
|
||||||
|
return StringPiece(*inserted.first);
|
||||||
|
} else {
|
||||||
|
return StringPiece(*found);
|
||||||
|
}
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Gets a TensorShape from a Tensor containing a scalar Variant.
|
// Gets a TensorShape from a Tensor containing a scalar Variant.
|
||||||
@ -94,26 +157,57 @@ Status GetUnaryVariantShape(const Tensor& variant_tensor, TensorShape* shape);
|
|||||||
//
|
//
|
||||||
bool DecodeUnaryVariant(Variant* variant);
|
bool DecodeUnaryVariant(Variant* variant);
|
||||||
|
|
||||||
// Sets *z_out = zeros_like(v). The variant v must have a registered
|
// Sets *v_out = unary_op(v). The variant v must have a registered
|
||||||
// ZerosLike function for the given Device. Returns an Internal error
|
// UnaryOp function for the given Device. Returns an Internal error
|
||||||
// if v does not have a registered zeros_like function for this device, or if
|
// if v does not have a registered unary_op function for this device, or if
|
||||||
// ZerosLike fails.
|
// UnaryOp fails.
|
||||||
//
|
//
|
||||||
// REQUIRES:
|
// REQUIRES:
|
||||||
// v_out is not null.
|
// v_out is not null.
|
||||||
//
|
//
|
||||||
template <typename Device>
|
template <typename Device>
|
||||||
Status CreateZerosLikeVariant(OpKernelContext* ctx, const Variant& v,
|
Status UnaryOpVariant(OpKernelContext* ctx, VariantUnaryOp op, const Variant& v,
|
||||||
Variant* v_out) {
|
Variant* v_out) {
|
||||||
const string& device = DeviceName<Device>::value;
|
const string& device = DeviceName<Device>::value;
|
||||||
UnaryVariantOpRegistry::VariantZerosLikeFn* zeros_like_fn =
|
UnaryVariantOpRegistry::VariantUnaryOpFn* unary_op_fn =
|
||||||
UnaryVariantOpRegistry::Global()->GetZerosLikeFn(device, v.TypeName());
|
UnaryVariantOpRegistry::Global()->GetUnaryOpFn(op, device, v.TypeName());
|
||||||
if (zeros_like_fn == nullptr) {
|
if (unary_op_fn == nullptr) {
|
||||||
return errors::Internal(
|
return errors::Internal(
|
||||||
"No unary variant zeros_like function found for Variant type_name: ",
|
"No unary variant unary_op function found for unary variant op enum: ",
|
||||||
v.TypeName(), " for device type: ", device);
|
op, " Variant type_name: ", v.TypeName(), " for device type: ", device);
|
||||||
}
|
}
|
||||||
return (*zeros_like_fn)(ctx, v, v_out);
|
return (*unary_op_fn)(ctx, v, v_out);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sets *out = binary_op(a, b). The variants a and b must be the same type
|
||||||
|
// and have a registered binary_op function for the given Device. Returns an
|
||||||
|
// Internal error if a and b are not the same type_name or if
|
||||||
|
// if a does not have a registered op function for this device, or if
|
||||||
|
// BinaryOp fails.
|
||||||
|
//
|
||||||
|
// REQUIRES:
|
||||||
|
// out is not null.
|
||||||
|
//
|
||||||
|
template <typename Device>
|
||||||
|
Status BinaryOpVariants(OpKernelContext* ctx, VariantBinaryOp op,
|
||||||
|
const Variant& a, const Variant& b, Variant* out) {
|
||||||
|
if (a.TypeName() != b.TypeName()) {
|
||||||
|
return errors::Internal(
|
||||||
|
"BianryOpVariants: Variants a and b have different "
|
||||||
|
"type names: '",
|
||||||
|
a.TypeName(), "' vs. '", b.TypeName(), "'");
|
||||||
|
}
|
||||||
|
const string& device = DeviceName<Device>::value;
|
||||||
|
UnaryVariantOpRegistry::VariantBinaryOpFn* binary_op_fn =
|
||||||
|
UnaryVariantOpRegistry::Global()->GetBinaryOpFn(op, device, a.TypeName());
|
||||||
|
if (binary_op_fn == nullptr) {
|
||||||
|
return errors::Internal(
|
||||||
|
"No unary variant binary_op function found for binary variant op "
|
||||||
|
"enum: ",
|
||||||
|
op, " Variant type_name: '", a.TypeName(),
|
||||||
|
"' for device type: ", device);
|
||||||
|
}
|
||||||
|
return (*binary_op_fn)(ctx, a, b, out);
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace variant_op_registry_fn_registration {
|
namespace variant_op_registry_fn_registration {
|
||||||
@ -165,30 +259,65 @@ class UnaryVariantDecodeRegistration {
|
|||||||
};
|
};
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
class UnaryVariantZerosLikeRegistration {
|
class UnaryVariantUnaryOpRegistration {
|
||||||
typedef std::function<Status(OpKernelContext* ctx, const T& t, T* t_out)>
|
typedef std::function<Status(OpKernelContext* ctx, const T& t, T* t_out)>
|
||||||
LocalVariantZerosLikeFn;
|
LocalVariantUnaryOpFn;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
UnaryVariantZerosLikeRegistration(
|
UnaryVariantUnaryOpRegistration(VariantUnaryOp op, const string& device,
|
||||||
const string& device, const string& type_name,
|
const string& type_name,
|
||||||
const LocalVariantZerosLikeFn& zeros_like_fn) {
|
const LocalVariantUnaryOpFn& unary_op_fn) {
|
||||||
auto wrapped_fn = [type_name, zeros_like_fn](OpKernelContext* ctx,
|
auto wrapped_fn = [type_name, unary_op_fn](OpKernelContext* ctx,
|
||||||
const Variant& v,
|
const Variant& v,
|
||||||
Variant* v_out) -> Status {
|
Variant* v_out) -> Status {
|
||||||
CHECK_NOTNULL(v_out);
|
CHECK_NOTNULL(v_out);
|
||||||
*v_out = T();
|
*v_out = T();
|
||||||
if (v.get<T>() == nullptr) {
|
if (v.get<T>() == nullptr) {
|
||||||
return errors::Internal(
|
return errors::Internal(
|
||||||
"VariantZerosLikeFn: Could not access object, type_name: ",
|
"VariantUnaryOpFn: Could not access object, type_name: ",
|
||||||
type_name);
|
type_name);
|
||||||
}
|
}
|
||||||
const T& t = *v.get<T>();
|
const T& t = *v.get<T>();
|
||||||
T* t_out = v_out->get<T>();
|
T* t_out = v_out->get<T>();
|
||||||
return zeros_like_fn(ctx, t, t_out);
|
return unary_op_fn(ctx, t, t_out);
|
||||||
};
|
};
|
||||||
UnaryVariantOpRegistry::Global()->RegisterZerosLikeFn(device, type_name,
|
UnaryVariantOpRegistry::Global()->RegisterUnaryOpFn(op, device, type_name,
|
||||||
wrapped_fn);
|
wrapped_fn);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
class UnaryVariantBinaryOpRegistration {
|
||||||
|
typedef std::function<Status(OpKernelContext* ctx, const T& a, const T& b,
|
||||||
|
T* out)>
|
||||||
|
LocalVariantBinaryOpFn;
|
||||||
|
|
||||||
|
public:
|
||||||
|
UnaryVariantBinaryOpRegistration(VariantBinaryOp op, const string& device,
|
||||||
|
const string& type_name,
|
||||||
|
const LocalVariantBinaryOpFn& binary_op_fn) {
|
||||||
|
auto wrapped_fn = [type_name, binary_op_fn](
|
||||||
|
OpKernelContext* ctx, const Variant& a,
|
||||||
|
const Variant& b, Variant* out) -> Status {
|
||||||
|
CHECK_NOTNULL(out);
|
||||||
|
*out = T();
|
||||||
|
if (a.get<T>() == nullptr) {
|
||||||
|
return errors::Internal(
|
||||||
|
"VariantBinaryOpFn: Could not access object 'a', type_name: ",
|
||||||
|
type_name);
|
||||||
|
}
|
||||||
|
if (b.get<T>() == nullptr) {
|
||||||
|
return errors::Internal(
|
||||||
|
"VariantBinaryOpFn: Could not access object 'b', type_name: ",
|
||||||
|
type_name);
|
||||||
|
}
|
||||||
|
const T& t_a = *a.get<T>();
|
||||||
|
const T& t_b = *b.get<T>();
|
||||||
|
T* t_out = out->get<T>();
|
||||||
|
return binary_op_fn(ctx, t_a, t_b, t_out);
|
||||||
|
};
|
||||||
|
UnaryVariantOpRegistry::Global()->RegisterBinaryOpFn(op, device, type_name,
|
||||||
|
wrapped_fn);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -223,25 +352,47 @@ class UnaryVariantZerosLikeRegistration {
|
|||||||
T> \
|
T> \
|
||||||
register_unary_variant_op_decoder_fn_##ctr(type_name)
|
register_unary_variant_op_decoder_fn_##ctr(type_name)
|
||||||
|
|
||||||
// Register a unary zeros_like variant function with the signature:
|
// Register a unary unary_op variant function with the signature:
|
||||||
// Status ZerosLikeFn(OpKernelContext* ctx, const T& t, T* t_out);
|
// Status UnaryOpFn(OpKernelContext* ctx, const T& t, T* t_out);
|
||||||
// to Variants having TypeName type_name, for device string device.
|
// to Variants having TypeName type_name, for device string device,
|
||||||
#define REGISTER_UNARY_VARIANT_ZEROS_LIKE_FUNCTION(device, T, type_name, \
|
// for UnaryVariantOp enum op.
|
||||||
zeros_like_function) \
|
#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(op, device, T, type_name, \
|
||||||
REGISTER_UNARY_VARIANT_ZEROS_LIKE_FUNCTION_UNIQ_HELPER( \
|
unary_op_function) \
|
||||||
__COUNTER__, device, T, type_name, zeros_like_function)
|
REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER( \
|
||||||
|
__COUNTER__, op, device, T, type_name, unary_op_function)
|
||||||
|
|
||||||
#define REGISTER_UNARY_VARIANT_ZEROS_LIKE_FUNCTION_UNIQ_HELPER( \
|
#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER( \
|
||||||
ctr, device, T, type_name, zeros_like_function) \
|
ctr, op, device, T, type_name, unary_op_function) \
|
||||||
REGISTER_UNARY_VARIANT_ZEROS_LIKE_FUNCTION_UNIQ(ctr, device, T, type_name, \
|
REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ(ctr, op, device, T, type_name, \
|
||||||
zeros_like_function)
|
unary_op_function)
|
||||||
|
|
||||||
#define REGISTER_UNARY_VARIANT_ZEROS_LIKE_FUNCTION_UNIQ( \
|
#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ( \
|
||||||
ctr, device, T, type_name, zeros_like_function) \
|
ctr, op, device, T, type_name, unary_op_function) \
|
||||||
static variant_op_registry_fn_registration:: \
|
static variant_op_registry_fn_registration::UnaryVariantUnaryOpRegistration< \
|
||||||
UnaryVariantZerosLikeRegistration<T> \
|
T> \
|
||||||
register_unary_variant_op_decoder_fn_##ctr(device, type_name, \
|
register_unary_variant_op_decoder_fn_##ctr(op, device, type_name, \
|
||||||
zeros_like_function)
|
unary_op_function)
|
||||||
|
|
||||||
|
// Register a binary_op variant function with the signature:
|
||||||
|
// Status BinaryOpFn(OpKernelContext* ctx, const T& a, const T& b, T* out);
|
||||||
|
// to Variants having TypeName type_name, for device string device,
|
||||||
|
// for BinaryVariantOp enum OP.
|
||||||
|
#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(op, device, T, type_name, \
|
||||||
|
binary_op_function) \
|
||||||
|
REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ_HELPER( \
|
||||||
|
__COUNTER__, op, device, T, type_name, binary_op_function)
|
||||||
|
|
||||||
|
#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ_HELPER( \
|
||||||
|
ctr, op, device, T, type_name, binary_op_function) \
|
||||||
|
REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ( \
|
||||||
|
ctr, op, device, T, type_name, binary_op_function)
|
||||||
|
|
||||||
|
#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ( \
|
||||||
|
ctr, op, device, T, type_name, binary_op_function) \
|
||||||
|
static variant_op_registry_fn_registration:: \
|
||||||
|
UnaryVariantBinaryOpRegistration<T> \
|
||||||
|
register_unary_variant_op_decoder_fn_##ctr(op, device, type_name, \
|
||||||
|
binary_op_function)
|
||||||
|
|
||||||
} // end namespace tensorflow
|
} // end namespace tensorflow
|
||||||
|
|
||||||
|
@ -50,7 +50,7 @@ struct VariantValue {
|
|||||||
if (v.early_exit) {
|
if (v.early_exit) {
|
||||||
return errors::InvalidArgument("early exit zeros_like!");
|
return errors::InvalidArgument("early exit zeros_like!");
|
||||||
}
|
}
|
||||||
v_out->zeros_like_set = 1; // CPU
|
v_out->value = 1; // CPU
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
static Status GPUZerosLikeFn(OpKernelContext* ctx, const VariantValue& v,
|
static Status GPUZerosLikeFn(OpKernelContext* ctx, const VariantValue& v,
|
||||||
@ -58,11 +58,27 @@ struct VariantValue {
|
|||||||
if (v.early_exit) {
|
if (v.early_exit) {
|
||||||
return errors::InvalidArgument("early exit zeros_like!");
|
return errors::InvalidArgument("early exit zeros_like!");
|
||||||
}
|
}
|
||||||
v_out->zeros_like_set = 2; // GPU
|
v_out->value = 2; // GPU
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
static Status CPUAddFn(OpKernelContext* ctx, const VariantValue& a,
|
||||||
|
const VariantValue& b, VariantValue* out) {
|
||||||
|
if (a.early_exit) {
|
||||||
|
return errors::InvalidArgument("early exit add!");
|
||||||
|
}
|
||||||
|
out->value = a.value + b.value; // CPU
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
static Status GPUAddFn(OpKernelContext* ctx, const VariantValue& a,
|
||||||
|
const VariantValue& b, VariantValue* out) {
|
||||||
|
if (a.early_exit) {
|
||||||
|
return errors::InvalidArgument("early exit add!");
|
||||||
|
}
|
||||||
|
out->value = -(a.value + b.value); // GPU
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
bool early_exit;
|
bool early_exit;
|
||||||
int zeros_like_set;
|
int value;
|
||||||
};
|
};
|
||||||
|
|
||||||
REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(VariantValue, "TEST VariantValue",
|
REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(VariantValue, "TEST VariantValue",
|
||||||
@ -70,13 +86,23 @@ REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(VariantValue, "TEST VariantValue",
|
|||||||
|
|
||||||
REGISTER_UNARY_VARIANT_DECODE_FUNCTION(VariantValue, "TEST VariantValue");
|
REGISTER_UNARY_VARIANT_DECODE_FUNCTION(VariantValue, "TEST VariantValue");
|
||||||
|
|
||||||
REGISTER_UNARY_VARIANT_ZEROS_LIKE_FUNCTION(DEVICE_CPU, VariantValue,
|
REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP,
|
||||||
"TEST VariantValue",
|
DEVICE_CPU, VariantValue,
|
||||||
VariantValue::CPUZerosLikeFn);
|
"TEST VariantValue",
|
||||||
|
VariantValue::CPUZerosLikeFn);
|
||||||
|
|
||||||
REGISTER_UNARY_VARIANT_ZEROS_LIKE_FUNCTION(DEVICE_GPU, VariantValue,
|
REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP,
|
||||||
"TEST VariantValue",
|
DEVICE_GPU, VariantValue,
|
||||||
VariantValue::GPUZerosLikeFn);
|
"TEST VariantValue",
|
||||||
|
VariantValue::GPUZerosLikeFn);
|
||||||
|
|
||||||
|
REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU,
|
||||||
|
VariantValue, "TEST VariantValue",
|
||||||
|
VariantValue::CPUAddFn);
|
||||||
|
|
||||||
|
REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_GPU,
|
||||||
|
VariantValue, "TEST VariantValue",
|
||||||
|
VariantValue::GPUAddFn);
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
@ -104,8 +130,9 @@ TEST(VariantOpShapeRegistryTest, TestBasic) {
|
|||||||
TEST(VariantOpShapeRegistryTest, TestDuplicate) {
|
TEST(VariantOpShapeRegistryTest, TestDuplicate) {
|
||||||
UnaryVariantOpRegistry registry;
|
UnaryVariantOpRegistry registry;
|
||||||
UnaryVariantOpRegistry::VariantShapeFn f;
|
UnaryVariantOpRegistry::VariantShapeFn f;
|
||||||
registry.RegisterShapeFn("fjfjfj", f);
|
string kTypeName = "fjfjfj";
|
||||||
EXPECT_DEATH(registry.RegisterShapeFn("fjfjfj", f),
|
registry.RegisterShapeFn(kTypeName, f);
|
||||||
|
EXPECT_DEATH(registry.RegisterShapeFn(kTypeName, f),
|
||||||
"fjfjfj already registered");
|
"fjfjfj already registered");
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -133,71 +160,146 @@ TEST(VariantOpDecodeRegistryTest, TestBasic) {
|
|||||||
TEST(VariantOpDecodeRegistryTest, TestDuplicate) {
|
TEST(VariantOpDecodeRegistryTest, TestDuplicate) {
|
||||||
UnaryVariantOpRegistry registry;
|
UnaryVariantOpRegistry registry;
|
||||||
UnaryVariantOpRegistry::VariantDecodeFn f;
|
UnaryVariantOpRegistry::VariantDecodeFn f;
|
||||||
registry.RegisterDecodeFn("fjfjfj", f);
|
string kTypeName = "fjfjfj";
|
||||||
EXPECT_DEATH(registry.RegisterDecodeFn("fjfjfj", f),
|
registry.RegisterDecodeFn(kTypeName, f);
|
||||||
|
EXPECT_DEATH(registry.RegisterDecodeFn(kTypeName, f),
|
||||||
"fjfjfj already registered");
|
"fjfjfj already registered");
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(VariantOpZerosLikeRegistryTest, TestBasicCPU) {
|
TEST(VariantOpZerosLikeRegistryTest, TestBasicCPU) {
|
||||||
EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetZerosLikeFn(
|
EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetUnaryOpFn(
|
||||||
DEVICE_CPU, "YOU SHALL NOT PASS"),
|
ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, "YOU SHALL NOT PASS"),
|
||||||
nullptr);
|
nullptr);
|
||||||
|
|
||||||
VariantValue vv_early_exit{true /* early_exit */, 0 /* zeros_like_set */};
|
VariantValue vv_early_exit{true /* early_exit */, 0 /* value */};
|
||||||
Variant v = vv_early_exit;
|
Variant v = vv_early_exit;
|
||||||
Variant v_out = VariantValue();
|
Variant v_out = VariantValue();
|
||||||
|
|
||||||
OpKernelContext* null_context_pointer = nullptr;
|
OpKernelContext* null_context_pointer = nullptr;
|
||||||
Status s0 =
|
Status s0 = UnaryOpVariant<CPUDevice>(null_context_pointer,
|
||||||
CreateZerosLikeVariant<CPUDevice>(null_context_pointer, v, &v_out);
|
ZEROS_LIKE_VARIANT_UNARY_OP, v, &v_out);
|
||||||
EXPECT_FALSE(s0.ok());
|
EXPECT_FALSE(s0.ok());
|
||||||
EXPECT_TRUE(
|
EXPECT_TRUE(
|
||||||
StringPiece(s0.error_message()).contains("early exit zeros_like"));
|
StringPiece(s0.error_message()).contains("early exit zeros_like"));
|
||||||
|
|
||||||
VariantValue vv_ok{false /* early_exit */, 0 /* zeros_like_set */};
|
VariantValue vv_ok{false /* early_exit */, 0 /* value */};
|
||||||
v = vv_ok;
|
v = vv_ok;
|
||||||
TF_EXPECT_OK(
|
TF_EXPECT_OK(UnaryOpVariant<CPUDevice>(
|
||||||
CreateZerosLikeVariant<CPUDevice>(null_context_pointer, v, &v_out));
|
null_context_pointer, ZEROS_LIKE_VARIANT_UNARY_OP, v, &v_out));
|
||||||
VariantValue* vv_out = CHECK_NOTNULL(v_out.get<VariantValue>());
|
VariantValue* vv_out = CHECK_NOTNULL(v_out.get<VariantValue>());
|
||||||
EXPECT_EQ(vv_out->zeros_like_set, 1); // CPU
|
EXPECT_EQ(vv_out->value, 1); // CPU
|
||||||
}
|
}
|
||||||
|
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA
|
||||||
TEST(VariantOpZerosLikeRegistryTest, TestBasicGPU) {
|
TEST(VariantOpUnaryOpRegistryTest, TestBasicGPU) {
|
||||||
EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetZerosLikeFn(
|
EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetUnaryOpFn(
|
||||||
DEVICE_GPU, "YOU SHALL NOT PASS"),
|
ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU, "YOU SHALL NOT PASS"),
|
||||||
nullptr);
|
nullptr);
|
||||||
|
|
||||||
VariantValue vv_early_exit{true /* early_exit */, 0 /* zeros_like_set */};
|
VariantValue vv_early_exit{true /* early_exit */, 0 /* value */};
|
||||||
Variant v = vv_early_exit;
|
Variant v = vv_early_exit;
|
||||||
Variant v_out = VariantValue();
|
Variant v_out = VariantValue();
|
||||||
|
|
||||||
OpKernelContext* null_context_pointer = nullptr;
|
OpKernelContext* null_context_pointer = nullptr;
|
||||||
Status s0 =
|
Status s0 = UnaryOpVariant<GPUDevice>(null_context_pointer,
|
||||||
CreateZerosLikeVariant<GPUDevice>(null_context_pointer, v, &v_out);
|
ZEROS_LIKE_VARIANT_UNARY_OP, v, &v_out);
|
||||||
EXPECT_FALSE(s0.ok());
|
EXPECT_FALSE(s0.ok());
|
||||||
EXPECT_TRUE(
|
EXPECT_TRUE(
|
||||||
StringPiece(s0.error_message()).contains("early exit zeros_like"));
|
StringPiece(s0.error_message()).contains("early exit zeros_like"));
|
||||||
|
|
||||||
VariantValue vv_ok{false /* early_exit */, 0 /* zeros_like_set */};
|
VariantValue vv_ok{false /* early_exit */, 0 /* value */};
|
||||||
v = vv_ok;
|
v = vv_ok;
|
||||||
TF_EXPECT_OK(
|
TF_EXPECT_OK(UnaryOpVariant<GPUDevice>(
|
||||||
CreateZerosLikeVariant<GPUDevice>(null_context_pointer, v, &v_out));
|
null_context_pointer, ZEROS_LIKE_VARIANT_UNARY_OP, v, &v_out));
|
||||||
VariantValue* vv_out = CHECK_NOTNULL(v_out.get<VariantValue>());
|
VariantValue* vv_out = CHECK_NOTNULL(v_out.get<VariantValue>());
|
||||||
EXPECT_EQ(vv_out->zeros_like_set, 2); // GPU
|
EXPECT_EQ(vv_out->value, 2); // GPU
|
||||||
}
|
}
|
||||||
#endif // GOOGLE_CUDA
|
#endif // GOOGLE_CUDA
|
||||||
|
|
||||||
TEST(VariantOpZerosLikeRegistryTest, TestDuplicate) {
|
TEST(VariantOpUnaryOpRegistryTest, TestDuplicate) {
|
||||||
UnaryVariantOpRegistry registry;
|
UnaryVariantOpRegistry registry;
|
||||||
UnaryVariantOpRegistry::VariantZerosLikeFn f;
|
UnaryVariantOpRegistry::VariantUnaryOpFn f;
|
||||||
|
string kTypeName = "fjfjfj";
|
||||||
|
|
||||||
registry.RegisterZerosLikeFn(DEVICE_CPU, "fjfjfj", f);
|
registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, kTypeName,
|
||||||
EXPECT_DEATH(registry.RegisterZerosLikeFn(DEVICE_CPU, "fjfjfj", f),
|
f);
|
||||||
|
EXPECT_DEATH(registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP,
|
||||||
|
DEVICE_CPU, kTypeName, f),
|
||||||
"fjfjfj already registered");
|
"fjfjfj already registered");
|
||||||
|
|
||||||
registry.RegisterZerosLikeFn(DEVICE_GPU, "fjfjfj", f);
|
registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU, kTypeName,
|
||||||
EXPECT_DEATH(registry.RegisterZerosLikeFn(DEVICE_GPU, "fjfjfj", f),
|
f);
|
||||||
|
EXPECT_DEATH(registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP,
|
||||||
|
DEVICE_GPU, kTypeName, f),
|
||||||
|
"fjfjfj already registered");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(VariantOpAddRegistryTest, TestBasicCPU) {
|
||||||
|
return;
|
||||||
|
EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetBinaryOpFn(
|
||||||
|
ADD_VARIANT_BINARY_OP, DEVICE_CPU, "YOU SHALL NOT PASS"),
|
||||||
|
nullptr);
|
||||||
|
|
||||||
|
VariantValue vv_early_exit{true /* early_exit */, 3 /* value */};
|
||||||
|
VariantValue vv_other{true /* early_exit */, 4 /* value */};
|
||||||
|
Variant v_a = vv_early_exit;
|
||||||
|
Variant v_b = vv_other;
|
||||||
|
Variant v_out = VariantValue();
|
||||||
|
|
||||||
|
OpKernelContext* null_context_pointer = nullptr;
|
||||||
|
Status s0 = BinaryOpVariants<CPUDevice>(
|
||||||
|
null_context_pointer, ADD_VARIANT_BINARY_OP, v_a, v_b, &v_out);
|
||||||
|
EXPECT_FALSE(s0.ok());
|
||||||
|
EXPECT_TRUE(StringPiece(s0.error_message()).contains("early exit add"));
|
||||||
|
|
||||||
|
VariantValue vv_ok{false /* early_exit */, 3 /* value */};
|
||||||
|
v_a = vv_ok;
|
||||||
|
TF_EXPECT_OK(BinaryOpVariants<CPUDevice>(
|
||||||
|
null_context_pointer, ADD_VARIANT_BINARY_OP, v_a, v_b, &v_out));
|
||||||
|
VariantValue* vv_out = CHECK_NOTNULL(v_out.get<VariantValue>());
|
||||||
|
EXPECT_EQ(vv_out->value, 7); // CPU
|
||||||
|
}
|
||||||
|
|
||||||
|
#if GOOGLE_CUDA
|
||||||
|
TEST(VariantOpAddRegistryTest, TestBasicGPU) {
|
||||||
|
EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetBinaryOpFn(
|
||||||
|
ADD_VARIANT_BINARY_OP, DEVICE_GPU, "YOU SHALL NOT PASS"),
|
||||||
|
nullptr);
|
||||||
|
|
||||||
|
VariantValue vv_early_exit{true /* early_exit */, 3 /* value */};
|
||||||
|
VariantValue vv_other{true /* early_exit */, 4 /* value */};
|
||||||
|
Variant v_a = vv_early_exit;
|
||||||
|
Variant v_b = vv_other;
|
||||||
|
Variant v_out = VariantValue();
|
||||||
|
|
||||||
|
OpKernelContext* null_context_pointer = nullptr;
|
||||||
|
Status s0 = BinaryOpVariants<GPUDevice>(
|
||||||
|
null_context_pointer, ADD_VARIANT_BINARY_OP, v_a, v_b, &v_out);
|
||||||
|
EXPECT_FALSE(s0.ok());
|
||||||
|
EXPECT_TRUE(StringPiece(s0.error_message()).contains("early exit add"));
|
||||||
|
|
||||||
|
VariantValue vv_ok{false /* early_exit */, 3 /* value */};
|
||||||
|
v_a = vv_ok;
|
||||||
|
TF_EXPECT_OK(BinaryOpVariants<GPUDevice>(
|
||||||
|
null_context_pointer, ADD_VARIANT_BINARY_OP, v_a, v_b, &v_out));
|
||||||
|
VariantValue* vv_out = CHECK_NOTNULL(v_out.get<VariantValue>());
|
||||||
|
EXPECT_EQ(vv_out->value, -7); // GPU
|
||||||
|
}
|
||||||
|
#endif // GOOGLE_CUDA
|
||||||
|
|
||||||
|
TEST(VariantOpAddRegistryTest, TestDuplicate) {
|
||||||
|
UnaryVariantOpRegistry registry;
|
||||||
|
UnaryVariantOpRegistry::VariantBinaryOpFn f;
|
||||||
|
string kTypeName = "fjfjfj";
|
||||||
|
|
||||||
|
registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_CPU, kTypeName, f);
|
||||||
|
EXPECT_DEATH(registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_CPU,
|
||||||
|
kTypeName, f),
|
||||||
|
"fjfjfj already registered");
|
||||||
|
|
||||||
|
registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_GPU, kTypeName, f);
|
||||||
|
EXPECT_DEATH(registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_GPU,
|
||||||
|
kTypeName, f),
|
||||||
"fjfjfj already registered");
|
"fjfjfj already registered");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -34,8 +34,8 @@ TensorId ParseTensorName(StringPiece name) {
|
|||||||
// whole name string forms the first part of the tensor name.
|
// whole name string forms the first part of the tensor name.
|
||||||
const char* base = name.data();
|
const char* base = name.data();
|
||||||
const char* p = base + name.size() - 1;
|
const char* p = base + name.size() - 1;
|
||||||
int index = 0;
|
unsigned int index = 0;
|
||||||
int mul = 1;
|
unsigned int mul = 1;
|
||||||
while (p > base && (*p >= '0' && *p <= '9')) {
|
while (p > base && (*p >= '0' && *p <= '9')) {
|
||||||
index += ((*p - '0') * mul);
|
index += ((*p - '0') * mul);
|
||||||
mul *= 10;
|
mul *= 10;
|
||||||
|
@ -24,6 +24,9 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/core/framework/numeric_op.h"
|
#include "tensorflow/core/framework/numeric_op.h"
|
||||||
#include "tensorflow/core/framework/register_types.h"
|
#include "tensorflow/core/framework/register_types.h"
|
||||||
|
#include "tensorflow/core/framework/variant.h"
|
||||||
|
#include "tensorflow/core/framework/variant_encode_decode.h"
|
||||||
|
#include "tensorflow/core/framework/variant_op_registry.h"
|
||||||
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
#include "tensorflow/core/lib/gtl/inlined_vector.h"
|
||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
|
|
||||||
@ -33,7 +36,7 @@ typedef Eigen::ThreadPoolDevice CPUDevice;
|
|||||||
typedef Eigen::GpuDevice GPUDevice;
|
typedef Eigen::GpuDevice GPUDevice;
|
||||||
#ifdef TENSORFLOW_USE_SYCL
|
#ifdef TENSORFLOW_USE_SYCL
|
||||||
typedef Eigen::SyclDevice SYCLDevice;
|
typedef Eigen::SyclDevice SYCLDevice;
|
||||||
#endif // TENSORFLOW_USE_SYCL
|
#endif // TENSORFLOW_USE_SYCL
|
||||||
|
|
||||||
template <typename Device, typename T>
|
template <typename Device, typename T>
|
||||||
class AddNOp : public OpKernel {
|
class AddNOp : public OpKernel {
|
||||||
@ -150,6 +153,65 @@ class AddNOp : public OpKernel {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <typename Device>
|
||||||
|
class AddNOp<Device, Variant> : public OpKernel {
|
||||||
|
public:
|
||||||
|
explicit AddNOp(OpKernelConstruction* context) : OpKernel(context) {}
|
||||||
|
|
||||||
|
void Compute(OpKernelContext* ctx) override {
|
||||||
|
if (!ctx->ValidateInputsAreSameShape(this)) return;
|
||||||
|
|
||||||
|
const Tensor& input0 = ctx->input(0);
|
||||||
|
const int num = ctx->num_inputs();
|
||||||
|
|
||||||
|
if (num == 1) {
|
||||||
|
ctx->set_output(0, input0);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < num; ++i) {
|
||||||
|
// Step 1: ensure unary variants.
|
||||||
|
OP_REQUIRES(
|
||||||
|
ctx, ctx->input(i).dims() == 0,
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"AddN of non-scalar Tensor with dtype=DT_VARIANT is not "
|
||||||
|
"supported; inputs[",
|
||||||
|
i, " has shape: ", ctx->input(i).shape().DebugString(), "."));
|
||||||
|
}
|
||||||
|
|
||||||
|
TensorShape common_shape;
|
||||||
|
OP_REQUIRES_OK(ctx, GetUnaryVariantShape(ctx->input(0), &common_shape));
|
||||||
|
// Step 2: access all variants and ensure shapes match.
|
||||||
|
for (int i = 1; i < num; ++i) {
|
||||||
|
TensorShape check_shape;
|
||||||
|
OP_REQUIRES_OK(ctx, GetUnaryVariantShape(ctx->input(i), &check_shape));
|
||||||
|
OP_REQUIRES(ctx, common_shape == check_shape,
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"AddN of Variants of differing shapes; inputs[0] shape: ",
|
||||||
|
common_shape.DebugString(), ", inputs[", i,
|
||||||
|
"] shape: ", check_shape.DebugString()));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 3: attempt to add using
|
||||||
|
// BinaryOpVariants(ADD_VARIANT_BINARY_OP, ...)
|
||||||
|
// For the output create a default-constructed variant object.
|
||||||
|
// TODO(ebrevdo): Perform summation in a tree-structure.
|
||||||
|
Tensor out(cpu_allocator(), DT_VARIANT, TensorShape({}));
|
||||||
|
Variant* v_out = &(out.scalar<Variant>()());
|
||||||
|
OP_REQUIRES_OK(
|
||||||
|
ctx, BinaryOpVariants<Device>(
|
||||||
|
ctx, ADD_VARIANT_BINARY_OP, ctx->input(0).scalar<Variant>()(),
|
||||||
|
ctx->input(1).scalar<Variant>()(), v_out));
|
||||||
|
for (int i = 2; i < num; ++i) {
|
||||||
|
const Variant tmp = std::move(*v_out);
|
||||||
|
const Variant& inp = ctx->input(i).scalar<Variant>()();
|
||||||
|
OP_REQUIRES_OK(ctx, BinaryOpVariants<Device>(ctx, ADD_VARIANT_BINARY_OP,
|
||||||
|
inp, tmp, v_out));
|
||||||
|
}
|
||||||
|
ctx->set_output(0, out);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
#define REGISTER_ADDN(type, dev) \
|
#define REGISTER_ADDN(type, dev) \
|
||||||
REGISTER_KERNEL_BUILDER( \
|
REGISTER_KERNEL_BUILDER( \
|
||||||
Name("AddN").Device(DEVICE_##dev).TypeConstraint<type>("T"), \
|
Name("AddN").Device(DEVICE_##dev).TypeConstraint<type>("T"), \
|
||||||
@ -158,6 +220,8 @@ class AddNOp : public OpKernel {
|
|||||||
#define REGISTER_ADDN_CPU(type) REGISTER_ADDN(type, CPU)
|
#define REGISTER_ADDN_CPU(type) REGISTER_ADDN(type, CPU)
|
||||||
|
|
||||||
TF_CALL_NUMBER_TYPES(REGISTER_ADDN_CPU);
|
TF_CALL_NUMBER_TYPES(REGISTER_ADDN_CPU);
|
||||||
|
REGISTER_ADDN_CPU(Variant);
|
||||||
|
|
||||||
#undef REGISTER_ADDN_CPU
|
#undef REGISTER_ADDN_CPU
|
||||||
|
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA
|
||||||
@ -176,6 +240,16 @@ REGISTER_KERNEL_BUILDER(Name("AddN")
|
|||||||
.HostMemory("inputs")
|
.HostMemory("inputs")
|
||||||
.HostMemory("sum"),
|
.HostMemory("sum"),
|
||||||
AddNOp<CPUDevice, int32>);
|
AddNOp<CPUDevice, int32>);
|
||||||
|
|
||||||
|
// TODO(ebrevdo): Once rendezvous has been properly set up for
|
||||||
|
// Variants, we'll no longer need a HostMemory attribute for this case.
|
||||||
|
REGISTER_KERNEL_BUILDER(Name("AddN")
|
||||||
|
.Device(DEVICE_GPU)
|
||||||
|
.TypeConstraint<Variant>("T")
|
||||||
|
.HostMemory("inputs")
|
||||||
|
.HostMemory("sum"),
|
||||||
|
AddNOp<GPUDevice, Variant>);
|
||||||
|
|
||||||
#endif // GOOGLE_CUDA
|
#endif // GOOGLE_CUDA
|
||||||
|
|
||||||
#ifdef TENSORFLOW_USE_SYCL
|
#ifdef TENSORFLOW_USE_SYCL
|
||||||
@ -191,7 +265,7 @@ REGISTER_KERNEL_BUILDER(Name("AddN")
|
|||||||
.HostMemory("inputs")
|
.HostMemory("inputs")
|
||||||
.HostMemory("sum"),
|
.HostMemory("sum"),
|
||||||
AddNOp<CPUDevice, int32>);
|
AddNOp<CPUDevice, int32>);
|
||||||
#endif // TENSORFLOW_USE_SYCL
|
#endif // TENSORFLOW_USE_SYCL
|
||||||
|
|
||||||
#undef REGISTER_ADDN
|
#undef REGISTER_ADDN
|
||||||
|
|
||||||
|
@ -279,13 +279,15 @@ class ZerosLikeOp : public OpKernel {
|
|||||||
const Tensor& input = ctx->input(0);
|
const Tensor& input = ctx->input(0);
|
||||||
const Device& d = ctx->eigen_device<Device>();
|
const Device& d = ctx->eigen_device<Device>();
|
||||||
if (std::is_same<T, Variant>::value) {
|
if (std::is_same<T, Variant>::value) {
|
||||||
OP_REQUIRES(ctx, input.dims() == 0,
|
OP_REQUIRES(
|
||||||
errors::InvalidArgument(
|
ctx, input.dims() == 0,
|
||||||
"ZerosLike of non-unary Variant not supported."));
|
errors::InvalidArgument("ZerosLike non-scalar Tensor with "
|
||||||
|
"dtype=DT_VARIANT is not supported."));
|
||||||
const Variant& v = input.scalar<Variant>()();
|
const Variant& v = input.scalar<Variant>()();
|
||||||
Tensor out(cpu_allocator(), DT_VARIANT, TensorShape({}));
|
Tensor out(cpu_allocator(), DT_VARIANT, TensorShape({}));
|
||||||
Variant* out_v = &(out.scalar<Variant>()());
|
Variant* out_v = &(out.scalar<Variant>()());
|
||||||
OP_REQUIRES_OK(ctx, CreateZerosLikeVariant<Device>(ctx, v, out_v));
|
OP_REQUIRES_OK(ctx, UnaryOpVariant<Device>(
|
||||||
|
ctx, ZEROS_LIKE_VARIANT_UNARY_OP, v, out_v));
|
||||||
ctx->set_output(0, out);
|
ctx->set_output(0, out);
|
||||||
} else {
|
} else {
|
||||||
Tensor* out = nullptr;
|
Tensor* out = nullptr;
|
||||||
|
@ -292,7 +292,8 @@ class RemoteCallOp : public AsyncOpKernel {
|
|||||||
OP_REQUIRES_OK_ASYNC(ctx, ctx->input("target", &target), done);
|
OP_REQUIRES_OK_ASYNC(ctx, ctx->input("target", &target), done);
|
||||||
AttrValueMap attr_values = func_->attr();
|
AttrValueMap attr_values = func_->attr();
|
||||||
AttrValue v;
|
AttrValue v;
|
||||||
v.set_s(target->scalar<string>()());
|
const string& target_device = target->scalar<string>()();
|
||||||
|
v.set_s(target_device);
|
||||||
AddAttr("_target", v, &attr_values);
|
AddAttr("_target", v, &attr_values);
|
||||||
|
|
||||||
FunctionLibraryRuntime* lib = ctx->function_library();
|
FunctionLibraryRuntime* lib = ctx->function_library();
|
||||||
@ -310,6 +311,11 @@ class RemoteCallOp : public AsyncOpKernel {
|
|||||||
FunctionLibraryRuntime::Options opts;
|
FunctionLibraryRuntime::Options opts;
|
||||||
opts.step_id = ctx->step_id();
|
opts.step_id = ctx->step_id();
|
||||||
opts.runner = ctx->runner();
|
opts.runner = ctx->runner();
|
||||||
|
opts.source_device = lib->device()->name();
|
||||||
|
if (opts.source_device != target_device) {
|
||||||
|
opts.remote_execution = true;
|
||||||
|
}
|
||||||
|
opts.rendezvous = ctx->rendezvous();
|
||||||
std::vector<Tensor> args;
|
std::vector<Tensor> args;
|
||||||
args.reserve(arguments.size());
|
args.reserve(arguments.size());
|
||||||
for (const Tensor& argument : arguments) {
|
for (const Tensor& argument : arguments) {
|
||||||
@ -334,10 +340,13 @@ class RemoteCallOp : public AsyncOpKernel {
|
|||||||
TF_DISALLOW_COPY_AND_ASSIGN(RemoteCallOp);
|
TF_DISALLOW_COPY_AND_ASSIGN(RemoteCallOp);
|
||||||
};
|
};
|
||||||
|
|
||||||
REGISTER_KERNEL_BUILDER(Name("RemoteCall").Device(DEVICE_CPU), RemoteCallOp);
|
REGISTER_KERNEL_BUILDER(
|
||||||
REGISTER_KERNEL_BUILDER(Name("RemoteCall").Device(DEVICE_GPU), RemoteCallOp);
|
Name("RemoteCall").Device(DEVICE_CPU).HostMemory("target"), RemoteCallOp);
|
||||||
|
REGISTER_KERNEL_BUILDER(
|
||||||
|
Name("RemoteCall").Device(DEVICE_GPU).HostMemory("target"), RemoteCallOp);
|
||||||
#if TENSORFLOW_USE_SYCL
|
#if TENSORFLOW_USE_SYCL
|
||||||
REGISTER_KERNEL_BUILDER(Name("RemoteCall").Device(DEVICE_SYCL), RemoteCallOp);
|
REGISTER_KERNEL_BUILDER(
|
||||||
|
Name("RemoteCall").Device(DEVICE_SYCL).HostMemory("target"), RemoteCallOp);
|
||||||
|
|
||||||
#endif // TENSORFLOW_USE_SYCL
|
#endif // TENSORFLOW_USE_SYCL
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -920,6 +920,13 @@ class MaxPoolingGradWithArgmaxOp : public OpKernel {
|
|||||||
public:
|
public:
|
||||||
explicit MaxPoolingGradWithArgmaxOp(OpKernelConstruction* context)
|
explicit MaxPoolingGradWithArgmaxOp(OpKernelConstruction* context)
|
||||||
: OpKernel(context) {
|
: OpKernel(context) {
|
||||||
|
string data_format_str;
|
||||||
|
auto status = context->GetAttr("data_format", &data_format_str);
|
||||||
|
if (status.ok()) {
|
||||||
|
OP_REQUIRES(context, FormatFromString(data_format_str, &data_format_),
|
||||||
|
errors::InvalidArgument("Invalid data format"));
|
||||||
|
}
|
||||||
|
|
||||||
OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
|
OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
|
||||||
OP_REQUIRES(context, ksize_.size() == 4,
|
OP_REQUIRES(context, ksize_.size() == 4,
|
||||||
errors::InvalidArgument("Sliding window ksize field must "
|
errors::InvalidArgument("Sliding window ksize field must "
|
||||||
@ -959,6 +966,7 @@ class MaxPoolingGradWithArgmaxOp : public OpKernel {
|
|||||||
std::vector<int32> ksize_;
|
std::vector<int32> ksize_;
|
||||||
std::vector<int32> stride_;
|
std::vector<int32> stride_;
|
||||||
Padding padding_;
|
Padding padding_;
|
||||||
|
TensorFormat data_format_;
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename Device, typename T>
|
template <typename Device, typename T>
|
||||||
@ -1051,17 +1059,36 @@ class MaxPoolingNoMaskOp<GPUDevice, T> : public OpKernel {
|
|||||||
TensorShape out_shape =
|
TensorShape out_shape =
|
||||||
ShapeFromFormat(data_format_, params.tensor_in_batch, params.out_height,
|
ShapeFromFormat(data_format_, params.tensor_in_batch, params.out_height,
|
||||||
params.out_width, params.depth);
|
params.out_width, params.depth);
|
||||||
if (use_dnn_ && data_format_ == FORMAT_NCHW) {
|
|
||||||
|
// Assuming qint8 <--> NCHW_VECT_C (int8x4) here.
|
||||||
|
constexpr bool is_int8x4 = std::is_same<T, qint8>::value;
|
||||||
|
OP_REQUIRES(context, (is_int8x4 == (data_format_ == FORMAT_NCHW_VECT_C)),
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"qint8 should be used with data_format NCHW_VECT_C."));
|
||||||
|
|
||||||
|
// These is_int8x4 checks avoid linker errors for missing qint8 kernels.
|
||||||
|
if (!is_int8x4 && use_dnn_ && data_format_ == FORMAT_NCHW) {
|
||||||
DnnPoolingOp<T>::Compute(
|
DnnPoolingOp<T>::Compute(
|
||||||
context, perftools::gputools::dnn::PoolingMode::kMaximum, ksize_,
|
context, perftools::gputools::dnn::PoolingMode::kMaximum, ksize_,
|
||||||
stride_, padding_, data_format_, tensor_in, out_shape);
|
stride_, padding_, data_format_, tensor_in, out_shape);
|
||||||
} else {
|
} else {
|
||||||
CHECK(data_format_ == FORMAT_NHWC)
|
|
||||||
<< "Non-Cudnn MaxPool only supports NHWC format";
|
|
||||||
Tensor* output = nullptr;
|
Tensor* output = nullptr;
|
||||||
OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
|
OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
|
||||||
LaunchMaxPoolingNoMask<Device, T>::launch(context, params, tensor_in,
|
if (is_int8x4) {
|
||||||
output);
|
LaunchMaxPoolingNoMask_NCHW_VECT_C<Device>::launch(context, params,
|
||||||
|
tensor_in, output);
|
||||||
|
} else if (data_format_ == FORMAT_NHWC) {
|
||||||
|
LaunchMaxPoolingNoMask<Device, T>::launch(context, params, tensor_in,
|
||||||
|
output);
|
||||||
|
} else {
|
||||||
|
LOG(FATAL) << "MaxPool currently only supports the following (layout, "
|
||||||
|
"type) combinations: (NHWC, non-qint8), "
|
||||||
|
"(NCHW, non-qint8) or (NCHW_VECT_C, qint8). The "
|
||||||
|
"requested combination ("
|
||||||
|
<< ToString(data_format_) << ", "
|
||||||
|
<< DataTypeString(DataTypeToEnum<T>::v())
|
||||||
|
<< ") is not supported.";
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1346,6 +1373,26 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_MAX_POOL_KERNELS);
|
|||||||
.TypeConstraint<int64>("Targmax"), \
|
.TypeConstraint<int64>("Targmax"), \
|
||||||
MaxPoolingGradGradWithArgmaxOp<GPUDevice, T>);
|
MaxPoolingGradGradWithArgmaxOp<GPUDevice, T>);
|
||||||
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_ONLY_POOL_KERNELS);
|
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_ONLY_POOL_KERNELS);
|
||||||
|
|
||||||
|
REGISTER_KERNEL_BUILDER(
|
||||||
|
Name("MaxPool").Device(DEVICE_GPU).TypeConstraint<qint8>("T"),
|
||||||
|
MaxPoolingNoMaskOp<GPUDevice, qint8>);
|
||||||
|
|
||||||
|
REGISTER_KERNEL_BUILDER(Name("MaxPoolV2")
|
||||||
|
.Device(DEVICE_GPU)
|
||||||
|
.HostMemory("ksize")
|
||||||
|
.HostMemory("strides")
|
||||||
|
.TypeConstraint<qint8>("T"),
|
||||||
|
MaxPoolingV2Op<GPUDevice, qint8>);
|
||||||
|
|
||||||
|
REGISTER_KERNEL_BUILDER(Name("MaxPoolV2")
|
||||||
|
.Device(DEVICE_GPU)
|
||||||
|
.HostMemory("ksize")
|
||||||
|
.HostMemory("strides")
|
||||||
|
.TypeConstraint<qint8>("T")
|
||||||
|
.Label("eigen_tensor"),
|
||||||
|
MaxPoolingV2Op<GPUDevice, qint8>);
|
||||||
|
|
||||||
#undef REGISTER_GPU_ONLY_POOL_KERNELS
|
#undef REGISTER_GPU_ONLY_POOL_KERNELS
|
||||||
|
|
||||||
#endif // GOOGLE_CUDA
|
#endif // GOOGLE_CUDA
|
||||||
|
@ -17,7 +17,9 @@ limitations under the License.
|
|||||||
#define TENSORFLOW_KERNELS_MAXPOOLING_OP_H_
|
#define TENSORFLOW_KERNELS_MAXPOOLING_OP_H_
|
||||||
// Functor definition for MaxPoolingOp, must be compilable by nvcc.
|
// Functor definition for MaxPoolingOp, must be compilable by nvcc.
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/numeric_types.h"
|
||||||
#include "tensorflow/core/framework/tensor_types.h"
|
#include "tensorflow/core/framework/tensor_types.h"
|
||||||
|
#include "tensorflow/core/framework/type_traits.h"
|
||||||
#include "tensorflow/core/kernels/eigen_pooling.h"
|
#include "tensorflow/core/kernels/eigen_pooling.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
|
||||||
@ -37,6 +39,14 @@ struct SpatialMaxPooling {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <typename Device>
|
||||||
|
struct SpatialMaxPooling<Device, qint8> {
|
||||||
|
void operator()(const Device& d, typename TTypes<qint8, 4>::Tensor output,
|
||||||
|
typename TTypes<qint8, 4>::ConstTensor input, int window_rows,
|
||||||
|
int window_cols, int row_stride, int col_stride,
|
||||||
|
const Eigen::PaddingType& padding) {}
|
||||||
|
};
|
||||||
|
|
||||||
} // namespace functor
|
} // namespace functor
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/core/framework/register_types.h"
|
#include "tensorflow/core/framework/register_types.h"
|
||||||
#include "tensorflow/core/framework/tensor_types.h"
|
#include "tensorflow/core/framework/tensor_types.h"
|
||||||
|
#include "tensorflow/core/framework/type_traits.h"
|
||||||
#include "tensorflow/core/kernels/maxpooling_op.h"
|
#include "tensorflow/core/kernels/maxpooling_op.h"
|
||||||
#include "tensorflow/core/kernels/maxpooling_op_gpu.h"
|
#include "tensorflow/core/kernels/maxpooling_op_gpu.h"
|
||||||
#include "tensorflow/core/util/cuda_kernel_helper.h"
|
#include "tensorflow/core/util/cuda_kernel_helper.h"
|
||||||
@ -89,6 +90,42 @@ __global__ void MaxPoolForwardNCHW(const int nthreads, const dtype* bottom_data,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// The parameters for MaxPoolForwardNoMaskKernel_NCHW_VECT_C are the same as for
|
||||||
|
// MaxPoolForwardNCHW above, except that mask is not supported, and each
|
||||||
|
// element of the input and output contains 4 adjacent channel values for
|
||||||
|
// the same X, y coordinate.
|
||||||
|
// (so channels = outer_channels, output_size = real output size / 4).
|
||||||
|
__global__ void MaxPoolForwardNoMaskKernel_NCHW_VECT_C(
|
||||||
|
const int nthreads, const int32* bottom_data, const int height,
|
||||||
|
const int width, const int channels, const int pooled_height,
|
||||||
|
const int pooled_width, const int kernel_h, const int kernel_w,
|
||||||
|
const int stride_h, const int stride_w, const int pad_t, const int pad_l,
|
||||||
|
int32* top_data) {
|
||||||
|
// TODO(pauldonnelly): Implement a better optimized version of this kernel.
|
||||||
|
const int32 kMinINT8X4 = 0x80808080;
|
||||||
|
CUDA_1D_KERNEL_LOOP(index, nthreads) {
|
||||||
|
int pw = index % pooled_width;
|
||||||
|
int ph = (index / pooled_width) % pooled_height;
|
||||||
|
int c = (index / pooled_width / pooled_height) % channels;
|
||||||
|
int n = index / pooled_width / pooled_height / channels;
|
||||||
|
int hstart = ph * stride_h - pad_t;
|
||||||
|
int wstart = pw * stride_w - pad_l;
|
||||||
|
int hend = min(hstart + kernel_h, height);
|
||||||
|
int wend = min(wstart + kernel_w, width);
|
||||||
|
hstart = max(hstart, 0);
|
||||||
|
wstart = max(wstart, 0);
|
||||||
|
int32 maxval = kMinINT8X4;
|
||||||
|
const int32* bottom_data_n = bottom_data + n * channels * height * width;
|
||||||
|
for (int h = hstart; h < hend; ++h) {
|
||||||
|
for (int w = wstart; w < wend; ++w) {
|
||||||
|
int idx = (c * height + h) * width + w;
|
||||||
|
maxval = __vmaxs4(maxval, bottom_data_n[idx]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
top_data[index] = maxval;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <typename dtype>
|
template <typename dtype>
|
||||||
__global__ void MaxPoolForwardNHWC(const int nthreads, const dtype* bottom_data,
|
__global__ void MaxPoolForwardNHWC(const int nthreads, const dtype* bottom_data,
|
||||||
const int height, const int width,
|
const int height, const int width,
|
||||||
@ -328,6 +365,25 @@ __global__ void MaxPoolGradBackward(const int nthreads, const dtype* top_diff,
|
|||||||
|
|
||||||
namespace functor {
|
namespace functor {
|
||||||
|
|
||||||
|
// Note: channels is the outer channels (dim 1) which has already been
|
||||||
|
// divided by 4.
|
||||||
|
bool MaxPoolForwardNoMask_NCHW_VECT_C::operator()(
|
||||||
|
const int32* bottom_data, const int batch, const int height,
|
||||||
|
const int width, int channels, const int pooled_height,
|
||||||
|
const int pooled_width, const int kernel_h, const int kernel_w,
|
||||||
|
const int stride_h, const int stride_w, const int pad_t, const int pad_l,
|
||||||
|
int32* top_data, const Eigen::GpuDevice& d) {
|
||||||
|
const int kThreadsPerBlock = 1024;
|
||||||
|
const int output_size = batch * channels * pooled_height * pooled_width;
|
||||||
|
MaxPoolForwardNoMaskKernel_NCHW_VECT_C<<<
|
||||||
|
(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock,
|
||||||
|
0, d.stream()>>>(output_size, bottom_data, height, width, channels,
|
||||||
|
pooled_height, pooled_width, kernel_h, kernel_w,
|
||||||
|
stride_h, stride_w, pad_t, pad_l, top_data);
|
||||||
|
d.synchronize();
|
||||||
|
return d.ok();
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
bool MaxPoolForwardWithOptionalArgmax<T>::operator()(
|
bool MaxPoolForwardWithOptionalArgmax<T>::operator()(
|
||||||
const T* bottom_data, const int batch, const int height, const int width,
|
const T* bottom_data, const int batch, const int height, const int width,
|
||||||
|
@ -42,6 +42,15 @@ struct MaxPoolForwardWithOptionalArgmax {
|
|||||||
const Eigen::GpuDevice& d);
|
const Eigen::GpuDevice& d);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct MaxPoolForwardNoMask_NCHW_VECT_C {
|
||||||
|
bool operator()(const int32* bottom_data, const int batch, const int height,
|
||||||
|
const int width, int channels, const int pooled_height,
|
||||||
|
const int pooled_width, const int kernel_h,
|
||||||
|
const int kernel_w, const int stride_h, const int stride_w,
|
||||||
|
const int pad_t, const int pad_l, int32* top_data,
|
||||||
|
const Eigen::GpuDevice& d);
|
||||||
|
};
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct MaxPoolBackwardWithArgmax {
|
struct MaxPoolBackwardWithArgmax {
|
||||||
bool operator()(const int output_size, const int input_size,
|
bool operator()(const int output_size, const int input_size,
|
||||||
|
@ -22,7 +22,6 @@ limitations under the License.
|
|||||||
|
|
||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA
|
||||||
#include "tensorflow/core/kernels/conv_2d.h"
|
#include "tensorflow/core/kernels/conv_2d.h"
|
||||||
#include "tensorflow/core/kernels/maxpooling_op_gpu.h"
|
|
||||||
#include "tensorflow/core/kernels/pooling_ops_common_gpu.h"
|
#include "tensorflow/core/kernels/pooling_ops_common_gpu.h"
|
||||||
#include "tensorflow/core/platform/stream_executor.h"
|
#include "tensorflow/core/platform/stream_executor.h"
|
||||||
#endif // GOOGLE_CUDA
|
#endif // GOOGLE_CUDA
|
||||||
@ -34,12 +33,18 @@ PoolParameters::PoolParameters(OpKernelContext* context,
|
|||||||
const std::vector<int32>& stride,
|
const std::vector<int32>& stride,
|
||||||
Padding padding, TensorFormat data_format,
|
Padding padding, TensorFormat data_format,
|
||||||
const TensorShape& tensor_in_shape) {
|
const TensorShape& tensor_in_shape) {
|
||||||
// For maxpooling, tensor_in should have 4 dimensions.
|
// For maxpooling, tensor_in should have 2 spatial dimensions.
|
||||||
OP_REQUIRES(context, tensor_in_shape.dims() == 4,
|
// Note: the total number of dimensions could be 4 for NHWC, NCHW,
|
||||||
errors::InvalidArgument("tensor_in must be 4-dimensional"));
|
// or 5 for NCHW_VECT_C.
|
||||||
|
OP_REQUIRES(context,
|
||||||
|
GetTensorSpatialDims(tensor_in_shape.dims(), data_format) == 2,
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"tensor_in_shape must have 2 spatial dimensions. ",
|
||||||
|
tensor_in_shape.dims(), " ", data_format));
|
||||||
|
|
||||||
this->data_format = data_format;
|
this->data_format = data_format;
|
||||||
depth = GetTensorDim(tensor_in_shape, data_format, 'C');
|
depth = GetTensorDim(tensor_in_shape, data_format, 'C') *
|
||||||
|
(data_format == FORMAT_NCHW_VECT_C ? 4 : 1);
|
||||||
tensor_in_cols = GetTensorDim(tensor_in_shape, data_format, 'W');
|
tensor_in_cols = GetTensorDim(tensor_in_shape, data_format, 'W');
|
||||||
tensor_in_rows = GetTensorDim(tensor_in_shape, data_format, 'H');
|
tensor_in_rows = GetTensorDim(tensor_in_shape, data_format, 'H');
|
||||||
tensor_in_batch = GetTensorDim(tensor_in_shape, data_format, 'N');
|
tensor_in_batch = GetTensorDim(tensor_in_shape, data_format, 'N');
|
||||||
|
@ -29,6 +29,10 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/util/tensor_format.h"
|
#include "tensorflow/core/util/tensor_format.h"
|
||||||
#include "tensorflow/core/util/work_sharder.h"
|
#include "tensorflow/core/util/work_sharder.h"
|
||||||
|
|
||||||
|
#if GOOGLE_CUDA
|
||||||
|
#include "tensorflow/core/kernels/maxpooling_op_gpu.h"
|
||||||
|
#endif // GOOGLE_CUDA
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
typedef Eigen::GpuDevice GPUDevice;
|
typedef Eigen::GpuDevice GPUDevice;
|
||||||
@ -256,6 +260,30 @@ class MaxPoolingOp : public OpKernel {
|
|||||||
TensorFormat data_format_;
|
TensorFormat data_format_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
template <typename Device>
|
||||||
|
struct LaunchMaxPoolingNoMask_NCHW_VECT_C;
|
||||||
|
|
||||||
|
#ifdef GOOGLE_CUDA
|
||||||
|
template <>
|
||||||
|
struct LaunchMaxPoolingNoMask_NCHW_VECT_C<Eigen::GpuDevice> {
|
||||||
|
static void launch(OpKernelContext* context, const PoolParameters& params,
|
||||||
|
const Tensor& input, Tensor* output) {
|
||||||
|
bool status = functor::MaxPoolForwardNoMask_NCHW_VECT_C()(
|
||||||
|
reinterpret_cast<const int32*>(input.flat<qint8>().data()),
|
||||||
|
params.tensor_in_batch, params.tensor_in_rows, params.tensor_in_cols,
|
||||||
|
params.depth, params.out_height, params.out_width, params.window_rows,
|
||||||
|
params.window_cols, params.row_stride, params.col_stride,
|
||||||
|
params.pad_rows, params.pad_cols,
|
||||||
|
reinterpret_cast<int32*>(output->flat<qint8>().data()),
|
||||||
|
context->eigen_gpu_device());
|
||||||
|
if (!status) {
|
||||||
|
context->SetStatus(errors::Internal(
|
||||||
|
"Failed launching LaunchMaxPoolingNoMask_NCHW_VECT_C"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
#endif
|
||||||
|
|
||||||
template <typename Device, typename T>
|
template <typename Device, typename T>
|
||||||
class MaxPoolingV2Op : public OpKernel {
|
class MaxPoolingV2Op : public OpKernel {
|
||||||
public:
|
public:
|
||||||
@ -266,8 +294,11 @@ class MaxPoolingV2Op : public OpKernel {
|
|||||||
OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
|
OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
|
||||||
errors::InvalidArgument("Invalid data format"));
|
errors::InvalidArgument("Invalid data format"));
|
||||||
OP_REQUIRES(
|
OP_REQUIRES(
|
||||||
context, data_format_ == FORMAT_NHWC,
|
context,
|
||||||
errors::InvalidArgument("Default MaxPoolingOp only supports NHWC."));
|
data_format_ == FORMAT_NHWC || data_format_ == FORMAT_NCHW_VECT_C,
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"MaxPoolingV2Op only supports NHWC or NCHW_VECT_C. Got: ",
|
||||||
|
data_format));
|
||||||
} else {
|
} else {
|
||||||
data_format_ = FORMAT_NHWC;
|
data_format_ = FORMAT_NHWC;
|
||||||
}
|
}
|
||||||
@ -315,8 +346,8 @@ class MaxPoolingV2Op : public OpKernel {
|
|||||||
errors::Unimplemented(
|
errors::Unimplemented(
|
||||||
"Pooling is not yet supported on the batch dimension."));
|
"Pooling is not yet supported on the batch dimension."));
|
||||||
|
|
||||||
PoolParameters params{context, ksize, stride,
|
PoolParameters params{context, ksize, stride,
|
||||||
padding_, FORMAT_NHWC, tensor_in.shape()};
|
padding_, data_format_, tensor_in.shape()};
|
||||||
if (!context->status().ok()) {
|
if (!context->status().ok()) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@ -368,13 +399,21 @@ class MaxPoolingV2Op : public OpKernel {
|
|||||||
// Spatial MaxPooling implementation.
|
// Spatial MaxPooling implementation.
|
||||||
//
|
//
|
||||||
// TODO(vrv): Remove this once we no longer need it.
|
// TODO(vrv): Remove this once we no longer need it.
|
||||||
|
#ifdef GOOGLE_CUDA
|
||||||
if (std::is_same<Device, GPUDevice>::value) {
|
if (std::is_same<Device, GPUDevice>::value) {
|
||||||
Eigen::PaddingType pt = BrainPadding2EigenPadding(padding);
|
Eigen::PaddingType pt = BrainPadding2EigenPadding(padding);
|
||||||
functor::SpatialMaxPooling<Device, T>()(
|
if (std::is_same<T, qint8>::value) {
|
||||||
context->eigen_device<Device>(), output->tensor<T, 4>(),
|
LaunchMaxPoolingNoMask_NCHW_VECT_C<GPUDevice>::launch(
|
||||||
tensor_in.tensor<T, 4>(), params.window_rows, params.window_cols,
|
context, params, tensor_in, output);
|
||||||
params.row_stride, params.col_stride, pt);
|
} else {
|
||||||
} else {
|
functor::SpatialMaxPooling<Device, T>()(
|
||||||
|
context->eigen_device<Device>(), output->tensor<T, 4>(),
|
||||||
|
tensor_in.tensor<T, 4>(), params.window_rows, params.window_cols,
|
||||||
|
params.row_stride, params.col_stride, pt);
|
||||||
|
}
|
||||||
|
} else
|
||||||
|
#endif
|
||||||
|
{
|
||||||
typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
|
typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
|
||||||
ConstEigenMatrixMap;
|
ConstEigenMatrixMap;
|
||||||
typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
|
typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
|
||||||
|
@ -82,8 +82,6 @@ Status SqliteQueryConnection::GetNext(std::vector<Tensor>* out_tensors,
|
|||||||
int rc = sqlite3_step(stmt_);
|
int rc = sqlite3_step(stmt_);
|
||||||
if (rc == SQLITE_ROW) {
|
if (rc == SQLITE_ROW) {
|
||||||
for (int i = 0; i < column_count_; i++) {
|
for (int i = 0; i < column_count_; i++) {
|
||||||
// TODO(b/64276939) Support other tensorflow types. Interpret columns as
|
|
||||||
// the types that the client specifies.
|
|
||||||
DataType dt = output_types_[i];
|
DataType dt = output_types_[i];
|
||||||
Tensor tensor(cpu_allocator(), dt, {});
|
Tensor tensor(cpu_allocator(), dt, {});
|
||||||
FillTensorWithResultSetEntry(dt, i, &tensor);
|
FillTensorWithResultSetEntry(dt, i, &tensor);
|
||||||
@ -125,11 +123,46 @@ void SqliteQueryConnection::FillTensorWithResultSetEntry(
|
|||||||
tensor->scalar<string>()() = value;
|
tensor->scalar<string>()() = value;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
case DT_INT8: {
|
||||||
|
int8 value = sqlite3_column_int(stmt_, column_index);
|
||||||
|
tensor->scalar<int8>()() = value;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case DT_INT16: {
|
||||||
|
int16 value = sqlite3_column_int(stmt_, column_index);
|
||||||
|
tensor->scalar<int16>()() = value;
|
||||||
|
break;
|
||||||
|
}
|
||||||
case DT_INT32: {
|
case DT_INT32: {
|
||||||
int32 value = sqlite3_column_int(stmt_, column_index);
|
int32 value = sqlite3_column_int(stmt_, column_index);
|
||||||
tensor->scalar<int32>()() = value;
|
tensor->scalar<int32>()() = value;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
case DT_INT64: {
|
||||||
|
int64 value = sqlite3_column_int64(stmt_, column_index);
|
||||||
|
tensor->scalar<int64>()() = value;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case DT_UINT8: {
|
||||||
|
uint8 value = sqlite3_column_int(stmt_, column_index);
|
||||||
|
tensor->scalar<uint8>()() = value;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case DT_UINT16: {
|
||||||
|
uint16 value = sqlite3_column_int(stmt_, column_index);
|
||||||
|
tensor->scalar<uint16>()() = value;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case DT_BOOL: {
|
||||||
|
int value = sqlite3_column_int(stmt_, column_index);
|
||||||
|
tensor->scalar<bool>()() = value ? true : false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case DT_DOUBLE: {
|
||||||
|
double value = sqlite3_column_double(stmt_, column_index);
|
||||||
|
tensor->scalar<double>()() = value;
|
||||||
|
break;
|
||||||
|
}
|
||||||
// Error preemptively thrown by SqlDatasetOp::MakeDataset in this case.
|
// Error preemptively thrown by SqlDatasetOp::MakeDataset in this case.
|
||||||
default: {
|
default: {
|
||||||
LOG(FATAL)
|
LOG(FATAL)
|
||||||
|
@ -34,13 +34,15 @@ class SqlDatasetOp : public DatasetOpKernel {
|
|||||||
explicit SqlDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {
|
explicit SqlDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {
|
||||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
|
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
|
||||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
|
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
|
||||||
// TODO(b/64276939) Remove this check when we add support for other
|
|
||||||
// tensorflow types.
|
|
||||||
for (const DataType& dt : output_types_) {
|
for (const DataType& dt : output_types_) {
|
||||||
OP_REQUIRES(
|
OP_REQUIRES(ctx,
|
||||||
ctx, dt == DT_STRING || dt == DT_INT32,
|
dt == DT_STRING || dt == DT_INT8 || dt == DT_INT16 ||
|
||||||
errors::InvalidArgument(
|
dt == DT_INT32 || dt == DT_INT64 || dt == DT_UINT8 ||
|
||||||
"Each element of `output_types_` must be DT_STRING or DT_INT32"));
|
dt == DT_UINT16 || dt == DT_BOOL || dt == DT_DOUBLE,
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"Each element of `output_types_` must be one of: "
|
||||||
|
"DT_STRING, DT_INT8, DT_INT16, DT_INT32, DT_INT64, "
|
||||||
|
"DT_UINT8, DT_UINT16, DT_BOOL, DT_DOUBLE "));
|
||||||
}
|
}
|
||||||
for (const PartialTensorShape& pts : output_shapes_) {
|
for (const PartialTensorShape& pts : output_shapes_) {
|
||||||
OP_REQUIRES(ctx, pts.dims() == 0,
|
OP_REQUIRES(ctx, pts.dims() == 0,
|
||||||
|
@ -303,6 +303,49 @@ op {
|
|||||||
is_aggregate: true
|
is_aggregate: true
|
||||||
is_commutative: true
|
is_commutative: true
|
||||||
}
|
}
|
||||||
|
op {
|
||||||
|
name: "AddN"
|
||||||
|
input_arg {
|
||||||
|
name: "inputs"
|
||||||
|
type_attr: "T"
|
||||||
|
number_attr: "N"
|
||||||
|
}
|
||||||
|
output_arg {
|
||||||
|
name: "sum"
|
||||||
|
type_attr: "T"
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "N"
|
||||||
|
type: "int"
|
||||||
|
has_minimum: true
|
||||||
|
minimum: 1
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "T"
|
||||||
|
type: "type"
|
||||||
|
allowed_values {
|
||||||
|
list {
|
||||||
|
type: DT_FLOAT
|
||||||
|
type: DT_DOUBLE
|
||||||
|
type: DT_INT64
|
||||||
|
type: DT_INT32
|
||||||
|
type: DT_UINT8
|
||||||
|
type: DT_UINT16
|
||||||
|
type: DT_INT16
|
||||||
|
type: DT_INT8
|
||||||
|
type: DT_COMPLEX64
|
||||||
|
type: DT_COMPLEX128
|
||||||
|
type: DT_QINT8
|
||||||
|
type: DT_QUINT8
|
||||||
|
type: DT_QINT32
|
||||||
|
type: DT_HALF
|
||||||
|
type: DT_VARIANT
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
is_aggregate: true
|
||||||
|
is_commutative: true
|
||||||
|
}
|
||||||
op {
|
op {
|
||||||
name: "AddSparseToTensorsMap"
|
name: "AddSparseToTensorsMap"
|
||||||
input_arg {
|
input_arg {
|
||||||
@ -13473,6 +13516,74 @@ op {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
op {
|
||||||
|
name: "MaxPool"
|
||||||
|
input_arg {
|
||||||
|
name: "input"
|
||||||
|
type_attr: "T"
|
||||||
|
}
|
||||||
|
output_arg {
|
||||||
|
name: "output"
|
||||||
|
type_attr: "T"
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "T"
|
||||||
|
type: "type"
|
||||||
|
default_value {
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
allowed_values {
|
||||||
|
list {
|
||||||
|
type: DT_FLOAT
|
||||||
|
type: DT_DOUBLE
|
||||||
|
type: DT_INT32
|
||||||
|
type: DT_INT64
|
||||||
|
type: DT_UINT8
|
||||||
|
type: DT_INT16
|
||||||
|
type: DT_INT8
|
||||||
|
type: DT_UINT16
|
||||||
|
type: DT_HALF
|
||||||
|
type: DT_QINT8
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "ksize"
|
||||||
|
type: "list(int)"
|
||||||
|
has_minimum: true
|
||||||
|
minimum: 4
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "strides"
|
||||||
|
type: "list(int)"
|
||||||
|
has_minimum: true
|
||||||
|
minimum: 4
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "padding"
|
||||||
|
type: "string"
|
||||||
|
allowed_values {
|
||||||
|
list {
|
||||||
|
s: "SAME"
|
||||||
|
s: "VALID"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "data_format"
|
||||||
|
type: "string"
|
||||||
|
default_value {
|
||||||
|
s: "NHWC"
|
||||||
|
}
|
||||||
|
allowed_values {
|
||||||
|
list {
|
||||||
|
s: "NHWC"
|
||||||
|
s: "NCHW"
|
||||||
|
s: "NCHW_VECT_C"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
op {
|
op {
|
||||||
name: "MaxPool3D"
|
name: "MaxPool3D"
|
||||||
input_arg {
|
input_arg {
|
||||||
@ -14435,6 +14546,70 @@ op {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
op {
|
||||||
|
name: "MaxPoolV2"
|
||||||
|
input_arg {
|
||||||
|
name: "input"
|
||||||
|
type_attr: "T"
|
||||||
|
}
|
||||||
|
input_arg {
|
||||||
|
name: "ksize"
|
||||||
|
type: DT_INT32
|
||||||
|
}
|
||||||
|
input_arg {
|
||||||
|
name: "strides"
|
||||||
|
type: DT_INT32
|
||||||
|
}
|
||||||
|
output_arg {
|
||||||
|
name: "output"
|
||||||
|
type_attr: "T"
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "T"
|
||||||
|
type: "type"
|
||||||
|
default_value {
|
||||||
|
type: DT_FLOAT
|
||||||
|
}
|
||||||
|
allowed_values {
|
||||||
|
list {
|
||||||
|
type: DT_FLOAT
|
||||||
|
type: DT_DOUBLE
|
||||||
|
type: DT_INT32
|
||||||
|
type: DT_INT64
|
||||||
|
type: DT_UINT8
|
||||||
|
type: DT_INT16
|
||||||
|
type: DT_INT8
|
||||||
|
type: DT_UINT16
|
||||||
|
type: DT_HALF
|
||||||
|
type: DT_QINT8
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "padding"
|
||||||
|
type: "string"
|
||||||
|
allowed_values {
|
||||||
|
list {
|
||||||
|
s: "SAME"
|
||||||
|
s: "VALID"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
attr {
|
||||||
|
name: "data_format"
|
||||||
|
type: "string"
|
||||||
|
default_value {
|
||||||
|
s: "NHWC"
|
||||||
|
}
|
||||||
|
allowed_values {
|
||||||
|
list {
|
||||||
|
s: "NHWC"
|
||||||
|
s: "NCHW"
|
||||||
|
s: "NCHW_VECT_C"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
op {
|
op {
|
||||||
name: "MaxPoolWithArgmax"
|
name: "MaxPoolWithArgmax"
|
||||||
input_arg {
|
input_arg {
|
||||||
|
@ -28,7 +28,7 @@ REGISTER_OP("AddN")
|
|||||||
.Input("inputs: N * T")
|
.Input("inputs: N * T")
|
||||||
.Output("sum: T")
|
.Output("sum: T")
|
||||||
.Attr("N: int >= 1")
|
.Attr("N: int >= 1")
|
||||||
.Attr("T: numbertype")
|
.Attr("T: {numbertype, variant}")
|
||||||
.SetIsCommutative()
|
.SetIsCommutative()
|
||||||
.SetIsAggregate()
|
.SetIsAggregate()
|
||||||
.SetShapeFn([](InferenceContext* c) {
|
.SetShapeFn([](InferenceContext* c) {
|
||||||
|
@ -1344,11 +1344,13 @@ output: The gradients for LRN.
|
|||||||
// --------------------------------------------------------------------------
|
// --------------------------------------------------------------------------
|
||||||
|
|
||||||
REGISTER_OP("MaxPool")
|
REGISTER_OP("MaxPool")
|
||||||
.Attr("T: realnumbertype = DT_FLOAT")
|
.Attr(
|
||||||
|
"T: {float, double, int32, int64, uint8, int16, int8, uint16, "
|
||||||
|
"half, qint8} = DT_FLOAT")
|
||||||
.Attr("ksize: list(int) >= 4")
|
.Attr("ksize: list(int) >= 4")
|
||||||
.Attr("strides: list(int) >= 4")
|
.Attr("strides: list(int) >= 4")
|
||||||
.Attr(GetPaddingAttrString())
|
.Attr(GetPaddingAttrString())
|
||||||
.Attr(GetConvnetDataFormatAttrString())
|
.Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'")
|
||||||
.Input("input: T")
|
.Input("input: T")
|
||||||
.Output("output: T")
|
.Output("output: T")
|
||||||
.SetShapeFn(shape_inference::MaxPoolShape)
|
.SetShapeFn(shape_inference::MaxPoolShape)
|
||||||
@ -1369,9 +1371,11 @@ output: The max pooled output tensor.
|
|||||||
)doc");
|
)doc");
|
||||||
|
|
||||||
REGISTER_OP("MaxPoolV2")
|
REGISTER_OP("MaxPoolV2")
|
||||||
.Attr("T: realnumbertype = DT_FLOAT")
|
.Attr(
|
||||||
|
"T: {float, double, int32, int64, uint8, int16, int8, uint16, "
|
||||||
|
"half, qint8} = DT_FLOAT")
|
||||||
.Attr(GetPaddingAttrString())
|
.Attr(GetPaddingAttrString())
|
||||||
.Attr(GetConvnetDataFormatAttrString())
|
.Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'")
|
||||||
.Input("input: T")
|
.Input("input: T")
|
||||||
.Input("ksize: int32")
|
.Input("ksize: int32")
|
||||||
.Input("strides: int32")
|
.Input("strides: int32")
|
||||||
|
@ -334,6 +334,7 @@ op {
|
|||||||
type: DT_QUINT8
|
type: DT_QUINT8
|
||||||
type: DT_QINT32
|
type: DT_QINT32
|
||||||
type: DT_HALF
|
type: DT_HALF
|
||||||
|
type: DT_VARIANT
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -12628,6 +12629,7 @@ op {
|
|||||||
type: DT_INT8
|
type: DT_INT8
|
||||||
type: DT_UINT16
|
type: DT_UINT16
|
||||||
type: DT_HALF
|
type: DT_HALF
|
||||||
|
type: DT_QINT8
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -12667,6 +12669,7 @@ op {
|
|||||||
list {
|
list {
|
||||||
s: "NHWC"
|
s: "NHWC"
|
||||||
s: "NCHW"
|
s: "NCHW"
|
||||||
|
s: "NCHW_VECT_C"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -13401,6 +13404,7 @@ op {
|
|||||||
type: DT_INT8
|
type: DT_INT8
|
||||||
type: DT_UINT16
|
type: DT_UINT16
|
||||||
type: DT_HALF
|
type: DT_HALF
|
||||||
|
type: DT_QINT8
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -13426,6 +13430,7 @@ op {
|
|||||||
list {
|
list {
|
||||||
s: "NHWC"
|
s: "NHWC"
|
||||||
s: "NCHW"
|
s: "NCHW"
|
||||||
|
s: "NCHW_VECT_C"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -216,7 +216,7 @@ seq2seq_attention_model.py:363:build_graph:self._add_train_o..., cpu: 1.28sec, a
|
|||||||
|
|
||||||
```shell
|
```shell
|
||||||
# The following example generates a timeline.
|
# The following example generates a timeline.
|
||||||
tfprof> graph -step 0 -max_depth 100000 -output timeline:outfile=<filename>
|
tfprof> graph -step -1 -max_depth 100000 -output timeline:outfile=<filename>
|
||||||
|
|
||||||
generating trace file.
|
generating trace file.
|
||||||
|
|
||||||
|
@ -14,7 +14,12 @@
|
|||||||
|
|
||||||
### Command Line Inputs
|
### Command Line Inputs
|
||||||
|
|
||||||
tfprof command line tool uses the following inputs:
|
tfprof command line tool uses the following input:
|
||||||
|
|
||||||
|
<b>--profile_path:</b> A ProfileProto binary proto file.
|
||||||
|
See QuickStart on generating the file.
|
||||||
|
|
||||||
|
<b>THE OLD WAY BELOW IS DEPRECATED:</b>
|
||||||
|
|
||||||
<b>--graph_path:</b> GraphDef proto file (required). Used to build in-memory
|
<b>--graph_path:</b> GraphDef proto file (required). Used to build in-memory
|
||||||
data structure of the model. For example, graph.pbtxt written by tf.Supervisor
|
data structure of the model. For example, graph.pbtxt written by tf.Supervisor
|
||||||
|
@ -84,7 +84,6 @@ string RunProfile(const string& command, const string& options,
|
|||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
bool NewProfiler(const string* graph, const string* op_log) {
|
bool NewProfiler(const string* graph, const string* op_log) {
|
||||||
CHECK(!tf_stat) << "Currently only 1 living tfprof profiler is allowed";
|
|
||||||
CHECK(graph) << "graph mustn't be null";
|
CHECK(graph) << "graph mustn't be null";
|
||||||
std::unique_ptr<GraphDef> graph_ptr(new GraphDef());
|
std::unique_ptr<GraphDef> graph_ptr(new GraphDef());
|
||||||
if (!graph_ptr->ParseFromString(*graph)) {
|
if (!graph_ptr->ParseFromString(*graph)) {
|
||||||
|
@ -175,22 +175,22 @@ class ExecStep {
|
|||||||
std::map<int32, std::pair<int64, uint64>> output_memory_;
|
std::map<int32, std::pair<int64, uint64>> output_memory_;
|
||||||
};
|
};
|
||||||
|
|
||||||
#define GRAPH_NODE_BYTES(type) \
|
#define GRAPH_NODE_BYTES(type) \
|
||||||
do { \
|
do { \
|
||||||
if (execs_.empty()) { \
|
if (execs_.empty()) { \
|
||||||
return 0; \
|
return 0; \
|
||||||
} \
|
} \
|
||||||
if (step >= 0) { \
|
if (step >= 0) { \
|
||||||
auto exec = execs_.find(step); \
|
auto exec = execs_.find(step); \
|
||||||
CHECK(exec != execs_.end()) << "unknown step " << step; \
|
if (exec == execs_.end()) return 0; \
|
||||||
return exec->second.type##_bytes(); \
|
return exec->second.type##_bytes(); \
|
||||||
} \
|
} \
|
||||||
\
|
\
|
||||||
int64 bytes = 0; \
|
int64 bytes = 0; \
|
||||||
for (const auto& exec : execs_) { \
|
for (const auto& exec : execs_) { \
|
||||||
bytes += exec.second.type##_bytes(); \
|
bytes += exec.second.type##_bytes(); \
|
||||||
} \
|
} \
|
||||||
return bytes / execs_.size(); \
|
return bytes / execs_.size(); \
|
||||||
} while (0)
|
} while (0)
|
||||||
|
|
||||||
class TFGraphNode {
|
class TFGraphNode {
|
||||||
@ -372,7 +372,9 @@ class TFGraphNode {
|
|||||||
}
|
}
|
||||||
if (step >= 0) {
|
if (step >= 0) {
|
||||||
auto exec = execs_.find(step);
|
auto exec = execs_.find(step);
|
||||||
CHECK(exec != execs_.end());
|
if (exec == execs_.end()) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
return exec->second.run_count();
|
return exec->second.run_count();
|
||||||
}
|
}
|
||||||
int64 total_run_count = 0;
|
int64 total_run_count = 0;
|
||||||
@ -390,7 +392,9 @@ class TFGraphNode {
|
|||||||
}
|
}
|
||||||
if (step >= 0) {
|
if (step >= 0) {
|
||||||
auto exec = execs_.find(step);
|
auto exec = execs_.find(step);
|
||||||
CHECK(exec != execs_.end());
|
if (exec == execs_.end()) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
return exec->second.exec_micros();
|
return exec->second.exec_micros();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -410,7 +414,9 @@ class TFGraphNode {
|
|||||||
}
|
}
|
||||||
if (step >= 0) {
|
if (step >= 0) {
|
||||||
auto exec = execs_.find(step);
|
auto exec = execs_.find(step);
|
||||||
CHECK(exec != execs_.end());
|
if (exec == execs_.end()) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
return exec->second.accelerator_exec_micros();
|
return exec->second.accelerator_exec_micros();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -430,7 +436,9 @@ class TFGraphNode {
|
|||||||
}
|
}
|
||||||
if (step >= 0) {
|
if (step >= 0) {
|
||||||
auto exec = execs_.find(step);
|
auto exec = execs_.find(step);
|
||||||
CHECK(exec != execs_.end());
|
if (exec == execs_.end()) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
return exec->second.cpu_exec_micros();
|
return exec->second.cpu_exec_micros();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -448,20 +456,26 @@ class TFGraphNode {
|
|||||||
|
|
||||||
int64 all_start_micros(int64 step) const {
|
int64 all_start_micros(int64 step) const {
|
||||||
auto exec = execs_.find(step);
|
auto exec = execs_.find(step);
|
||||||
CHECK(exec != execs_.end()) << "unknown step " << step;
|
if (exec == execs_.end()) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
return exec->second.all_start_micros();
|
return exec->second.all_start_micros();
|
||||||
}
|
}
|
||||||
|
|
||||||
int64 latest_end_micros(int64 step) const {
|
int64 latest_end_micros(int64 step) const {
|
||||||
auto exec = execs_.find(step);
|
auto exec = execs_.find(step);
|
||||||
CHECK(exec != execs_.end()) << "unknown step " << step;
|
if (exec == execs_.end()) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
return exec->second.latest_end_micros();
|
return exec->second.latest_end_micros();
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::map<string, std::vector<std::pair<int64, int64>>>& op_execs(
|
const std::map<string, std::vector<std::pair<int64, int64>>>& op_execs(
|
||||||
int64 step) const {
|
int64 step) const {
|
||||||
auto exec = execs_.find(step);
|
auto exec = execs_.find(step);
|
||||||
CHECK(exec != execs_.end()) << "unknown step " << step;
|
if (exec == execs_.end()) {
|
||||||
|
return empty_op_execs_;
|
||||||
|
}
|
||||||
return exec->second.op_execs();
|
return exec->second.op_execs();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -469,33 +483,45 @@ class TFGraphNode {
|
|||||||
|
|
||||||
int64 accelerator_temp_bytes(int64 step) const {
|
int64 accelerator_temp_bytes(int64 step) const {
|
||||||
auto exec = execs_.find(step);
|
auto exec = execs_.find(step);
|
||||||
CHECK(exec != execs_.end()) << "unknown step " << step;
|
if (exec == execs_.end()) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
return exec->second.accelerator_temp_bytes();
|
return exec->second.accelerator_temp_bytes();
|
||||||
}
|
}
|
||||||
int64 host_temp_bytes(int64 step) const {
|
int64 host_temp_bytes(int64 step) const {
|
||||||
auto exec = execs_.find(step);
|
auto exec = execs_.find(step);
|
||||||
CHECK(exec != execs_.end()) << "unknown step " << step;
|
if (exec == execs_.end()) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
return exec->second.host_temp_bytes();
|
return exec->second.host_temp_bytes();
|
||||||
}
|
}
|
||||||
int64 accelerator_persistent_bytes(int64 step) const {
|
int64 accelerator_persistent_bytes(int64 step) const {
|
||||||
auto exec = execs_.find(step);
|
auto exec = execs_.find(step);
|
||||||
CHECK(exec != execs_.end()) << "unknown step " << step;
|
if (exec == execs_.end()) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
return exec->second.accelerator_persistent_bytes();
|
return exec->second.accelerator_persistent_bytes();
|
||||||
}
|
}
|
||||||
int64 host_persistent_bytes(int64 step) const {
|
int64 host_persistent_bytes(int64 step) const {
|
||||||
auto exec = execs_.find(step);
|
auto exec = execs_.find(step);
|
||||||
CHECK(exec != execs_.end()) << "unknown step " << step;
|
if (exec == execs_.end()) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
return exec->second.host_persistent_bytes();
|
return exec->second.host_persistent_bytes();
|
||||||
}
|
}
|
||||||
const std::map<int32, std::pair<int64, uint64>>& output_memory(
|
const std::map<int32, std::pair<int64, uint64>>& output_memory(
|
||||||
int64 step) const {
|
int64 step) const {
|
||||||
auto exec = execs_.find(step);
|
auto exec = execs_.find(step);
|
||||||
CHECK(exec != execs_.end()) << "unknown step " << step;
|
if (exec == execs_.end()) {
|
||||||
|
return empty_output_memory_;
|
||||||
|
}
|
||||||
return exec->second.output_memory();
|
return exec->second.output_memory();
|
||||||
}
|
}
|
||||||
int64 allocator_bytes_in_use(int64 step) const {
|
int64 allocator_bytes_in_use(int64 step) const {
|
||||||
auto exec = execs_.find(step);
|
auto exec = execs_.find(step);
|
||||||
CHECK(exec != execs_.end()) << "unknown step " << step;
|
if (exec == execs_.end()) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
return exec->second.allocator_bytes_in_use();
|
return exec->second.allocator_bytes_in_use();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -566,6 +592,9 @@ class TFGraphNode {
|
|||||||
std::set<string> op_types_;
|
std::set<string> op_types_;
|
||||||
|
|
||||||
std::map<int64, ExecStep> execs_;
|
std::map<int64, ExecStep> execs_;
|
||||||
|
|
||||||
|
std::map<int32, std::pair<int64, uint64>> empty_output_memory_;
|
||||||
|
std::map<string, std::vector<std::pair<int64, int64>>> empty_op_execs_;
|
||||||
};
|
};
|
||||||
|
|
||||||
class TFMultiGraphNode {
|
class TFMultiGraphNode {
|
||||||
|
@ -88,6 +88,9 @@ TFStats::TFStats(const string& filename,
|
|||||||
node_pb.second.name(), std::move(node)));
|
node_pb.second.name(), std::move(node)));
|
||||||
}
|
}
|
||||||
has_code_traces_ = profile.has_trace();
|
has_code_traces_ = profile.has_trace();
|
||||||
|
for (int64 s : profile.steps()) {
|
||||||
|
steps_.insert(s);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void TFStats::BuildView(const string& cmd) {
|
void TFStats::BuildView(const string& cmd) {
|
||||||
@ -136,6 +139,14 @@ const GraphNodeProto& TFStats::ShowGraphNode(const string& cmd,
|
|||||||
if (cmd == kCmds[0]) {
|
if (cmd == kCmds[0]) {
|
||||||
return scope_view_->Show(opts);
|
return scope_view_->Show(opts);
|
||||||
} else if (cmd == kCmds[1]) {
|
} else if (cmd == kCmds[1]) {
|
||||||
|
if (opts.step < 0 && opts.output_type == kOutput[0]) {
|
||||||
|
for (int64 step : steps_) {
|
||||||
|
Options nopts = opts;
|
||||||
|
nopts.step = step;
|
||||||
|
graph_view_->Show(nopts);
|
||||||
|
}
|
||||||
|
return empty_graph_node_;
|
||||||
|
}
|
||||||
return graph_view_->Show(opts);
|
return graph_view_->Show(opts);
|
||||||
} else {
|
} else {
|
||||||
fprintf(stderr, "Unknown command: %s\n", cmd.c_str());
|
fprintf(stderr, "Unknown command: %s\n", cmd.c_str());
|
||||||
@ -148,7 +159,11 @@ const MultiGraphNodeProto& TFStats::ShowMultiGraphNode(
|
|||||||
if (!Validate(opts)) {
|
if (!Validate(opts)) {
|
||||||
return empty_multi_graph_node_;
|
return empty_multi_graph_node_;
|
||||||
}
|
}
|
||||||
if (cmd == kCmds[2] && has_code_traces()) {
|
if (cmd == kCmds[2]) {
|
||||||
|
if (!has_code_traces()) {
|
||||||
|
fprintf(stderr, "No code trace information\n");
|
||||||
|
return empty_multi_graph_node_;
|
||||||
|
}
|
||||||
return code_view_->Show(opts);
|
return code_view_->Show(opts);
|
||||||
} else if (cmd == kCmds[3]) {
|
} else if (cmd == kCmds[3]) {
|
||||||
return op_view_->Show(opts);
|
return op_view_->Show(opts);
|
||||||
@ -212,7 +227,9 @@ void TFStats::AddOpLogProto(std::unique_ptr<OpLogProto> op_log) {
|
|||||||
}
|
}
|
||||||
if (entry.has_code_def()) {
|
if (entry.has_code_def()) {
|
||||||
has_code_traces_ = true;
|
has_code_traces_ = true;
|
||||||
node->second->AddCode(entry.code_def());
|
if (node->second->code().traces_size() == 0) {
|
||||||
|
node->second->AddCode(entry.code_def());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -258,9 +275,11 @@ void TFStats::WriteProfile(const string& filename) {
|
|||||||
}
|
}
|
||||||
(*profile.mutable_nodes())[it->second->id()].MergeFrom(
|
(*profile.mutable_nodes())[it->second->id()].MergeFrom(
|
||||||
it->second->ToProto(nodes_map_));
|
it->second->ToProto(nodes_map_));
|
||||||
if (it->second->code().traces_size() > 0) {
|
}
|
||||||
profile.set_has_trace(true);
|
|
||||||
}
|
profile.set_has_trace(has_code_traces_);
|
||||||
|
for (int64 s : steps_) {
|
||||||
|
profile.add_steps(s);
|
||||||
}
|
}
|
||||||
Status s =
|
Status s =
|
||||||
WriteStringToFile(Env::Default(), filename, profile.SerializeAsString());
|
WriteStringToFile(Env::Default(), filename, profile.SerializeAsString());
|
||||||
@ -271,7 +290,12 @@ void TFStats::WriteProfile(const string& filename) {
|
|||||||
|
|
||||||
bool TFStats::Validate(const Options& opts) const {
|
bool TFStats::Validate(const Options& opts) const {
|
||||||
if (opts.step >= 0 && steps_.find(opts.step) == steps_.end()) {
|
if (opts.step >= 0 && steps_.find(opts.step) == steps_.end()) {
|
||||||
fprintf(stderr, "Options -step=%lld not found\n", opts.step);
|
fprintf(stderr,
|
||||||
|
"Options -step=%lld not found.\nAvailable steps: ", opts.step);
|
||||||
|
for (int64 s : steps_) {
|
||||||
|
fprintf(stderr, "%lld ", s);
|
||||||
|
}
|
||||||
|
fprintf(stderr, "\n");
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
#include "tensorflow/core/lib/strings/str_util.h"
|
#include "tensorflow/core/lib/strings/str_util.h"
|
||||||
|
#include "tensorflow/core/lib/strings/stringprintf.h"
|
||||||
#include "tensorflow/core/profiler/internal/tfprof_utils.h"
|
#include "tensorflow/core/profiler/internal/tfprof_utils.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
@ -303,11 +304,12 @@ void Timeline::GenerateCodeTimeline(const CodeNode* node) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void Timeline::OutputTimeline() {
|
void Timeline::OutputTimeline() {
|
||||||
|
string outfile = strings::Printf("%s_%lld", outfile_.c_str(), step());
|
||||||
Status s =
|
Status s =
|
||||||
WriteStringToFile(Env::Default(), outfile_, chrome_formatter_.Format());
|
WriteStringToFile(Env::Default(), outfile, chrome_formatter_.Format());
|
||||||
if (!s.ok()) {
|
if (!s.ok()) {
|
||||||
fprintf(stderr, "Failed to write timeline file: %s\nError: %s\n",
|
fprintf(stderr, "Failed to write timeline file: %s\nError: %s\n",
|
||||||
outfile_.c_str(), s.ToString().c_str());
|
outfile.c_str(), s.ToString().c_str());
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
fprintf(stdout, "\n******************************************************\n");
|
fprintf(stdout, "\n******************************************************\n");
|
||||||
@ -315,7 +317,7 @@ void Timeline::OutputTimeline() {
|
|||||||
"Timeline file is written to %s.\n"
|
"Timeline file is written to %s.\n"
|
||||||
"Open a Chrome browser, enter URL chrome://tracing and "
|
"Open a Chrome browser, enter URL chrome://tracing and "
|
||||||
"load the timeline file.",
|
"load the timeline file.",
|
||||||
outfile_.c_str());
|
outfile.c_str());
|
||||||
fprintf(stdout, "\n******************************************************\n");
|
fprintf(stdout, "\n******************************************************\n");
|
||||||
fflush(stdout);
|
fflush(stdout);
|
||||||
}
|
}
|
||||||
|
@ -70,7 +70,7 @@ TEST_F(TFProfTimelineTest, GraphView) {
|
|||||||
tf_stats_->ShowGraphNode("graph", opts);
|
tf_stats_->ShowGraphNode("graph", opts);
|
||||||
|
|
||||||
string dump_str;
|
string dump_str;
|
||||||
TF_CHECK_OK(ReadFileToString(Env::Default(), dump_file, &dump_str));
|
TF_CHECK_OK(ReadFileToString(Env::Default(), dump_file + "_0", &dump_str));
|
||||||
EXPECT_EQ(1754536562981488144ull, Hash64(dump_str));
|
EXPECT_EQ(1754536562981488144ull, Hash64(dump_str));
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -84,7 +84,7 @@ TEST_F(TFProfTimelineTest, ScopeView) {
|
|||||||
tf_stats_->ShowGraphNode("scope", opts);
|
tf_stats_->ShowGraphNode("scope", opts);
|
||||||
|
|
||||||
string dump_str;
|
string dump_str;
|
||||||
TF_CHECK_OK(ReadFileToString(Env::Default(), dump_file, &dump_str));
|
TF_CHECK_OK(ReadFileToString(Env::Default(), dump_file + "_0", &dump_str));
|
||||||
EXPECT_EQ(17545174915963890413ull, Hash64(dump_str));
|
EXPECT_EQ(17545174915963890413ull, Hash64(dump_str));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -42,6 +42,8 @@ message ProfileProto {
|
|||||||
map<int64, ProfileNode> nodes = 1;
|
map<int64, ProfileNode> nodes = 1;
|
||||||
// Whether or not has code traces.
|
// Whether or not has code traces.
|
||||||
bool has_trace = 2;
|
bool has_trace = 2;
|
||||||
|
// Traced steps.
|
||||||
|
repeated int64 steps = 3;
|
||||||
}
|
}
|
||||||
|
|
||||||
message ProfileNode {
|
message ProfileNode {
|
||||||
|
@ -632,6 +632,22 @@ define an attr with constraints, you can use the following `<attr-type-expr>`s:
|
|||||||
tf.number_type(t=tf.bool) # Invalid
|
tf.number_type(t=tf.bool) # Invalid
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Lists can be combined with other lists and single types. The following
|
||||||
|
op allows attr `t` to be any of the numberic types, or the bool type:
|
||||||
|
|
||||||
|
```c++
|
||||||
|
REGISTER_OP("NumberOrBooleanType")
|
||||||
|
.Attr("t: {numbertype, bool}");
|
||||||
|
```
|
||||||
|
|
||||||
|
For this op:
|
||||||
|
|
||||||
|
```python
|
||||||
|
tf.number_or_boolean_type(t=tf.int32) # Valid
|
||||||
|
tf.number_or_boolean_type(t=tf.bool) # Valid
|
||||||
|
tf.number_or_boolean_type(t=tf.string) # Invalid
|
||||||
|
```
|
||||||
|
|
||||||
* `int >= <n>`: The value must be an int whose value is greater than or equal to
|
* `int >= <n>`: The value must be an int whose value is greater than or equal to
|
||||||
`<n>`, where `<n>` is a natural number.
|
`<n>`, where `<n>` is a natural number.
|
||||||
|
|
||||||
|
@ -1,5 +1,13 @@
|
|||||||
# Installing TensorFlow
|
# Installing TensorFlow
|
||||||
|
|
||||||
|
We've built and tested TensorFlow on the following 64-bit laptop/desktop
|
||||||
|
operating systems:
|
||||||
|
* MacOS X 10.11 (El Capitan) or later.
|
||||||
|
* Ubuntu 14.04 or later
|
||||||
|
* Windows 7 or later.
|
||||||
|
Although you might be able to install TensorFlow on other laptop or desktop
|
||||||
|
systems, we only support (and only fix issues in) the preceding configurations.
|
||||||
|
|
||||||
The following guides explain how to install a version of TensorFlow
|
The following guides explain how to install a version of TensorFlow
|
||||||
that enables you to write applications in Python:
|
that enables you to write applications in Python:
|
||||||
|
|
||||||
|
@ -1,43 +1,182 @@
|
|||||||
# Performance Guide
|
# Performance Guide
|
||||||
|
|
||||||
This guide contains a collection of best practices for optimizing your
|
This guide contains a collection of best practices for optimizing TensorFlow
|
||||||
TensorFlow code. The best practices apply to both new and experienced
|
code. The guide is divided into a few sections:
|
||||||
Tensorflow users. As a complement to the best practices in this document, the
|
|
||||||
@{$performance_models$High-Performance Models} document links to example code
|
|
||||||
and details for creating models that scale on a variety of hardware.
|
|
||||||
|
|
||||||
## Best Practices
|
* [General best practices](#general_best_practices) covers topics that are
|
||||||
While optimizing implementations of different types of models can be different,
|
common across a variety of model types and hardware.
|
||||||
the topics below cover best practices to get the most performance from
|
* [Optimizing for GPU](#optimizing_for_gpu) details tips specifically relevant
|
||||||
TensorFlow. Although these suggestions focus on image-based models, we will
|
to GPUs.
|
||||||
regularly add tips for all kinds of models. The following list highlights key
|
* [Optimizing for CPU](#optimizing_for_cpu) details CPU specific information.
|
||||||
best practices:
|
|
||||||
|
|
||||||
* Build and install from source
|
## General best practices
|
||||||
* Utilize queues for reading data
|
|
||||||
* Preprocessing on the CPU
|
|
||||||
* Use `NCHW` image data format
|
|
||||||
* Place shared parameters on the GPU
|
|
||||||
* Use fused batch norm
|
|
||||||
|
|
||||||
The following sections detail the preceding suggestions.
|
The sections below cover best practices that are relevant to a variety of
|
||||||
|
hardware and models. The best practices section is broken down into the
|
||||||
|
following sections:
|
||||||
|
|
||||||
### Build and install from source
|
* [Input pipeline optimizations](#input-pipeline-optimization)
|
||||||
|
* [Data formats](#data-formats)
|
||||||
|
* [Common fused Ops](#common-fused-ops)
|
||||||
|
* [Building and installing from source](#building-and-installing-from-source)
|
||||||
|
|
||||||
To install the most optimized version of TensorFlow, build and install
|
### Input pipeline optimization
|
||||||
TensorFlow from source by following [Installing TensorFlow from Source](../install/install_sources).
|
|
||||||
Building from source with compiler optimizations for the target hardware and
|
|
||||||
ensuring the latest CUDA platform and cuDNN libraries are installed results in
|
|
||||||
the highest performing installs.
|
|
||||||
|
|
||||||
For the most stable experience, build from the [latest release](https://github.com/tensorflow/tensorflow/releases)
|
Typical models retrieve data from disk and preprocess it before sending the data
|
||||||
branch. To get the latest performance changes and accept some stability risk,
|
through the network. For example, models that process JPEG images will follow
|
||||||
build from [master](https://github.com/tensorflow/tensorflow).
|
this flow: load image from disk, decode JPEG into a tensor, crop and pad,
|
||||||
|
possibly flip and distort, and then batch. This flow is referred to as the input
|
||||||
|
pipeline. As GPUs and other hardware accelerators get faster, preprocessing of
|
||||||
|
data can be a bottleneck.
|
||||||
|
|
||||||
If there is a need to build TensorFlow on a platform that has different hardware
|
Determining if the input pipeline is the bottleneck can be complicated. One of
|
||||||
than the target, then cross-compile with the highest optimizations for the target
|
the most straightforward methods is to reduce the model to a single operation
|
||||||
platform. The following command is an example of telling `bazel` to compile for
|
(trivial model) after the input pipeline and measure the examples per second. If
|
||||||
a specific platform:
|
the difference in examples per second for the full model and the trivial model
|
||||||
|
is minimal then the input pipeline is likely a bottleneck. Below are some other
|
||||||
|
approaches to identifying issues:
|
||||||
|
|
||||||
|
* Check if a GPU is underutilized by running `watch -n 2 nvidia-smi`. If GPU
|
||||||
|
utilization is not approaching 80-100%, then the input pipeline may be the
|
||||||
|
bottleneck.
|
||||||
|
* Generate a timeline and look for large blocks of white space (waiting). An
|
||||||
|
example of generating a timeline exists as part of the @{$jit$XLA JIT}
|
||||||
|
tutorial.
|
||||||
|
* Check CPU usage. It is possible to have an optimized input pipeline and lack
|
||||||
|
the CPU cycles to process the pipeline.
|
||||||
|
* Estimate the throughput needed and verify the disk used is capable of that
|
||||||
|
level of throughput. Some cloud solutions have network attached disks that
|
||||||
|
start as low as 50 MB/sec, which is slower than spinning disks (150 MB/sec),
|
||||||
|
SATA SSDs (500 MB/sec), and PCIe SSDs (2,000+ MB/sec).
|
||||||
|
|
||||||
|
#### Preprocessing on the CPU
|
||||||
|
|
||||||
|
Placing input pipeline operations on the CPU can significantly improve
|
||||||
|
performance. Utilizing the CPU for the input pipeline frees the GPU to focus on
|
||||||
|
training. To ensure preprocessing is on the CPU, wrap the preprocessing
|
||||||
|
operations as shown below:
|
||||||
|
|
||||||
|
```python
|
||||||
|
with tf.device('/cpu:0'):
|
||||||
|
# function to get and process images or data.
|
||||||
|
distorted_inputs = load_and_distort_images()
|
||||||
|
```
|
||||||
|
|
||||||
|
If using `tf.estimator.Estimator` the input function is automatically placed on
|
||||||
|
the CPU.
|
||||||
|
|
||||||
|
#### Using the Dataset API
|
||||||
|
|
||||||
|
The @{$datasets$Dataset API} is replacing `queue_runner` as the recommended API
|
||||||
|
for building input pipelines. The API was added to contrib as part of TensorFlow
|
||||||
|
1.2 and will move to core in the near future. This
|
||||||
|
[ResNet example](https://github.com/tensorflow/models/tree/master/tutorials/image/cifar10_estimator/cifar10_main.py)
|
||||||
|
([arXiv:1512.03385](https://arxiv.org/abs/1512.03385))
|
||||||
|
training CIFAR-10 illustrates the use of the Dataset API along with
|
||||||
|
`tf.estimator.Estimator`. The Dataset API utilizes C++ multi-threading and has a
|
||||||
|
much lower overhead than the Python-based `queue_runner` that is limited by
|
||||||
|
Python's multi-threading performance.
|
||||||
|
|
||||||
|
While feeding data using a `feed_dict` offers a high level of flexibility, in
|
||||||
|
most instances using `feed_dict` does not scale optimally. However, in instances
|
||||||
|
where only a single GPU is being used the difference can be negligible. Using
|
||||||
|
the Dataset API is still strongly recommended. Try to avoid the following:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# feed_dict often results in suboptimal performance when using large inputs.
|
||||||
|
sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Use large files
|
||||||
|
|
||||||
|
Reading large numbers of small files significantly impacts I/O performance.
|
||||||
|
One approach to get maximum I/O throughput is to preprocess input data into
|
||||||
|
larger (~100MB) `TFRecord` files. For smaller data sets (200MB-1GB), the best
|
||||||
|
approach is often to load the entire data set into memory. The document
|
||||||
|
[Downloading and converting to TFRecord format](https://github.com/tensorflow/models/tree/master/slim#Data)
|
||||||
|
includes information and scripts for creating `TFRecords` and this
|
||||||
|
[script](https://github.com/tensorflow/models/tree/master/tutorials/image/cifar10_estimator/generate_cifar10_tfrecords.py)
|
||||||
|
converts the CIFAR-10 data set into `TFRecords`.
|
||||||
|
|
||||||
|
### Data formats
|
||||||
|
|
||||||
|
Data formats refers to the structure of the Tensor passed to a given Op. The
|
||||||
|
discussion below is specifically about 4D Tensors representing images. In
|
||||||
|
TensorFlow the parts of the 4D tensor are often referred to by the following
|
||||||
|
letters:
|
||||||
|
|
||||||
|
* N refers to the number of images in a batch.
|
||||||
|
* H refers to the number of pixels in the vertical (height) dimension.
|
||||||
|
* W refers to the number of pixels in the horizontal (width) dimension.
|
||||||
|
* C refers to the channels. For example, 1 for black and white or grayscale
|
||||||
|
and 3 for RGB.
|
||||||
|
|
||||||
|
Within TensorFlow there are two naming conventions representing the two most
|
||||||
|
common data formats:
|
||||||
|
|
||||||
|
* `NCHW` or `channels_first`
|
||||||
|
* `NHWC` or `channels_last`
|
||||||
|
|
||||||
|
`NHWC` is the TensorFlow default and `NCHW` is the optimal format to use when
|
||||||
|
training on NVIDIA GPUs using [cuDNN](https://developer.nvidia.com/cudnn).
|
||||||
|
|
||||||
|
The best practice is to build models that work with both data formats. This
|
||||||
|
simplifies training on GPUs and then running inference on CPUs. If TensorFlow is
|
||||||
|
compiled with the [Intel MKL](#tensorflow_with_intel_mkl-dnn) optimizations,
|
||||||
|
many operations, especially those related to CNN based models, will be optimized
|
||||||
|
and support `NCHW`. If not using the MKL, some operations are not supported on
|
||||||
|
CPU when using `NCHW`.
|
||||||
|
|
||||||
|
The brief history of these two formats is that TensorFlow started by using
|
||||||
|
`NHWC` because it was a little faster on CPUs. In the long term, we are working
|
||||||
|
on tools to auto rewrite graphs to make switching between the formats
|
||||||
|
transparent and take advantages of micro optimizations where a GPU Op may be
|
||||||
|
faster using `NHWC` than the normally most efficient `NCHW`.
|
||||||
|
|
||||||
|
### Common fused Ops
|
||||||
|
|
||||||
|
Fused Ops combine multiple operations into a single kernel for improved
|
||||||
|
performance. There are many fused Ops within TensorFlow and @{$xla$XLA} will
|
||||||
|
create fused Ops when possible to automatically improve performance. Collected
|
||||||
|
below are select fused Ops that can greatly improve performance and may be
|
||||||
|
overlooked.
|
||||||
|
|
||||||
|
#### Fused batch norm
|
||||||
|
|
||||||
|
Fused batch norm combines the multiple operations needed to do batch
|
||||||
|
normalization into a single kernel. Batch norm is an expensive process that for
|
||||||
|
some models makes up a large percentage of the operation time. Using fused batch
|
||||||
|
norm can result in a 12%-30% speedup.
|
||||||
|
|
||||||
|
There are two commonly used batch norms and both support fusing. The core
|
||||||
|
@{tf.layers.batch_normalization} added fused starting in TensorFlow 1.3.
|
||||||
|
|
||||||
|
```python
|
||||||
|
bn = tf.layers.batch_normalization(
|
||||||
|
input_layer, fused=True, data_format='NCHW')
|
||||||
|
```
|
||||||
|
|
||||||
|
The contrib @{tf.contrib.layers.batch_norm} method has had fused as an option
|
||||||
|
since before TensorFlow 1.0.
|
||||||
|
|
||||||
|
```python
|
||||||
|
bn = tf.contrib.layers.batch_norm(input_layer, fused=True, data_format='NCHW')
|
||||||
|
```
|
||||||
|
|
||||||
|
### Building and installing from source
|
||||||
|
|
||||||
|
The default TensorFlow binaries target the broadest range of hardware to make
|
||||||
|
TensorFlow accessible to everyone. If using CPUs for training or inference, it
|
||||||
|
is recommended to compile TensorFlow with all of the optimizations available for
|
||||||
|
the CPU in use. Speedups for training and inference on CPU are documented below
|
||||||
|
in [Comparing compiler optimizations](#comparing-compiler-optimizations).
|
||||||
|
|
||||||
|
To install the most optimized version of TensorFlow,
|
||||||
|
@{$install_sources$build and install} from source. If there is a need to build
|
||||||
|
TensorFlow on a platform that has different hardware than the target, then
|
||||||
|
cross-compile with the highest optimizations for the target platform. The
|
||||||
|
following command is an example of using `bazel` to compile for a specific
|
||||||
|
platform:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
# This command optimizes for Intel’s Broadwell processor
|
# This command optimizes for Intel’s Broadwell processor
|
||||||
@ -47,106 +186,467 @@ bazel build -c opt --copt=-march="broadwell" --config=cuda //tensorflow/tools/pi
|
|||||||
|
|
||||||
#### Environment, build, and install tips
|
#### Environment, build, and install tips
|
||||||
|
|
||||||
* Compile with the highest level of compute the [GPU
|
* `./configure` asks which compute capability to include in the build. This
|
||||||
supports](http://developer.nvidia.com/cuda-gpus), e.g. P100: 6.0, Titan X
|
does not impact overall performance but does impact initial startup. After
|
||||||
(pascal): 6.2, Titan X (maxwell): 5.2, and K80: 3.7.
|
running TensorFlow once, the compiled kernels are cached by CUDA. If using
|
||||||
* Install the latest CUDA platform and cuDNN libraries.
|
a docker container, the data is not cached and the penalty is paid each time
|
||||||
* Make sure to use a version of gcc that supports all of the optimizations of
|
TensorFlow starts. The best practice is to include the
|
||||||
the target CPU. The recommended minimum gcc version is 4.8.3. On OS X upgrade
|
[compute capabilities](http://developer.nvidia.com/cuda-gpus)
|
||||||
to the latest Xcode version and use the version of clang that comes with Xcode.
|
of the GPUs that will be used, e.g. P100: 6.0, Titan X (Pascal): 6.1, Titan
|
||||||
* TensorFlow checks on startup whether it has been compiled with the
|
X (Maxwell): 5.2, and K80: 3.7.
|
||||||
optimizations available on the CPU. If the optimizations are not included,
|
* Use a version of gcc that supports all of the optimizations of the target
|
||||||
TensorFlow will emit warnings, e.g. AVX, AVX2, and FMA instructions not
|
CPU. The recommended minimum gcc version is 4.8.3. On OS X, upgrade to the
|
||||||
included.
|
latest Xcode version and use the version of clang that comes with Xcode.
|
||||||
|
* Install the latest stable CUDA platform and cuDNN libraries supported by
|
||||||
|
TensorFlow.
|
||||||
|
|
||||||
### Utilize queues for reading data
|
## Optimizing for GPU
|
||||||
|
|
||||||
One common cause of poor performance is underutilizing GPUs, or essentially
|
This section contains GPU-specific tips that are not covered in the
|
||||||
"starving" them of data by not setting up an efficient pipeline. Make sure to
|
[General best practices](#general-best-practices). Obtaining optimal performance
|
||||||
set up an input pipeline to utilize queues and stream data effectively. Review
|
on multi-GPUs is a challenge. A common approach is to use data parallelism.
|
||||||
the @{$reading_data#reading_from_files$Reading Data guide} for implementation
|
Scaling through the use of data parallelism involves making multiple copies of
|
||||||
details. One way to identify a "starved" GPU is to generate and review
|
the model, which are referred to as "towers", and then placing one tower on each
|
||||||
timelines. A detailed tutorial for timelines does not exist, but a quick example
|
of the GPUs. Each tower operates on a different mini-batch of data and then
|
||||||
of generating a timeline exists as part of the @{$jit$XLA JIT} tutorial. Another
|
updates variables, also known as parameters, that need to be shared between
|
||||||
simple way to check if a GPU is underutilized is to run `watch nvidia-smi`, and
|
each of the towers. How each tower gets the updated variables and how the
|
||||||
if GPU utilization is not approaching 100% then the GPU is not getting data fast
|
gradients are applied has an impact on the performance, scaling, and convergence
|
||||||
enough.
|
of the model. The rest of this section provides an overview of variable
|
||||||
|
placement and the towering of a model on multiple GPUs.
|
||||||
|
@{$performance_models$High-Performance Models} gets into more details regarding
|
||||||
|
more complex methods that can be used to share and update variables between
|
||||||
|
towers.
|
||||||
|
|
||||||
Unless for a special circumstance or for example code, do not feed data
|
The best approach to handling variable updates depends on the model, hardware,
|
||||||
into the session from Python variables, e.g. `dictionary`.
|
and even how the hardware has been configured. An example of this, is that two
|
||||||
|
systems can be built with NVIDIA Tesla P100s but one may be using PCIe and the
|
||||||
|
other [NVLink](http://www.nvidia.com/object/nvlink.html). In that scenario, the
|
||||||
|
optimal solution for each system may be different. For real world examples, read
|
||||||
|
the @{$benchmarks$benchmark} page which details the settings that were optimal
|
||||||
|
for a variety of platforms. Below is a summary of what was learned from
|
||||||
|
benchmarking various platforms and configurations:
|
||||||
|
|
||||||
|
* **Tesla K80**: If the GPUs are on the same PCI Express root complex and are
|
||||||
|
able to use [NVIDIA GPUDirect](https://developer.nvidia.com/gpudirect) Peer
|
||||||
|
to Peer, then placing the variables equally across the GPUs used for
|
||||||
|
training is the best approach. If the GPUs cannot use GPUDirect, then
|
||||||
|
placing the variables on the CPU is the best option.
|
||||||
|
|
||||||
|
* **Titan X (Maxwell and Pascal), M40, P100, and similar**: For models like
|
||||||
|
ResNet and InceptionV3, placing variables on the CPU is the optimal setting,
|
||||||
|
but for models with a lot of variables like AlexNet and VGG, using GPUs with
|
||||||
|
`NCCL` is better.
|
||||||
|
|
||||||
|
A common approach to managing where variables are placed, is to create a method
|
||||||
|
to determine where each Op is to be placed and use that method in place of a
|
||||||
|
specific device name when calling `with tf.device():`. Consider a scenario where
|
||||||
|
a model is being trained on 2 GPUs and the variables are to be placed on the
|
||||||
|
CPU. There would be a loop for creating and placing the "towers" on each of the
|
||||||
|
2 GPUs. A custom device placement method would be created that watches for Ops
|
||||||
|
of type `Variable`, `VariableV2`, and `VarHandleOp` and indicates that they are
|
||||||
|
to be placed on the CPU. All other Ops would be placed on the target GPU.
|
||||||
|
The building of the graph would proceed as follows:
|
||||||
|
|
||||||
|
* On the first loop a "tower" of the model would be created for `gpu:0`.
|
||||||
|
During the placement of the Ops, the custom device placement method would
|
||||||
|
indicate that variables are to be placed on `cpu:0` and all other Ops on
|
||||||
|
`gpu:0`.
|
||||||
|
|
||||||
|
* On the second loop, `reuse` is set to `True` to indicate that variables are
|
||||||
|
to be reused and then the "tower" is created on `gpu:1`. During the
|
||||||
|
placement of the Ops associated with the "tower", the variables that were
|
||||||
|
placed on `cpu:0` are reused and all other Ops are created and placed on
|
||||||
|
`gpu:1`.
|
||||||
|
|
||||||
|
The final result is all of the variables are placed on the CPU with each GPU
|
||||||
|
having a copy of all of the computational Ops associated with the model.
|
||||||
|
|
||||||
|
The code snippet below illustrates two different approaches for variable
|
||||||
|
placement: one is placing variables on the CPU; the other is placing variables
|
||||||
|
equally across the GPUs.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
# Using feed_dict often results in suboptimal performance when using large inputs.
|
|
||||||
sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
|
class GpuParamServerDeviceSetter(object):
|
||||||
|
"""Used with tf.device() to place variables on the least loaded GPU.
|
||||||
|
|
||||||
|
A common use for this class is to pass a list of GPU devices, e.g. ['gpu:0',
|
||||||
|
'gpu:1','gpu:2'], as ps_devices. When each variable is placed, it will be
|
||||||
|
placed on the least loaded gpu. All other Ops, which will be the computation
|
||||||
|
Ops, will be placed on the worker_device.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, worker_device, ps_devices):
|
||||||
|
"""Initializer for GpuParamServerDeviceSetter.
|
||||||
|
Args:
|
||||||
|
worker_device: the device to use for computation Ops.
|
||||||
|
ps_devices: a list of devices to use for Variable Ops. Each variable is
|
||||||
|
assigned to the least loaded device.
|
||||||
|
"""
|
||||||
|
self.ps_devices = ps_devices
|
||||||
|
self.worker_device = worker_device
|
||||||
|
self.ps_sizes = [0] * len(self.ps_devices)
|
||||||
|
|
||||||
|
def __call__(self, op):
|
||||||
|
if op.device:
|
||||||
|
return op.device
|
||||||
|
if op.type not in ['Variable', 'VariableV2', 'VarHandleOp']:
|
||||||
|
return self.worker_device
|
||||||
|
|
||||||
|
# Gets the least loaded ps_device
|
||||||
|
device_index, _ = min(enumerate(self.ps_sizes), key=operator.itemgetter(1))
|
||||||
|
device_name = self.ps_devices[device_index]
|
||||||
|
var_size = op.outputs[0].get_shape().num_elements()
|
||||||
|
self.ps_sizes[device_index] += var_size
|
||||||
|
|
||||||
|
return device_name
|
||||||
|
|
||||||
|
def _create_device_setter(is_cpu_ps, worker, num_gpus):
|
||||||
|
"""Create device setter object."""
|
||||||
|
if is_cpu_ps:
|
||||||
|
# tf.train.replica_device_setter supports placing variables on the CPU, all
|
||||||
|
# on one GPU, or on ps_servers defined in a cluster_spec.
|
||||||
|
return tf.train.replica_device_setter(
|
||||||
|
worker_device=worker, ps_device='/cpu:0', ps_tasks=1)
|
||||||
|
else:
|
||||||
|
gpus = ['/gpu:%d' % i for i in range(num_gpus)]
|
||||||
|
return ParamServerDeviceSetter(worker, gpus)
|
||||||
|
|
||||||
|
# The method below is a modified snippet from the full example.
|
||||||
|
def _resnet_model_fn():
|
||||||
|
# When set to False, variables are placed on the least loaded GPU. If set
|
||||||
|
# to True, the variables will be placed on the CPU.
|
||||||
|
is_cpu_ps = False
|
||||||
|
|
||||||
|
# Loops over the number of GPUs and creates a copy ("tower") of the model on
|
||||||
|
# each GPU.
|
||||||
|
for i in range(num_gpus):
|
||||||
|
worker = '/gpu:%d' % i
|
||||||
|
# Creates a device setter used to determine where Ops are to be placed.
|
||||||
|
device_setter = _create_device_setter(is_cpu_ps, worker, FLAGS.num_gpus)
|
||||||
|
# Creates variables on the first loop. On subsequent loops reuse is set
|
||||||
|
# to True, which results in the "towers" sharing variables.
|
||||||
|
with tf.variable_scope('resnet', reuse=bool(i != 0)):
|
||||||
|
with tf.name_scope('tower_%d' % i) as name_scope:
|
||||||
|
# tf.device calls the device_setter for each Op that is created.
|
||||||
|
# device_setter returns the device the Op is to be placed on.
|
||||||
|
with tf.device(device_setter):
|
||||||
|
# Creates the "tower".
|
||||||
|
_tower_fn(is_training, weight_decay, tower_features[i],
|
||||||
|
tower_labels[i], tower_losses, tower_gradvars,
|
||||||
|
tower_preds, False)
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### Preprocessing on the CPU
|
In the near future the above code will be for illustration purposes only as
|
||||||
|
there will be easy to use high level methods to support a wide range of popular
|
||||||
|
approaches. This
|
||||||
|
[example](https://github.com/tensorflow/models/tree/master/tutorials/image/cifar10_estimator)
|
||||||
|
will continue to get updated as the API expands and evolves to address multi-GPU
|
||||||
|
scenarios.
|
||||||
|
|
||||||
Placing preprocessing operations on the CPU can significantly improve
|
## Optimizing for CPU
|
||||||
performance. When preprocessing occurs on the GPU the flow of data is
|
|
||||||
CPU -> GPU (preprocessing) -> CPU -> GPU (training). The data is bounced back
|
|
||||||
and forth between the CPU and GPU. When preprocessing is placed on the CPU,
|
|
||||||
the data flow is CPU (preprocessing) -> GPU (training). Another benefit is
|
|
||||||
preprocessing on the CPU frees GPU time to focus on training.
|
|
||||||
|
|
||||||
Placing preprocessing on the CPU can result in a 6X+ increase in samples/sec
|
CPUs, which includes Intel® Xeon Phi™, achieve optimal performance when
|
||||||
processed, which could lead to training in 1/6th of the time. To ensure
|
TensorFlow is @{$install_sources$built from source} with all of the instructions
|
||||||
preprocessing is on the CPU, wrap the preprocessing operations as shown below:
|
supported by the target CPU.
|
||||||
|
|
||||||
|
Beyond using the latest instruction sets, Intel® has added support for the
|
||||||
|
Intel® Math Kernel Library for Deep Neural Networks (Intel® MKL-DNN) to
|
||||||
|
TensorFlow. While the name is not completely accurate, these optimizations are
|
||||||
|
often simply referred to as 'MKL' or 'TensorFlow with MKL'. [TensorFlow
|
||||||
|
with Intel® MKL-DNN](#tensorflow_with_intel_mkl_dnn) contains details on the
|
||||||
|
MKL optimizations.
|
||||||
|
|
||||||
|
The two configurations listed below are used to optimize CPU performance by
|
||||||
|
adjusting the thread pools.
|
||||||
|
|
||||||
|
* `intra_op_parallelism_threads`: Nodes that can use multiple threads to
|
||||||
|
parallelize their execution will schedule the individual pieces into this
|
||||||
|
pool.
|
||||||
|
* `inter_op_parallelism_threads`: All ready nodes are scheduled in this pool.
|
||||||
|
|
||||||
|
These configurations are set via the `tf.ConfigProto` and passed to `tf.Session`
|
||||||
|
in the `config` attribute as shown in the snippet below. For both configuration
|
||||||
|
options, if they are unset or set to 0, will default to the number of logical
|
||||||
|
CPU cores. Testing has shown that the default is effective for systems ranging
|
||||||
|
from one CPU with 4 cores to multiple CPUs with 70+ combined logical cores.
|
||||||
|
A common alternative optimization is to set the number of threads in both pools
|
||||||
|
equal to the number of physical cores rather than logical cores.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
with tf.device('/cpu:0'):
|
|
||||||
# function to get and process images or data.
|
config = tf.ConfigProto()
|
||||||
distorted_inputs = load_and_distort_images()
|
config.intra_op_parallelism_threads = 44
|
||||||
|
config.inter_op_parallelism_threads = 44
|
||||||
|
tf.session(config=config)
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### Use large files
|
The [Comparing compiler optimizations](#comparing-compiler-optimizations)
|
||||||
|
section contains the results of tests that used different compiler
|
||||||
|
optimizations.
|
||||||
|
|
||||||
Under some circumstances, both the CPU and GPU can be starved for data by the
|
### TensorFlow with Intel® MKL DNN
|
||||||
I/O system. If you are using many small files to form your input data set, you
|
|
||||||
may be limited by the speed of your filesystem. If your training loop runs
|
|
||||||
faster when using SSDs vs HDDs for storing your input data, you could be
|
|
||||||
I/O bottlenecked.
|
|
||||||
|
|
||||||
If this is the case, you should pre-process your input data, creating a few
|
Intel® has added optimizations to TensorFlow for Intel® Xeon® and Intel® Xeon
|
||||||
large TFRecord files.
|
Phi™ though the use of Intel® Math Kernel Library for Deep Neural Networks
|
||||||
|
(Intel® MKL-DNN) optimized primitives. The optimizations also provide speedups
|
||||||
|
for the consumer line of processors, e.g. i5 and i7 Intel processors. The Intel
|
||||||
|
published paper
|
||||||
|
[TensorFlow* Optimizations on Modern Intel® Architecture](https://software.intel.com/en-us/articles/tensorflow-optimizations-on-modern-intel-architecture)
|
||||||
|
contains additional details on the implementation.
|
||||||
|
|
||||||
### Use NCHW image data format
|
> Note: MKL was added as of TensorFlow 1.2 and currently only works on Linux. It
|
||||||
|
> also does not work when also using `--config=cuda`.
|
||||||
|
|
||||||
Image data format refers to the representation of batches of images. TensorFlow
|
In addition to providing significant performance improvements for training CNN
|
||||||
supports `NHWC` (TensorFlow default) and `NCHW` (cuDNN default). N refers to the
|
based models, compiling with the MKL creates a binary that is optimized for AVX
|
||||||
number of images in a batch, H refers to the number of pixels in the vertical
|
and AVX2. The result is a single binary that is optimized and compatible with
|
||||||
dimension, W refers to the number of pixels in the horizontal dimension, and C
|
most modern (post-2011) processors.
|
||||||
refers to the channels (e.g. 1 for black and white, 3 for RGB, etc.) Although
|
|
||||||
cuDNN can operate on both formats, it is faster to operate in its default
|
|
||||||
format.
|
|
||||||
|
|
||||||
The best practice is to build models that work with both `NCHW` and `NHWC` as it
|
TensorFlow can be compiled with the MKL optimizations using the following
|
||||||
is common to train using `NCHW` on GPU, and then do inference with `NHWC` on CPU.
|
commands that depending on the version of the TensorFlow source used.
|
||||||
|
|
||||||
There are edge cases where `NCHW` can be slower on GPU than `NHWC`. One
|
For TensorFlow source versions after 1.3.0:
|
||||||
[case](https://github.com/tensorflow/tensorflow/issues/7551#issuecomment-280421351)
|
|
||||||
is using non-fused batch norm on WRN-16-4 without dropout. In that case using
|
|
||||||
fused batch norm, which is also recommended, is the optimal solution.
|
|
||||||
|
|
||||||
The very brief history of these two formats is that TensorFlow started by using
|
```bash
|
||||||
`NHWC` because it was a little faster on CPUs. Then the TensorFlow team
|
./configure
|
||||||
discovered that `NCHW` performs better when using the NVIDIA cuDNN library. The
|
# Pick the desired options
|
||||||
current recommendation is that users support both formats in their models. In
|
bazel build --config=mkl -c opt //tensorflow/tools/pip_package:build_pip_package
|
||||||
the long term, we plan to rewrite graphs to make switching between the formats
|
|
||||||
transparent.
|
|
||||||
|
|
||||||
### Use fused batch norm
|
```
|
||||||
|
|
||||||
When using batch norm
|
For TensorFlow versions 1.2.0 through 1.3.0:
|
||||||
@{tf.contrib.layers.batch_norm} set the attribute `fused=True`:
|
|
||||||
|
```bash
|
||||||
|
./configure
|
||||||
|
Do you wish to build TensorFlow with MKL support? [y/N] Y
|
||||||
|
Do you wish to download MKL LIB from the web? [Y/n] Y
|
||||||
|
# Select the defaults for the rest of the options.
|
||||||
|
|
||||||
|
bazel build --config=mkl --copt="-DEIGEN_USE_VML" -c opt //tensorflow/tools/pip_package:build_pip_package
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Tuning MKL for the best performance
|
||||||
|
|
||||||
|
This section details the different configurations and environment variables that
|
||||||
|
can be used to tune the MKL to get optimal performance. Before tweaking various
|
||||||
|
environment variables make sure the model is using the `NCHW` (`channels_first`)
|
||||||
|
[data format](#data-formats). The MKL is optimized for `NCHW` and Intel is
|
||||||
|
working to get near performance parity when using `NHWC`.
|
||||||
|
|
||||||
|
MKL uses the following environment variables to tune performance:
|
||||||
|
|
||||||
|
* KMP_BLOCKTIME - Sets the time, in milliseconds, that a thread should wait,
|
||||||
|
after completing the execution of a parallel region, before sleeping.
|
||||||
|
* KMP_AFFINITY - Enables the run-time library to bind threads to physical
|
||||||
|
processing units.
|
||||||
|
* KMP_SETTINGS - Enables (true) or disables (false) the printing of OpenMP*
|
||||||
|
run-time library environment variables during program execution.
|
||||||
|
* OMP_NUM_THREADS - Specifies the number of threads to use.
|
||||||
|
|
||||||
|
More details on the KMP variables are on
|
||||||
|
[Intel's](https://software.intel.com/en-us/node/522775) site and the OMP
|
||||||
|
variables on
|
||||||
|
[gnu.org](https://gcc.gnu.org/onlinedocs/libgomp/Environment-Variables.html)
|
||||||
|
|
||||||
|
While there can be substantial gains from adjusting the environment variables,
|
||||||
|
which is discussed below, the simplified advice is to set the
|
||||||
|
`inter_op_parallelism_threads` equal to the number of physical CPUs and to set
|
||||||
|
the following environment variables:
|
||||||
|
|
||||||
|
* KMP_BLOCKTIME=0
|
||||||
|
* KMP_AFFINITY=granularity=fine,verbose,compact,1,0
|
||||||
|
|
||||||
|
Example setting MKL variables with command-line arguments:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
KMP_BLOCKTIME=0 KMP_AFFINITY=granularity=fine,verbose,compact,1,0 \
|
||||||
|
KMP_SETTINGS=1 python your_python_script.py
|
||||||
|
```
|
||||||
|
|
||||||
|
Example setting MKL variables with python `os.environ`:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
bn = tf.contrib.layers.batch_norm(
|
os.environ["KMP_BLOCKTIME"] = str(FLAGS.kmp_blocktime)
|
||||||
input_layer, fused=True, data_format='NCHW'
|
os.environ["KMP_SETTINGS"] = str(FLAGS.kmp_settings)
|
||||||
scope=scope, **kwargs)
|
os.environ["KMP_AFFINITY"]= FLAGS.kmp_affinity
|
||||||
|
if FLAGS.num_intra_threads > 0:
|
||||||
|
os.environ["OMP_NUM_THREADS"]= str(FLAGS.num_intra_threads)
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
The non-fused batch norm does computations using several individual Ops. Fused
|
There are models and hardware platforms that benefit from different settings.
|
||||||
batch norm combines the individual operations into a single kernel, which runs
|
Each variable that impacts performance is discussed below.
|
||||||
faster.
|
|
||||||
|
|
||||||
|
* **KMP_BLOCKTIME**: The MKL default is 200ms, which was not optimal in our
|
||||||
|
testing. 0 (0ms) was a good default for CNN based models that were tested.
|
||||||
|
The best performance for AlexNex was achieved at 30ms and both GoogleNet and
|
||||||
|
VGG11 performed best set at 1ms.
|
||||||
|
|
||||||
|
* **KMP_AFFINITY**: The recommended setting is
|
||||||
|
`granularity=fine,verbose,compact,1,0`.
|
||||||
|
|
||||||
|
* **OMP_NUM_THREADS**: This defaults to the number of physical cores.
|
||||||
|
Adjusting this parameter beyond matching the number of cores can have an
|
||||||
|
impact when using Intel® Xeon Phi™ (Knights Landing) for some models. See
|
||||||
|
[TensorFlow* Optimizations on Modern Intel® Architecture](https://software.intel.com/en-us/articles/tensorflow-optimizations-on-modern-intel-architecture)
|
||||||
|
for optimal settings.
|
||||||
|
|
||||||
|
* **intra_op_parallelism_threads**: Setting this equal to the number of
|
||||||
|
physical cores is recommended. Setting the value to 0, which is the default
|
||||||
|
and will result in the value being set to the number of logical cores, is an
|
||||||
|
option to try for some architectures. This value and `OMP_NUM_THREADS`
|
||||||
|
should be equal.
|
||||||
|
|
||||||
|
* **inter_op_parallelism_threads**: Setting this equal to the number of
|
||||||
|
sockets is recommended. Setting the value to 0, which is the default,
|
||||||
|
results in the value being set to the number of logical cores.
|
||||||
|
|
||||||
|
### Comparing compiler optimizations
|
||||||
|
|
||||||
|
Collected below are performance results running training and inference on
|
||||||
|
different types of CPUs on different platforms with various compiler
|
||||||
|
optimizations. The models used were ResNet-50
|
||||||
|
([arXiv:1512.03385](https://arxiv.org/abs/1512.03385)) and
|
||||||
|
InceptionV3 ([arXiv:1512.00567](https://arxiv.org/abs/1512.00567)).
|
||||||
|
|
||||||
|
For each test, when the MKL optimization was used the environment variable
|
||||||
|
KMP_BLOCKTIME was set to 0 (0ms) and KMP_AFFINITY to
|
||||||
|
`granularity=fine,verbose,compact,1,0`.
|
||||||
|
|
||||||
|
#### Inference InceptionV3
|
||||||
|
|
||||||
|
**Environment**
|
||||||
|
|
||||||
|
* Instance Type: AWS EC2 m4.xlarge
|
||||||
|
* CPU: Intel(R) Xeon(R) CPU E5-2686 v4 @ 2.30GHz (Broadwell)
|
||||||
|
* Dataset: ImageNet
|
||||||
|
* TensorFlow Version: 1.2.0 RC2
|
||||||
|
* Test Script: [tf_cnn_benchmarks.py](https://github.com/tensorflow/benchmarks/blob/mkl_experiment/scripts/tf_cnn_benchmarks/tf_cnn_benchmarks.py)
|
||||||
|
|
||||||
|
**Batch Size: 1**
|
||||||
|
|
||||||
|
Command executed for the MKL test:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python tf_cnn_benchmarks.py --forward_only=True --device=cpu --mkl=True \
|
||||||
|
--kmp_blocktime=0 --nodistortions --model=inception3 --data_format=NCHW \
|
||||||
|
--batch_size=1 --num_inter_threads=1 --num_intra_threads=4 \
|
||||||
|
--data_dir=<path to ImageNet TFRecords>
|
||||||
|
```
|
||||||
|
|
||||||
|
| Optimization | Data Format | Images/Sec | Intra threads | Inter Threads |
|
||||||
|
: : : (step time) : : :
|
||||||
|
| ------------ | ----------- | ------------ | ------------- | ------------- |
|
||||||
|
| AVX2 | NHWC | 6.8 (147ms) | 4 | 0 |
|
||||||
|
| MKL | NCHW | 6.6 (151ms) | 4 | 1 |
|
||||||
|
| MKL | NHWC | 5.95 (168ms) | 4 | 1 |
|
||||||
|
| AVX | NHWC | 4.7 (211ms) | 4 | 0 |
|
||||||
|
| SSE3 | NHWC | 2.7 (370ms) | 4 | 0 |
|
||||||
|
|
||||||
|
**Batch Size: 32**
|
||||||
|
|
||||||
|
Command executed for the MKL test:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python tf_cnn_benchmarks.py --forward_only=True --device=cpu --mkl=True \
|
||||||
|
--kmp_blocktime=0 --nodistortions --model=inception3 --data_format=NCHW \
|
||||||
|
--batch_size=32 --num_inter_threads=1 --num_intra_threads=4 \
|
||||||
|
--data_dir=<path to ImageNet TFRecords>
|
||||||
|
```
|
||||||
|
|
||||||
|
| Optimization | Data Format | Images/Sec | Intra threads | Inter Threads |
|
||||||
|
: : : (step time) : : :
|
||||||
|
| ------------ | ----------- | ------------- | ------------- | ------------- |
|
||||||
|
| MKL | NCHW | 10.24 | 4 | 1 |
|
||||||
|
: : : (3125ms) : : :
|
||||||
|
| MKL | NHWC | 8.9 (3595ms) | 4 | 1 |
|
||||||
|
| AVX2 | NHWC | 7.3 (4383ms) | 4 | 0 |
|
||||||
|
| AVX | NHWC | 5.1 (6275ms) | 4 | 0 |
|
||||||
|
| SSE3 | NHWC | 2.8 (11428ms) | 4 | 0 |
|
||||||
|
|
||||||
|
#### Inference ResNet-50
|
||||||
|
|
||||||
|
**Environment**
|
||||||
|
|
||||||
|
* Instance Type: AWS EC2 m4.xlarge
|
||||||
|
* CPU: Intel(R) Xeon(R) CPU E5-2686 v4 @ 2.30GHz (Broadwell)
|
||||||
|
* Dataset: ImageNet
|
||||||
|
* TensorFlow Version: 1.2.0 RC2
|
||||||
|
* Test Script: [tf_cnn_benchmarks.py](https://github.com/tensorflow/benchmarks/blob/mkl_experiment/scripts/tf_cnn_benchmarks/tf_cnn_benchmarks.py)
|
||||||
|
|
||||||
|
**Batch Size: 1**
|
||||||
|
|
||||||
|
Command executed for the MKL test:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python tf_cnn_benchmarks.py --forward_only=True --device=cpu --mkl=True \
|
||||||
|
--kmp_blocktime=0 --nodistortions --model=resnet50 --data_format=NCHW \
|
||||||
|
--batch_size=1 --num_inter_threads=1 --num_intra_threads=4 \
|
||||||
|
--data_dir=<path to ImageNet TFRecords>
|
||||||
|
```
|
||||||
|
|
||||||
|
| Optimization | Data Format | Images/Sec | Intra threads | Inter Threads |
|
||||||
|
: : : (step time) : : :
|
||||||
|
| ------------ | ----------- | ------------ | ------------- | ------------- |
|
||||||
|
| AVX2 | NHWC | 6.8 (147ms) | 4 | 0 |
|
||||||
|
| MKL | NCHW | 6.6 (151ms) | 4 | 1 |
|
||||||
|
| MKL | NHWC | 5.95 (168ms) | 4 | 1 |
|
||||||
|
| AVX | NHWC | 4.7 (211ms) | 4 | 0 |
|
||||||
|
| SSE3 | NHWC | 2.7 (370ms) | 4 | 0 |
|
||||||
|
|
||||||
|
**Batch Size: 32**
|
||||||
|
|
||||||
|
Command executed for the MKL test:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python tf_cnn_benchmarks.py --forward_only=True --device=cpu --mkl=True \
|
||||||
|
--kmp_blocktime=0 --nodistortions --model=resnet50 --data_format=NCHW \
|
||||||
|
--batch_size=32 --num_inter_threads=1 --num_intra_threads=4 \
|
||||||
|
--data_dir=<path to ImageNet TFRecords>
|
||||||
|
```
|
||||||
|
|
||||||
|
| Optimization | Data Format | Images/Sec | Intra threads | Inter Threads |
|
||||||
|
: : : (step time) : : :
|
||||||
|
| ------------ | ----------- | ------------- | ------------- | ------------- |
|
||||||
|
| MKL | NCHW | 10.24 | 4 | 1 |
|
||||||
|
: : : (3125ms) : : :
|
||||||
|
| MKL | NHWC | 8.9 (3595ms) | 4 | 1 |
|
||||||
|
| AVX2 | NHWC | 7.3 (4383ms) | 4 | 0 |
|
||||||
|
| AVX | NHWC | 5.1 (6275ms) | 4 | 0 |
|
||||||
|
| SSE3 | NHWC | 2.8 (11428ms) | 4 | 0 |
|
||||||
|
|
||||||
|
#### Training InceptionV3
|
||||||
|
|
||||||
|
**Environment**
|
||||||
|
|
||||||
|
* Instance Type: Dedicated AWS EC2 r4.16xlarge (Broadwell)
|
||||||
|
* CPU: Intel Xeon E5-2686 v4 (Broadwell) Processors
|
||||||
|
* Dataset: ImageNet
|
||||||
|
* TensorFlow Version: 1.2.0 RC2
|
||||||
|
* Test Script: [tf_cnn_benchmarks.py](https://github.com/tensorflow/benchmarks/blob/mkl_experiment/scripts/tf_cnn_benchmarks/tf_cnn_benchmarks.py)
|
||||||
|
|
||||||
|
Command executed for MKL test:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python tf_cnn_benchmarks.py --device=cpu --mkl=True --kmp_blocktime=0 \
|
||||||
|
--nodistortions --model=resnet50 --data_format=NCHW --batch_size=32 \
|
||||||
|
--num_inter_threads=2 --num_intra_threads=36 \
|
||||||
|
--data_dir=<path to ImageNet TFRecords>
|
||||||
|
```
|
||||||
|
|
||||||
|
Optimization | Data Format | Images/Sec | Intra threads | Inter Threads
|
||||||
|
------------ | ----------- | ---------- | ------------- | -------------
|
||||||
|
MKL | NCHW | 20.8 | 36 | 2
|
||||||
|
AVX2 | NHWC | 6.2 | 36 | 0
|
||||||
|
AVX | NHWC | 5.7 | 36 | 0
|
||||||
|
SSE3 | NHWC | 4.3 | 36 | 0
|
||||||
|
|
||||||
|
ResNet and [AlexNet](http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf)
|
||||||
|
were also run on this configuration but in an ad hoc manner. There were not
|
||||||
|
enough runs executed to publish a coherent table of results. The incomplete
|
||||||
|
results strongly indicated the final result would be similar to the table above
|
||||||
|
with MKL providing significant 3x+ gains over AVX2.
|
||||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user