Merge pull request #12956 from yifeif/branch_168186374

Branch 168186374
Yifei Feng 2017-09-11 10:46:32 -07:00 committed by GitHub
commit 40eef4473b
132 changed files with 7138 additions and 847 deletions

View File

@ -79,6 +79,8 @@ tf_cc_test(
"//tensorflow/cc:ops",
"//tensorflow/cc:scope",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],

View File

@ -479,20 +479,16 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
if (kernel == nullptr) {
const tensorflow::NodeDef& ndef = op->attrs.BuildNodeDef();
kernel = new tensorflow::KernelAndDevice(ctx->rendezvous);
if (!op->is_function()) {
status->status =
tensorflow::KernelAndDevice::InitOp(device, ndef, kernel);
} else {
// Knowledge of the implementation of InitFn (and, in turn,
// FunctionLibraryRuntime::CreateKernel) tells us that ctx->func_lib_def
// will be accessed, so grab on to the lock.
// See WARNING comment below - would be nice to rework to avoid this
// subtlety.
tensorflow::mutex_lock l(ctx->functions_mu);
status->status = tensorflow::KernelAndDevice::InitFn(
ndef, ctx->func_lib(device), kernel);
}
// Knowledge of the implementation of Init (and, in turn,
// FunctionLibraryRuntime::CreateKernel) tells us that ctx->func_lib_def
// will be accessed, so grab on to the lock.
// See WARNING comment below - would be nice to rework to avoid this
// subtlety.
tensorflow::tf_shared_lock l(ctx->functions_mu);
status->status =
tensorflow::KernelAndDevice::Init(ndef, ctx->func_lib(device), kernel);
if (!status->status.ok()) {
delete kernel;
return;
}
tensorflow::gtl::InsertOrUpdate(&(ctx->kernel_cache), cache_key, kernel);
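The hunk above replaces the exclusive `mutex_lock` with a `tf_shared_lock`: `KernelAndDevice::Init` only reads `ctx->func_lib_def`, so concurrent executions can now hold the lock in shared mode. A minimal sketch of that reader/writer caching pattern, using standard C++17 primitives and hypothetical names (not TensorFlow's `mutex`, `tf_shared_lock`, or `gtl::InsertOrUpdate`):

```cpp
#include <memory>
#include <mutex>
#include <shared_mutex>
#include <string>
#include <unordered_map>

struct Kernel {};  // Hypothetical stand-in for KernelAndDevice.

class KernelCache {
 public:
  // Fast path: many executors may hold the shared (reader) lock at once,
  // as TFE_Execute now does while initializing a kernel.
  Kernel* Lookup(const std::string& key) {
    std::shared_lock<std::shared_mutex> l(mu_);
    auto it = cache_.find(key);
    return it == cache_.end() ? nullptr : it->second.get();
  }

  // Slow path: mutation still takes the exclusive (writer) lock.
  Kernel* Insert(const std::string& key, std::unique_ptr<Kernel> kernel) {
    std::unique_lock<std::shared_mutex> l(mu_);
    auto& slot = cache_[key];
    if (slot == nullptr) slot = std::move(kernel);
    return slot.get();
  }

 private:
  std::shared_mutex mu_;
  std::unordered_map<std::string, std::unique_ptr<Kernel>> cache_;
};
```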

View File

@ -238,9 +238,8 @@ Status KernelAndDevice::InitOp(Device* device, const NodeDef& ndef,
}
// static
Status KernelAndDevice::InitFn(const NodeDef& ndef,
FunctionLibraryRuntime* flib,
KernelAndDevice* out) {
Status KernelAndDevice::Init(const NodeDef& ndef, FunctionLibraryRuntime* flib,
KernelAndDevice* out) {
OpKernel* k = nullptr;
Status s = flib->CreateKernel(ndef, &k);
out->device_ = flib->device();

View File

@ -150,28 +150,19 @@ class KernelAndDevice {
public:
// Populates 'out' with a kernel appropriate for 'ndef'.
//
// Assumes that 'ndef' refers to a primitive op (as opposed to a function).
static Status InitOp(Device* device, const NodeDef& ndef,
KernelAndDevice* out);
// Like InitOp but for functions defined in flib (i.e., ndef.op() refers to a
// TensorFlow function in the FunctionLibraryRuntime).
//
// The provided FunctionLibraryRuntime MUST outlive all calls to
// Run() on the returned KernelAndDevice.
//
// TODO(ashankar): There shouldn't be a need for a separate InitOp and InitFn.
// The implementation of InitFn should work for both because
// FunctionLibraryRuntime::CreateKernel will create a primitive op kernel if
// appropriate. However, for now we keep them separate because I haven't
// figured out thread-safety concerns around FunctionLibraryRuntime (in
// particular, how the underlying FunctionLibraryDefinition might be mutated
// by another thread as new functions are registered with it).
// Conservatively, thread-safe usage of the FunctionLibraryRuntime is pushed
// on to the caller (see locking in c_api.cc) for now. But I really should
// dig into this so that both InitOp and InitFn can be collapsed to
// FunctionLibraryRuntime::CreateKernel.
static Status InitFn(const NodeDef& ndef, FunctionLibraryRuntime* flib,
// TODO(ashankar): Figure out thread-safety concerns around
// FunctionLibraryRuntime (in particular, how the underlying
// FunctionLibraryDefinition might be mutated by another thread as new
// functions are registered with it). Conservatively, thread-safe usage of
// the FunctionLibraryRuntime is pushed on to the caller (see locking in
// c_api.cc).
static Status Init(const NodeDef& ndef, FunctionLibraryRuntime* flib,
KernelAndDevice* out);
// TODO(ashankar): Remove this
static Status InitOp(Device* device, const NodeDef& ndef,
KernelAndDevice* out);
KernelAndDevice(tensorflow::Rendezvous* rendez)
@ -184,10 +175,10 @@ class KernelAndDevice {
private:
std::unique_ptr<OpKernel> kernel_;
tensorflow::Device* device_;
tensorflow::FunctionLibraryRuntime* flib_;
tensorflow::checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_;
tensorflow::Rendezvous* rendez_;
Device* device_;
FunctionLibraryRuntime* flib_;
checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_;
Rendezvous* rendez_;
};
} // namespace tensorflow

View File

@ -23,15 +23,36 @@ limitations under the License.
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/public/version.h"
namespace tensorflow {
namespace {
Device* CPUDevice() {
return DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0");
}
class TestEnv {
public:
TestEnv() : flib_def_(OpRegistry::Global(), {}) {
Device* device =
DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0");
device_mgr_.reset(new DeviceMgr({device}));
flib_runtime_ = NewFunctionLibraryRuntime(device_mgr_.get(), Env::Default(),
device, TF_GRAPH_DEF_VERSION,
&flib_def_, {}, nullptr);
}
FunctionLibraryRuntime* function_library_runtime() const {
return flib_runtime_.get();
}
private:
FunctionLibraryDefinition flib_def_;
std::unique_ptr<DeviceMgr> device_mgr_;
std::unique_ptr<FunctionLibraryRuntime> flib_runtime_;
};
TEST(AttrTypeMap, Lookup) {
const AttrTypeMap* m = nullptr;
@ -69,9 +90,10 @@ TEST(KernelAndDevice, Run) {
.Set("transpose_b", false)
.NumInputs(inputs.size())
.BuildNodeDef());
std::unique_ptr<Device> device(CPUDevice());
TestEnv env;
KernelAndDevice kernel(nullptr);
Status s = KernelAndDevice::InitOp(device.get(), ndef, &kernel);
Status s =
KernelAndDevice::Init(ndef, env.function_library_runtime(), &kernel);
ASSERT_TRUE(s.ok()) << s;
std::vector<Tensor> outputs;
s = kernel.Run(&inputs, &outputs);
@ -132,11 +154,12 @@ void BM_KernelAndDeviceInit(int iters) {
.Set("transpose_b", false)
.NumInputs(2)
.BuildNodeDef());
std::unique_ptr<Device> device(CPUDevice());
TestEnv env;
KernelAndDevice k(nullptr);
tensorflow::testing::StartTiming();
for (int i = 0; i < iters; ++i) {
TF_CHECK_OK(KernelAndDevice::InitOp(device.get(), ndef, &k));
TF_CHECK_OK(
KernelAndDevice::Init(ndef, env.function_library_runtime(), &k));
}
}
BENCHMARK(BM_KernelAndDeviceInit);
@ -154,9 +177,10 @@ void BM_KernelAndDeviceRun(int iters) {
.Set("transpose_b", false)
.NumInputs(inputs.size())
.BuildNodeDef());
std::unique_ptr<Device> device(CPUDevice());
TestEnv env;
KernelAndDevice kernel(nullptr);
TF_CHECK_OK(KernelAndDevice::InitOp(device.get(), ndef, &kernel));
TF_CHECK_OK(
KernelAndDevice::Init(ndef, env.function_library_runtime(), &kernel));
tensorflow::testing::StartTiming();
for (int i = 0; i < iters; ++i) {
TF_CHECK_OK(kernel.Run(&inputs, &outputs));

View File

@ -286,6 +286,7 @@ cc_library(
srcs = ["call_inliner.cc"],
hdrs = ["call_inliner.h"],
deps = [
":call_graph",
":hlo_pass",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/core:lib",

View File

@ -17,33 +17,11 @@ limitations under the License.
#include <deque>
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/core/lib/core/errors.h"
namespace xla {
StatusOr<bool> CallInliner::Run(HloModule* module) {
std::deque<HloInstruction*> work_queue;
// Seed the work queue with call instructions from the main computation.
TF_RETURN_IF_ERROR(
module->entry_computation()->Accept([&](HloInstruction* hlo) {
if (hlo->opcode() == HloOpcode::kCall) {
work_queue.push_back(hlo);
}
return Status::OK();
}));
VLOG(1) << "Work queue seeded with " << work_queue.size() << " entries.";
bool mutated = false;
while (!work_queue.empty()) {
mutated = true;
HloInstruction* call = work_queue.front();
work_queue.pop_front();
TF_RETURN_IF_ERROR(ReplaceWithInlinedBody(call, &work_queue));
}
return mutated;
}
namespace {
// Traverses the callee computation, inlining cloned nodes into the caller
// computation and connecting them to producers/consumers appropriately.
@ -52,14 +30,18 @@ StatusOr<bool> CallInliner::Run(HloModule* module) {
// computation have been added to the work_queue.
class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault {
public:
SubcomputationInsertionVisitor(HloInstruction* call,
std::deque<HloInstruction*>* work_queue)
: call_(call), outer_(call->parent()), work_queue_(work_queue) {}
// call is the call operation -- it will be replaced with the body of the
// called computation.
explicit SubcomputationInsertionVisitor(HloInstruction* call)
: call_(call), outer_(call->parent()) {
CHECK_EQ(HloOpcode::kCall, call_->opcode());
}
// Resolves the operands to the HLO instruction in the inlined (caller) graph,
// and clones the HLO instruction into that graph with the new operands.
// If the instruction is a call, it is added to the work queue.
Status DefaultAction(HloInstruction* hlo) override {
TF_RET_CHECK(hlo->opcode() != HloOpcode::kCall);
std::vector<HloInstruction*> new_operands;
for (HloInstruction* operand : hlo->operands()) {
TF_ASSIGN_OR_RETURN(HloInstruction * new_operand, Resolve(operand));
@ -79,12 +61,6 @@ class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault {
new_control_predecessor->AddControlDependencyTo(new_hlo_pointer));
}
if (new_hlo_pointer->opcode() == HloOpcode::kCall) {
VLOG(1) << "Adding new call HLO to work queue.";
// Call instructions we observe in the subcomputation are added to the
// inliner work queue.
work_queue_->push_back(new_hlo_pointer);
}
return Status::OK();
}
@ -141,16 +117,30 @@ class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault {
std::deque<HloInstruction*>* work_queue_;
};
Status CallInliner::ReplaceWithInlinedBody(
HloInstruction* call, std::deque<HloInstruction*>* work_queue) {
TF_RET_CHECK(call->opcode() == HloOpcode::kCall);
TF_RET_CHECK(call->called_computations().size() == 1);
HloComputation* called = call->called_computations()[0];
VLOG(1) << "Replacing call " << call->ToString() << " with inlined body of "
<< called->name();
} // namespace
SubcomputationInsertionVisitor visitor(call, work_queue);
return called->Accept(&visitor);
StatusOr<bool> CallInliner::Run(HloModule* module) {
std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
// Because call graph nodes are visited in post-order (callees before
// callers), we'll always inline kCalls into their callers in the
// appropriate order.
bool did_mutate = false;
TF_RETURN_IF_ERROR(
call_graph->VisitNodes([&](const CallGraphNode& node) -> Status {
for (const CallSite& callsite : node.caller_callsites()) {
VLOG(1) << "Visiting callsite: " << callsite.ToString();
if (callsite.instruction()->opcode() == HloOpcode::kCall) {
did_mutate = true;
const auto& callees = callsite.called_computations();
TF_RET_CHECK(callees.size() == 1);
HloComputation* callee = callees[0];
// We visit the callee, cloning its body into its caller.
SubcomputationInsertionVisitor visitor(callsite.instruction());
TF_RETURN_IF_ERROR(callee->Accept(&visitor));
}
}
return Status::OK();
}));
return did_mutate;
}
} // namespace xla
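The rewritten `Run` leans on `CallGraph::VisitNodes` delivering nodes in post-order, so a callee's body already contains no `kCall` instructions by the time it is cloned into its caller. A self-contained illustration of that post-order guarantee (hypothetical node type, not the XLA `CallGraph` API):

```cpp
#include <functional>
#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

// Hypothetical call-graph node: a computation and the computations it calls.
struct Node {
  std::string name;
  std::vector<const Node*> callees;
};

// Visits callees before callers, mirroring the traversal order the pass
// relies on.
void VisitPostOrder(const Node* node, std::unordered_set<const Node*>* seen,
                    const std::function<void(const Node*)>& visit) {
  if (!seen->insert(node).second) return;  // Already visited.
  for (const Node* callee : node->callees) {
    VisitPostOrder(callee, seen, visit);
  }
  visit(node);
}

int main() {
  Node inner{"inner", {}};
  Node middle{"middle", {&inner}};
  Node entry{"entry", {&middle, &inner}};
  std::unordered_set<const Node*> seen;
  // Prints "inner", "middle", "entry": every callee is handled first.
  VisitPostOrder(&entry, &seen,
                 [](const Node* n) { std::cout << n->name << "\n"; });
  return 0;
}
```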

View File

@ -31,16 +31,6 @@ class CallInliner : public HloPassInterface {
tensorflow::StringPiece name() const override { return "CallInliner"; }
StatusOr<bool> Run(HloModule* module) override;
private:
// Replaces the given call operation -- which must be an operation inside the
// entry computation with opcode kCall -- with the called computation's body,
// such that the called computation is inline in the entry computation.
//
// On successful inlining, the inlined computation may have itself contained
// calls; if so, they are added to the work_queue.
Status ReplaceWithInlinedBody(HloInstruction* call,
std::deque<HloInstruction*>* work_queue);
};
} // namespace xla

View File

@ -44,6 +44,8 @@ namespace {
using CallInlinerTest = HloTestBase;
TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) {
// "inner" computation just has a control dependency from the "zero" value to
// the "one" value.
HloComputation::Builder inner(TestName() + ".inner");
HloInstruction* zero = inner.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(24.0f)));
@ -54,6 +56,7 @@ TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) {
HloComputation* inner_computation =
module->AddEmbeddedComputation(inner.Build());
// "outer" computation just calls the "inner" computation.
HloComputation::Builder outer(TestName() + ".outer");
Shape r0f32 = ShapeUtil::MakeShape(F32, {});
outer.AddInstruction(
@ -73,5 +76,44 @@ TEST_F(CallInlinerTest, ControlDependenciesAreCarriedToCaller) {
EXPECT_EQ(prior->literal().GetFirstElement<float>(), 24);
}
// Tests for referential transparency (a function that calls a function that
// returns false should be identical to just returning false).
TEST_F(CallInlinerTest, CallsWithinWhileBodiesAreInlined) {
const Shape pred = ShapeUtil::MakeShape(PRED, {});
auto module = CreateNewModule();
// Create a lambda that calls a function that returns the false predicate.
// Note we also use this lambda twice by reference, just to make the test a
// little trickier.
HloComputation::Builder just_false(TestName() + ".false");
just_false.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
HloComputation* false_computation =
module->AddEmbeddedComputation(just_false.Build());
HloComputation::Builder call_false_builder(TestName() + ".call_false");
call_false_builder.AddInstruction(
HloInstruction::CreateCall(pred, {}, false_computation));
HloComputation* call_false =
module->AddEmbeddedComputation(call_false_builder.Build());
HloComputation::Builder outer(TestName() + ".outer");
HloInstruction* init_value = outer.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
outer.AddInstruction(
HloInstruction::CreateWhile(pred, call_false, call_false, init_value));
auto computation = module->AddEntryComputation(outer.Build());
CallInliner call_inliner;
TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get()));
ASSERT_TRUE(mutated);
EXPECT_THAT(
computation->root_instruction()->while_condition()->root_instruction(),
op::Constant());
EXPECT_THAT(computation->root_instruction()->while_body()->root_instruction(),
op::Constant());
}
} // namespace
} // namespace xla

View File

@ -216,8 +216,7 @@ class CollectProfileCandidates : public DfsHloVisitorWithDefault {
Status HandleCall(HloInstruction* call) override {
TF_RETURN_IF_ERROR(DefaultAction(call));
CollectProfileCandidates candidates_for_call(hlo_to_profile_idx_);
TF_RETURN_IF_ERROR(
call->to_apply()->root_instruction()->Accept(&candidates_for_call));
TF_RETURN_IF_ERROR(call->to_apply()->Accept(&candidates_for_call));
return Status::OK();
}

View File

@ -45,8 +45,7 @@ string HloExecutionProfile::ToString(
const HloComputation& computation,
const DeviceDescription& device_description,
HloCostAnalysis* cost_analysis) const {
tensorflow::Status analysis_status =
computation.root_instruction()->Accept(cost_analysis);
tensorflow::Status analysis_status = computation.Accept(cost_analysis);
if (!analysis_status.ok()) {
return "";
}

View File

@ -1179,8 +1179,7 @@ tensorflow::Status Service::GetComputationStats(
HloCostAnalysis analysis(
execute_backend_->compiler()->ShapeSizeBytesFunction());
TF_RETURN_IF_ERROR(
module->entry_computation()->root_instruction()->Accept(&analysis));
TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&analysis));
ComputationStats stats;
stats.set_flop_count(analysis.flop_count());

View File

@ -151,19 +151,6 @@ XLA_TEST_F(ScalarComputationsTest, SubtractTwoScalarsS32) {
ComputeAndCompareR0<int32>(&builder, -3, {});
}
XLA_TEST_F(ScalarComputationsTest, CastS64ToF32) {
ComputationBuilder builder(client_, TestName());
auto a = builder.Parameter(0, ShapeUtil::MakeShape(S64, {}), "a");
builder.ConvertElementType(a, F32);
int64 value = 3LL << 32;
std::unique_ptr<Literal> a_literal = Literal::CreateR0<int64>(value);
std::unique_ptr<GlobalData> a_data =
client_->TransferToServer(*a_literal).ConsumeValueOrDie();
ComputeAndCompareR0<float>(&builder, static_cast<float>(value),
{a_data.get()});
}
XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsF32) {
ComputationBuilder builder(client_, TestName());
builder.Mul(builder.Mul(builder.ConstantR0<float>(2.1f),

View File

@ -94,7 +94,7 @@ void RealMain(tensorflow::gtl::ArraySlice<char*> args) {
OperationDumper dumper(arg);
for (auto& computation : module.computations()) {
TF_CHECK_OK(computation->root_instruction()->Accept(&dumper));
TF_CHECK_OK(computation->Accept(&dumper));
}
}
}

View File

@ -113,6 +113,7 @@ py_test(
srcs_version = "PY2AND3",
tags = [
"nomac", # b/63258195
"notsan", # b/62863147
],
deps = [
":gbdt_batch",

View File

@ -19,6 +19,14 @@ set(GIFLIB_INCLUDES
"lib/gif_lib.h"
)
if (WIN32)
# Suppress warnings to reduce build log size.
add_definitions(/wd4267 /wd4244 /wd4800 /wd4503 /wd4554 /wd4996 /wd4348 /wd4018)
add_definitions(/wd4099 /wd4146 /wd4267 /wd4305 /wd4307)
add_definitions(/wd4715 /wd4722 /wd4723 /wd4838 /wd4309 /wd4334)
add_definitions(/wd4003 /wd4244 /wd4267 /wd4503 /wd4506 /wd4800 /wd4996)
endif()
include_directories("${CMAKE_CURRENT_SOURCE_DIR}/lib")
add_library(giflib ${GIFLIB_SRCS})

View File

@ -62,6 +62,14 @@ set(LIBJPEG_INCLUDES
"jversion.h"
)
if (WIN32)
# Suppress warnings to reduce build log size.
add_definitions(/wd4267 /wd4244 /wd4800 /wd4503 /wd4554 /wd4996 /wd4348 /wd4018)
add_definitions(/wd4099 /wd4146 /wd4267 /wd4305 /wd4307)
add_definitions(/wd4715 /wd4722 /wd4723 /wd4838 /wd4309 /wd4334)
add_definitions(/wd4003 /wd4244 /wd4267 /wd4503 /wd4506 /wd4800 /wd4996)
endif()
include_directories("${CMAKE_CURRENT_SOURCE_DIR}")
add_library(libjpeg ${LIBJPEG_SRCS})

View File

@ -12,6 +12,14 @@ set(LIBLMDB_INCLUDES
"libraries/liblmdb/midl.h"
)
if (WIN32)
# Suppress warnings to reduce build log size.
add_definitions(/wd4267 /wd4244 /wd4800 /wd4503 /wd4554 /wd4996 /wd4348 /wd4018)
add_definitions(/wd4099 /wd4146 /wd4267 /wd4305 /wd4307)
add_definitions(/wd4715 /wd4722 /wd4723 /wd4838 /wd4309 /wd4334)
add_definitions(/wd4003 /wd4244 /wd4267 /wd4503 /wd4506 /wd4800 /wd4996)
endif()
include_directories("${CMAKE_CURRENT_SOURCE_DIR}")
add_library(lmdb ${LIBLMDB_SRCS})

View File

@ -361,6 +361,11 @@ add_python_module("tensorflow/contrib/framework/python")
add_python_module("tensorflow/contrib/framework/python/framework")
add_python_module("tensorflow/contrib/framework/python/ops")
add_python_module("tensorflow/contrib/gan")
add_python_module("tensorflow/contrib/gan/python")
add_python_module("tensorflow/contrib/gan/python/features")
add_python_module("tensorflow/contrib/gan/python/features/python")
add_python_module("tensorflow/contrib/gan/python/losses")
add_python_module("tensorflow/contrib/gan/python/losses/python")
add_python_module("tensorflow/contrib/graph_editor")
add_python_module("tensorflow/contrib/graph_editor/examples")
add_python_module("tensorflow/contrib/graph_editor/tests")
@ -1147,12 +1152,24 @@ add_custom_command(TARGET tf_python_build_pip_package POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_CURRENT_BINARY_DIR}/eigen/src/eigen/unsupported/Eigen
${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/include/unsupported/Eigen)
if(${tensorflow_ENABLE_GPU})
add_custom_command(TARGET tf_python_build_pip_package POST_BUILD
COMMAND ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_BINARY_DIR}/tf_python/setup.py bdist_wheel --project_name tensorflow_gpu
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/tf_python)
if(${tensorflow_TF_NIGHTLY})
if(${tensorflow_ENABLE_GPU})
add_custom_command(TARGET tf_python_build_pip_package POST_BUILD
COMMAND ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_BINARY_DIR}/tf_python/setup.py bdist_wheel --project_name tf_nightly_gpu
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/tf_python)
else()
add_custom_command(TARGET tf_python_build_pip_package POST_BUILD
COMMAND ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_BINARY_DIR}/tf_python/setup.py bdist_wheel --project_name tf_nightly
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/tf_python)
endif(${tensorflow_ENABLE_GPU})
else()
add_custom_command(TARGET tf_python_build_pip_package POST_BUILD
COMMAND ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_BINARY_DIR}/tf_python/setup.py bdist_wheel
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/tf_python)
endif(${tensorflow_ENABLE_GPU})
if(${tensorflow_ENABLE_GPU})
add_custom_command(TARGET tf_python_build_pip_package POST_BUILD
COMMAND ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_BINARY_DIR}/tf_python/setup.py bdist_wheel --project_name tensorflow_gpu
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/tf_python)
else()
add_custom_command(TARGET tf_python_build_pip_package POST_BUILD
COMMAND ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_BINARY_DIR}/tf_python/setup.py bdist_wheel
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/tf_python)
endif(${tensorflow_ENABLE_GPU})
endif(${tensorflow_TF_NIGHTLY})

View File

@ -24,6 +24,8 @@ py_test(
"//tensorflow/python:functional_ops",
"//tensorflow/python:gradients",
"//tensorflow/python:math_ops",
"//tensorflow/python:parsing_ops",
"//tensorflow/python:script_ops",
"//tensorflow/python:training",
"//third_party/py/numpy",
],

View File

@ -27,10 +27,13 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.platform import test
from tensorflow.python.training import server_lib
@ -420,7 +423,7 @@ class IteratorTest(test.TestCase):
def testRemoteIteratorUsingRemoteCallOpDirectSession(self):
worker_config = config_pb2.ConfigProto()
worker_config.device_count["CPU"] = 2
worker_config.device_count["CPU"] = 3
with ops.device("/job:localhost/replica:0/task:0/cpu:1"):
dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
@ -448,12 +451,12 @@ class IteratorTest(test.TestCase):
target_placeholder: "/job:localhost/replica:0/task:0/cpu:1"
})
self.assertEqual(elem, [1])
# Fails when target is cpu:0 where the resource is not located.
# Fails when target is cpu:2 where the resource is not located.
with self.assertRaises(errors.InvalidArgumentError):
sess.run(
remote_op,
feed_dict={
target_placeholder: "/job:localhost/replica:0/task:0/cpu:0"
target_placeholder: "/job:localhost/replica:0/task:0/cpu:2"
})
elem = sess.run(
remote_op,
@ -474,6 +477,61 @@ class IteratorTest(test.TestCase):
target_placeholder: "/job:localhost/replica:0/task:0/cpu:1"
})
def testRemoteIteratorUsingRemoteCallOpDirectSessionGPUCPU(self):
if not test_util.is_gpu_available():
self.skipTest("No GPU available")
with ops.device("/job:localhost/replica:0/task:0/cpu:0"):
dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3])
iterator_3 = dataset_3.make_one_shot_iterator()
iterator_3_handle = iterator_3.string_handle()
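# The remote_call op below is constructed under a GPU device scope, and
# string tensors are not supported on GPUs, so the iterator's string handle
# is shipped across as a uint8 tensor and re-encoded to a string inside the
# remote function via py_func.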
def _encode_raw(byte_array):
return "".join([chr(item) for item in byte_array])
@function.Defun(dtypes.uint8)
def _remote_fn(h):
handle = script_ops.py_func(_encode_raw, [h], dtypes.string)
remote_iterator = dataset_ops.Iterator.from_string_handle(
handle, dataset_3.output_types, dataset_3.output_shapes)
return remote_iterator.get_next()
with ops.device("/job:localhost/replica:0/task:0/device:GPU:0"):
target_placeholder = array_ops.placeholder(dtypes.string, shape=[])
iterator_3_handle_uint8 = parsing_ops.decode_raw(
bytes=iterator_3_handle, out_type=dtypes.uint8)
remote_op = functional_ops.remote_call(
args=[iterator_3_handle_uint8],
Tout=[dtypes.int32],
f=_remote_fn,
target=target_placeholder)
with self.test_session() as sess:
elem = sess.run(
remote_op,
feed_dict={
target_placeholder: "/job:localhost/replica:0/task:0/cpu:0"
})
self.assertEqual(elem, [1])
elem = sess.run(
remote_op,
feed_dict={
target_placeholder: "/job:localhost/replica:0/task:0/cpu:0"
})
self.assertEqual(elem, [2])
elem = sess.run(
remote_op,
feed_dict={
target_placeholder: "/job:localhost/replica:0/task:0/cpu:0"
})
self.assertEqual(elem, [3])
with self.assertRaises(errors.OutOfRangeError):
sess.run(
remote_op,
feed_dict={
target_placeholder: "/job:localhost/replica:0/task:0/cpu:0"
})
if __name__ == "__main__":
test.main()

View File

@ -235,7 +235,7 @@ class SloppyInterleaveDatasetTest(test.TestCase):
self.read_coordination_events[expected_element].acquire()
else:
self.write_coordination_events[expected_element].set()
time.sleep(0.01) # Sleep to consistently "avoid" the race condition.
time.sleep(0.1) # Sleep to consistently "avoid" the race condition.
actual_element = sess.run(self.next_element)
if not done_first_event:
done_first_event = True
@ -300,7 +300,7 @@ class SloppyInterleaveDatasetTest(test.TestCase):
self.read_coordination_events[expected_element].acquire()
else:
self.write_coordination_events[expected_element].set()
time.sleep(0.01) # Sleep to consistently "avoid" the race condition.
time.sleep(0.1) # Sleep to consistently "avoid" the race condition.
actual_element = sess.run(self.next_element)
if not done_first_event:
done_first_event = True

View File

@ -49,25 +49,46 @@ class SqlDatasetTest(test.TestCase):
c = conn.cursor()
c.execute("DROP TABLE IF EXISTS students")
c.execute("DROP TABLE IF EXISTS people")
c.execute("DROP TABLE IF EXISTS townspeople")
c.execute(
"CREATE TABLE IF NOT EXISTS students (id INTEGER NOT NULL PRIMARY KEY,"
" first_name VARCHAR(100), last_name VARCHAR(100), motto VARCHAR(100),"
" school_id VARCHAR(100), favorite_nonsense_word VARCHAR(100), "
"grade_level INTEGER, income INTEGER, favorite_number INTEGER)")
"CREATE TABLE IF NOT EXISTS students (id INTEGER NOT NULL PRIMARY KEY, "
"first_name VARCHAR(100), last_name VARCHAR(100), motto VARCHAR(100), "
"school_id VARCHAR(100), favorite_nonsense_word VARCHAR(100), "
"desk_number INTEGER, income INTEGER, favorite_number INTEGER, "
"favorite_big_number INTEGER, favorite_negative_number INTEGER, "
"favorite_medium_sized_number INTEGER, brownie_points INTEGER, "
"account_balance INTEGER, registration_complete INTEGER)")
c.executemany(
"INSERT INTO students (first_name, last_name, motto, school_id, "
"favorite_nonsense_word, grade_level, income, favorite_number) "
"VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
[("John", "Doe", "Hi!", "123", "n\0nsense", 9, 0, 2147483647),
("Jane", "Moe", "Hi again!", "1000", "nonsense\0", 11, -20000,
-2147483648)])
"favorite_nonsense_word, desk_number, income, favorite_number, "
"favorite_big_number, favorite_negative_number, "
"favorite_medium_sized_number, brownie_points, account_balance, "
"registration_complete) "
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
[("John", "Doe", "Hi!", "123", "n\0nsense", 9, 0, 2147483647,
9223372036854775807, -2, 32767, 0, 0, 1),
("Jane", "Moe", "Hi again!", "1000", "nonsense\0", 127, -20000,
-2147483648, -9223372036854775808, -128, -32768, 255, 65535, 0)])
c.execute(
"CREATE TABLE IF NOT EXISTS people (id INTEGER NOT NULL PRIMARY KEY, "
"first_name VARCHAR(100), last_name VARCHAR(100), state VARCHAR(100))")
c.executemany(
"INSERT INTO people (first_name, last_name, state) VALUES (?, ?, ?)",
"INSERT INTO PEOPLE (first_name, last_name, state) VALUES (?, ?, ?)",
[("Benjamin", "Franklin", "Pennsylvania"), ("John", "Doe",
"California")])
c.execute(
"CREATE TABLE IF NOT EXISTS townspeople (id INTEGER NOT NULL PRIMARY "
"KEY, first_name VARCHAR(100), last_name VARCHAR(100), victories "
"FLOAT, accolades FLOAT, triumphs FLOAT)")
c.executemany(
"INSERT INTO townspeople (first_name, last_name, victories, "
"accolades, triumphs) VALUES (?, ?, ?, ?, ?)",
[("George", "Washington", 20.00,
1331241.321342132321324589798264627463827647382647382643874,
9007199254740991.0),
("John", "Adams", -19.95,
1331241321342132321324589798264627463827647382647382643874.0,
9007199254740992.0)])
conn.commit()
conn.close()
@ -80,7 +101,6 @@ class SqlDatasetTest(test.TestCase):
sess.run(
init_op,
feed_dict={
self.driver_name: "sqlite",
self.query: "SELECT first_name, last_name, motto FROM students "
"ORDER BY first_name DESC"
})
@ -98,7 +118,6 @@ class SqlDatasetTest(test.TestCase):
sess.run(
init_op,
feed_dict={
self.driver_name: "sqlite",
self.query:
"SELECT students.first_name, state, motto FROM students "
"INNER JOIN people "
@ -118,7 +137,6 @@ class SqlDatasetTest(test.TestCase):
sess.run(
init_op,
feed_dict={
self.driver_name: "sqlite",
self.query:
"SELECT first_name, last_name, favorite_nonsense_word "
"FROM students ORDER BY first_name DESC"
@ -249,20 +267,124 @@ class SqlDatasetTest(test.TestCase):
with self.assertRaises(errors.InvalidArgumentError):
sess.run(get_next)
# Test that `SqlDataset` can read an integer from a SQLite database table and
# place it in an `int8` tensor.
def testReadResultSetInt8(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int8))
with self.test_session() as sess:
sess.run(
init_op,
feed_dict={
self.query: "SELECT first_name, desk_number FROM students "
"ORDER BY first_name DESC"
})
self.assertEqual((b"John", 9), sess.run(get_next))
self.assertEqual((b"Jane", 127), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Test that `SqlDataset` can read a negative or 0-valued integer from a
# SQLite database table and place it in an `int8` tensor.
def testReadResultSetInt8NegativeAndZero(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int8,
dtypes.int8))
with self.test_session() as sess:
sess.run(
init_op,
feed_dict={
self.query: "SELECT first_name, income, favorite_negative_number "
"FROM students "
"WHERE first_name = 'John' ORDER BY first_name DESC"
})
self.assertEqual((b"John", 0, -2), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Test that `SqlDataset` can read a large (positive or negative) integer from
# a SQLite database table and place it in an `int8` tensor.
def testReadResultSetInt8MaxValues(self):
init_op, get_next = self._createSqlDataset((dtypes.int8, dtypes.int8))
with self.test_session() as sess:
sess.run(
init_op,
feed_dict={
self.query:
"SELECT desk_number, favorite_negative_number FROM students "
"ORDER BY first_name DESC"
})
self.assertEqual((9, -2), sess.run(get_next))
# Max and min values of int8
self.assertEqual((127, -128), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Test that `SqlDataset` can read an integer from a SQLite database table and
# place it in an `int16` tensor.
def testReadResultSetInt16(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int16))
with self.test_session() as sess:
sess.run(
init_op,
feed_dict={
self.query: "SELECT first_name, desk_number FROM students "
"ORDER BY first_name DESC"
})
self.assertEqual((b"John", 9), sess.run(get_next))
self.assertEqual((b"Jane", 127), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Test that `SqlDataset` can read a negative or 0-valued integer from a
# SQLite database table and place it in an `int16` tensor.
def testReadResultSetInt16NegativeAndZero(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int16,
dtypes.int16))
with self.test_session() as sess:
sess.run(
init_op,
feed_dict={
self.query: "SELECT first_name, income, favorite_negative_number "
"FROM students "
"WHERE first_name = 'John' ORDER BY first_name DESC"
})
self.assertEqual((b"John", 0, -2), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Test that `SqlDataset` can read a large (positive or negative) integer from
# a SQLite database table and place it in an `int16` tensor.
def testReadResultSetInt16MaxValues(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int16))
with self.test_session() as sess:
sess.run(
init_op,
feed_dict={
self.query: "SELECT first_name, favorite_medium_sized_number "
"FROM students ORDER BY first_name DESC"
})
# Max value of int16
self.assertEqual((b"John", 32767), sess.run(get_next))
# Min value of int16
self.assertEqual((b"Jane", -32768), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Test that `SqlDataset` can read an integer from a SQLite database table and
# place it in an `int32` tensor.
def testReadResultSetInt32(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
with self.test_session() as sess:
sess.run(
init_op,
feed_dict={
self.query: "SELECT first_name, grade_level FROM students "
self.query: "SELECT first_name, desk_number FROM students "
"ORDER BY first_name DESC"
})
self.assertEqual((b"John", 9), sess.run(get_next))
self.assertEqual((b"Jane", 11), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
self.assertEqual((b"Jane", 127), sess.run(get_next))
# Test that `SqlDataset` can read a negative or 0-valued integer from a
# SQLite database table and place it in an `int32` tensor.
def testReadResultSetInt32NegativeAndZero(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
with self.test_session() as sess:
@ -277,6 +399,8 @@ class SqlDatasetTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Test that `SqlDataset` can read a large (positive or negative) integer from
# a SQLite database table and place it in an `int32` tensor.
def testReadResultSetInt32MaxValues(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
with self.test_session() as sess:
@ -286,7 +410,9 @@ class SqlDatasetTest(test.TestCase):
self.query: "SELECT first_name, favorite_number FROM students "
"ORDER BY first_name DESC"
})
# Max value of int32
self.assertEqual((b"John", 2147483647), sess.run(get_next))
# Min value of int32
self.assertEqual((b"Jane", -2147483648), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@ -307,6 +433,224 @@ class SqlDatasetTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Test that `SqlDataset` can read an integer from a SQLite database table
# and place it in an `int64` tensor.
def testReadResultSetInt64(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int64))
with self.test_session() as sess:
sess.run(
init_op,
feed_dict={
self.query: "SELECT first_name, desk_number FROM students "
"ORDER BY first_name DESC"
})
self.assertEqual((b"John", 9), sess.run(get_next))
self.assertEqual((b"Jane", 127), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Test that `SqlDataset` can read a negative or 0-valued integer from a
# SQLite database table and place it in an `int64` tensor.
def testReadResultSetInt64NegativeAndZero(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int64))
with self.test_session() as sess:
sess.run(
init_op,
feed_dict={
self.query: "SELECT first_name, income FROM students "
"ORDER BY first_name DESC"
})
self.assertEqual((b"John", 0), sess.run(get_next))
self.assertEqual((b"Jane", -20000), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Test that `SqlDataset` can read a large (positive or negative) integer from
# a SQLite database table and place it in an `int64` tensor.
def testReadResultSetInt64MaxValues(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int64))
with self.test_session() as sess:
sess.run(
init_op,
feed_dict={
self.query:
"SELECT first_name, favorite_big_number FROM students "
"ORDER BY first_name DESC"
})
# Max value of int64
self.assertEqual((b"John", 9223372036854775807), sess.run(get_next))
# Min value of int64
self.assertEqual((b"Jane", -9223372036854775808), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Test that `SqlDataset` can read an integer from a SQLite database table and
# place it in a `uint8` tensor.
def testReadResultSetUInt8(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint8))
with self.test_session() as sess:
sess.run(
init_op,
feed_dict={
self.query: "SELECT first_name, desk_number FROM students "
"ORDER BY first_name DESC"
})
self.assertEqual((b"John", 9), sess.run(get_next))
self.assertEqual((b"Jane", 127), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Test that `SqlDataset` can read the minimum and maximum uint8 values from a
# SQLite database table and place them in `uint8` tensors.
def testReadResultSetUInt8MinAndMaxValues(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint8))
with self.test_session() as sess:
sess.run(
init_op,
feed_dict={
self.query: "SELECT first_name, brownie_points FROM students "
"ORDER BY first_name DESC"
})
# Min value of uint8
self.assertEqual((b"John", 0), sess.run(get_next))
# Max value of uint8
self.assertEqual((b"Jane", 255), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Test that `SqlDataset` can read an integer from a SQLite database table
# and place it in a `uint16` tensor.
def testReadResultSetUInt16(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint16))
with self.test_session() as sess:
sess.run(
init_op,
feed_dict={
self.query: "SELECT first_name, desk_number FROM students "
"ORDER BY first_name DESC"
})
self.assertEqual((b"John", 9), sess.run(get_next))
self.assertEqual((b"Jane", 127), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Test that `SqlDataset` can read the minimum and maximum uint16 values from a
# SQLite database table and place them in `uint16` tensors.
def testReadResultSetUInt16MinAndMaxValues(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint16))
with self.test_session() as sess:
sess.run(
init_op,
feed_dict={
self.query: "SELECT first_name, account_balance FROM students "
"ORDER BY first_name DESC"
})
# Min value of uint16
self.assertEqual((b"John", 0), sess.run(get_next))
# Max value of uint16
self.assertEqual((b"Jane", 65535), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Test that `SqlDataset` can read 1-valued and 0-valued integers from a
# SQLite database table and place them in `bool` tensors as `True` and
# `False` respectively.
def testReadResultSetBool(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.bool))
with self.test_session() as sess:
sess.run(
init_op,
feed_dict={
self.query:
"SELECT first_name, registration_complete FROM students "
"ORDER BY first_name DESC"
})
self.assertEqual((b"John", True), sess.run(get_next))
self.assertEqual((b"Jane", False), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Test that `SqlDataset` can read an integer that is not 0-valued or 1-valued
# from a SQLite database table and place it as `True` in a `bool` tensor.
def testReadResultSetBoolNotZeroOrOne(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.bool))
with self.test_session() as sess:
sess.run(
init_op,
feed_dict={
self.query: "SELECT first_name, favorite_medium_sized_number "
"FROM students ORDER BY first_name DESC"
})
self.assertEqual((b"John", True), sess.run(get_next))
self.assertEqual((b"Jane", True), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Test that `SqlDataset` can read a float from a SQLite database table
# and place it in a `float64` tensor.
def testReadResultSetFloat64(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.float64))
with self.test_session() as sess:
sess.run(
init_op,
feed_dict={
self.query:
"SELECT first_name, last_name, victories FROM townspeople "
"ORDER BY first_name"
})
self.assertEqual((b"George", b"Washington", 20.0), sess.run(get_next))
self.assertEqual((b"John", b"Adams", -19.95), sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Test that `SqlDataset` can read a float from a SQLite database table beyond
# the precision of 64-bit IEEE, without throwing an error. Test that
# `SqlDataset` identifies such a value as equal to itself.
def testReadResultSetFloat64OverlyPrecise(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.float64))
with self.test_session() as sess:
sess.run(
init_op,
feed_dict={
self.query:
"SELECT first_name, last_name, accolades FROM townspeople "
"ORDER BY first_name"
})
self.assertEqual(
(b"George", b"Washington",
1331241.321342132321324589798264627463827647382647382643874),
sess.run(get_next))
self.assertEqual(
(b"John", b"Adams",
1331241321342132321324589798264627463827647382647382643874.0),
sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
# Test that `SqlDataset` can read a float from a SQLite database table,
# representing the largest integer representable as a 64-bit IEEE float
# such that the previous integer is also representable as a 64-bit IEEE float.
# Test that `SqlDataset` can distinguish these two numbers.
def testReadResultSetFloat64LargestConsecutiveWholeNumbersNotEqual(self):
init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
dtypes.float64))
with self.test_session() as sess:
sess.run(
init_op,
feed_dict={
self.query:
"SELECT first_name, last_name, triumphs FROM townspeople "
"ORDER BY first_name"
})
self.assertNotEqual((b"George", b"Washington", 9007199254740992.0),
sess.run(get_next))
self.assertNotEqual((b"John", b"Adams", 9007199254740991.0),
sess.run(get_next))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
if __name__ == "__main__":
test.main()
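The two float64 precision tests above hinge on a property of IEEE-754 doubles: every integer of magnitude up to 2^53 (9007199254740992) is exactly representable, but 2^53 + 1 is not. A quick standalone check of that boundary in plain Python, independent of TensorFlow:

```python
BOUNDARY = 2**53  # 9007199254740992

# Both members of the largest exactly-representable consecutive integer pair
# survive the round-trip through float64.
assert float(BOUNDARY - 1) == BOUNDARY - 1
assert float(BOUNDARY) == BOUNDARY

# The next integer has no float64 representation and rounds back down to the
# boundary, which is why the test can still distinguish 9007199254740991.0
# from 9007199254740992.0 while larger neighbors collapse together.
assert float(BOUNDARY + 1) == float(BOUNDARY)
```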

View File

@ -2276,6 +2276,23 @@ class SqlDataset(Dataset):
def __init__(self, driver_name, data_source_name, query, output_types):
"""Creates a `SqlDataset`.
`SqlDataset` allows a user to read data from the result set of a SQL query.
For example:
```python
dataset = tf.contrib.data.SqlDataset("sqlite", "/foo/bar.sqlite3",
                                     "SELECT name, age FROM people",
                                     (tf.string, tf.int32))
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
# Prints the rows of the result set of the above query.
with tf.Session() as sess:
  while True:
    try:
      print(sess.run(next_element))
    except tf.errors.OutOfRangeError:
      break
```
Args:
driver_name: A 0-D `tf.string` tensor containing the database type.
Currently, the only supported value is 'sqlite'.

View File

@ -21,6 +21,16 @@ py_library(
],
)
cuda_py_test(
name = "tfe_test",
srcs = ["tfe_test.py"],
additional_deps = [
":tfe",
"//tensorflow/python:client_testlib",
"//tensorflow/python:platform_test",
],
)
py_library(
name = "saver",
srcs = ["saver.py"],

View File

@ -18,9 +18,9 @@ EXPERIMENTAL: APIs here are unstable and likely to change without notice.
To use, at program startup, call `tfe.enable_eager_execution()`.
@@list_devices
@@device
@@list_devices
@@num_gpus
@@defun
@@implicit_gradients
@ -58,9 +58,10 @@ from tensorflow.python.util.all_util import remove_undocumented
from tensorflow.python.eager import backprop
from tensorflow.python.eager.custom_gradient import custom_gradient
from tensorflow.python.eager import function
from tensorflow.python.eager.context import context
from tensorflow.python.eager.context import device
from tensorflow.python.eager.context import enable_eager_execution
from tensorflow.python.eager.context import list_devices
from tensorflow.python.eager.context import num_gpus
from tensorflow.python.eager.context import run
from tensorflow.python.eager.core import enable_tracing
from tensorflow.python.eager.execution_callbacks import add_execution_callback
@ -70,10 +71,6 @@ from tensorflow.python.eager.execution_callbacks import inf_nan_callback
from tensorflow.python.eager.execution_callbacks import nan_callback
from tensorflow.python.eager.execution_callbacks import seterr
def list_devices():
return context().devices()
defun = function.defun
implicit_gradients = backprop.implicit_grad
implicit_value_and_gradients = backprop.implicit_val_and_grad
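`list_devices` is no longer defined in this module; it is now imported from `tensorflow.python.eager.context`, alongside `num_gpus`. A hedged sketch of the resulting user-facing flow (import path taken from the new `tfe_test.py` below):

```python
from tensorflow.contrib.eager.python import tfe

tfe.enable_eager_execution()

# Device names look like "/job:localhost/replica:0/task:0/cpu:0".
for name in tfe.list_devices():
  print(name)
print("GPUs available:", tfe.num_gpus())
```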

View File

@ -0,0 +1,36 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tfe.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.eager.python import tfe
from tensorflow.python.platform import test
class TFETest(test.TestCase):
def testListDevices(self):
# Expect at least one device.
self.assertTrue(tfe.list_devices())
def testNumGPUs(self):
devices = tfe.list_devices()
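# list_devices() includes the CPU device; every remaining device is
# expected to be a GPU.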
self.assertEqual(len(devices) - 1, tfe.num_gpus())
if __name__ == "__main__":
test.main()

View File

@ -26,6 +26,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":extenders",
":head",
],
)
@ -59,3 +60,14 @@ py_test(
"//third_party/py/numpy",
],
)
py_library(
name = "head",
srcs = [
"python/estimator/head.py",
],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python/estimator:head",
],
)

View File

@ -20,10 +20,16 @@ from __future__ import print_function
# pylint: disable=unused-import,line-too-long,wildcard-import
from tensorflow.contrib.estimator.python.estimator.extenders import *
from tensorflow.contrib.estimator.python.estimator.head import *
from tensorflow.python.util.all_util import remove_undocumented
# pylint: enable=unused-import,line-too-long,wildcard-import
_allowed_symbols = ['add_metrics']
_allowed_symbols = [
'add_metrics',
'binary_classification_head',
'multi_class_head',
'regression_head',
]
remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)

View File

@ -0,0 +1,125 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Abstractions for the head(s) of a model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.estimator.canned import head as head_lib
def multi_class_head(n_classes,
weight_column=None,
label_vocabulary=None,
head_name=None):
"""Creates a `_Head` for multi class classification.
Uses `sparse_softmax_cross_entropy` loss.
This head expects to be fed integer labels specifying the class index.
Args:
n_classes: Number of classes; must be greater than 2 (for 2 classes, use
`binary_classification_head`).
weight_column: A string or a `_NumericColumn` created by
`tf.feature_column.numeric_column` defining feature column representing
weights. It is used to down weight or boost examples during training. It
will be multiplied by the loss of the example.
label_vocabulary: A list of strings representing possible label values. If
it is not given, labels are assumed to be already encoded as integers
within [0, n_classes). If given, labels must be of string type and take
values from `label_vocabulary`; an error is raised if the vocabulary is
not provided but the labels are strings.
head_name: Name of the head. If provided, summary and metrics keys will be
suffixed by `"/" + head_name`.
Returns:
An instance of `_Head` for multi class classification.
Raises:
ValueError: if `n_classes` or `label_vocabulary` is invalid.
"""
return head_lib._multi_class_head_with_softmax_cross_entropy_loss( # pylint:disable=protected-access
n_classes=n_classes,
weight_column=weight_column,
label_vocabulary=label_vocabulary,
head_name=head_name)
def binary_classification_head(
weight_column=None, thresholds=None, label_vocabulary=None, head_name=None):
"""Creates a `_Head` for single label binary classification.
This head uses `sigmoid_cross_entropy_with_logits` loss.
This head expects to be fed float labels of shape `(batch_size, 1)`.
Args:
weight_column: A string or a `_NumericColumn` created by
`tf.feature_column.numeric_column` defining feature column representing
weights. It is used to down weight or boost examples during training. It
will be multiplied by the loss of the example.
thresholds: Iterable of floats in the range `(0, 1)`. For binary
classification metrics such as precision and recall, an eval metric is
generated for each threshold value. This threshold is applied to the
logistic values to determine the binary classification (i.e., above the
threshold is `true`, below is `false`).
label_vocabulary: A list of strings representing possible label values. If
it is not given, labels are assumed to be already encoded within [0, 1].
If given, labels must be of string type and take values from
`label_vocabulary`; an error is raised if the vocabulary is not provided
but the labels are strings.
head_name: Name of the head. If provided, summary and metrics keys will be
suffixed by `"/" + head_name`.
Returns:
An instance of `_Head` for binary classification.
Raises:
ValueError: if `thresholds` contains a value outside of `(0, 1)`.
"""
return head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( # pylint:disable=protected-access
weight_column=weight_column,
thresholds=thresholds,
label_vocabulary=label_vocabulary,
head_name=head_name)
def regression_head(weight_column=None,
label_dimension=1,
head_name=None):
"""Creates a `_Head` for regression using the mean squared loss.
Uses `mean_squared_error` loss.
Args:
weight_column: A string or a `_NumericColumn` created by
`tf.feature_column.numeric_column` defining feature column representing
weights. It is used to down weight or boost examples during training. It
will be multiplied by the loss of the example.
label_dimension: Number of regression labels per example. This is the size
of the last dimension of the labels `Tensor` (typically, this has shape
`[batch_size, label_dimension]`).
head_name: Name of the head. If provided, summary and metrics keys will be
suffixed by `"/" + head_name`.
Returns:
An instance of `_Head` for linear regression.
"""
return head_lib._regression_head_with_mean_squared_error_loss( # pylint:disable=protected-access
weight_column=weight_column,
label_dimension=label_dimension,
head_name=head_name)
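These wrappers only re-export canned heads under public names. A hedged sketch of constructing each head, based on the docstrings above (argument values are illustrative):

```python
from tensorflow.contrib.estimator.python.estimator import head as head_lib

# Multi-class classification over a string label vocabulary, weighting each
# example by the 'example_weight' feature column.
multi = head_lib.multi_class_head(
    n_classes=3,
    weight_column='example_weight',
    label_vocabulary=['cat', 'dog', 'bird'],
    head_name='species')

# Binary classification; precision/recall eval metrics are generated at each
# of the given logistic thresholds.
binary = head_lib.binary_classification_head(thresholds=[0.25, 0.5, 0.75])

# Mean-squared-error regression over 2-dimensional labels.
regression = head_lib.regression_head(label_dimension=2)
```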

View File

@ -162,6 +162,7 @@ tf_py_test(
"//tensorflow/python:platform_test",
"//tensorflow/python:variables",
],
tags = ["notsan"], # b/62863147
)
py_library(

View File

@ -1,9 +1,12 @@
# Files for using TFGAN framework.
package(default_visibility = ["//tensorflow:__subpackages__"])
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
load("//tensorflow:tensorflow.bzl", "py_test")
py_library(
name = "gan",
srcs = [
@ -11,6 +14,192 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
":features",
":losses",
],
)
py_library(
name = "losses",
srcs = ["python/losses/__init__.py"],
srcs_version = "PY2AND3",
deps = [
":losses_impl",
":tuple_losses",
"//tensorflow/python:util",
],
)
py_library(
name = "features",
srcs = ["python/features/__init__.py"],
srcs_version = "PY2AND3",
deps = [
":clip_weights",
":conditioning_utils",
":virtual_batchnorm",
"//tensorflow/python:util",
],
)
py_library(
name = "losses_impl",
srcs = ["python/losses/python/losses_impl.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/framework:framework_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:clip_ops",
"//tensorflow/python:framework_ops",
"//tensorflow/python:gradients",
"//tensorflow/python:math_ops",
"//tensorflow/python:random_ops",
"//tensorflow/python:summary",
"//tensorflow/python:tensor_util",
"//tensorflow/python:variable_scope",
"//tensorflow/python/ops/distributions",
"//tensorflow/python/ops/losses",
"//third_party/py/numpy",
],
)
py_test(
name = "losses_impl_test",
srcs = ["python/losses/python/losses_impl_test.py"],
srcs_version = "PY2AND3",
deps = [
":losses_impl",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:clip_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:random_ops",
"//tensorflow/python:random_seed",
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
"//tensorflow/python/ops/distributions",
"//tensorflow/python/ops/losses",
],
)
py_library(
name = "tuple_losses",
srcs = [
"python/losses/python/losses_wargs.py",
"python/losses/python/tuple_losses.py",
"python/losses/python/tuple_losses_impl.py",
],
srcs_version = "PY2AND3",
deps = [
":losses_impl",
"//tensorflow/python:util",
],
)
py_test(
name = "tuple_losses_test",
srcs = ["python/losses/python/tuple_losses_test.py"],
srcs_version = "PY2AND3",
deps = [
":tuple_losses",
"//tensorflow/python:client_testlib",
"//third_party/py/numpy",
],
)
py_library(
name = "conditioning_utils",
srcs = [
"python/features/python/conditioning_utils.py",
"python/features/python/conditioning_utils_impl.py",
],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/contrib/layers:layers_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:embedding_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:tensor_util",
"//tensorflow/python:variable_scope",
],
)
py_test(
name = "conditioning_utils_test",
srcs = ["python/features/python/conditioning_utils_test.py"],
srcs_version = "PY2AND3",
deps = [
":conditioning_utils",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
],
)
py_library(
name = "virtual_batchnorm",
srcs = [
"python/features/python/virtual_batchnorm.py",
"python/features/python/virtual_batchnorm_impl.py",
],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:init_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:nn",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:tensor_util",
"//tensorflow/python:variable_scope",
],
)
py_test(
name = "virtual_batchnorm_test",
srcs = ["python/features/python/virtual_batchnorm_test.py"],
srcs_version = "PY2AND3",
deps = [
":virtual_batchnorm",
"//tensorflow/contrib/framework:framework_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:layers",
"//tensorflow/python:math_ops",
"//tensorflow/python:nn",
"//tensorflow/python:random_ops",
"//tensorflow/python:random_seed",
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
"//third_party/py/numpy",
],
)
py_library(
name = "clip_weights",
srcs = [
"python/features/python/clip_weights.py",
"python/features/python/clip_weights_impl.py",
],
srcs_version = "PY2AND3",
deps = ["//tensorflow/contrib/opt:opt_py"],
)
py_test(
name = "clip_weights_test",
srcs = ["python/features/python/clip_weights_test.py"],
srcs_version = "PY2AND3",
deps = [
":clip_weights",
"//tensorflow/python:client_testlib",
"//tensorflow/python:training",
"//tensorflow/python:variables",
],
)

View File

@ -1,4 +1,4 @@
# Copyright 2017 Google Inc. All Rights Reserved.
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -12,8 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TFGAN grouped API."""
"""TFGAN grouped API. Please see README.md for details and usage."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Collapse TFGAN into a tiered namespace.
from tensorflow.contrib.gan.python import features
from tensorflow.contrib.gan.python import losses
del absolute_import
del division
del print_function

View File

@ -0,0 +1,37 @@
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TFGAN grouped API. Please see README.md for details and usage."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Collapse features into a single namespace.
# pylint: disable=unused-import,wildcard-import
from tensorflow.contrib.gan.python.features.python import clip_weights
from tensorflow.contrib.gan.python.features.python import conditioning_utils
from tensorflow.contrib.gan.python.features.python import virtual_batchnorm
from tensorflow.contrib.gan.python.features.python.clip_weights import *
from tensorflow.contrib.gan.python.features.python.conditioning_utils import *
from tensorflow.contrib.gan.python.features.python.virtual_batchnorm import *
# pylint: enable=unused-import,wildcard-import
from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = clip_weights.__all__
_allowed_symbols += conditioning_utils.__all__
_allowed_symbols += virtual_batchnorm.__all__
remove_undocumented(__name__, _allowed_symbols)

View File

@ -0,0 +1,28 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities to clip weights."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.gan.python.features.python import clip_weights_impl
# pylint: disable=wildcard-import
from tensorflow.contrib.gan.python.features.python.clip_weights_impl import *
# pylint: enable=wildcard-import
from tensorflow.python.util.all_util import remove_undocumented
__all__ = clip_weights_impl.__all__
remove_undocumented(__name__, __all__)

View File

@ -0,0 +1,80 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities to clip weights.
This is useful in the original formulation of the Wasserstein loss, which
requires that the discriminator be K-Lipschitz. See
https://arxiv.org/pdf/1701.07875 for more details.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.opt.python.training import variable_clipping_optimizer
__all__ = [
'clip_variables',
'clip_discriminator_weights',
]
def clip_discriminator_weights(optimizer, model, weight_clip):
"""Modifies an optimizer so it clips weights to a certain value.
Args:
optimizer: An optimizer to perform variable weight clipping.
model: A GANModel namedtuple.
weight_clip: Positive Python float to clip discriminator weights. Used to
enforce a K-Lipschitz condition, which is useful for some GAN training
schemes (e.g. WGAN: https://arxiv.org/pdf/1701.07875).
Returns:
An optimizer to perform weight clipping after updates.
Raises:
ValueError: If `weight_clip` is less than 0.
"""
return clip_variables(optimizer, model.discriminator_variables, weight_clip)
def clip_variables(optimizer, variables, weight_clip):
"""Modifies an optimizer so it clips weights to a certain value.
Args:
optimizer: An optimizer to perform variable weight clipping.
variables: A list of TensorFlow variables.
weight_clip: Positive Python float to clip discriminator weights. Used to
enforce a K-Lipschitz condition, which is useful for some GAN training
schemes (e.g. WGAN: https://arxiv.org/pdf/1701.07875).
Returns:
An optimizer to perform weight clipping after updates.
Raises:
ValueError: If `weight_clip` is less than 0.
"""
if weight_clip < 0:
raise ValueError(
'`weight_clip` must be positive. Instead, was %s' % weight_clip)
return variable_clipping_optimizer.VariableClippingOptimizer(
opt=optimizer,
# Do no reduction, so clipping happens per-value.
vars_to_clip_dims={var: [] for var in variables},
max_norm=weight_clip,
use_locking=True,
colocate_clip_ops_with_vars=True)
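# --- Editor's usage sketch (illustrative; not part of this commit). ---
# A minimal example of wrapping a stock optimizer with
# `clip_discriminator_weights` for WGAN-style weight clipping; the
# namedtuple below is a stand-in for a real TFGAN GANModel.
import collections
import tensorflow as tf
from tensorflow.contrib.gan.python.features.python import clip_weights_impl as clip_weights

disc_var = tf.Variable(2.0)
gan_model = collections.namedtuple(
    'GANModel', ['discriminator_variables'])([disc_var])
loss = disc_var * 2.0  # Toy discriminator loss.
opt = tf.train.GradientDescentOptimizer(1.0)
# Every `minimize` step is followed by a per-value clip to [-0.1, 0.1].
clip_opt = clip_weights.clip_discriminator_weights(opt, gan_model, 0.1)
train_op = clip_opt.minimize(loss, var_list=[disc_var])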

View File

@ -0,0 +1,81 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tfgan.python.features.clip_weights."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
from tensorflow.contrib.gan.python.features.python import clip_weights_impl as clip_weights
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import training
class ClipWeightsTest(test.TestCase):
"""Tests for `discriminator_weight_clip`."""
def setUp(self):
self.variables = [variables.Variable(2.0)]
self.tuple = collections.namedtuple(
'VarTuple', ['discriminator_variables'])(self.variables)
def _test_weight_clipping_helper(self, use_tuple):
loss = self.variables[0] * 2.0
opt = training.GradientDescentOptimizer(1.0)
if use_tuple:
opt_clip = clip_weights.clip_discriminator_weights(opt, self.tuple, 0.1)
else:
opt_clip = clip_weights.clip_variables(opt, self.variables, 0.1)
train_op1 = opt.minimize(loss, var_list=self.variables)
train_op2 = opt_clip.minimize(loss, var_list=self.variables)
with self.test_session(use_gpu=True) as sess:
sess.run(variables.global_variables_initializer())
self.assertEqual(2.0, self.variables[0].eval())
sess.run(train_op1)
self.assertLess(0.1, self.variables[0].eval())
with self.test_session(use_gpu=True) as sess:
sess.run(variables.global_variables_initializer())
self.assertEqual(2.0, self.variables[0].eval())
sess.run(train_op2)
self.assertNear(0.1, self.variables[0].eval(), 1e-7)
def test_weight_clipping_argsonly(self):
self._test_weight_clipping_helper(False)
def test_weight_clipping_ganmodel(self):
self._test_weight_clipping_helper(True)
def _test_incorrect_weight_clip_value_helper(self, use_tuple):
opt = training.GradientDescentOptimizer(1.0)
if use_tuple:
with self.assertRaisesRegexp(ValueError, 'must be positive'):
clip_weights.clip_discriminator_weights(opt, self.tuple, weight_clip=-1)
else:
with self.assertRaisesRegexp(ValueError, 'must be positive'):
clip_weights.clip_variables(opt, self.variables, weight_clip=-1)
def test_incorrect_weight_clip_value_argsonly(self):
self._test_incorrect_weight_clip_value_helper(False)
def test_incorrect_weight_clip_value_tuple(self):
self._test_incorrect_weight_clip_value_helper(True)

View File

@ -0,0 +1,28 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Miscellanous utilities for TFGAN code and examples."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.gan.python.features.python import conditioning_utils_impl
# pylint: disable=wildcard-import
from tensorflow.contrib.gan.python.features.python.conditioning_utils_impl import *
# pylint: enable=wildcard-import
from tensorflow.python.util.all_util import remove_undocumented
__all__ = conditioning_utils_impl.__all__
remove_undocumented(__name__, __all__)

View File

@ -0,0 +1,112 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Miscellanous utilities for TFGAN code and examples.
Includes:
1) Conditioning the value of a Tensor, based on techniques from
https://arxiv.org/abs/1609.03499.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.layers.python.layers import layers
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
__all__ = [
'condition_tensor',
'condition_tensor_from_onehot',
]
def _get_shape(tensor):
tensor_shape = array_ops.shape(tensor)
static_tensor_shape = tensor_util.constant_value(tensor_shape)
return (static_tensor_shape if static_tensor_shape is not None else
tensor_shape)
def condition_tensor(tensor, conditioning):
"""Condition the value of a tensor.
Conditioning scheme based on https://arxiv.org/abs/1609.03499.
Args:
tensor: A minibatch tensor to be conditioned.
conditioning: A minibatch Tensor to condition on. Must be at least 2D, with
the first dimension the same as `tensor`.
Returns:
`tensor` conditioned on `conditioning`.
Raises:
ValueError: If the non-batch dimensions of `tensor` aren't fully defined.
ValueError: If `conditioning` isn't at least 2D.
ValueError: If the batch dimensions of the input Tensors don't match.
"""
tensor.shape[1:].assert_is_fully_defined()
num_features = tensor.shape[1:].num_elements()
mapped_conditioning = layers.linear(
layers.flatten(conditioning), num_features)
if not mapped_conditioning.shape.is_compatible_with(tensor.shape):
mapped_conditioning = array_ops.reshape(
mapped_conditioning, _get_shape(tensor))
return tensor + mapped_conditioning
def _one_hot_to_embedding(one_hot, embedding_size):
"""Get a dense embedding vector from a one-hot encoding."""
num_tokens = one_hot.shape[1]
label_id = math_ops.argmax(one_hot, axis=1)
embedding = variable_scope.get_variable(
'embedding', [num_tokens, embedding_size])
return embedding_ops.embedding_lookup(
embedding, label_id, name='token_to_embedding')
def _validate_onehot(one_hot_labels):
one_hot_labels.shape.assert_has_rank(2)
one_hot_labels.shape[1:].assert_is_fully_defined()
def condition_tensor_from_onehot(tensor, one_hot_labels, embedding_size=256):
"""Condition a tensor based on a one-hot tensor.
Conditioning scheme based on https://arxiv.org/abs/1609.03499.
Args:
tensor: Tensor to be conditioned.
one_hot_labels: A Tensor of one-hot labels. Shape is
[batch_size, num_classes].
embedding_size: The size of the class embedding.
Returns:
`tensor` conditioned on `one_hot_labels`.
Raises:
ValueError: `one_hot_labels` isn't 2D, if non-batch dimensions aren't
fully defined, or if batch sizes don't match.
"""
_validate_onehot(one_hot_labels)
conditioning = _one_hot_to_embedding(one_hot_labels, embedding_size)
return condition_tensor(tensor, conditioning)
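# --- Editor's usage sketch (illustrative; not part of this commit). ---
# Conditioning a batch of feature vectors on one-hot class labels; the
# shapes and `embedding_size` below are arbitrary.
import tensorflow as tf
from tensorflow.contrib.gan.python.features.python import conditioning_utils_impl as conditioning_utils

features = tf.placeholder(tf.float32, [16, 64])
one_hot_labels = tf.placeholder(tf.float32, [16, 10])
# Each label is embedded, mapped linearly to the feature size, and added
# to `features` (the scheme from https://arxiv.org/abs/1609.03499).
conditioned = conditioning_utils.condition_tensor_from_onehot(
    features, one_hot_labels, embedding_size=32)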

View File

@ -0,0 +1,76 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tfgan.python.features.conditioning_utils."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.gan.python.features.python import conditioning_utils_impl as conditioning_utils
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
class ConditioningUtilsTest(test.TestCase):
def test_condition_tensor_multiple_shapes(self):
for tensor_shape in [(4, 1), (4, 2), (4, 2, 6), (None, 5, 3)]:
for conditioning_shape in [(4, 1), (4, 8), (4, 5, 3)]:
conditioning_utils.condition_tensor(
array_ops.placeholder(dtypes.float32, tensor_shape),
array_ops.placeholder(dtypes.float32, conditioning_shape))
def test_condition_tensor_asserts(self):
with self.assertRaisesRegexp(ValueError, 'Cannot reshape'):
conditioning_utils.condition_tensor(
array_ops.placeholder(dtypes.float32, (4, 1)),
array_ops.placeholder(dtypes.float32, (5, 1)))
with self.assertRaisesRegexp(ValueError, 'Shape .* is not fully defined'):
conditioning_utils.condition_tensor(
array_ops.placeholder(dtypes.float32, (5, None)),
array_ops.placeholder(dtypes.float32, (5, 1)))
with self.assertRaisesRegexp(ValueError, 'must have a least 2 dimensions.'):
conditioning_utils.condition_tensor(
array_ops.placeholder(dtypes.float32, (5, 2)),
array_ops.placeholder(dtypes.float32, (5)))
def test_condition_tensor_from_onehot(self):
conditioning_utils.condition_tensor_from_onehot(
array_ops.placeholder(dtypes.float32, (5, 4, 1)),
array_ops.placeholder(dtypes.float32, (5, 10)))
def test_condition_tensor_from_onehot_asserts(self):
with self.assertRaisesRegexp(ValueError, 'Shape .* must have rank 2'):
conditioning_utils.condition_tensor_from_onehot(
array_ops.placeholder(dtypes.float32, (5, 1)),
array_ops.placeholder(dtypes.float32, (5)))
with self.assertRaisesRegexp(ValueError, 'Shape .* is not fully defined'):
conditioning_utils.condition_tensor_from_onehot(
array_ops.placeholder(dtypes.float32, (5, 1)),
array_ops.placeholder(dtypes.float32, (5, None)))
with self.assertRaisesRegexp(ValueError, 'Cannot reshape a tensor'):
conditioning_utils.condition_tensor_from_onehot(
array_ops.placeholder(dtypes.float32, (5, 1)),
array_ops.placeholder(dtypes.float32, (4, 6)))
if __name__ == '__main__':
test.main()

View File

@ -0,0 +1,27 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Virtual batch normalization."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.gan.python.features.python import virtual_batchnorm_impl
# pylint: disable=wildcard-import
from tensorflow.contrib.gan.python.features.python.virtual_batchnorm_impl import *
# pylint: enable=wildcard-import
from tensorflow.python.util.all_util import remove_undocumented
__all__ = virtual_batchnorm_impl.__all__
remove_undocumented(__name__, __all__)

View File

@ -0,0 +1,306 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Virtual batch normalization.
This technique was first introduced in `Improved Techniques for Training GANs`
(Salimans et al, https://arxiv.org/abs/1606.03498). Instead of using batch
normalization on a minibatch, it fixes a reference subset of the data to use for
calculating normalization statistics.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import variable_scope
__all__ = [
'VBN',
]
def _static_or_dynamic_batch_size(tensor, batch_axis):
"""Returns the static or dynamic batch size."""
batch_size = array_ops.shape(tensor)[batch_axis]
static_batch_size = tensor_util.constant_value(batch_size)
return static_batch_size or batch_size
def _statistics(x, axes):
"""Calculate the mean and mean square of `x`.
Modified from the implementation of `tf.nn.moments`.
Args:
x: A `Tensor`.
axes: Array of ints. Axes along which to compute the mean and
mean square.
Returns:
Two `Tensor` objects: `mean` and `square mean`.
"""
# The dynamic range of fp16 is too limited to support the collection of
# sufficient statistics. As a workaround we simply perform the operations
# on 32-bit floats before converting the mean and variance back to fp16
y = math_ops.cast(x, dtypes.float32) if x.dtype == dtypes.float16 else x
# Compute true mean while keeping the dims for proper broadcasting.
shift = array_ops.stop_gradient(math_ops.reduce_mean(y, axes, keep_dims=True))
shifted_mean = math_ops.reduce_mean(y - shift, axes, keep_dims=True)
mean = shifted_mean + shift
mean_squared = math_ops.reduce_mean(math_ops.square(y), axes, keep_dims=True)
mean = array_ops.squeeze(mean, axes)
mean_squared = array_ops.squeeze(mean_squared, axes)
if x.dtype == dtypes.float16:
return (math_ops.cast(mean, dtypes.float16),
math_ops.cast(mean_squared, dtypes.float16))
else:
return (mean, mean_squared)
def _validate_init_input_and_get_axis(reference_batch, axis):
"""Validate input and return the used axis value."""
if reference_batch.shape.ndims is None:
raise ValueError('`reference_batch` has unknown dimensions.')
ndims = reference_batch.shape.ndims
if axis < 0:
used_axis = ndims + axis
else:
used_axis = axis
if used_axis < 0 or used_axis >= ndims:
raise ValueError('Value of `axis` argument ' + str(used_axis) +
' is out of range for input with rank ' + str(ndims))
return used_axis
def _validate_call_input(tensor_list, batch_dim):
"""Verifies that tensor shapes are compatible, except for `batch_dim`."""
def _get_shape(tensor):
shape = tensor.shape.as_list()
del shape[batch_dim]
return shape
base_shape = tensor_shape.TensorShape(_get_shape(tensor_list[0]))
for tensor in tensor_list:
base_shape.assert_is_compatible_with(_get_shape(tensor))
class VBN(object):
"""A class to perform virtual batch normalization.
This technique was first introduced in `Improved Techniques for Training GANs`
(Salimans et al, https://arxiv.org/abs/1606.03498). Instead of using batch
normalization on a minibatch, it fixes a reference subset of the data to use
for calculating normalization statistics.
To do this, we calculate the reference batch mean and mean square, and modify
those statistics for each example. We use mean square instead of variance,
since the mean square is linear in the examples and so can be adjusted on a
per-example basis.
Note that if `center` or `scale` variables are created, they are shared
between all calls to this object.
The `__init__` API is intended to mimic `tf.layers.batch_normalization` as
closely as possible.
"""
def __init__(self,
reference_batch,
axis=-1,
epsilon=1e-3,
center=True,
scale=True,
beta_initializer=init_ops.zeros_initializer(),
gamma_initializer=init_ops.ones_initializer(),
beta_regularizer=None,
gamma_regularizer=None,
trainable=True,
name=None,
batch_axis=0):
"""Initialize virtual batch normalization object.
We precompute the 'mean' and 'mean squared' of the reference batch, so that
`__call__` is efficient. This means that the axis must be supplied when the
object is created, not when it is called.
We precompute 'square mean' instead of 'variance', because the square mean
can be easily adjusted on a per-example basis.
Args:
reference_batch: A minibatch Tensor. This will form the reference data
from which the normalization statistics are calculated. See
https://arxiv.org/abs/1606.03498 for more details.
axis: Integer, the axis that should be normalized (typically the features
axis). For instance, after a `Convolution2D` layer with
`data_format="channels_first"`, set `axis=1` in `BatchNormalization`.
epsilon: Small float added to variance to avoid dividing by zero.
center: If True, add offset of `beta` to normalized tensor. If False,
`beta` is ignored.
scale: If True, multiply by `gamma`. If False, `gamma` is
not used. When the next layer is linear (also e.g. `nn.relu`), this can
be disabled since the scaling can be done by the next layer.
beta_initializer: Initializer for the beta weight.
gamma_initializer: Initializer for the gamma weight.
beta_regularizer: Optional regularizer for the beta weight.
gamma_regularizer: Optional regularizer for the gamma weight.
trainable: Boolean, if `True` also add variables to the graph collection
`GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
name: String, the name of the ops.
batch_axis: The axis of the batch dimension. This dimension is treated
differently in `virtual batch normalization` vs `batch normalization`.
Raises:
ValueError: If `reference_batch` has unknown dimensions at graph
construction.
ValueError: If `batch_axis` is the same as `axis`.
"""
axis = _validate_init_input_and_get_axis(reference_batch, axis)
self._epsilon = epsilon
self._beta = 0
self._gamma = 1
self._batch_axis = _validate_init_input_and_get_axis(
reference_batch, batch_axis)
if axis == self._batch_axis:
raise ValueError('`axis` and `batch_axis` cannot be the same.')
with variable_scope.variable_scope(name, 'VBN',
values=[reference_batch]) as self._vs:
self._reference_batch = reference_batch
# Calculate important shapes:
# 1) Reduction axes for the reference batch
# 2) Broadcast shape, if necessary
# 3) Reduction axes for the virtual batchnormed batch
# 4) Shape for optional parameters
input_shape = self._reference_batch.shape
ndims = input_shape.ndims
reduction_axes = list(range(ndims))
del reduction_axes[axis]
self._broadcast_shape = [1] * len(input_shape)
self._broadcast_shape[axis] = input_shape[axis].value
self._example_reduction_axes = list(range(ndims))
del self._example_reduction_axes[max(axis, self._batch_axis)]
del self._example_reduction_axes[min(axis, self._batch_axis)]
params_shape = self._reference_batch.shape[axis]
# Determines whether broadcasting is needed. This is slightly different
# than in the `nn.batch_normalization` case, due to `batch_dim`.
self._needs_broadcasting = (
sorted(self._example_reduction_axes) != list(range(ndims))[:-2])
# Calculate the sufficient statistics for the reference batch in a way
# that can be easily modified by additional examples.
self._ref_mean, self._ref_mean_squares = _statistics(
self._reference_batch, reduction_axes)
self._ref_variance = (self._ref_mean_squares -
math_ops.square(self._ref_mean))
# Virtual batch normalization uses a weighted average between example
# statistics and the reference batch statistics.
ref_batch_size = _static_or_dynamic_batch_size(
self._reference_batch, self._batch_axis)
self._example_weight = 1. / (math_ops.to_float(ref_batch_size) + 1.)
self._ref_weight = 1. - self._example_weight
# Make the variables, if necessary.
if center:
self._beta = variable_scope.get_variable(
name='beta',
shape=(params_shape,),
initializer=beta_initializer,
regularizer=beta_regularizer,
trainable=trainable)
if scale:
self._gamma = variable_scope.get_variable(
name='gamma',
shape=(params_shape,),
initializer=gamma_initializer,
regularizer=gamma_regularizer,
trainable=trainable)
def _virtual_statistics(self, inputs, reduction_axes):
"""Compute the statistics needed for virtual batch normalization."""
cur_mean, cur_mean_sq = _statistics(inputs, reduction_axes)
vb_mean = (self._example_weight * cur_mean +
self._ref_weight * self._ref_mean)
vb_mean_sq = (self._example_weight * cur_mean_sq +
self._ref_weight * self._ref_mean_squares)
return (vb_mean, vb_mean_sq)
def _broadcast(self, v, broadcast_shape=None):
# The exact broadcast shape depends on the current batch, not the reference
# batch, unless we're calculating the batch normalization of the reference
# batch.
b_shape = broadcast_shape or self._broadcast_shape
if self._needs_broadcasting and v is not None:
return array_ops.reshape(v, b_shape)
return v
def reference_batch_normalization(self):
"""Return the reference batch, but batch normalized."""
with ops.name_scope(self._vs.name):
return nn.batch_normalization(self._reference_batch,
self._broadcast(self._ref_mean),
self._broadcast(self._ref_variance),
self._broadcast(self._beta),
self._broadcast(self._gamma),
self._epsilon)
def __call__(self, inputs):
"""Run virtual batch normalization on inputs.
Args:
inputs: Tensor input.
Returns:
A virtual batch normalized version of `inputs`.
Raises:
ValueError: If `inputs` shape isn't compatible with the reference batch.
"""
_validate_call_input([inputs, self._reference_batch], self._batch_axis)
with ops.name_scope(self._vs.name, values=[inputs, self._reference_batch]):
# Calculate the statistics on the current input on a per-example basis.
vb_mean, vb_mean_sq = self._virtual_statistics(
inputs, self._example_reduction_axes)
vb_variance = vb_mean_sq - math_ops.square(vb_mean)
# The exact broadcast shape of the input statistic Tensors depends on the
# current batch, not the reference batch. The parameter broadcast shape
# is independent of the shape of the input statistic Tensor dimensions.
b_shape = self._broadcast_shape[:] # deep copy
b_shape[self._batch_axis] = _static_or_dynamic_batch_size(
inputs, self._batch_axis)
return nn.batch_normalization(
inputs,
self._broadcast(vb_mean, b_shape),
self._broadcast(vb_variance, b_shape),
self._broadcast(self._beta, self._broadcast_shape),
self._broadcast(self._gamma, self._broadcast_shape),
self._epsilon)
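# --- Editor's usage sketch (illustrative; not part of this commit). ---
# Normalization statistics come from a fixed reference batch chosen at
# construction time; the shapes below are arbitrary.
import tensorflow as tf
from tensorflow.contrib.gan.python.features.python import virtual_batchnorm_impl as virtual_batchnorm

reference_batch = tf.random_normal([64, 5, 3])
vbn = virtual_batchnorm.VBN(reference_batch)
# Each example is normalized with a weighted mix of its own statistics and
# the reference statistics (example weight 1 / (reference_batch_size + 1)),
# so the output for one example is independent of the rest of the minibatch.
minibatch = tf.random_normal([8, 5, 3])
normalized = vbn(minibatch)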

View File

@ -0,0 +1,267 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tfgan.python.features.virtual_batchnorm."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.contrib.framework.python.ops import variables as contrib_variables_lib
from tensorflow.contrib.gan.python.features.python import virtual_batchnorm_impl as virtual_batchnorm
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import random_seed
from tensorflow.python.layers import normalization
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import test
class VirtualBatchnormTest(test.TestCase):
def test_syntax(self):
reference_batch = array_ops.zeros([5, 3, 16, 9, 15])
vbn = virtual_batchnorm.VBN(reference_batch, batch_axis=1)
vbn(array_ops.ones([5, 7, 16, 9, 15]))
def test_no_broadcast_needed(self):
"""When `axis` and `batch_axis` are at the end, no broadcast is needed."""
reference_batch = array_ops.zeros([5, 3, 16, 9, 15])
minibatch = array_ops.zeros([5, 3, 16, 3, 15])
vbn = virtual_batchnorm.VBN(reference_batch, axis=-1, batch_axis=-2)
vbn(minibatch)
def test_statistics(self):
"""Check that `_statistics` gives the same result as `nn.moments`."""
random_seed.set_random_seed(1234)
tensors = random_ops.random_normal([4, 5, 7, 3])
for axes in [(3), (0, 2), (1, 2, 3)]:
vb_mean, mean_sq = virtual_batchnorm._statistics(tensors, axes)
mom_mean, mom_var = nn.moments(tensors, axes)
vb_var = mean_sq - math_ops.square(vb_mean)
with self.test_session(use_gpu=True) as sess:
vb_mean_np, vb_var_np, mom_mean_np, mom_var_np = sess.run([
vb_mean, vb_var, mom_mean, mom_var])
self.assertAllClose(mom_mean_np, vb_mean_np)
self.assertAllClose(mom_var_np, vb_var_np)
def test_virtual_statistics(self):
"""Check that `_virtual_statistics` gives same result as `nn.moments`."""
random_seed.set_random_seed(1234)
batch_axis = 0
partial_batch = random_ops.random_normal([4, 5, 7, 3])
single_example = random_ops.random_normal([1, 5, 7, 3])
full_batch = array_ops.concat([partial_batch, single_example], axis=0)
for reduction_axis in range(1, 4):
# Get `nn.moments` on the full batch.
reduction_axes = list(range(4))
del reduction_axes[reduction_axis]
mom_mean, mom_variance = nn.moments(full_batch, reduction_axes)
# Get virtual batch statistics.
vb_reduction_axes = list(range(4))
del vb_reduction_axes[reduction_axis]
del vb_reduction_axes[batch_axis]
vbn = virtual_batchnorm.VBN(partial_batch, reduction_axis)
vb_mean, mean_sq = vbn._virtual_statistics(
single_example, vb_reduction_axes)
vb_variance = mean_sq - math_ops.square(vb_mean)
# Remove singleton batch dim for easy comparisons.
vb_mean = array_ops.squeeze(vb_mean, batch_axis)
vb_variance = array_ops.squeeze(vb_variance, batch_axis)
with self.test_session(use_gpu=True) as sess:
vb_mean_np, vb_var_np, mom_mean_np, mom_var_np = sess.run([
vb_mean, vb_variance, mom_mean, mom_variance])
self.assertAllClose(mom_mean_np, vb_mean_np)
self.assertAllClose(mom_var_np, vb_var_np)
def test_reference_batch_normalization(self):
"""Check that batch norm from VBN agrees with opensource implementation."""
random_seed.set_random_seed(1234)
batch = random_ops.random_normal([6, 5, 7, 3, 3])
for axis in range(5):
# Get `layers` batchnorm result.
bn_normalized = normalization.batch_normalization(
batch, axis, training=True)
# Get VBN's batch normalization on reference batch.
batch_axis = 0 if axis != 0 else 1 # axis and batch_axis can't be the same
vbn = virtual_batchnorm.VBN(batch, axis, batch_axis=batch_axis)
vbn_normalized = vbn.reference_batch_normalization()
with self.test_session(use_gpu=True) as sess:
variables_lib.global_variables_initializer().run()
bn_normalized_np, vbn_normalized_np = sess.run(
[bn_normalized, vbn_normalized])
self.assertAllClose(bn_normalized_np, vbn_normalized_np)
def test_same_as_batchnorm(self):
"""Check that batch norm on set X is the same as ref of X / y on `y`."""
random_seed.set_random_seed(1234)
num_examples = 4
examples = [random_ops.random_normal([5, 7, 3]) for _ in
range(num_examples)]
# Get the result of the opensource batch normalization.
batch_normalized = normalization.batch_normalization(
array_ops.stack(examples), training=True)
for i in range(num_examples):
examples_except_i = array_ops.stack(examples[:i] + examples[i+1:])
# Get the result of VBN's batch normalization.
vbn = virtual_batchnorm.VBN(examples_except_i)
vb_normed = array_ops.squeeze(
vbn(array_ops.expand_dims(examples[i], [0])), [0])
with self.test_session(use_gpu=True) as sess:
variables_lib.global_variables_initializer().run()
bn_np, vb_np = sess.run([batch_normalized, vb_normed])
self.assertAllClose(bn_np[i, ...], vb_np)
def test_minibatch_independent(self):
"""Test that virtual batch normalized exampels are independent.
Unlike batch normalization, virtual batch normalization has the property
that the virtual batch normalized value of an example is independent of the
other examples in the minibatch. In this test, we verify this property.
"""
random_seed.set_random_seed(1234)
# These can be random, but must be the same for all session calls.
reference_batch = constant_op.constant(
np.random.normal(size=[4, 7, 3]), dtype=dtypes.float32)
fixed_example = constant_op.constant(np.random.normal(size=[7, 3]),
dtype=dtypes.float32)
# Get the VBN object and the virtual batch normalized value for
# `fixed_example`.
vbn = virtual_batchnorm.VBN(reference_batch)
vbn_fixed_example = array_ops.squeeze(
vbn(array_ops.expand_dims(fixed_example, 0)), 0)
with self.test_session(use_gpu=True):
variables_lib.global_variables_initializer().run()
vbn_fixed_example_np = vbn_fixed_example.eval()
# Check that the value is the same for different minibatches, and different
# sized minibatches.
for minibatch_size in range(1, 6):
examples = [random_ops.random_normal([7, 3]) for _ in
range(minibatch_size)]
minibatch = array_ops.stack([fixed_example] + examples)
vbn_minibatch = vbn(minibatch)
cur_vbn_fixed_example = vbn_minibatch[0, ...]
with self.test_session(use_gpu=True):
variables_lib.global_variables_initializer().run()
cur_vbn_fixed_example_np = cur_vbn_fixed_example.eval()
self.assertAllClose(vbn_fixed_example_np, cur_vbn_fixed_example_np)
def test_variable_reuse(self):
"""Test that variable scopes work and inference on a real-ish case."""
tensor1_ref = array_ops.zeros([6, 5, 7, 3, 3])
tensor1_examples = array_ops.zeros([4, 5, 7, 3, 3])
tensor2_ref = array_ops.zeros([4, 2, 3])
tensor2_examples = array_ops.zeros([2, 2, 3])
with variable_scope.variable_scope('dummy_scope', reuse=True):
with self.assertRaisesRegexp(
ValueError, 'does not exist, or was not created with '
'tf.get_variable()'):
virtual_batchnorm.VBN(tensor1_ref)
vbn1 = virtual_batchnorm.VBN(tensor1_ref, name='vbn1')
vbn2 = virtual_batchnorm.VBN(tensor2_ref, name='vbn2')
# Fetch reference and examples after virtual batch normalization. Also
# fetch in variable reuse case.
to_fetch = []
to_fetch.append(vbn1.reference_batch_normalization())
to_fetch.append(vbn2.reference_batch_normalization())
to_fetch.append(vbn1(tensor1_examples))
to_fetch.append(vbn2(tensor2_examples))
variable_scope.get_variable_scope().reuse_variables()
to_fetch.append(vbn1.reference_batch_normalization())
to_fetch.append(vbn2.reference_batch_normalization())
to_fetch.append(vbn1(tensor1_examples))
to_fetch.append(vbn2(tensor2_examples))
self.assertEqual(4, len(contrib_variables_lib.get_variables()))
with self.test_session(use_gpu=True) as sess:
variables_lib.global_variables_initializer().run()
sess.run(to_fetch)
def test_invalid_input(self):
# Reference batch has unknown dimensions.
with self.assertRaisesRegexp(
ValueError, '`reference_batch` has unknown dimensions.'):
virtual_batchnorm.VBN(array_ops.placeholder(dtypes.float32), name='vbn1')
# Axis too negative.
with self.assertRaisesRegexp(
ValueError, 'Value of `axis` argument .* is out of range'):
virtual_batchnorm.VBN(array_ops.zeros([1, 2]), axis=-3, name='vbn2')
# Axis too large.
with self.assertRaisesRegexp(
ValueError, 'Value of `axis` argument .* is out of range'):
virtual_batchnorm.VBN(array_ops.zeros([1, 2]), axis=2, name='vbn3')
# Batch axis too negative.
with self.assertRaisesRegexp(
ValueError, 'Value of `axis` argument .* is out of range'):
virtual_batchnorm.VBN(array_ops.zeros([1, 2]), name='vbn4', batch_axis=-3)
# Batch axis too large.
with self.assertRaisesRegexp(
ValueError, 'Value of `axis` argument .* is out of range'):
virtual_batchnorm.VBN(array_ops.zeros([1, 2]), name='vbn5', batch_axis=2)
# Axis and batch axis are the same.
with self.assertRaisesRegexp(
ValueError, '`axis` and `batch_axis` cannot be the same.'):
virtual_batchnorm.VBN(array_ops.zeros(
[1, 2]), axis=1, name='vbn6', batch_axis=1)
# Reference Tensor and example Tensor have incompatible shapes.
tensor_ref = array_ops.zeros([5, 2, 3])
tensor_examples = array_ops.zeros([3, 2, 3])
vbn = virtual_batchnorm.VBN(tensor_ref, name='vbn7', batch_axis=1)
with self.assertRaisesRegexp(ValueError, 'Shapes .* are incompatible'):
vbn(tensor_examples)
if __name__ == '__main__':
test.main()

View File

@ -0,0 +1,32 @@
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TFGAN grouped API. Please see README.md for details and usage."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Collapse losses into a single namespace.
from tensorflow.contrib.gan.python.losses.python import losses_wargs as wargs
from tensorflow.contrib.gan.python.losses.python import tuple_losses
# pylint: disable=wildcard-import
from tensorflow.contrib.gan.python.losses.python.tuple_losses import *
# pylint: enable=wildcard-import
from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = ['wargs'] + tuple_losses.__all__
remove_undocumented(__name__, _allowed_symbols)

View File

@ -0,0 +1,887 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Losses that are useful for training GANs.
Most losses belong to one of two main groups, though some do not:
1) xxxxx_generator_loss
2) xxxxx_discriminator_loss
Example:
1) wasserstein_generator_loss
2) wasserstein_discriminator_loss
Other example:
wasserstein_gradient_penalty
All losses must be able to accept 1D or 2D Tensors, so as to be compatible with
PatchGAN-style losses (https://arxiv.org/abs/1611.07004).
To make these losses usable in the TFGAN framework, please create a tuple
version of the losses with `losses_utils.py`.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.contrib.framework.python.ops import variables as contrib_variables_lib
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.distributions import distribution as ds
from tensorflow.python.ops.losses import losses
from tensorflow.python.ops.losses import util
from tensorflow.python.summary import summary
__all__ = [
'acgan_discriminator_loss',
'acgan_generator_loss',
'least_squares_discriminator_loss',
'least_squares_generator_loss',
'modified_discriminator_loss',
'modified_generator_loss',
'minimax_discriminator_loss',
'minimax_generator_loss',
'wasserstein_discriminator_loss',
'wasserstein_generator_loss',
'wasserstein_gradient_penalty',
'mutual_information_penalty',
'combine_adversarial_loss',
]
# Wasserstein losses from `Wasserstein GAN` (https://arxiv.org/abs/1701.07875).
def wasserstein_generator_loss(
discriminator_gen_outputs,
weights=1.0,
scope=None,
loss_collection=ops.GraphKeys.LOSSES,
reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
add_summaries=False):
"""Wasserstein generator loss for GANs.
See `Wasserstein GAN` (https://arxiv.org/abs/1701.07875) for more details.
Args:
discriminator_gen_outputs: Discriminator output on generated data. Expected
to be in the range of (-inf, inf).
weights: Optional `Tensor` whose rank is either 0, or the same rank as
`labels`, and must be broadcastable to `labels` (i.e., all dimensions must
be either `1`, or the same as the corresponding `losses` dimension).
scope: The scope for the operations performed in computing the loss.
loss_collection: collection to which this loss will be added.
reduction: A `tf.losses.Reduction` to apply to loss.
add_summaries: Whether or not to add detailed summaries for the loss.
Returns:
A loss Tensor. The shape depends on `reduction`.
"""
with ops.name_scope(scope, 'generator_wasserstein_loss', (
discriminator_gen_outputs, weights)) as scope:
discriminator_gen_outputs = math_ops.to_float(discriminator_gen_outputs)
loss = - discriminator_gen_outputs
loss = losses.compute_weighted_loss(
loss, weights, scope, loss_collection, reduction)
if add_summaries:
summary.scalar('generator_wass_loss', loss)
return loss
def wasserstein_discriminator_loss(
discriminator_real_outputs,
discriminator_gen_outputs,
real_weights=1.0,
generated_weights=1.0,
scope=None,
loss_collection=ops.GraphKeys.LOSSES,
reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
add_summaries=False):
"""Wasserstein discriminator loss for GANs.
See `Wasserstein GAN` (https://arxiv.org/abs/1701.07875) for more details.
Args:
discriminator_real_outputs: Discriminator output on real data.
discriminator_gen_outputs: Discriminator output on generated data. Expected
to be in the range of (-inf, inf).
real_weights: A scalar or a `Tensor` of size [batch_size, K] used to rescale
the real loss.
generated_weights: A scalar or a `Tensor` of size [batch_size, K] used to
rescale the generated loss.
scope: The scope for the operations performed in computing the loss.
loss_collection: collection to which this loss will be added.
reduction: A `tf.losses.Reduction` to apply to loss.
add_summaries: Whether or not to add summaries for the loss.
Returns:
A loss Tensor. The shape depends on `reduction`.
"""
with ops.name_scope(scope, 'discriminator_wasserstein_loss', (
discriminator_real_outputs, discriminator_gen_outputs, real_weights,
generated_weights)) as scope:
discriminator_real_outputs = math_ops.to_float(discriminator_real_outputs)
discriminator_gen_outputs = math_ops.to_float(discriminator_gen_outputs)
discriminator_real_outputs.shape.assert_is_compatible_with(
discriminator_gen_outputs.shape)
loss_on_generated = losses.compute_weighted_loss(
discriminator_gen_outputs, generated_weights, scope,
loss_collection=None, reduction=reduction)
loss_on_real = losses.compute_weighted_loss(
discriminator_real_outputs, real_weights, scope, loss_collection=None,
reduction=reduction)
loss = loss_on_generated - loss_on_real
util.add_loss(loss, loss_collection)
if add_summaries:
summary.scalar('discriminator_gen_wass_loss', loss_on_generated)
summary.scalar('discriminator_real_wass_loss', loss_on_real)
summary.scalar('discriminator_wass_loss', loss)
return loss
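# --- Editor's usage sketch (illustrative; not part of this commit). ---
# Pairing the two Wasserstein losses; `d_real`/`d_gen` stand for unreduced
# critic scores on real and generated data.
import tensorflow as tf
from tensorflow.contrib.gan.python.losses.python import losses_impl as tfgan_losses

d_real = tf.placeholder(tf.float32, [16, 1])
d_gen = tf.placeholder(tf.float32, [16, 1])
# The critic minimizes E[D(G(z))] - E[D(x)] ...
dis_loss = tfgan_losses.wasserstein_discriminator_loss(d_real, d_gen)
# ... while the generator minimizes -E[D(G(z))].
gen_loss = tfgan_losses.wasserstein_generator_loss(d_gen)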
# ACGAN losses from `Conditional Image Synthesis With Auxiliary Classifier GANs`
# (https://arxiv.org/abs/1610.09585).
def acgan_discriminator_loss(
discriminator_gen_classification_logits,
discriminator_real_classification_logits,
one_hot_labels,
label_smoothing=0.0,
real_weights=1.0,
generated_weights=1.0,
scope=None,
loss_collection=ops.GraphKeys.LOSSES,
reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
add_summaries=False):
"""ACGAN loss for the discriminator.
The ACGAN loss adds a classification loss to the conditional discriminator.
Therefore, the discriminator must output a tuple consisting of
(1) the real/fake prediction and
(2) the logits for the classification (usually the last conv layer,
flattened).
For more details:
ACGAN: https://arxiv.org/abs/1610.09585
Args:
discriminator_gen_classification_logits: Classification logits for generated
data.
discriminator_real_classification_logits: Classification logits for real
data.
one_hot_labels: A Tensor holding one-hot labels for the batch.
label_smoothing: A float in [0, 1]. If greater than 0, smooth the labels for
"discriminator on real data" as suggested in
https://arxiv.org/pdf/1701.00160
real_weights: A scalar or a `Tensor` of size [batch_size, K] used to rescale
the real loss.
generated_weights: A scalar or a `Tensor` of size [batch_size, K] used to
rescale the generated loss.
scope: The scope for the operations performed in computing the loss.
loss_collection: collection to which this loss will be added.
reduction: A `tf.losses.Reduction` to apply to loss.
add_summaries: Whether or not to add summaries for the loss.
Returns:
A loss Tensor. Shape depends on `reduction`.
Raises:
TypeError: If the discriminator does not output a tuple.
"""
loss_on_generated = losses.softmax_cross_entropy(
one_hot_labels, discriminator_gen_classification_logits,
weights=generated_weights, scope=scope, loss_collection=None,
reduction=reduction)
loss_on_real = losses.softmax_cross_entropy(
one_hot_labels, discriminator_real_classification_logits,
weights=real_weights, label_smoothing=label_smoothing, scope=scope,
loss_collection=None, reduction=reduction)
loss = loss_on_generated + loss_on_real
util.add_loss(loss, loss_collection)
if add_summaries:
summary.scalar('discriminator_gen_ac_loss', loss_on_generated)
summary.scalar('discriminator_real_ac_loss', loss_on_real)
summary.scalar('discriminator_ac_loss', loss)
return loss
def acgan_generator_loss(
discriminator_gen_classification_logits,
one_hot_labels,
weights=1.0,
scope=None,
loss_collection=ops.GraphKeys.LOSSES,
reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
add_summaries=False):
"""ACGAN loss for the generator.
The ACGAN loss adds a classification loss to the conditional discriminator.
Therefore, the discriminator must output a tuple consisting of
(1) the real/fake prediction and
(2) the logits for the classification (usually the last conv layer,
flattened).
For more details:
ACGAN: https://arxiv.org/abs/1610.09585
Args:
discriminator_gen_classification_logits: Classification logits for generated
data.
one_hot_labels: A Tensor holding one-hot labels for the batch.
weights: Optional `Tensor` whose rank is either 0, or the same rank as
`labels`, and must be broadcastable to `labels` (i.e., all dimensions must
be either `1`, or the same as the corresponding `losses` dimension).
scope: The scope for the operations performed in computing the loss.
loss_collection: collection to which this loss will be added.
reduction: A `tf.losses.Reduction` to apply to loss.
add_summaries: Whether or not to add summaries for the loss.
Returns:
A loss Tensor. Shape depends on `reduction`.
Raises:
TypeError: If the discriminator does not output a tuple.
"""
loss = losses.softmax_cross_entropy(
one_hot_labels, discriminator_gen_classification_logits, weights=weights,
scope=scope, loss_collection=loss_collection, reduction=reduction)
if add_summaries:
summary.scalar('generator_ac_loss', loss)
return loss
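# --- Editor's usage sketch (illustrative; not part of this commit). ---
# The ACGAN loss pair; the logits stand for outputs of the discriminator's
# auxiliary classification head, and the shapes are arbitrary.
import tensorflow as tf
from tensorflow.contrib.gan.python.losses.python import losses_impl as tfgan_losses

gen_logits = tf.placeholder(tf.float32, [16, 10])
real_logits = tf.placeholder(tf.float32, [16, 10])
one_hot_labels = tf.placeholder(tf.float32, [16, 10])
# The discriminator learns to classify both real and generated samples ...
dis_ac_loss = tfgan_losses.acgan_discriminator_loss(
    gen_logits, real_logits, one_hot_labels)
# ... and the generator learns to produce correctly classifiable samples.
gen_ac_loss = tfgan_losses.acgan_generator_loss(gen_logits, one_hot_labels)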
# Wasserstein Gradient Penalty losses from `Improved Training of Wasserstein
# GANs` (https://arxiv.org/abs/1704.00028).
# TODO(joelshor): Figure out why this function can't be inside a name scope.
def wasserstein_gradient_penalty(
generated_data,
real_data,
generator_inputs,
discriminator_fn,
discriminator_scope,
epsilon=1e-10,
weights=1.0,
scope=None,
loss_collection=ops.GraphKeys.LOSSES,
reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
add_summaries=False):
"""The gradient penalty for the Wasserstein discriminator loss.
See `Improved Training of Wasserstein GANs`
(https://arxiv.org/abs/1704.00028) for more details.
Args:
generated_data: Output of the generator.
real_data: Real data.
generator_inputs: Exact argument to pass to the generator, which is used
as optional conditioning to the discriminator.
discriminator_fn: A discriminator function that conforms to TFGAN API.
discriminator_scope: If not `None`, reuse discriminators from this scope.
epsilon: A small positive number added for numerical stability when
computing the gradient norm.
weights: Optional `Tensor` whose rank is either 0, or the same rank as
`labels`, and must be broadcastable to `labels` (i.e., all dimensions must
be either `1`, or the same as the corresponding `losses` dimension).
scope: The scope for the operations performed in computing the loss.
loss_collection: collection to which this loss will be added.
reduction: A `tf.losses.Reduction` to apply to loss.
add_summaries: Whether or not to add summaries for the loss.
Returns:
A loss Tensor. The shape depends on `reduction`.
Raises:
ValueError: If the rank of data Tensors is unknown.
"""
if generated_data.shape.ndims is None:
raise ValueError('`generated_data` can\'t have unknown rank.')
if real_data.shape.ndims is None:
raise ValueError('`real_data` can\'t have unknown rank.')
differences = generated_data - real_data
batch_size = differences.shape[0].value or array_ops.shape(differences)[0]
alpha_shape = [batch_size] + [1] * (differences.shape.ndims - 1)
alpha = random_ops.random_uniform(shape=alpha_shape)
interpolates = real_data + (alpha * differences)
# Reuse variables if a discriminator scope already exists.
reuse = False if discriminator_scope is None else True
with variable_scope.variable_scope(discriminator_scope, 'gpenalty_dscope',
reuse=reuse):
disc_interpolates = discriminator_fn(interpolates, generator_inputs)
if isinstance(disc_interpolates, tuple):
# ACGAN case: disc outputs more than one tensor
disc_interpolates = disc_interpolates[0]
gradients = gradients_impl.gradients(disc_interpolates, interpolates)[0]
gradient_squares = math_ops.reduce_sum(
math_ops.square(gradients), axis=list(range(1, gradients.shape.ndims)))
# Propagate shape information, if possible.
if isinstance(batch_size, int):
gradient_squares.set_shape([
batch_size] + gradient_squares.shape.as_list()[1:])
# For numerical stability, add epsilon to the sum before taking the square
# root. Note tf.norm does not add epsilon.
slopes = math_ops.sqrt(gradient_squares + epsilon)
penalties = math_ops.square(slopes - 1.0)
penalty = losses.compute_weighted_loss(
penalties, weights, scope=scope, loss_collection=loss_collection,
reduction=reduction)
if add_summaries:
summary.scalar('gradient_penalty_loss', penalty)
return penalty
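# --- Editor's usage sketch (illustrative; not part of this commit). ---
# Computing the gradient penalty against a toy critic. Reusing the critic's
# variable scope is what lets the penalty evaluate the same network.
import tensorflow as tf
from tensorflow.contrib.gan.python.losses.python import losses_impl as tfgan_losses

def toy_discriminator_fn(data, unused_conditioning):
  return tf.layers.dense(data, 1, name='fc')

real_data = tf.random_normal([16, 64])
generated_data = tf.random_normal([16, 64])
with tf.variable_scope('discriminator') as dis_scope:
  toy_discriminator_fn(real_data, None)  # Create the critic variables once.
# Penalizes deviation of the critic's gradient norm from 1 on random
# interpolates between real and generated points.
gp_loss = tfgan_losses.wasserstein_gradient_penalty(
    generated_data, real_data, None, toy_discriminator_fn, dis_scope)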
# Original losses from `Generative Adversarial Nets`
# (https://arxiv.org/abs/1406.2661).
def minimax_discriminator_loss(
discriminator_real_outputs,
discriminator_gen_outputs,
label_smoothing=0.25,
real_weights=1.0,
generated_weights=1.0,
scope=None,
loss_collection=ops.GraphKeys.LOSSES,
reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
add_summaries=False):
"""Original minimax discriminator loss for GANs, with label smoothing.
Note that the authors don't recommend using this loss. A more practically
useful loss is `modified_discriminator_loss`.
L = - real_weights * log(sigmoid(D(x)))
- generated_weights * log(1 - sigmoid(D(G(z))))
See `Generative Adversarial Nets` (https://arxiv.org/abs/1406.2661) for more
details.
Args:
discriminator_real_outputs: Discriminator output on real data.
discriminator_gen_outputs: Discriminator output on generated data. Expected
to be in the range of (-inf, inf).
label_smoothing: The amount of smoothing for positive labels. This technique
is taken from `Improved Techniques for Training GANs`
(https://arxiv.org/abs/1606.03498). `0.0` means no smoothing.
real_weights: A scalar or a `Tensor` of size [batch_size, K] used to rescale
the real loss.
generated_weights: A scalar or a `Tensor` of size [batch_size, K] used to
rescale the generated loss.
scope: The scope for the operations performed in computing the loss.
loss_collection: collection to which this loss will be added.
reduction: A `tf.losses.Reduction` to apply to loss.
add_summaries: Whether or not to add summaries for the loss.
Returns:
A loss Tensor. The shape depends on `reduction`.
"""
with ops.name_scope(scope, 'discriminator_minimax_loss', (
discriminator_real_outputs, discriminator_gen_outputs, real_weights,
generated_weights, label_smoothing)) as scope:
# -log(sigmoid(D(x))), with positive labels smoothed from 1.0 to
# (1.0 - label_smoothing).
loss_on_real = losses.sigmoid_cross_entropy(
array_ops.ones_like(discriminator_real_outputs),
discriminator_real_outputs, real_weights, label_smoothing, scope,
loss_collection=None, reduction=reduction)
# -log(1 - sigmoid(D(G(z))))
loss_on_generated = losses.sigmoid_cross_entropy(
array_ops.zeros_like(discriminator_gen_outputs),
discriminator_gen_outputs, generated_weights, scope=scope,
loss_collection=None, reduction=reduction)
loss = loss_on_real + loss_on_generated
util.add_loss(loss, loss_collection)
if add_summaries:
summary.scalar('discriminator_gen_minimax_loss', loss_on_generated)
summary.scalar('discriminator_real_minimax_loss', loss_on_real)
summary.scalar('discriminator_minimax_loss', loss)
return loss
def minimax_generator_loss(
discriminator_gen_outputs,
label_smoothing=0.0,
weights=1.0,
scope=None,
loss_collection=ops.GraphKeys.LOSSES,
reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
add_summaries=False):
"""Original minimax generator loss for GANs.
Note that the authors don't recommend using this loss. A more practically
useful loss is `modified_generator_loss`.
L = log(sigmoid(D(x))) + log(1 - sigmoid(D(G(z))))
See `Generative Adversarial Nets` (https://arxiv.org/abs/1406.2661) for more
details.
Args:
discriminator_gen_outputs: Discriminator output on generated data. Expected
to be in the range of (-inf, inf).
label_smoothing: The amount of smoothing for positive labels. This technique
is taken from `Improved Techniques for Training GANs`
(https://arxiv.org/abs/1606.03498). `0.0` means no smoothing.
weights: A scalar or a `Tensor` of size [batch_size, K] used to rescale
the loss.
scope: The scope for the operations performed in computing the loss.
loss_collection: collection to which this loss will be added.
reduction: A `tf.losses.Reduction` to apply to loss.
add_summaries: Whether or not to add summaries for the loss.
Returns:
A loss Tensor. The shape depends on `reduction`.
"""
with ops.name_scope(scope, 'generator_minimax_loss') as scope:
loss = - minimax_discriminator_loss(
array_ops.ones_like(discriminator_gen_outputs),
discriminator_gen_outputs, label_smoothing, weights, weights, scope,
loss_collection, reduction, add_summaries=False)
if add_summaries:
summary.scalar('generator_minimax_loss', loss)
return loss
def modified_discriminator_loss(
discriminator_real_outputs,
discriminator_gen_outputs,
label_smoothing=0.25,
real_weights=1.0,
generated_weights=1.0,
scope=None,
loss_collection=ops.GraphKeys.LOSSES,
reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
add_summaries=False):
"""Same as minimax discriminator loss.
See `Generative Adversarial Nets` (https://arxiv.org/abs/1406.2661) for more
details.
Args:
discriminator_real_outputs: Discriminator output on real data.
discriminator_gen_outputs: Discriminator output on generated data. Expected
to be in the range of (-inf, inf).
label_smoothing: The amount of smoothing for positive labels. This technique
is taken from `Improved Techniques for Training GANs`
(https://arxiv.org/abs/1606.03498). `0.0` means no smoothing.
real_weights: A scalar or a `Tensor` of size [batch_size, K] used to rescale
the real loss.
generated_weights: A scalar or a `Tensor` of size [batch_size, K] used to
rescale the generated loss.
scope: The scope for the operations performed in computing the loss.
loss_collection: collection to which this loss will be added.
reduction: A `tf.losses.Reduction` to apply to loss.
add_summaries: Whether or not to add summaries for the loss.
Returns:
A loss Tensor. The shape depends on `reduction`.
"""
return minimax_discriminator_loss(
discriminator_real_outputs,
discriminator_gen_outputs,
label_smoothing,
real_weights,
generated_weights,
scope or 'discriminator_modified_loss',
loss_collection,
reduction,
add_summaries)
def modified_generator_loss(
discriminator_gen_outputs,
label_smoothing=0.0,
weights=1.0,
scope='generator_modified_loss',
loss_collection=ops.GraphKeys.LOSSES,
reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
add_summaries=False):
"""Modified generator loss for GANs.
L = -log(sigmoid(D(G(z))))
This is the trick used in the original paper to avoid vanishing gradients
early in training. See `Generative Adversarial Nets`
(https://arxiv.org/abs/1406.2661) for more details.
Args:
discriminator_gen_outputs: Discriminator output on generated data. Expected
to be in the range of (-inf, inf).
label_smoothing: The amount of smoothing for positive labels. This technique
is taken from `Improved Techniques for Training GANs`
(https://arxiv.org/abs/1606.03498). `0.0` means no smoothing.
weights: Optional `Tensor` whose rank is either 0, or the same rank as
`labels`, and must be broadcastable to `labels` (i.e., all dimensions must
be either `1`, or the same as the corresponding `losses` dimension).
scope: The scope for the operations performed in computing the loss.
loss_collection: collection to which this loss will be added.
reduction: A `tf.losses.Reduction` to apply to loss.
add_summaries: Whether or not to add summaries for the loss.
Returns:
A loss Tensor. The shape depends on `reduction`.
"""
loss = losses.sigmoid_cross_entropy(
array_ops.ones_like(discriminator_gen_outputs), discriminator_gen_outputs,
weights, label_smoothing, scope, loss_collection, reduction)
if add_summaries:
summary.scalar('generator_modified_loss', loss)
return loss
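# Editor's sketch (hypothetical helper): the vanishing-gradient trick above,
# made concrete. For a strongly rejected sample the minimax generator loss
# saturates (d loss / d logit ~= 0 at logit -10), while the modified loss
# keeps a gradient of ~= -1, which is why it trains better early on.
def _example_modified_vs_minimax_generator_loss():
  gen_logits = ops.convert_to_tensor([-10.0])  # D confidently says "fake".
  minimax = minimax_generator_loss(gen_logits, loss_collection=None)
  modified = modified_generator_loss(gen_logits, loss_collection=None)
  return minimax, modified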
# Least Squares loss from `Least Squares Generative Adversarial Networks`
# (https://arxiv.org/abs/1611.04076).
def least_squares_generator_loss(
discriminator_gen_outputs,
real_label=1,
weights=1.0,
scope=None,
loss_collection=ops.GraphKeys.LOSSES,
reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
add_summaries=False):
"""Least squares generator loss.
This loss comes from `Least Squares Generative Adversarial Networks`
(https://arxiv.org/abs/1611.04076).
L = 1/2 * (D(G(z)) - `real_label`) ** 2
where D(y) are discriminator logits.
Args:
discriminator_gen_outputs: Discriminator output on generated data. Expected
to be in the range of (-inf, inf).
real_label: The value that the generator is trying to get the discriminator
to output on generated data.
weights: Optional `Tensor` whose rank is either 0, or the same rank as
`labels`, and must be broadcastable to `labels` (i.e., all dimensions must
be either `1`, or the same as the corresponding `losses` dimension).
scope: The scope for the operations performed in computing the loss.
loss_collection: collection to which this loss will be added.
reduction: A `tf.losses.Reduction` to apply to loss.
add_summaries: Whether or not to add summaries for the loss.
Returns:
A loss Tensor. The shape depends on `reduction`.
"""
with ops.name_scope(scope, 'lsq_generator_loss',
(discriminator_gen_outputs, real_label)) as scope:
discriminator_gen_outputs = math_ops.to_float(discriminator_gen_outputs)
loss = math_ops.squared_difference(
discriminator_gen_outputs, real_label) / 2.0
loss = losses.compute_weighted_loss(
loss, weights, scope, loss_collection, reduction)
if add_summaries:
summary.scalar('generator_lsq_loss', loss)
return loss
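# Editor's sketch (hypothetical helper): a hand-checkable value. For logits
# [0., 2.] and real_label=1, L = mean([(0 - 1)**2 / 2, (2 - 1)**2 / 2]) = 0.5.
def _example_least_squares_generator_loss():
  gen_logits = ops.convert_to_tensor([0.0, 2.0])
  return least_squares_generator_loss(gen_logits, loss_collection=None)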
def least_squares_discriminator_loss(
discriminator_real_outputs,
discriminator_gen_outputs,
real_label=1,
fake_label=0,
real_weights=1.0,
generated_weights=1.0,
scope=None,
loss_collection=ops.GraphKeys.LOSSES,
reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
add_summaries=False):
"""Least squares generator loss.
This loss comes from `Least Squares Generative Adversarial Networks`
(https://arxiv.org/abs/1611.04076).
L = 1/2 * (D(x) - `real_label`) ** 2 +
1/2 * (D(G(z)) - `fake_label`) ** 2
where D(y) are discriminator logits.
Args:
discriminator_real_outputs: Discriminator output on real data.
discriminator_gen_outputs: Discriminator output on generated data. Expected
to be in the range of (-inf, inf).
real_label: The value that the discriminator tries to output for real data.
fake_label: The value that the discriminator tries to output for fake data.
real_weights: A scalar or a `Tensor` of size [batch_size, K] used to rescale
the real loss.
generated_weights: A scalar or a `Tensor` of size [batch_size, K] used to
rescale the generated loss.
scope: The scope for the operations performed in computing the loss.
loss_collection: collection to which this loss will be added.
reduction: A `tf.losses.Reduction` to apply to loss.
add_summaries: Whether or not to add summaries for the loss.
Returns:
A loss Tensor. The shape depends on `reduction`.
"""
with ops.name_scope(scope, 'lsq_discriminator_loss',
(discriminator_gen_outputs, real_label)) as scope:
discriminator_real_outputs = math_ops.to_float(discriminator_real_outputs)
discriminator_gen_outputs = math_ops.to_float(discriminator_gen_outputs)
discriminator_real_outputs.shape.assert_is_compatible_with(
discriminator_gen_outputs.shape)
real_losses = math_ops.squared_difference(
discriminator_real_outputs, real_label) / 2.0
fake_losses = math_ops.squared_difference(
discriminator_gen_outputs, fake_label) / 2.0
loss_on_real = losses.compute_weighted_loss(
real_losses, real_weights, scope, loss_collection=None,
reduction=reduction)
loss_on_generated = losses.compute_weighted_loss(
fake_losses, generated_weights, scope, loss_collection=None,
reduction=reduction)
loss = loss_on_real + loss_on_generated
util.add_loss(loss, loss_collection)
if add_summaries:
summary.scalar('discriminator_gen_lsq_loss', loss_on_generated)
summary.scalar('discriminator_real_lsq_loss', loss_on_real)
summary.scalar('discriminator_lsq_loss', loss)
return loss
# InfoGAN loss from `InfoGAN: Interpretable Representation Learning by
# Information Maximizing Generative Adversarial Nets`
# (https://arxiv.org/abs/1606.03657).
def _validate_distributions(distributions):
if not isinstance(distributions, (list, tuple)):
raise ValueError('`distributions` must be a list or tuple. Instead, '
'found %s.' % type(distributions))
for x in distributions:
if not isinstance(x, ds.Distribution):
raise ValueError('`distributions` must be a list of `Distributions`. '
'Instead, found %s.' % type(x))
def _validate_information_penalty_inputs(
structured_generator_inputs, predicted_distributions):
"""Validate input to `mutual_information_penalty`."""
_validate_distributions(predicted_distributions)
if len(structured_generator_inputs) != len(predicted_distributions):
raise ValueError('`structured_generator_inputs` length %i must be the same '
'as `predicted_distributions` length %i.' % (
len(structured_generator_inputs),
len(predicted_distributions)))
def mutual_information_penalty(
structured_generator_inputs,
predicted_distributions,
weights=1.0,
scope=None,
loss_collection=ops.GraphKeys.LOSSES,
reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
add_summaries=False):
"""Returns a penalty on the mutual information in an InfoGAN model.
This loss comes from an InfoGAN paper https://arxiv.org/abs/1606.03657.
Args:
structured_generator_inputs: A list of Tensors representing the random noise
that must have high mutual information with the generator output. List
length should match `predicted_distributions`.
predicted_distributions: A list of tf.Distributions. Predicted by the
recognizer, and used to evaluate the likelihood of the structured noise.
List length should match `structured_generator_inputs`.
weights: Optional `Tensor` whose rank is either 0, or the same rank as
`labels`, and must be broadcastable to `labels` (i.e., all dimensions must
be either `1`, or the same as the corresponding `losses` dimension).
scope: The scope for the operations performed in computing the loss.
loss_collection: collection to which this loss will be added.
reduction: A `tf.losses.Reduction` to apply to loss.
add_summaries: Whether or not to add summaries for the loss.
Returns:
A scalar Tensor representing the mutual information loss.
"""
_validate_information_penalty_inputs(
structured_generator_inputs, predicted_distributions)
# Calculate the negative log-likelihood of the reconstructed noise.
log_probs = [math_ops.reduce_mean(dist.log_prob(noise)) for dist, noise in
zip(predicted_distributions, structured_generator_inputs)]
loss = -1 * losses.compute_weighted_loss(
log_probs, weights, scope, loss_collection=loss_collection,
reduction=reduction)
if add_summaries:
summary.scalar('mutual_information_penalty', loss)
return loss
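# Editor's sketch (hypothetical helper): the penalty is the mean negative
# log-likelihood of the structured noise under the predicted distributions.
# The `normal` import mirrors the tests added elsewhere in this commit.
def _example_mutual_information_penalty():
  from tensorflow.python.ops.distributions import normal
  noise = [ops.convert_to_tensor(0.0)]
  predicted = [normal.Normal(loc=0.0, scale=1.0)]
  # log N(0 | 0, 1) = -0.5 * log(2 * pi) ~= -0.9189, so the penalty ~= 0.9189.
  return mutual_information_penalty(noise, predicted, loss_collection=None)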
def _numerically_stable_global_norm(tensor_list):
"""Compute the global norm of a list of Tensors, with improved stability.
The global norm computation sometimes overflows due to the intermediate L2
step. To avoid this, we first divide each tensor by a cheap-to-compute max
over all the tensor elements, then rescale the result.
Args:
tensor_list: A list of tensors, or `None`.
Returns:
A scalar tensor with the global norm, or the Python float `0.0` if every
entry of `tensor_list` is `None`.
"""
if np.all([x is None for x in tensor_list]):
return 0.0
list_max = math_ops.reduce_max([math_ops.reduce_max(math_ops.abs(x)) for x in
tensor_list if x is not None])
return list_max * clip_ops.global_norm([x / list_max for x in tensor_list
if x is not None])
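# Editor's sketch (hypothetical helper): why the rescaling helps. Squaring
# 1e19 in float32 gives 1e38 per element and 4e38 summed, past the ~3.4e38
# float32 max, so clip_ops.global_norm returns inf here; the rescaled
# computation stays finite (1e19 * sqrt(4) = 2e19).
def _example_stable_global_norm():
  big = [array_ops.ones([4]) * 1e19]
  return _numerically_stable_global_norm(big)  # Finite, unlike global_norm.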
def _used_weight(weights_list):
for weight in weights_list:
if weight is not None:
return tensor_util.constant_value(ops.convert_to_tensor(weight))
def _validate_args(losses_list, weight_factor, gradient_ratio):
for loss in losses_list:
loss.shape.assert_is_compatible_with([])
if weight_factor is None and gradient_ratio is None:
raise ValueError(
'`weight_factor` and `gradient_ratio` cannot both be `None`.')
if weight_factor is not None and gradient_ratio is not None:
raise ValueError(
'`weight_factor` and `gradient_ratio` cannot both be specified.')
# TODO(joelshor): Add ability to pass in gradients, to avoid recomputing.
def combine_adversarial_loss(main_loss,
adversarial_loss,
weight_factor=None,
gradient_ratio=None,
gradient_ratio_epsilon=1e-6,
variables=None,
scalar_summaries=True,
gradient_summaries=True,
scope=None):
"""Utility to combine main and adversarial losses.
This utility combines the main and adversarial losses in one of two ways.
1) Fixed coefficient on adversarial loss. Use `weight_factor` in this case.
2) Fixed ratio of gradients. Use `gradient_ratio` in this case. This is often
used to make sure both losses affect weights roughly equally, as in
https://arxiv.org/pdf/1705.05823.
One can optionally also visualize the scalar and gradient behavior of the
losses.
Args:
main_loss: A floating scalar Tensor indicating the main loss.
adversarial_loss: A floating scalar Tensor indicating the adversarial loss.
weight_factor: If not `None`, the coefficient by which to multiply the
adversarial loss. Exactly one of this and `gradient_ratio` must be
non-None.
gradient_ratio: If not `None`, the ratio of the magnitude of the gradients.
Specifically,
gradient_ratio = grad_mag(main_loss) / grad_mag(adversarial_loss)
Exactly one of this and `weight_factor` must be non-None.
gradient_ratio_epsilon: An epsilon to add to the adversarial loss
coefficient denominator, to avoid division-by-zero.
variables: List of variables to calculate gradients with respect to. If not
present, defaults to all trainable variables.
scalar_summaries: Create scalar summaries of losses.
gradient_summaries: Create gradient summaries of losses.
scope: Optional name scope.
Returns:
A floating scalar Tensor indicating the desired combined loss.
Raises:
ValueError: Malformed input.
"""
_validate_args([main_loss, adversarial_loss], weight_factor, gradient_ratio)
if variables is None:
variables = contrib_variables_lib.get_trainable_variables()
with ops.name_scope(scope, 'adversarial_loss',
values=[main_loss, adversarial_loss]):
# Compute gradients if we will need them.
if gradient_summaries or gradient_ratio is not None:
main_loss_grad_mag = _numerically_stable_global_norm(
gradients_impl.gradients(main_loss, variables))
adv_loss_grad_mag = _numerically_stable_global_norm(
gradients_impl.gradients(adversarial_loss, variables))
# Add summaries, if applicable.
if scalar_summaries:
summary.scalar('main_loss', main_loss)
summary.scalar('adversarial_loss', adversarial_loss)
if gradient_summaries:
summary.scalar('main_loss_gradients', main_loss_grad_mag)
summary.scalar('adversarial_loss_gradients', adv_loss_grad_mag)
# Combine losses in the appropriate way.
# If the statically-known weight (`weight_factor` or `gradient_ratio`) is
# `0`, avoid computing the adversarial loss tensor entirely.
if _used_weight((weight_factor, gradient_ratio)) == 0:
final_loss = main_loss
elif weight_factor is not None:
final_loss = (main_loss +
array_ops.stop_gradient(weight_factor) * adversarial_loss)
elif gradient_ratio is not None:
grad_mag_ratio = main_loss_grad_mag / (
adv_loss_grad_mag + gradient_ratio_epsilon)
adv_coeff = grad_mag_ratio / gradient_ratio
if scalar_summaries:
summary.scalar('adversarial_coefficient', adv_coeff)
final_loss = (main_loss +
array_ops.stop_gradient(adv_coeff) * adversarial_loss)
return final_loss
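# Editor's sketch (hypothetical helper): the two mutually exclusive modes.
# With `weight_factor` the adversarial loss gets a fixed coefficient; with
# `gradient_ratio` the coefficient is chosen so the main-loss gradient
# magnitude is `gradient_ratio` times that of the scaled adversarial loss.
def _example_combine_adversarial_loss(main_loss, adversarial_loss):
  fixed = combine_adversarial_loss(
      main_loss, adversarial_loss, weight_factor=2.0,
      scalar_summaries=False, gradient_summaries=False)
  matched = combine_adversarial_loss(
      main_loss, adversarial_loss, gradient_ratio=1.0,
      scalar_summaries=False, gradient_summaries=False)
  return fixed, matched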

View File

@ -0,0 +1,606 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for TFGAN losses."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.gan.python.losses.python import losses_impl as tfgan_losses
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.ops.distributions import categorical
from tensorflow.python.ops.distributions import normal
from tensorflow.python.ops.losses import losses as tf_losses
from tensorflow.python.platform import test
# TODO(joelshor): Use `parameterized` tests when opensourced.
class _LossesTest(object):
def init_constants(self):
self._discriminator_real_outputs_np = [-5.0, 1.4, 12.5, 2.7]
self._discriminator_gen_outputs_np = [10.0, 4.4, -5.5, 3.6]
self._weights = 2.3
self._discriminator_real_outputs = constant_op.constant(
self._discriminator_real_outputs_np, dtype=dtypes.float32)
self._discriminator_gen_outputs = constant_op.constant(
self._discriminator_gen_outputs_np, dtype=dtypes.float32)
def test_generator_all_correct(self):
loss = self._g_loss_fn(self._discriminator_gen_outputs)
self.assertEqual(self._discriminator_gen_outputs.dtype, loss.dtype)
self.assertEqual(self._generator_loss_name, loss.op.name)
with self.test_session():
self.assertAlmostEqual(self._expected_g_loss, loss.eval(), 5)
def test_discriminator_all_correct(self):
loss = self._d_loss_fn(
self._discriminator_real_outputs, self._discriminator_gen_outputs)
self.assertEqual(self._discriminator_gen_outputs.dtype, loss.dtype)
self.assertEqual(self._discriminator_loss_name, loss.op.name)
with self.test_session():
self.assertAlmostEqual(self._expected_d_loss, loss.eval(), 5)
def test_generator_loss_collection(self):
self.assertEqual(0, len(ops.get_collection('collection')))
self._g_loss_fn(
self._discriminator_gen_outputs, loss_collection='collection')
self.assertEqual(1, len(ops.get_collection('collection')))
def test_discriminator_loss_collection(self):
self.assertEqual(0, len(ops.get_collection('collection')))
self._d_loss_fn(
self._discriminator_real_outputs, self._discriminator_gen_outputs,
loss_collection='collection')
self.assertEqual(1, len(ops.get_collection('collection')))
def test_generator_no_reduction(self):
loss = self._g_loss_fn(
self._discriminator_gen_outputs, reduction=tf_losses.Reduction.NONE)
self.assertAllEqual([4], loss.shape)
def test_discriminator_no_reduction(self):
loss = self._d_loss_fn(
self._discriminator_real_outputs, self._discriminator_gen_outputs,
reduction=tf_losses.Reduction.NONE)
self.assertAllEqual([4], loss.shape)
def test_generator_patch(self):
loss = self._g_loss_fn(
array_ops.reshape(self._discriminator_gen_outputs, [2, 2]))
self.assertEqual(self._discriminator_gen_outputs.dtype, loss.dtype)
with self.test_session():
self.assertAlmostEqual(self._expected_g_loss, loss.eval(), 5)
def test_discriminator_patch(self):
loss = self._d_loss_fn(
array_ops.reshape(self._discriminator_real_outputs, [2, 2]),
array_ops.reshape(self._discriminator_gen_outputs, [2, 2]))
self.assertEqual(self._discriminator_gen_outputs.dtype, loss.dtype)
with self.test_session():
self.assertAlmostEqual(self._expected_d_loss, loss.eval(), 5)
def test_generator_loss_with_placeholder_for_logits(self):
logits = array_ops.placeholder(dtypes.float32, shape=(None, 4))
weights = array_ops.ones_like(logits, dtype=dtypes.float32)
loss = self._g_loss_fn(logits, weights=weights)
self.assertEqual(logits.dtype, loss.dtype)
with self.test_session() as sess:
loss = sess.run(loss,
feed_dict={
logits: [[10.0, 4.4, -5.5, 3.6]],
})
self.assertAlmostEqual(self._expected_g_loss, loss, 5)
def test_discriminator_loss_with_placeholder_for_logits(self):
logits = array_ops.placeholder(dtypes.float32, shape=(None, 4))
logits2 = array_ops.placeholder(dtypes.float32, shape=(None, 4))
real_weights = array_ops.ones_like(logits, dtype=dtypes.float32)
generated_weights = array_ops.ones_like(logits, dtype=dtypes.float32)
loss = self._d_loss_fn(
logits, logits2, real_weights=real_weights,
generated_weights=generated_weights)
with self.test_session() as sess:
loss = sess.run(loss,
feed_dict={
logits: [self._discriminator_real_outputs_np],
logits2: [self._discriminator_gen_outputs_np],
})
self.assertAlmostEqual(self._expected_d_loss, loss, 5)
def test_generator_with_python_scalar_weight(self):
loss = self._g_loss_fn(
self._discriminator_gen_outputs, weights=self._weights)
with self.test_session():
self.assertAlmostEqual(self._expected_g_loss * self._weights,
loss.eval(), 4)
def test_discriminator_with_python_scalar_weight(self):
loss = self._d_loss_fn(
self._discriminator_real_outputs, self._discriminator_gen_outputs,
real_weights=self._weights, generated_weights=self._weights)
with self.test_session():
self.assertAlmostEqual(self._expected_d_loss * self._weights,
loss.eval(), 4)
def test_generator_with_scalar_tensor_weight(self):
loss = self._g_loss_fn(self._discriminator_gen_outputs,
weights=constant_op.constant(self._weights))
with self.test_session():
self.assertAlmostEqual(self._expected_g_loss * self._weights,
loss.eval(), 4)
def test_discriminator_with_scalar_tensor_weight(self):
weights = constant_op.constant(self._weights)
loss = self._d_loss_fn(
self._discriminator_real_outputs, self._discriminator_gen_outputs,
real_weights=weights, generated_weights=weights)
with self.test_session():
self.assertAlmostEqual(self._expected_d_loss * self._weights,
loss.eval(), 4)
def test_generator_add_summaries(self):
self.assertEqual(0, len(ops.get_collection(ops.GraphKeys.SUMMARIES)))
self._g_loss_fn(self._discriminator_gen_outputs, add_summaries=True)
self.assertLess(0, len(ops.get_collection(ops.GraphKeys.SUMMARIES)))
def test_discriminator_add_summaries(self):
self.assertEqual(0, len(ops.get_collection(ops.GraphKeys.SUMMARIES)))
self._d_loss_fn(
self._discriminator_real_outputs, self._discriminator_gen_outputs,
add_summaries=True)
self.assertLess(0, len(ops.get_collection(ops.GraphKeys.SUMMARIES)))
class LeastSquaresLossTest(test.TestCase, _LossesTest):
"""Tests for least_squares_xxx_loss."""
def setUp(self):
super(LeastSquaresLossTest, self).setUp()
self.init_constants()
self._expected_g_loss = 17.69625
self._expected_d_loss = 41.73375
self._generator_loss_name = 'lsq_generator_loss/value'
self._discriminator_loss_name = 'lsq_discriminator_loss/add'
self._g_loss_fn = tfgan_losses.least_squares_generator_loss
self._d_loss_fn = tfgan_losses.least_squares_discriminator_loss
class ModifiedLossTest(test.TestCase, _LossesTest):
"""Tests for modified_xxx_loss."""
def setUp(self):
super(ModifiedLossTest, self).setUp()
self.init_constants()
self._expected_g_loss = 1.38582
self._expected_d_loss = 6.19637
self._generator_loss_name = 'generator_modified_loss/value'
self._discriminator_loss_name = 'discriminator_modified_loss/add_1'
self._g_loss_fn = tfgan_losses.modified_generator_loss
self._d_loss_fn = tfgan_losses.modified_discriminator_loss
class MinimaxLossTest(test.TestCase, _LossesTest):
"""Tests for minimax_xxx_loss."""
def setUp(self):
super(MinimaxLossTest, self).setUp()
self.init_constants()
self._expected_g_loss = -4.82408
self._expected_d_loss = 6.19637
self._generator_loss_name = 'generator_minimax_loss/Neg'
self._discriminator_loss_name = 'discriminator_minimax_loss/add_1'
self._g_loss_fn = tfgan_losses.minimax_generator_loss
self._d_loss_fn = tfgan_losses.minimax_discriminator_loss
class WassersteinLossTest(test.TestCase, _LossesTest):
"""Tests for wasserstein_xxx_loss."""
def setUp(self):
super(WassersteinLossTest, self).setUp()
self.init_constants()
self._expected_g_loss = -3.12500
self._expected_d_loss = 0.22500
self._generator_loss_name = 'generator_wasserstein_loss/value'
self._discriminator_loss_name = 'discriminator_wasserstein_loss/sub'
self._g_loss_fn = tfgan_losses.wasserstein_generator_loss
self._d_loss_fn = tfgan_losses.wasserstein_discriminator_loss
# TODO(joelshor): Use `parameterized` tests when opensourced.
# TODO(joelshor): Refactor this test to use the same code as the other losses.
class ACGANLossTest(test.TestCase):
"""Tests for wasserstein_xxx_loss."""
def setUp(self):
super(ACGANLossTest, self).setUp()
self._g_loss_fn = tfgan_losses.acgan_generator_loss
self._d_loss_fn = tfgan_losses.acgan_discriminator_loss
self._discriminator_gen_classification_logits_np = [[10.0, 4.4, -5.5, 3.6],
[-4.0, 4.4, 5.2, 4.6],
[1.1, 2.4, -3.5, 5.6],
[1.1, 2.4, -3.5, 5.6]]
self._discriminator_real_classification_logits_np = [[-2.0, 0.4, 12.5, 2.7],
[-1.2, 1.9, 12.3, 2.6],
[-2.4, -1.7, 2.5, 2.7],
[1.1, 2.4, -3.5, 5.6]]
self._one_hot_labels_np = [[0, 1, 0, 0],
[0, 0, 1, 0],
[1, 0, 0, 0],
[1, 0, 0, 0]]
self._weights = 2.3
self._discriminator_gen_classification_logits = constant_op.constant(
self._discriminator_gen_classification_logits_np, dtype=dtypes.float32)
self._discriminator_real_classification_logits = constant_op.constant(
self._discriminator_real_classification_logits_np, dtype=dtypes.float32)
self._one_hot_labels = constant_op.constant(
self._one_hot_labels_np, dtype=dtypes.float32)
self._generator_kwargs = {
'discriminator_gen_classification_logits':
self._discriminator_gen_classification_logits,
'one_hot_labels': self._one_hot_labels,
}
self._discriminator_kwargs = {
'discriminator_gen_classification_logits':
self._discriminator_gen_classification_logits,
'discriminator_real_classification_logits':
self._discriminator_real_classification_logits,
'one_hot_labels': self._one_hot_labels,
}
self._generator_loss_name = 'softmax_cross_entropy_loss/value'
self._discriminator_loss_name = 'add'
self._expected_g_loss = 3.84974
self._expected_d_loss = 9.43950
def test_generator_all_correct(self):
loss = self._g_loss_fn(**self._generator_kwargs)
self.assertEqual(
self._discriminator_gen_classification_logits.dtype, loss.dtype)
self.assertEqual(self._generator_loss_name, loss.op.name)
with self.test_session():
self.assertAlmostEqual(self._expected_g_loss, loss.eval(), 5)
def test_discriminator_all_correct(self):
loss = self._d_loss_fn(**self._discriminator_kwargs)
self.assertEqual(
self._discriminator_gen_classification_logits.dtype, loss.dtype)
self.assertEqual(self._discriminator_loss_name, loss.op.name)
with self.test_session():
self.assertAlmostEqual(self._expected_d_loss, loss.eval(), 5)
def test_generator_loss_collection(self):
self.assertEqual(0, len(ops.get_collection('collection')))
self._g_loss_fn(loss_collection='collection', **self._generator_kwargs)
self.assertEqual(1, len(ops.get_collection('collection')))
def test_discriminator_loss_collection(self):
self.assertEqual(0, len(ops.get_collection('collection')))
self._d_loss_fn(loss_collection='collection', **self._discriminator_kwargs)
self.assertEqual(1, len(ops.get_collection('collection')))
def test_generator_no_reduction(self):
loss = self._g_loss_fn(
reduction=tf_losses.Reduction.NONE, **self._generator_kwargs)
self.assertAllEqual([4], loss.shape)
def test_discriminator_no_reduction(self):
loss = self._d_loss_fn(
reduction=tf_losses.Reduction.NONE, **self._discriminator_kwargs)
self.assertAllEqual([4], loss.shape)
def test_generator_patch(self):
patch_args = {x: array_ops.reshape(y, [2, 2, 4]) for x, y in
self._generator_kwargs.items()}
loss = self._g_loss_fn(**patch_args)
with self.test_session():
self.assertAlmostEqual(self._expected_g_loss, loss.eval(), 5)
def test_discriminator_patch(self):
patch_args = {x: array_ops.reshape(y, [2, 2, 4]) for x, y in
self._discriminator_kwargs.items()}
loss = self._d_loss_fn(**patch_args)
with self.test_session():
self.assertAlmostEqual(self._expected_d_loss, loss.eval(), 5)
def test_generator_loss_with_placeholder_for_logits(self):
gen_logits = array_ops.placeholder(dtypes.float32, shape=(None, 4))
one_hot_labels = array_ops.placeholder(dtypes.int32, shape=(None, 4))
loss = self._g_loss_fn(gen_logits, one_hot_labels)
with self.test_session() as sess:
loss = sess.run(
loss, feed_dict={
gen_logits: self._discriminator_gen_classification_logits_np,
one_hot_labels: self._one_hot_labels_np,
})
self.assertAlmostEqual(self._expected_g_loss, loss, 5)
def test_discriminator_loss_with_placeholder_for_logits_and_weights(self):
gen_logits = array_ops.placeholder(dtypes.float32, shape=(None, 4))
real_logits = array_ops.placeholder(dtypes.float32, shape=(None, 4))
one_hot_labels = array_ops.placeholder(dtypes.int32, shape=(None, 4))
loss = self._d_loss_fn(gen_logits, real_logits, one_hot_labels)
with self.test_session() as sess:
loss = sess.run(
loss, feed_dict={
gen_logits: self._discriminator_gen_classification_logits_np,
real_logits: self._discriminator_real_classification_logits_np,
one_hot_labels: self._one_hot_labels_np,
})
self.assertAlmostEqual(self._expected_d_loss, loss, 5)
def test_generator_with_python_scalar_weight(self):
loss = self._g_loss_fn(weights=self._weights, **self._generator_kwargs)
with self.test_session():
self.assertAlmostEqual(self._expected_g_loss * self._weights,
loss.eval(), 4)
def test_discriminator_with_python_scalar_weight(self):
loss = self._d_loss_fn(
real_weights=self._weights, generated_weights=self._weights,
**self._discriminator_kwargs)
with self.test_session():
self.assertAlmostEqual(self._expected_d_loss * self._weights,
loss.eval(), 4)
def test_generator_with_scalar_tensor_weight(self):
loss = self._g_loss_fn(
weights=constant_op.constant(self._weights), **self._generator_kwargs)
with self.test_session():
self.assertAlmostEqual(self._expected_g_loss * self._weights,
loss.eval(), 4)
def test_discriminator_with_scalar_tensor_weight(self):
weights = constant_op.constant(self._weights)
loss = self._d_loss_fn(real_weights=weights, generated_weights=weights,
**self._discriminator_kwargs)
with self.test_session():
self.assertAlmostEqual(self._expected_d_loss * self._weights,
loss.eval(), 4)
def test_generator_add_summaries(self):
self.assertEqual(0, len(ops.get_collection(ops.GraphKeys.SUMMARIES)))
self._g_loss_fn(add_summaries=True, **self._generator_kwargs)
self.assertLess(0, len(ops.get_collection(ops.GraphKeys.SUMMARIES)))
def test_discriminator_add_summaries(self):
self.assertEqual(0, len(ops.get_collection(ops.GraphKeys.SUMMARIES)))
self._d_loss_fn(add_summaries=True, **self._discriminator_kwargs)
self.assertLess(0, len(ops.get_collection(ops.GraphKeys.SUMMARIES)))
class _PenaltyTest(object):
def test_all_correct(self):
loss = self._penalty_fn(**self._kwargs)
self.assertEqual(self._expected_dtype, loss.dtype)
self.assertEqual(self._expected_op_name, loss.op.name)
with self.test_session():
variables.global_variables_initializer().run()
self.assertAlmostEqual(self._expected_loss, loss.eval(), 6)
def test_loss_collection(self):
self.assertEqual(0, len(ops.get_collection('collection')))
self._penalty_fn(loss_collection='collection', **self._kwargs)
self.assertEqual(1, len(ops.get_collection('collection')))
def test_no_reduction(self):
loss = self._penalty_fn(reduction=tf_losses.Reduction.NONE, **self._kwargs)
self.assertAllEqual([self._batch_size], loss.shape)
def test_python_scalar_weight(self):
loss = self._penalty_fn(weights=2.3, **self._kwargs)
with self.test_session():
variables.global_variables_initializer().run()
self.assertAlmostEqual(self._expected_loss * 2.3, loss.eval(), 3)
def test_scalar_tensor_weight(self):
loss = self._penalty_fn(weights=constant_op.constant(2.3), **self._kwargs)
with self.test_session():
variables.global_variables_initializer().run()
self.assertAlmostEqual(self._expected_loss * 2.3, loss.eval(), 3)
class GradientPenaltyTest(test.TestCase, _PenaltyTest):
"""Tests for wasserstein_gradient_penalty."""
def setUp(self):
super(GradientPenaltyTest, self).setUp()
self._penalty_fn = tfgan_losses.wasserstein_gradient_penalty
self._generated_data_np = [[3.1, 2.3, -12.3, 32.1]]
self._real_data_np = [[-12.3, 23.2, 16.3, -43.2]]
self._expected_dtype = dtypes.float32
with variable_scope.variable_scope('fake_scope') as self._scope:
self._discriminator_fn(0.0, 0.0)
self._kwargs = {
'generated_data': constant_op.constant(
self._generated_data_np, dtype=self._expected_dtype),
'real_data': constant_op.constant(
self._real_data_np, dtype=self._expected_dtype),
'generator_inputs': None,
'discriminator_fn': self._discriminator_fn,
'discriminator_scope': self._scope,
}
self._expected_loss = 9.00000
self._expected_op_name = 'weighted_loss/value'
self._batch_size = 1
def _discriminator_fn(self, inputs, _):
return variable_scope.get_variable('dummy_d', initializer=2.0) * inputs
def test_loss_with_placeholder(self):
generated_data = array_ops.placeholder(dtypes.float32, shape=(None, None))
real_data = array_ops.placeholder(dtypes.float32, shape=(None, None))
loss = tfgan_losses.wasserstein_gradient_penalty(
generated_data,
real_data,
self._kwargs['generator_inputs'],
self._kwargs['discriminator_fn'],
self._kwargs['discriminator_scope'])
self.assertEqual(generated_data.dtype, loss.dtype)
with self.test_session() as sess:
variables.global_variables_initializer().run()
loss = sess.run(loss,
feed_dict={
generated_data: self._generated_data_np,
real_data: self._real_data_np,
})
self.assertAlmostEqual(self._expected_loss, loss, 5)
def test_reuses_scope(self):
"""Test that gradient penalty reuses discriminator scope."""
num_vars = len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
tfgan_losses.wasserstein_gradient_penalty(**self._kwargs)
self.assertEqual(
num_vars, len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)))
class MutualInformationPenaltyTest(test.TestCase, _PenaltyTest):
"""Tests for mutual_information_penalty."""
def setUp(self):
super(MutualInformationPenaltyTest, self).setUp()
self._penalty_fn = tfgan_losses.mutual_information_penalty
self._structured_generator_inputs = [1.0, 2.0]
self._predicted_distributions = [categorical.Categorical(logits=[1.0, 2.0]),
normal.Normal([0.0], [1.0])]
self._expected_dtype = dtypes.float32
self._kwargs = {
'structured_generator_inputs': self._structured_generator_inputs,
'predicted_distributions': self._predicted_distributions,
}
self._expected_loss = 1.61610
self._expected_op_name = 'mul'
self._batch_size = 2
class CombineAdversarialLossTest(test.TestCase):
"""Tests for combine_adversarial_loss."""
def setUp(self):
super(CombineAdversarialLossTest, self).setUp()
self._generated_data_np = [[3.1, 2.3, -12.3, 32.1]]
self._real_data_np = [[-12.3, 23.2, 16.3, -43.2]]
self._generated_data = constant_op.constant(
self._generated_data_np, dtype=dtypes.float32)
self._real_data = constant_op.constant(
self._real_data_np, dtype=dtypes.float32)
self._generated_inputs = None
self._expected_loss = 9.00000
def _test_correct_helper(self, use_weight_factor):
variable_list = [variables.Variable(1.0)]
main_loss = variable_list[0] * 2
adversarial_loss = variable_list[0] * 3
gradient_ratio_epsilon = 1e-6
if use_weight_factor:
weight_factor = constant_op.constant(2.0)
gradient_ratio = None
adv_coeff = 2.0
expected_loss = 1.0 * 2 + adv_coeff * 1.0 * 3
else:
weight_factor = None
gradient_ratio = constant_op.constant(0.5)
adv_coeff = 2.0 / (3 * 0.5 + gradient_ratio_epsilon)
expected_loss = 1.0 * 2 + adv_coeff * 1.0 * 3
combined_loss = tfgan_losses.combine_adversarial_loss(
main_loss,
adversarial_loss,
weight_factor=weight_factor,
gradient_ratio=gradient_ratio,
gradient_ratio_epsilon=gradient_ratio_epsilon,
variables=variable_list)
with self.test_session(use_gpu=True):
variables.global_variables_initializer().run()
self.assertNear(expected_loss, combined_loss.eval(), 1e-5)
def test_correct_useweightfactor(self):
self._test_correct_helper(True)
def test_correct_nouseweightfactor(self):
self._test_correct_helper(False)
def _test_no_weight_skips_adversarial_loss_helper(self, use_weight_factor):
"""Test the 0 adversarial weight or grad ratio skips adversarial loss."""
main_loss = constant_op.constant(1.0)
adversarial_loss = constant_op.constant(1.0)
weight_factor = 0.0 if use_weight_factor else None
gradient_ratio = None if use_weight_factor else 0.0
combined_loss = tfgan_losses.combine_adversarial_loss(
main_loss,
adversarial_loss,
weight_factor=weight_factor,
gradient_ratio=gradient_ratio,
gradient_summaries=False)
with self.test_session(use_gpu=True):
self.assertEqual(1.0, combined_loss.eval())
def test_no_weight_skips_adversarial_loss_useweightfactor(self):
self._test_no_weight_skips_adversarial_loss_helper(True)
def test_no_weight_skips_adversarial_loss_nouseweightfactor(self):
self._test_no_weight_skips_adversarial_loss_helper(False)
def test_stable_global_norm_avoids_overflow(self):
tensors = [array_ops.ones([4]), array_ops.ones([4, 4]) * 1e19, None]
gnorm_is_inf = math_ops.is_inf(clip_ops.global_norm(tensors))
stable_gnorm_is_inf = math_ops.is_inf(
tfgan_losses._numerically_stable_global_norm(tensors))
with self.test_session(use_gpu=True):
self.assertTrue(gnorm_is_inf.eval())
self.assertFalse(stable_gnorm_is_inf.eval())
def test_stable_global_norm_unchanged(self):
"""Test that preconditioning doesn't change global norm value."""
random_seed.set_random_seed(1234)
tensors = [random_ops.random_uniform([3]*i, -10.0, 10.0) for i in range(6)]
gnorm = clip_ops.global_norm(tensors)
precond_gnorm = tfgan_losses._numerically_stable_global_norm(tensors)
with self.test_session(use_gpu=True) as sess:
for _ in range(10): # spot check closeness on more than one sample.
gnorm_np, precond_gnorm_np = sess.run([gnorm, precond_gnorm])
self.assertNear(gnorm_np, precond_gnorm_np, 1e-5)
if __name__ == '__main__':
test.main()

View File

@ -0,0 +1,27 @@
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TFGAN grouped API. Please see README.md for details and usage."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=wildcard-import
from tensorflow.contrib.gan.python.losses.python import losses_impl
from tensorflow.contrib.gan.python.losses.python.losses_impl import *
# pylint: enable=wildcard-import
from tensorflow.python.util.all_util import remove_undocumented
remove_undocumented(__name__, losses_impl.__all__)

View File

@ -0,0 +1,27 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TFGAN utilities for loss functions that accept GANModel namedtuples."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=wildcard-import
from tensorflow.contrib.gan.python.losses.python import tuple_losses_impl
from tensorflow.contrib.gan.python.losses.python.tuple_losses_impl import *
# pylint: enable=wildcard-import
from tensorflow.python.util.all_util import remove_undocumented
__all__ = tuple_losses_impl.__all__
remove_undocumented(__name__, __all__)

View File

@ -0,0 +1,203 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TFGAN utilities for loss functions that accept GANModel namedtuples.
Example:
```python
# `tfgan.losses.args` losses take individual arguments.
w_loss = tfgan.losses.args.wasserstein_discriminator_loss(
discriminator_real_outputs,
discriminator_gen_outputs)
# `tfgan.losses` losses take GANModel namedtuples.
w_loss2 = tfgan.losses.wasserstein_discriminator_loss(gan_model)
```
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.gan.python.losses.python import losses_impl
from tensorflow.python.util import tf_inspect
__all__ = [
'acgan_discriminator_loss',
'acgan_generator_loss',
'least_squares_discriminator_loss',
'least_squares_generator_loss',
'modified_discriminator_loss',
'modified_generator_loss',
'minimax_discriminator_loss',
'minimax_generator_loss',
'wasserstein_discriminator_loss',
'wasserstein_generator_loss',
'wasserstein_gradient_penalty',
'mutual_information_penalty',
'combine_adversarial_loss',
]
def _args_to_gan_model(loss_fn):
"""Converts a loss taking individual args to one taking a GANModel namedtuple.
The new function has the same name as the original one.
Args:
loss_fn: A python function taking individual loss arguments, some of which
match fields of a `GANModel` namedtuple, and returning a loss Tensor. The
shape of the loss depends on `reduction`.
Returns:
A new function that takes a `GANModel` namedtuple and returns the same loss.
"""
# Match arguments in `loss_fn` to elements of `namedtuple`.
# TODO(joelshor): Properly handle `varargs` and `keywords`.
argspec = tf_inspect.getargspec(loss_fn)
defaults = argspec.defaults or []
# Compute the split point explicitly: `args[:-len(defaults)]` would be wrong
# when `defaults` is empty, since `args[:-0]` is `args[:0]`.
num_required = len(argspec.args) - len(defaults)
required_args = set(argspec.args[:num_required])
args_with_defaults = argspec.args[num_required:]
default_args_dict = dict(zip(args_with_defaults, defaults))
def new_loss_fn(gan_model, **kwargs): # pylint:disable=missing-docstring
gan_model_dict = gan_model._asdict()
# Make sure non-tuple required args are supplied.
args_from_tuple = set(argspec.args).intersection(set(gan_model._fields))
required_args_not_from_tuple = required_args - args_from_tuple
for arg in required_args_not_from_tuple:
if arg not in kwargs:
raise ValueError('`%s` must be supplied to %s loss function.' % (
arg, loss_fn.__name__))
# Make sure tuple args aren't also supplied as keyword args.
ambiguous_args = set(gan_model._fields).intersection(set(kwargs.keys()))
if ambiguous_args:
raise ValueError(
'The following args are present in both the tuple and keyword args '
'for %s: %s' % (loss_fn.__name__, ambiguous_args))
# Add required args to arg dictionary.
required_args_from_tuple = required_args.intersection(args_from_tuple)
for arg in required_args_from_tuple:
assert arg not in kwargs
kwargs[arg] = gan_model_dict[arg]
# Add arguments that have defaults.
for arg in default_args_dict:
val_from_tuple = gan_model_dict[arg] if arg in gan_model_dict else None
val_from_kwargs = kwargs[arg] if arg in kwargs else None
assert not (val_from_tuple is not None and val_from_kwargs is not None)
kwargs[arg] = (val_from_tuple if val_from_tuple is not None else
val_from_kwargs if val_from_kwargs is not None else
default_args_dict[arg])
return loss_fn(**kwargs)
new_docstring = """The gan_model version of %s.""" % loss_fn.__name__
new_loss_fn.__doc__ = new_docstring
new_loss_fn.__name__ = loss_fn.__name__
new_loss_fn.__module__ = loss_fn.__module__
return new_loss_fn
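# Editor's sketch (hypothetical helper): how the wrapper behaves. Fields
# found on the namedtuple are filled in automatically; everything else is
# passed as a keyword argument, with defaults respected.
def _example_args_to_gan_model():
  import collections  # Local import keeps this illustration self-contained.
  fake_model = collections.namedtuple('FakeModel', ['x'])(x=3.0)
  double = _args_to_gan_model(lambda x, scale=2.0: x * scale)
  return double(fake_model), double(fake_model, scale=4.0)  # (6.0, 12.0)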
# Wasserstein losses from `Wasserstein GAN` (https://arxiv.org/abs/1701.07875).
wasserstein_generator_loss = _args_to_gan_model(
losses_impl.wasserstein_generator_loss)
wasserstein_discriminator_loss = _args_to_gan_model(
losses_impl.wasserstein_discriminator_loss)
wasserstein_gradient_penalty = _args_to_gan_model(
losses_impl.wasserstein_gradient_penalty)
# ACGAN losses from `Conditional Image Synthesis With Auxiliary Classifier GANs`
# (https://arxiv.org/abs/1610.09585).
acgan_discriminator_loss = _args_to_gan_model(
losses_impl.acgan_discriminator_loss)
acgan_generator_loss = _args_to_gan_model(
losses_impl.acgan_generator_loss)
# Original losses from `Generative Adversarial Nets`
# (https://arxiv.org/abs/1406.2661).
minimax_discriminator_loss = _args_to_gan_model(
losses_impl.minimax_discriminator_loss)
minimax_generator_loss = _args_to_gan_model(
losses_impl.minimax_generator_loss)
modified_discriminator_loss = _args_to_gan_model(
losses_impl.modified_discriminator_loss)
modified_generator_loss = _args_to_gan_model(
losses_impl.modified_generator_loss)
# Least Squares loss from `Least Squares Generative Adversarial Networks`
# (https://arxiv.org/abs/1611.04076).
least_squares_generator_loss = _args_to_gan_model(
losses_impl.least_squares_generator_loss)
least_squares_discriminator_loss = _args_to_gan_model(
losses_impl.least_squares_discriminator_loss)
# InfoGAN loss from `InfoGAN: Interpretable Representation Learning by
# Information Maximizing Generative Adversarial Nets`
# (https://arxiv.org/abs/1606.03657).
mutual_information_penalty = _args_to_gan_model(
losses_impl.mutual_information_penalty)
def combine_adversarial_loss(gan_loss,
gan_model,
non_adversarial_loss,
weight_factor=None,
gradient_ratio=None,
gradient_ratio_epsilon=1e-6,
scalar_summaries=True,
gradient_summaries=True):
"""Combine adversarial loss and main loss.
Uses `combine_adversarial_loss` to combine the losses, and returns
a modified GANLoss namedtuple.
Args:
gan_loss: A GANLoss namedtuple. Assumes the `GANLoss.generator_loss` is the
adversarial loss.
gan_model: A GANModel namedtuple. Used to access the generator's variables.
non_adversarial_loss: Same as `main_loss` from
`combine_adversarial_loss`.
weight_factor: Same as `weight_factor` from
`combine_adversarial_loss`.
gradient_ratio: Same as `gradient_ratio` from
`combine_adversarial_loss`.
gradient_ratio_epsilon: Same as `gradient_ratio_epsilon` from
`combine_adversarial_loss`.
scalar_summaries: Same as `scalar_summaries` from
`combine_adversarial_loss`.
gradient_summaries: Same as `gradient_summaries` from
`combine_adversarial_loss`.
Returns:
A modified GANLoss namedtuple, with `non_adversarial_loss` included
appropriately.
"""
combined_loss = losses_impl.combine_adversarial_loss(
non_adversarial_loss,
gan_loss.generator_loss,
weight_factor,
gradient_ratio,
gradient_ratio_epsilon,
gan_model.generator_variables,
scalar_summaries,
gradient_summaries)
return gan_loss._replace(generator_loss=combined_loss)
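# Editor's sketch (hypothetical names): a typical call site. `gan_loss` and
# `gan_model` would come from the TFGAN train APIs; `reconstruction_loss`
# stands in for any non-adversarial task loss.
#
#   gan_loss = combine_adversarial_loss(
#       gan_loss, gan_model, reconstruction_loss, weight_factor=1e-2)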

View File

@ -0,0 +1,134 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for contrib.gan.python.losses."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import numpy as np
from tensorflow.contrib.gan.python.losses.python import tuple_losses_impl as tfgan_losses
from tensorflow.python.platform import test
class ArgsToGanModelTest(test.TestCase):
def test_args_to_gan_model(self):
"""Test `_args_to_gan_model`."""
tuple_type = collections.namedtuple('fake_type', ['arg1', 'arg3'])
def args_loss(arg1, arg2, arg3=3, arg4=4):
return arg1 + arg2 + arg3 + arg4
gan_model_loss = tfgan_losses._args_to_gan_model(args_loss)
# Value is correct.
self.assertEqual(1 + 2 + 5 + 6,
gan_model_loss(tuple_type(1, 2), arg2=5, arg4=6))
# Uses tuple argument with defaults.
self.assertEqual(1 + 5 + 3 + 7,
gan_model_loss(tuple_type(1, None), arg2=5, arg4=7))
# Uses non-tuple argument with defaults.
self.assertEqual(1 + 5 + 2 + 4,
gan_model_loss(tuple_type(1, 2), arg2=5))
# Requires non-tuple, non-default arguments.
with self.assertRaisesRegexp(ValueError, '`arg2` must be supplied'):
gan_model_loss(tuple_type(1, 2))
# Can't pass tuple argument outside tuple.
with self.assertRaisesRegexp(
ValueError, 'present in both the tuple and keyword args'):
gan_model_loss(tuple_type(1, 2), arg2=1, arg3=5)
def test_args_to_gan_model_name(self):
"""Test that `_args_to_gan_model` produces correctly named functions."""
def loss_fn(x):
return x
new_loss_fn = tfgan_losses._args_to_gan_model(loss_fn)
self.assertEqual('loss_fn', new_loss_fn.__name__)
self.assertIn('The gan_model version of', new_loss_fn.__doc__)
def test_tuple_respects_optional_args(self):
"""Test that optional args can be changed with tuple losses."""
tuple_type = collections.namedtuple('fake_type', ['arg1', 'arg2'])
def args_loss(arg1, arg2, arg3=3):
return arg1 + 2 * arg2 + 3 * arg3
loss_fn = tfgan_losses._args_to_gan_model(args_loss)
loss = loss_fn(tuple_type(arg1=-1, arg2=2), arg3=4)
# If `arg3` were not set properly, this value would be different.
self.assertEqual(-1 + 2 * 2 + 3 * 4, loss)
class ConsistentLossesTest(test.TestCase):
pass
def _tuple_from_dict(args_dict):
return collections.namedtuple('Tuple', args_dict.keys())(**args_dict)
def add_loss_consistency_test(test_class, loss_name_str, loss_args):
tuple_loss = getattr(tfgan_losses, loss_name_str)
arg_loss = getattr(tfgan_losses.losses_impl, loss_name_str)
def consistency_test(self):
self.assertEqual(arg_loss.__name__, tuple_loss.__name__)
with self.test_session():
self.assertEqual(arg_loss(**loss_args).eval(),
tuple_loss(_tuple_from_dict(loss_args)).eval())
test_name = 'test_loss_consistency_%s' % loss_name_str
setattr(test_class, test_name, consistency_test)
# A list of consistency tests which need to be manually written.
manual_tests = [
'acgan_discriminator_loss',
'acgan_generator_loss',
'combine_adversarial_loss',
'mutual_information_penalty',
'wasserstein_gradient_penalty',
]
discriminator_keyword_args = {
'discriminator_real_outputs': np.array([[3.4, 2.3, -2.3],
[6.3, -2.1, 0.2]]),
'discriminator_gen_outputs': np.array([[6.2, -1.5, 2.3],
[-2.9, -5.1, 0.1]]),
}
generator_keyword_args = {
'discriminator_gen_outputs': np.array([[6.2, -1.5, 2.3],
[-2.9, -5.1, 0.1]]),
}
if __name__ == '__main__':
for loss_name in tfgan_losses.__all__:
if loss_name in manual_tests: continue
keyword_args = (generator_keyword_args if 'generator' in loss_name else
discriminator_keyword_args)
add_loss_consistency_test(ConsistentLossesTest, loss_name, keyword_args)
test.main()

View File

@ -127,7 +127,6 @@ py_test(
name = "sdca_estimator_test",
srcs = ["python/sdca_estimator_test.py"],
srcs_version = "PY2AND3",
tags = ["notsan"],
deps = [
":sdca_estimator_py",
"//tensorflow/contrib/layers:layers_py",

View File

@ -61,6 +61,7 @@ tf_kernel_library(
srcs = ["kernels/hyperplane_lsh_probes.cc"],
deps = [
":hyperplane_lsh_probes",
":nearest_neighbor_ops_op_lib",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//third_party/eigen3",

View File

@ -20,6 +20,7 @@
@@ARModel
@@CSVReader
@@NumpyReader
@@RandomWindowInputFn
@@WholeDatasetInputFn
@@predict_continuation_input_fn

View File

@ -1788,6 +1788,7 @@ tf_cuda_library(
"common_runtime/process_util.cc",
"common_runtime/renamed_device.cc",
"common_runtime/rendezvous_mgr.cc",
"common_runtime/rendezvous_util.cc",
"common_runtime/resource_variable_read_optimizer.cc",
"common_runtime/session.cc",
"common_runtime/session_factory.cc",
@ -1831,6 +1832,7 @@ tf_cuda_library(
"common_runtime/profile_handler.h",
"common_runtime/renamed_device.h",
"common_runtime/rendezvous_mgr.h",
"common_runtime/rendezvous_util.h",
"common_runtime/session_factory.h",
"common_runtime/graph_execution_state.h",
"common_runtime/placer.h",
@ -2675,29 +2677,29 @@ tf_cc_test(
srcs = ["common_runtime/process_function_library_runtime_test.cc"],
linkstatic = tf_kernel_tests_linkstatic(),
deps = [
":core",
":core_cpu",
":core_cpu_internal",
":direct_session_internal",
":framework",
":framework_internal",
":lib",
":lib_internal",
":ops",
":protos_all_cc",
":test",
":test_main",
":testlib",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:cc_ops_internal",
"//tensorflow/cc:function_ops",
"//tensorflow/cc:functional_ops",
"//tensorflow/core/kernels:cast_op",
"//tensorflow/core/kernels:cwise_op",
"//tensorflow/core/kernels:function_ops",
"//tensorflow/core/kernels:matmul_op",
"//tensorflow/core/kernels:shape_ops",
"//third_party/eigen3",
],
)
tf_cc_test(
name = "common_runtime_rendezvous_util_test",
size = "small",
srcs = ["common_runtime/rendezvous_util_test.cc"],
linkstatic = tf_kernel_tests_linkstatic(),
deps = [
":core_cpu_internal",
":lib",
":test",
":test_main",
],
)

View File

@ -213,6 +213,9 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
FunctionBody** g_body);
bool IsLocalTarget(const AttrSlice& attrs);
AttrValueMap FixAttrs(const AttrSlice& attrs);
void RunRemote(const Options& opts, Handle handle,
gtl::ArraySlice<Tensor> args, std::vector<Tensor>* rets,
Executor::Args* exec_args, Item* item, DoneCallback done);
TF_DISALLOW_COPY_AND_ASSIGN(FunctionLibraryRuntimeImpl);
};
@ -557,52 +560,130 @@ Status FunctionLibraryRuntimeImpl::GetOrCreateItem(Handle handle, Item** item) {
return Status::OK();
}
void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle,
gtl::ArraySlice<Tensor> args,
std::vector<Tensor>* rets,
Executor::Args* exec_args,
Item* item, DoneCallback done) {
FunctionCallFrame* frame = exec_args->call_frame;
string target_device = parent_->GetDeviceName(handle);
string source_device = opts.source_device;
Rendezvous* rendezvous = opts.rendezvous;
// TODO(rohanj): Handle alloc_attrs in Rendezvous::Args.
Rendezvous::Args rendez_args;
Status s =
parent_->GetDeviceContext(target_device, &rendez_args.device_context);
if (!s.ok()) {
delete frame;
delete exec_args;
done(s);
return;
}
// The ProcFLR sends the arguments to the function from the source_device to
// the target_device. So here we receive those arguments. Similarly, when the
// computation is done and stored in *rets, we send the return values back
// to the source_device (caller) so that the ProcFLR can receive them later.
std::vector<Tensor>* remote_args = new std::vector<Tensor>;
ProcessFunctionLibraryRuntime::ReceiveTensorsAsync(
source_device, target_device, "arg_", args.size(), rendez_args,
rendezvous, remote_args,
[frame, remote_args, item, source_device, target_device, rendezvous,
rendez_args, rets, done, exec_args](const Status& status) {
Status s = status;
if (s.ok()) {
s = frame->SetArgs(*remote_args);
}
if (!s.ok()) {
delete frame;
delete remote_args;
delete exec_args;
done(s);
return;
}
item->exec->RunAsync(
*exec_args,
[item, frame, rets, done, source_device, target_device, rendezvous,
rendez_args, remote_args, exec_args](const Status& status) {
item->Unref();
Status s = status;
if (s.ok()) {
s = frame->ConsumeRetvals(rets);
}
delete frame;
if (!s.ok()) {
delete remote_args;
delete exec_args;
done(s);
return;
}
s = ProcessFunctionLibraryRuntime::SendTensors(
target_device, source_device, "ret_", *rets, rendez_args,
rendezvous);
delete remote_args;
delete exec_args;
done(s);
});
});
}
void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
gtl::ArraySlice<Tensor> args,
std::vector<Tensor>* rets,
DoneCallback done) {
if (opts.cancellation_manager && opts.cancellation_manager->IsCancelled()) {
return done(errors::Cancelled(""));
done(errors::Cancelled(""));
return;
}
if (!parent_->IsInstantiatedOnDevice(device_name_, handle)) {
return parent_->Run(opts, handle, args, rets, done);
parent_->Run(opts, handle, args, rets, done);
return;
}
const FunctionBody* fbody = GetFunctionBody(handle);
FunctionCallFrame* frame =
new FunctionCallFrame(fbody->arg_types, fbody->ret_types);
Status s = frame->SetArgs(args);
if (!s.ok()) {
delete frame;
return done(s);
}
Item* item = nullptr;
s = GetOrCreateItem(handle, &item);
Status s = GetOrCreateItem(handle, &item);
if (!s.ok()) {
delete frame;
return done(s);
done(s);
return;
}
DCHECK(opts.runner != nullptr);
Executor::Args exec_args;
Executor::Args* exec_args = new Executor::Args;
// Inherit the step_id from the caller.
exec_args.step_id = opts.step_id;
exec_args.rendezvous = opts.rendezvous;
exec_args.stats_collector = opts.stats_collector;
exec_args.call_frame = frame;
exec_args.cancellation_manager = opts.cancellation_manager;
exec_args.step_container = opts.step_container;
exec_args.runner = *opts.runner;
exec_args->step_id = opts.step_id;
exec_args->rendezvous = opts.rendezvous;
exec_args->stats_collector = opts.stats_collector;
exec_args->call_frame = frame;
exec_args->cancellation_manager = opts.cancellation_manager;
exec_args->step_container = opts.step_container;
exec_args->runner = *opts.runner;
if (opts.remote_execution) {
RunRemote(opts, handle, args, rets, exec_args, item, done);
return;
}
s = frame->SetArgs(args);
if (!s.ok()) {
delete frame;
delete exec_args;
done(s);
return;
}
item->exec->RunAsync(
// Executor args
exec_args,
*exec_args,
// Done callback.
[item, frame, rets, done](const Status& status) {
[item, frame, rets, done, exec_args](const Status& status) {
item->Unref();
Status s = status;
if (s.ok()) {
s = frame->GetRetvals(rets);
s = frame->ConsumeRetvals(rets);
}
delete frame;
delete exec_args;
done(s);
});
}


@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/function_testlib.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/op.h"
@@ -155,6 +156,7 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
}
Status Run(FunctionLibraryRuntime* flr, FunctionLibraryRuntime::Handle handle,
FunctionLibraryRuntime::Options opts,
const std::vector<Tensor>& args, std::vector<Tensor*> rets) {
std::atomic<int32> call_count(0);
std::function<void(std::function<void()>)> runner =
@@ -164,7 +166,6 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
};
Notification done;
FunctionLibraryRuntime::Options opts;
opts.runner = &runner;
std::vector<Tensor> out;
Status status;
@@ -205,7 +206,8 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
if (!status.ok()) {
return status;
}
return Run(flr, handle, args, std::move(rets));
FunctionLibraryRuntime::Options opts;
return Run(flr, handle, opts, args, std::move(rets));
}
std::unique_ptr<Graph> GetFuncBody(FunctionLibraryRuntime* flr,
@@ -963,15 +965,21 @@ TEST_F(FunctionLibraryRuntimeTest, CrossDevice) {
{{"_target", "/job:localhost/replica:0/task:0/cpu:1"}}, &handle));
Tensor y;
FunctionLibraryRuntime::Options opts;
opts.rendezvous = new IntraProcessRendezvous(device_mgr_.get());
opts.source_device = "/device:CPU:1";
// Run on flr1_, flr2_ and make sure that the device it ran on was cpu:1.
TF_CHECK_OK(Run(flr1_, handle, {}, {&y}));
TF_CHECK_OK(Run(flr1_, handle, opts, {}, {&y}));
test::ExpectTensorEqual<string>(
y, test::AsTensor<string>({"/job:localhost/replica:0/task:0/cpu:1"},
TensorShape({})));
TF_CHECK_OK(Run(flr2_, handle, {}, {&y}));
opts.remote_execution = true;
opts.source_device = "/job:localhost/replica:0/task:0/cpu:2";
TF_CHECK_OK(Run(flr2_, handle, opts, {}, {&y}));
test::ExpectTensorEqual<string>(
y, test::AsTensor<string>({"/job:localhost/replica:0/task:0/cpu:1"},
TensorShape({})));
opts.rendezvous->Unref();
}
namespace {


@@ -17,6 +17,7 @@ limitations under the License.
#include <utility>
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/rendezvous_util.h"
#include "tensorflow/core/lib/gtl/map_util.h"
namespace tensorflow {
@@ -57,6 +58,7 @@ ProcessFunctionLibraryRuntime::ProcessFunctionLibraryRuntime(
}
}
/* static */
string ProcessFunctionLibraryRuntime::ObtainFunctionTarget(
const AttrSlice& attrs) {
const AttrValue* value;
@@ -66,6 +68,63 @@ string ProcessFunctionLibraryRuntime::ObtainFunctionTarget(
return value->s();
}
/* static */
Status ProcessFunctionLibraryRuntime::SendTensors(
const string& source_device, const string& target_device,
const string& key_prefix, gtl::ArraySlice<Tensor> tensors_to_send,
const Rendezvous::Args& args, Rendezvous* rendezvous) {
std::vector<string> keys;
for (int i = 0; i < tensors_to_send.size(); ++i) {
string name = strings::StrCat(key_prefix, i);
string key = Rendezvous::CreateKey(source_device, i, target_device, name,
FrameAndIter(0, 0));
keys.push_back(key);
}
TF_RETURN_IF_ERROR(
SendTensorsToRendezvous(rendezvous, args, keys, tensors_to_send));
return Status::OK();
}
/* static */
void ProcessFunctionLibraryRuntime::ReceiveTensorsAsync(
const string& source_device, const string& target_device,
const string& key_prefix, int64 num_tensors, const Rendezvous::Args& args,
Rendezvous* rendezvous, std::vector<Tensor>* received_tensors,
const StatusCallback& done) {
std::vector<string> keys;
for (int64 i = 0; i < num_tensors; ++i) {
string name = strings::StrCat(key_prefix, i);
string key = Rendezvous::CreateKey(source_device, i, target_device, name,
FrameAndIter(0, 0));
keys.push_back(key);
}
RecvOutputsFromRendezvousAsync(
rendezvous, args, keys, received_tensors,
[done](const Status& status) { done(status); });
}
Status ProcessFunctionLibraryRuntime::GetDeviceContext(
const string& device_name, DeviceContext** device_context) {
*device_context = nullptr;
FunctionLibraryRuntime* flr = GetFLR(device_name);
if (flr == nullptr) {
return errors::InvalidArgument("Device name: ", device_name, " not found.");
}
Device* device = flr->device();
string device_type = device->parsed_name().type;
if (device_type == "CPU") return Status::OK();
if (device_type == "GPU") {
auto* dev_info = flr->device()->tensorflow_gpu_device_info();
if (dev_info) {
*device_context = dev_info->default_context;
return Status::OK();
}
}
return errors::Internal("Device type: ", device_type,
" is currently unsupported for remote ",
"function executions");
}
FunctionLibraryRuntime* ProcessFunctionLibraryRuntime::GetFLR(
const string& device_name) {
if (flr_map_.find(device_name) == flr_map_.end()) {
@@ -105,6 +164,7 @@ FunctionLibraryRuntime::LocalHandle
ProcessFunctionLibraryRuntime::GetHandleOnDevice(
const string& device_name, FunctionLibraryRuntime::Handle handle) {
mutex_lock l(mu_);
CHECK_LE(handle, function_data_.size());
std::pair<string, FunctionLibraryRuntime::LocalHandle> p =
function_data_[handle];
if (p.first != device_name) {
@@ -113,6 +173,15 @@ ProcessFunctionLibraryRuntime::GetHandleOnDevice(
return p.second;
}
string ProcessFunctionLibraryRuntime::GetDeviceName(
FunctionLibraryRuntime::Handle handle) {
mutex_lock l(mu_);
CHECK_LE(handle, function_data_.size());
std::pair<string, FunctionLibraryRuntime::LocalHandle> p =
function_data_[handle];
return p.first;
}
Status ProcessFunctionLibraryRuntime::Instantiate(
const string& function_name, AttrSlice attrs,
FunctionLibraryRuntime::Handle* handle) {
@@ -129,15 +198,58 @@ void ProcessFunctionLibraryRuntime::Run(
const FunctionLibraryRuntime::Options& opts,
FunctionLibraryRuntime::Handle handle, gtl::ArraySlice<Tensor> args,
std::vector<Tensor>* rets, FunctionLibraryRuntime::DoneCallback done) {
if (!opts.remote_execution) {
done(errors::InvalidArgument(
"ProcessFunctionLibraryRuntime::Run should only be called when there ",
"is a remote execution."));
return;
}
FunctionLibraryRuntime* flr = nullptr;
string target_device;
{
mutex_lock l(mu_);
CHECK_LE(handle, function_data_.size());
std::pair<string, FunctionLibraryRuntime::LocalHandle> p =
function_data_[handle];
target_device = p.first;
flr = GetFLR(p.first);
}
if (flr != nullptr) {
return flr->Run(opts, handle, args, rets, std::move(done));
auto rendezvous = opts.rendezvous;
string source_device = opts.source_device;
Rendezvous::Args rendez_args;
Status s = GetDeviceContext(source_device, &rendez_args.device_context);
if (!s.ok()) {
done(s);
return;
}
// Send the args over to the target device.
s = SendTensors(source_device, target_device, "arg_", args, rendez_args,
rendezvous);
if (!s.ok()) {
done(s);
return;
}
std::vector<Tensor>* remote_rets = new std::vector<Tensor>;
flr->Run(opts, handle, args, remote_rets,
[source_device, target_device, rendezvous, remote_rets, rets, done,
rendez_args](const Status& status) {
if (!status.ok()) {
delete remote_rets;
done(status);
return;
}
int64 num_returns = remote_rets->size();
delete remote_rets;
// Now receive the return values from the target.
ReceiveTensorsAsync(target_device, source_device, "ret_",
num_returns, rendez_args, rendezvous, rets,
done);
});
} else {
done(errors::Internal("Could not find device"));
return;
}
}


@@ -45,6 +45,31 @@ class ProcessFunctionLibraryRuntime {
// attribute, returns "". Canonicalizes the device name.
static string ObtainFunctionTarget(const AttrSlice& attrs);
// Sends `tensors_to_send` from `source_device` to `target_device` using
// `rendezvous`. `key_prefix` is used as a prefix for the keys sent to the
// Rendezvous. Takes a reference on each tensor in `tensors_to_send` and
// does not block.
static Status SendTensors(const string& source_device,
const string& target_device,
const string& key_prefix,
gtl::ArraySlice<Tensor> tensors_to_send,
const Rendezvous::Args& args,
Rendezvous* rendezvous);
typedef std::function<void(const Status&)> StatusCallback;
// Receives into `received_tensors` the `num_tensors` tensors sent from
// `source_device` to `target_device` via `rendezvous`. Uses `key_prefix` to
// construct the keys to be retrieved. Does not block; calls `done` once all
// `num_tensors` tensors have been fetched.
static void ReceiveTensorsAsync(const string& source_device,
const string& target_device,
const string& key_prefix, int64 num_tensors,
const Rendezvous::Args& args,
Rendezvous* rendezvous,
std::vector<Tensor>* received_tensors,
const StatusCallback& done);
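A minimal round-trip sketch of this send/receive pair (assumptions beyond the patch: a DeviceMgr* device_mgr that registers both cpu:0 and cpu:1, test::AsTensor from tensor_testutil.h for brevity, and default Rendezvous::Args, i.e. host memory and no DeviceContext):
Rendezvous* r = new IntraProcessRendezvous(device_mgr);
Rendezvous::Args args;
std::vector<Tensor> to_send = {test::AsTensor<float>({1., 2., 3.})};
TF_CHECK_OK(ProcessFunctionLibraryRuntime::SendTensors(
    "/job:a/replica:0/task:0/cpu:0", "/job:a/replica:0/task:0/cpu:1",
    "arg_", to_send, args, r));
auto* received = new std::vector<Tensor>;
ProcessFunctionLibraryRuntime::ReceiveTensorsAsync(
    "/job:a/replica:0/task:0/cpu:0", "/job:a/replica:0/task:0/cpu:1",
    "arg_", /*num_tensors=*/1, args, r, received,
    [received, r](const Status& s) {
      // On success, (*received)[0] holds the tensor sent above.
      delete received;
      r->Unref();
    });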
static const char kDefaultFLRDevice[];
// Returns the FunctionLibraryRuntime for the corresponding device_name.
FunctionLibraryRuntime* GetFLR(const string& device_name);
@@ -85,6 +110,17 @@
FunctionLibraryRuntime::DoneCallback done);
private:
// For a given device_name, returns a DeviceContext for copying
// tensors to/from the device.
Status GetDeviceContext(const string& device_name,
DeviceContext** device_context);
// Looks up the information for the given `handle` and returns the name
// of the device where the function is registered.
string GetDeviceName(FunctionLibraryRuntime::Handle handle);
friend class FunctionLibraryRuntimeImpl;
mutable mutex mu_;
// Holds all the function invocations here.


@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/function_testlib.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/platform/test.h"
@@ -43,10 +44,12 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
proc_flr_.reset(new ProcessFunctionLibraryRuntime(
device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
opts));
rendezvous_ = new IntraProcessRendezvous(device_mgr_.get());
}
Status Run(const string& name, test::function::Attrs attrs,
const std::vector<Tensor>& args, std::vector<Tensor*> rets) {
Status Run(const string& name, FunctionLibraryRuntime::Options opts,
test::function::Attrs attrs, const std::vector<Tensor>& args,
std::vector<Tensor*> rets) {
FunctionLibraryRuntime::Handle handle;
Status status = proc_flr_->Instantiate(name, attrs, &handle);
if (!status.ok()) {
@@ -61,7 +64,6 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
};
Notification done;
FunctionLibraryRuntime::Options opts;
opts.runner = &runner;
std::vector<Tensor> out;
proc_flr_->Run(opts, handle, args, &out, [&status, &done](const Status& s) {
@@ -86,6 +88,7 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
std::unique_ptr<DeviceMgr> device_mgr_;
std::unique_ptr<FunctionLibraryDefinition> lib_def_;
std::unique_ptr<ProcessFunctionLibraryRuntime> proc_flr_;
IntraProcessRendezvous* rendezvous_;
};
TEST_F(ProcessFunctionLibraryRuntimeTest, Basic) {
@@ -99,6 +102,7 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, Basic) {
EXPECT_EQ(flr->device(), devices_[1]);
flr = proc_flr_->GetFLR("abc");
EXPECT_EQ(flr, nullptr);
rendezvous_->Unref();
}
TEST_F(ProcessFunctionLibraryRuntimeTest, ObtainFunctionTarget) {
@@ -118,69 +122,94 @@ TEST_F(ProcessFunctionLibraryRuntimeTest, ObtainFunctionTarget) {
TEST_F(ProcessFunctionLibraryRuntimeTest, SingleCall) {
Init({test::function::XTimesTwo()});
FunctionLibraryRuntime::Options opts;
opts.source_device = "/job:a/replica:0/task:0/cpu:0";
opts.rendezvous = rendezvous_;
opts.remote_execution = true;
auto x = test::AsTensor<float>({1, 2, 3, 4});
Tensor y;
TF_CHECK_OK(
Run("XTimesTwo",
Run("XTimesTwo", opts,
{{"T", DT_FLOAT}, {"_target", "/job:a/replica:0/task:0/cpu:0"}}, {x},
{&y}));
test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8}));
rendezvous_->Unref();
}
TEST_F(ProcessFunctionLibraryRuntimeTest, SingleCallFindDevice) {
Init({test::function::FindDevice()});
FunctionLibraryRuntime::Options opts;
opts.source_device = "/job:a/replica:0/task:0/cpu:0";
opts.rendezvous = rendezvous_;
opts.remote_execution = true;
Tensor y;
TF_CHECK_OK(Run("FindDevice", {{"_target", "/job:a/replica:0/task:0/cpu:0"}},
{}, {&y}));
TF_CHECK_OK(Run("FindDevice", opts,
{{"_target", "/job:a/replica:0/task:0/cpu:0"}}, {}, {&y}));
test::ExpectTensorEqual<string>(
y, test::AsTensor<string>({"/job:a/replica:0/task:0/cpu:0"},
TensorShape({})));
rendezvous_->Unref();
}
TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsSameDeviceXTimes) {
Init({test::function::XTimesTwo(), test::function::XTimesFour()});
auto x = test::AsTensor<float>({1, 2, 3, 4});
FunctionLibraryRuntime::Options opts;
opts.source_device = "/job:a/replica:0/task:0/cpu:0";
opts.rendezvous = rendezvous_;
opts.remote_execution = true;
Tensor y;
TF_CHECK_OK(
Run("XTimesTwo",
Run("XTimesTwo", opts,
{{"T", DT_FLOAT}, {"_target", "/job:a/replica:0/task:0/cpu:0"}}, {x},
{&y}));
test::ExpectTensorEqual<float>(y, test::AsTensor<float>({2, 4, 6, 8}));
TF_CHECK_OK(
Run("XTimesFour",
Run("XTimesFour", opts,
{{"T", DT_FLOAT}, {"_target", "/job:a/replica:0/task:0/cpu:0"}}, {x},
{&y}));
test::ExpectTensorEqual<float>(y, test::AsTensor<float>({4, 8, 12, 16}));
rendezvous_->Unref();
}
TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsSameDeviceFindDevice) {
Init({test::function::FindDevice()});
FunctionLibraryRuntime::Options opts;
opts.source_device = "/job:a/replica:0/task:0/cpu:0";
opts.rendezvous = rendezvous_;
opts.remote_execution = true;
Tensor y;
TF_CHECK_OK(Run("FindDevice", {{"_target", "/job:a/replica:0/task:0/cpu:1"}},
{}, {&y}));
TF_CHECK_OK(Run("FindDevice", opts,
{{"_target", "/job:a/replica:0/task:0/cpu:1"}}, {}, {&y}));
test::ExpectTensorEqual<string>(
y, test::AsTensor<string>({"/job:a/replica:0/task:0/cpu:1"},
TensorShape({})));
TF_CHECK_OK(Run("FindDevice", {{"_target", "/job:a/replica:0/task:0/cpu:1"}},
{}, {&y}));
TF_CHECK_OK(Run("FindDevice", opts,
{{"_target", "/job:a/replica:0/task:0/cpu:1"}}, {}, {&y}));
test::ExpectTensorEqual<string>(
y, test::AsTensor<string>({"/job:a/replica:0/task:0/cpu:1"},
TensorShape({})));
rendezvous_->Unref();
}
TEST_F(ProcessFunctionLibraryRuntimeTest, MultipleCallsDiffDeviceFindDevice) {
Init({test::function::FindDevice()});
FunctionLibraryRuntime::Options opts;
opts.source_device = "/job:a/replica:0/task:0/cpu:0";
opts.rendezvous = rendezvous_;
opts.remote_execution = true;
Tensor y;
TF_CHECK_OK(Run("FindDevice", {{"_target", "/job:a/replica:0/task:0/cpu:0"}},
{}, {&y}));
TF_CHECK_OK(Run("FindDevice", opts,
{{"_target", "/job:a/replica:0/task:0/cpu:0"}}, {}, {&y}));
test::ExpectTensorEqual<string>(
y, test::AsTensor<string>({"/job:a/replica:0/task:0/cpu:0"},
TensorShape({})));
TF_CHECK_OK(Run("FindDevice", {{"_target", "/job:a/replica:0/task:0/cpu:1"}},
{}, {&y}));
TF_CHECK_OK(Run("FindDevice", opts,
{{"_target", "/job:a/replica:0/task:0/cpu:1"}}, {}, {&y}));
test::ExpectTensorEqual<string>(
y, test::AsTensor<string>({"/job:a/replica:0/task:0/cpu:1"},
TensorShape({})));
rendezvous_->Unref();
}
} // anonymous namespace


@@ -0,0 +1,119 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/rendezvous_util.h"
namespace tensorflow {
Status SendTensorsToRendezvous(Rendezvous* rendezvous,
const Rendezvous::Args& args,
const std::vector<string>& keys,
gtl::ArraySlice<Tensor> tensors_to_send) {
if (keys.size() != tensors_to_send.size()) {
return errors::InvalidArgument(
"keys and tensors_to_send are not the same size. keys.size() = ",
keys.size(), "; tensors_to_send.size() = ", tensors_to_send.size());
}
Rendezvous::ParsedKey parsed;
for (int i = 0; i < keys.size(); ++i) {
TF_RETURN_IF_ERROR(Rendezvous::ParseKey(keys[i], &parsed));
TF_RETURN_IF_ERROR(
rendezvous->Send(parsed, args, tensors_to_send[i], false));
}
return Status::OK();
}
void RecvOutputsFromRendezvousAsync(Rendezvous* rendezvous,
const Rendezvous::Args& args,
const std::vector<string>& keys,
std::vector<Tensor>* received_tensors,
const StatusCallback& done) {
if (keys.empty()) {
done(Status::OK());
return;
}
received_tensors->reserve(keys.size());
std::vector<std::tuple<string, Tensor*, Rendezvous::ParsedKey>> arguments;
for (int i = 0; i < keys.size(); ++i) {
Rendezvous::ParsedKey parsed;
Status s = Rendezvous::ParseKey(keys[i], &parsed);
received_tensors->push_back(Tensor());
if (!s.ok()) {
done(s);
return;
}
arguments.push_back(
std::make_tuple(keys[i], &((*received_tensors)[i]), parsed));
}
typedef struct {
mutex mu;
int64 done_counter;
Status shared_status = Status::OK();
} CallState;
CallState* call_state = new CallState;
call_state->done_counter = keys.size();
for (auto& p : arguments) {
const string& key = std::get<0>(p);
Tensor* val = std::get<1>(p);
Rendezvous::ParsedKey parsed = std::get<2>(p);
rendezvous->RecvAsync(
parsed, args,
[val, done, key, call_state](const Status& s,
const Rendezvous::Args& send_args,
const Rendezvous::Args& recv_args,
const Tensor& v, const bool is_dead) {
Status status = s;
if (status.ok()) {
*val = v;
if (is_dead) {
status = errors::InvalidArgument("The tensor returned for ", key,
" was not valid.");
}
}
call_state->mu.lock();
call_state->shared_status.Update(status);
call_state->done_counter--;
// If we are the last async call to return, call the done callback.
if (call_state->done_counter == 0) {
const Status& final_status = call_state->shared_status;
call_state->mu.unlock();
done(final_status);
delete call_state;
return;
}
call_state->mu.unlock();
});
}
}
Status RecvOutputsFromRendezvous(Rendezvous* rendezvous, NamedTensors* out,
const Rendezvous::Args& args) {
// Receives values requested by the caller.
Rendezvous::ParsedKey parsed;
for (auto& p : *out) {
const string& key = p.first;
Tensor* val = &p.second;
bool is_dead = false;
TF_RETURN_IF_ERROR(Rendezvous::ParseKey(key, &parsed));
TF_RETURN_IF_ERROR(rendezvous->Recv(parsed, args, val, &is_dead));
if (is_dead) {
return errors::InvalidArgument("The tensor returned for ", key,
" was not valid.");
}
}
return Status::OK();
}
} // namespace tensorflow


@@ -0,0 +1,44 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_RENDEZVOUS_UTIL_H_
#define THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_RENDEZVOUS_UTIL_H_
#include <map>
#include "tensorflow/core/framework/rendezvous.h"
namespace tensorflow {
typedef std::map<string, Tensor> NamedTensors;
typedef std::function<void(const Status&)> StatusCallback;
// Uses `rendezvous` to send `tensors_to_send`, one tensor per key in `keys`.
Status SendTensorsToRendezvous(Rendezvous* rendezvous,
const Rendezvous::Args& args,
const std::vector<string>& keys,
gtl::ArraySlice<Tensor> tensors_to_send);
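// Uses `rendezvous` to obtain the tensors for `keys`, writing them into
// `received_tensors`. Does not block; invokes `done` once every key has been
// received, or earlier on a key-parse failure.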
void RecvOutputsFromRendezvousAsync(Rendezvous* rendezvous,
const Rendezvous::Args& args,
const std::vector<string>& keys,
std::vector<Tensor>* received_tensors,
const StatusCallback& done);
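// Blocking variant: fills in the values of `out` from `rendezvous`, failing
// if any received tensor is dead.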
Status RecvOutputsFromRendezvous(Rendezvous* rendezvous, NamedTensors* out,
const Rendezvous::Args& args);
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CORE_COMMON_RUNTIME_RENDEZVOUS_UTIL_H_


@@ -0,0 +1,94 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/rendezvous_util.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
class RendezvousUtilTest : public ::testing::Test {
public:
RendezvousUtilTest() { rendez_ = NewLocalRendezvous(); }
~RendezvousUtilTest() override { rendez_->Unref(); }
Rendezvous* rendez_;
};
// string -> Tensor<string>
Tensor V(const string& content) {
Tensor tensor(DT_STRING, TensorShape({}));
tensor.scalar<string>()() = content;
return tensor;
}
// Tensor<string> -> string
string V(const Tensor& tensor) {
CHECK_EQ(tensor.dtype(), DT_STRING);
CHECK(TensorShapeUtils::IsScalar(tensor.shape()));
return tensor.scalar<string>()();
}
string MakeStringKey(const string& name) {
return Rendezvous::CreateKey(
"/job:localhost/replica:0/task:0/device:CPU:0", 0,
"/job:localhost/replica:0/task:0/device:GPU:0", name, FrameAndIter(0, 0));
}
TEST_F(RendezvousUtilTest, SendBeforeRecv) {
// Fire off sends before receiving the tensors.
Rendezvous::Args args;
TF_ASSERT_OK(SendTensorsToRendezvous(
rendez_, args, {MakeStringKey("hello1"), MakeStringKey("hello2")},
{V("hello1"), V("hello2")}));
Notification n;
std::vector<Tensor> received_keys;
RecvOutputsFromRendezvousAsync(
rendez_, args, {MakeStringKey("hello1"), MakeStringKey("hello2")},
&received_keys, [&n](const Status& status) { n.Notify(); });
n.WaitForNotification();
EXPECT_EQ(2, received_keys.size());
EXPECT_EQ("hello1", V(received_keys[0]));
EXPECT_EQ("hello2", V(received_keys[1]));
}
TEST_F(RendezvousUtilTest, RecvBeforeSend) {
// Fire off recvs, wait for a notification in the callback.
Rendezvous::Args args;
Notification n;
std::vector<Tensor> received_keys;
RecvOutputsFromRendezvousAsync(
rendez_, args, {MakeStringKey("hello1"), MakeStringKey("hello2")},
&received_keys, [&n](const Status& status) { n.Notify(); });
TF_ASSERT_OK(SendTensorsToRendezvous(
rendez_, args, {MakeStringKey("hello1"), MakeStringKey("hello2")},
{V("hello1"), V("hello2")}));
n.WaitForNotification();
EXPECT_EQ(2, received_keys.size());
EXPECT_EQ("hello1", V(received_keys[0]));
EXPECT_EQ("hello2", V(received_keys[1]));
}
} // namespace
} // namespace tensorflow


@@ -26,12 +26,13 @@ limitations under the License.
#include "grpc++/create_channel.h"
#else
// winsock2.h is used in grpc, so Ws2_32.lib is needed
#pragma comment(lib,"Ws2_32.lib")
#pragma comment(lib, "Ws2_32.lib")
#endif // #ifndef PLATFORM_WINDOWS
#include "tensorflow/core/debug/debugger_event_metadata.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/summary.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/lib/core/bits.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/io/path.h"


@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/memory_types.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/rendezvous_util.h"
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
#include "tensorflow/core/framework/cancellation.h"
@@ -321,116 +322,25 @@ Status GraphMgr::DeregisterAll() {
return Status::OK();
}
Status GraphMgr::SendInputsToRendezvous(Rendezvous* rendezvous,
const NamedTensors& in) {
Rendezvous::ParsedKey parsed;
for (const auto& p : in) {
const string& key = p.first;
const Tensor& val = p.second;
Status s = Rendezvous::ParseKey(key, &parsed);
if (s.ok()) {
s = rendezvous->Send(parsed, Rendezvous::Args(), val, false);
}
if (!s.ok()) {
return s;
}
}
return Status::OK();
}
Status GraphMgr::RecvOutputsFromRendezvous(Rendezvous* rendezvous,
NamedTensors* out) {
// Receives values requested by the caller.
Rendezvous::ParsedKey parsed;
for (auto& p : *out) {
const string& key = p.first;
Tensor* val = &p.second;
bool is_dead = false;
Status s = Rendezvous::ParseKey(key, &parsed);
if (s.ok()) {
s = rendezvous->Recv(parsed, Rendezvous::Args(), val, &is_dead);
}
if (is_dead) {
s = errors::InvalidArgument("The tensor returned for ", key,
" was not valid.");
}
if (!s.ok()) return s;
}
return Status::OK();
}
void GraphMgr::RecvOutputsFromRendezvousAsync(Rendezvous* rendezvous,
NamedTensors* out,
const StatusCallback& done) {
if (out->empty()) {
done(Status::OK());
return;
}
// We compute the args before calling RecvAsync because we need to ensure that
// out isn't being iterated over after done is called, since done deletes out.
std::vector<std::tuple<string, Tensor*, Rendezvous::ParsedKey>> args;
for (auto& p : *out) {
Rendezvous::ParsedKey parsed;
Status s = Rendezvous::ParseKey(p.first, &parsed);
if (!s.ok()) {
done(s);
return;
}
args.push_back(std::make_tuple(p.first, &p.second, parsed));
}
typedef struct {
mutex mu;
int done_counter;
Status shared_status = Status::OK();
} CallState;
CallState* call_state = new CallState;
call_state->done_counter = out->size();
for (auto& p : args) {
const string& key = std::get<0>(p);
Tensor* val = std::get<1>(p);
Rendezvous::ParsedKey parsed = std::get<2>(p);
rendezvous->RecvAsync(
parsed, Rendezvous::Args(),
[val, done, key, call_state](const Status& s,
const Rendezvous::Args& send_args,
const Rendezvous::Args& recv_args,
const Tensor& v, const bool is_dead) {
Status status = s;
if (status.ok()) {
*val = v;
if (is_dead) {
status = errors::InvalidArgument("The tensor returned for ", key,
" was not valid.");
}
}
call_state->mu.lock();
call_state->shared_status.Update(status);
call_state->done_counter--;
// If we are the last async call to return, call the done callback.
if (call_state->done_counter == 0) {
const Status& final_status = call_state->shared_status;
call_state->mu.unlock();
done(final_status);
delete call_state;
return;
}
call_state->mu.unlock();
});
}
}
Status GraphMgr::SendInputs(const int64 step_id, const NamedTensors& in) {
Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
Status s = SendInputsToRendezvous(rendezvous, in);
std::vector<string> keys;
std::vector<Tensor> tensors_to_send;
keys.reserve(in.size());
tensors_to_send.reserve(in.size());
for (const auto& p : in) {
keys.push_back(p.first);
tensors_to_send.push_back(p.second);
}
Status s = SendTensorsToRendezvous(rendezvous, Rendezvous::Args(), keys,
tensors_to_send);
rendezvous->Unref();
return s;
}
Status GraphMgr::RecvOutputs(const int64 step_id, NamedTensors* out) {
Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
Status s = RecvOutputsFromRendezvous(rendezvous, out);
Status s = RecvOutputsFromRendezvous(rendezvous, out, Rendezvous::Args());
rendezvous->Unref();
return s;
}
@@ -438,11 +348,24 @@ Status GraphMgr::RecvOutputs(const int64 step_id, NamedTensors* out) {
void GraphMgr::RecvOutputsAsync(const int64 step_id, NamedTensors* out,
StatusCallback done) {
Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
RecvOutputsFromRendezvousAsync(rendezvous, out,
[done, rendezvous](const Status s) {
rendezvous->Unref();
done(s);
});
std::vector<string> keys;
std::vector<Tensor>* received_keys = new std::vector<Tensor>;
keys.reserve(out->size());
received_keys->reserve(out->size());
for (const auto& p : *out) {
keys.push_back(p.first);
received_keys->push_back(p.second);
}
RecvOutputsFromRendezvousAsync(
rendezvous, Rendezvous::Args(), keys, received_keys,
[done, rendezvous, received_keys, out, keys](const Status s) {
rendezvous->Unref();
for (int i = 0; i < keys.size(); ++i) {
(*out)[keys[i]] = (*received_keys)[i];
}
delete received_keys;
done(s);
});
}
void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id,
@@ -484,7 +407,16 @@ void GraphMgr::ExecuteAsync(const string& handle, const int64 step_id,
// Sends values specified by the caller.
if (s.ok()) {
s = SendInputsToRendezvous(rendezvous, in);
std::vector<string> keys;
std::vector<Tensor> tensors_to_send;
keys.reserve(in.size());
tensors_to_send.reserve(in.size());
for (auto& p : in) {
keys.push_back(p.first);
tensors_to_send.push_back(p.second);
}
s = SendTensorsToRendezvous(rendezvous, Rendezvous::Args(), keys,
tensors_to_send);
}
if (!s.ok()) {


@@ -169,11 +169,6 @@ class GraphMgr {
void BuildCostModel(Item* item, StepStatsCollector* collector,
CostGraphDef* cost_graph);
Status SendInputsToRendezvous(Rendezvous* rendezvous, const NamedTensors& in);
Status RecvOutputsFromRendezvous(Rendezvous* rendezvous, NamedTensors* out);
void RecvOutputsFromRendezvousAsync(Rendezvous* rendezvous, NamedTensors* out,
const StatusCallback& done);
Status InitItem(const string& session, const GraphDef& gdef,
const GraphOptions& graph_options,
const DebugOptions& debug_options, Item* item);


@@ -465,7 +465,6 @@ tf_cuda_cc_test(
linkstatic = tf_kernel_tests_linkstatic(),
tags = tf_cuda_tests_tags() + [
"no_oss", # b/62956105: port conflicts.
"noguitar", # b/64805119
],
deps = [
":grpc_channel",


@@ -426,6 +426,10 @@ class FunctionLibraryRuntime {
StepStatsCollector* stats_collector = nullptr;
std::function<void(std::function<void()>)>* runner = nullptr;
// Parameters for remote function execution.
bool remote_execution = false;
string source_device = ""; // Fully specified device name.
};
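For reference, a sketch of how a caller wires the new fields up, mirroring the updated tests (here `rendezvous`, `flr`, `handle`, `args`, `rets`, and `done` are assumed to already exist):
FunctionLibraryRuntime::Options opts;
opts.rendezvous = rendezvous;  // Carries args/rets between devices.
opts.source_device = "/job:a/replica:0/task:0/cpu:0";
opts.remote_execution = true;  // Route through the ProcFLR send/recv path.
flr->Run(opts, handle, args, &rets, done);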
typedef std::function<void(const Status&)> DoneCallback;
virtual void Run(const Options& opts, Handle handle,


@@ -110,6 +110,37 @@ bool ConsumeAttrNumber(StringPiece* sp, int64* out) {
} \
} while (false)
bool ConsumeCompoundAttrType(StringPiece* sp, StringPiece* out) {
auto capture_begin = sp->begin();
if (sp->Consume("numbertype") || sp->Consume("numerictype") ||
sp->Consume("quantizedtype") || sp->Consume("realnumbertype") ||
sp->Consume("realnumberictype")) {
*out = StringPiece(capture_begin, sp->begin() - capture_begin);
return true;
}
return false;
}
bool ProcessCompoundType(const StringPiece type_string, AttrValue* allowed) {
if (type_string == "numbertype" || type_string == "numerictype") {
for (DataType dt : NumberTypes()) {
allowed->mutable_list()->add_type(dt);
}
} else if (type_string == "quantizedtype") {
for (DataType dt : QuantizedTypes()) {
allowed->mutable_list()->add_type(dt);
}
} else if (type_string == "realnumbertype" ||
type_string == "realnumerictype") {
for (DataType dt : RealNumberTypes()) {
allowed->mutable_list()->add_type(dt);
}
} else {
return false;
}
return true;
}
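A quick illustration of what ProcessCompoundType accumulates (a sketch, not part of the patch):
AttrValue allowed;
CHECK(ProcessCompoundType("quantizedtype", &allowed));
// allowed.list().type() now holds QuantizedTypes():
//   DT_QINT8, DT_QUINT8, DT_QINT32, DT_QINT16, DT_QUINT16.
CHECK(!ProcessCompoundType("int32", &allowed));  // Not a compound type.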
void FinalizeAttr(StringPiece spec, OpDef* op_def,
std::vector<string>* errors) {
OpDef::AttrDef* attr = op_def->add_attr();
@@ -123,6 +154,7 @@ void FinalizeAttr(StringPiece spec, OpDef* op_def,
// Read "<type>" or "list(<type>)".
bool is_list = ConsumeListPrefix(&spec);
string type;
StringPiece type_string; // Used if type == "type"
if (spec.Consume("string")) {
type = "string";
} else if (spec.Consume("int")) {
@@ -139,29 +171,15 @@ void FinalizeAttr(StringPiece spec, OpDef* op_def,
type = "tensor";
} else if (spec.Consume("func")) {
type = "func";
} else if (spec.Consume("numbertype") || spec.Consume("numerictype")) {
} else if (ConsumeCompoundAttrType(&spec, &type_string)) {
type = "type";
AttrValue* allowed = attr->mutable_allowed_values();
for (DataType dt : NumberTypes()) {
allowed->mutable_list()->add_type(dt);
}
} else if (spec.Consume("quantizedtype")) {
type = "type";
AttrValue* allowed = attr->mutable_allowed_values();
for (DataType dt : QuantizedTypes()) {
allowed->mutable_list()->add_type(dt);
}
} else if (spec.Consume("realnumbertype") ||
spec.Consume("realnumerictype")) {
type = "type";
AttrValue* allowed = attr->mutable_allowed_values();
for (DataType dt : RealNumberTypes()) {
allowed->mutable_list()->add_type(dt);
}
VERIFY(ProcessCompoundType(type_string, allowed),
"Expected to see a compound type, saw: ", type_string);
} else if (spec.Consume("{")) {
// e.g. "{ int32, float, bool }" or "{ \"foo\", \"bar\" }"
str_util::RemoveLeadingWhitespace(&spec);
AttrValue* allowed = attr->mutable_allowed_values();
str_util::RemoveLeadingWhitespace(&spec);
if (spec.starts_with("\"") || spec.starts_with("'")) {
type = "string"; // "{ \"foo\", \"bar\" }" or "{ 'foo', 'bar' }"
while (true) {
@@ -172,8 +190,8 @@ void FinalizeAttr(StringPiece spec, OpDef* op_def,
string unescaped;
string error;
VERIFY(str_util::CUnescape(escaped_string, &unescaped, &error),
"Trouble unescaping \"", escaped_string, "\", got error: ",
error);
"Trouble unescaping \"", escaped_string,
"\", got error: ", error);
allowed->mutable_list()->add_s(unescaped);
if (spec.Consume(",")) {
str_util::RemoveLeadingWhitespace(&spec);
@@ -184,16 +202,19 @@ void FinalizeAttr(StringPiece spec, OpDef* op_def,
break;
}
}
} else { // "{ int32, float, bool }"
} else { // "{ bool, numbertype, string }"
type = "type";
while (true) {
StringPiece type_string;
VERIFY(ConsumeAttrType(&spec, &type_string),
"Trouble parsing type string at '", spec, "'");
DataType dt;
VERIFY(DataTypeFromString(type_string, &dt),
"Unrecognized type string '", type_string, "'");
allowed->mutable_list()->add_type(dt);
if (ProcessCompoundType(type_string, allowed)) {
// Processed a compound type.
} else {
DataType dt;
VERIFY(DataTypeFromString(type_string, &dt),
"Unrecognized type string '", type_string, "'");
allowed->mutable_list()->add_type(dt);
}
if (spec.Consume(",")) {
str_util::RemoveLeadingWhitespace(&spec);
if (spec.Consume("}")) break; // Allow ending with ", }".
@@ -204,7 +225,7 @@ void FinalizeAttr(StringPiece spec, OpDef* op_def,
}
}
}
} else {
} else { // if spec.Consume("{")
VERIFY(false, "Trouble parsing type string at '", spec, "'");
}
str_util::RemoveLeadingWhitespace(&spec);


@@ -57,8 +57,10 @@ class OpDefBuilder {
// (by convention only using capital letters for attrs that can be inferred)
// <type> can be:
// "string", "int", "float", "bool", "type", "shape", or "tensor"
// "numbertype", "realnumbertype", "quantizedtype", "{int32,int64}"
// "numbertype", "realnumbertype", "quantizedtype"
// (meaning "type" with a restriction on valid values)
// "{int32,int64}" or {realnumbertype,quantizedtype,string}"
// (meaning "type" with a restriction containing unions of value types)
// "{\"foo\", \"bar\n baz\"}", or "{'foo', 'bar\n baz'}"
// (meaning "string" with a restriction on valid values)
// "list(string)", ..., "list(tensor)", "list(numbertype)", ...


@@ -125,13 +125,27 @@ TEST_F(OpDefBuilderTest, AttrWithRestrictions) {
"[DT_HALF, DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, DT_INT16, "
"DT_UINT16, DT_INT8, DT_COMPLEX64, DT_COMPLEX128, DT_QINT8, DT_QUINT8, "
"DT_QINT32] } } }");
ExpectSuccess(
b().Attr("a:{numbertype, variant}"),
"attr: { name: 'a' type: 'type' allowed_values { list { type: "
"[DT_HALF, DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, DT_INT16, "
"DT_UINT16, DT_INT8, DT_COMPLEX64, DT_COMPLEX128, DT_QINT8, DT_QUINT8, "
"DT_QINT32, DT_VARIANT] } } }");
ExpectSuccess(b().Attr("a:realnumbertype"),
"attr: { name: 'a' type: 'type' allowed_values { list { type: "
"[DT_HALF, DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, "
"DT_INT16, DT_UINT16, DT_INT8] } } }");
ExpectSuccess(b().Attr("a:{realnumbertype, variant , string, }"),
"attr: { name: 'a' type: 'type' allowed_values { list { type: "
"[DT_HALF, DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, "
"DT_INT16, DT_UINT16, DT_INT8, DT_VARIANT, DT_STRING] } } }");
ExpectSuccess(b().Attr("a:quantizedtype"),
"attr: { name: 'a' type: 'type' allowed_values { list { type: "
"[DT_QINT8, DT_QUINT8, DT_QINT32, DT_QINT16, DT_QUINT16]} } }");
ExpectSuccess(b().Attr("a:{quantizedtype ,string}"),
"attr: { name: 'a' type: 'type' allowed_values { list { type: "
"[DT_QINT8, DT_QUINT8, DT_QINT32, DT_QINT16, DT_QUINT16, "
"DT_STRING]} } }");
ExpectSuccess(b().Attr("a:{string,int32}"),
"attr: { name: 'a' type: 'type' allowed_values { list { type: "
"[DT_STRING, DT_INT32] } } }");
@@ -202,6 +216,11 @@ TEST_F(OpDefBuilderTest, AttrListOfRestricted) {
"attr: { name: 'a' type: 'list(type)' allowed_values { list { type: "
"[DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, DT_INT16, "
"DT_UINT16, DT_INT8, DT_HALF] } } }");
ExpectSuccess(
b().Attr("a:list({realnumbertype, variant})"),
"attr: { name: 'a' type: 'list(type)' allowed_values { list { type: "
"[DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, DT_INT16, "
"DT_UINT16, DT_INT8, DT_HALF, DT_VARIANT] } } }");
ExpectSuccess(
b().Attr("a:list(quantizedtype)"),
"attr: { name: 'a' type: 'list(type)' allowed_values { list { type: "


@@ -24,6 +24,12 @@ limitations under the License.
namespace tensorflow {
std::unordered_set<string>* UnaryVariantOpRegistry::PersistentStringStorage() {
static std::unordered_set<string>* string_storage =
new std::unordered_set<string>();
return string_storage;
}
// static
UnaryVariantOpRegistry* UnaryVariantOpRegistry::Global() {
static UnaryVariantOpRegistry* global_unary_variant_op_registry =
@@ -32,7 +38,7 @@ UnaryVariantOpRegistry* UnaryVariantOpRegistry::Global() {
}
UnaryVariantOpRegistry::VariantShapeFn* UnaryVariantOpRegistry::GetShapeFn(
const string& type_name) {
StringPiece type_name) {
auto found = shape_fns.find(type_name);
if (found == shape_fns.end()) return nullptr;
return &found->second;
@@ -45,7 +51,8 @@ void UnaryVariantOpRegistry::RegisterShapeFn(const string& type_name,
CHECK_EQ(existing, nullptr)
<< "Unary VariantShapeFn for type_name: " << type_name
<< " already registered";
shape_fns.insert(std::pair<string, VariantShapeFn>(type_name, shape_fn));
shape_fns.insert(std::pair<StringPiece, VariantShapeFn>(
GetPersistentStringPiece(type_name), shape_fn));
}
Status GetUnaryVariantShape(const Tensor& variant_tensor, TensorShape* shape) {
@@ -65,8 +72,29 @@ Status GetUnaryVariantShape(const Tensor& variant_tensor, TensorShape* shape) {
return (*shape_fn)(v, shape);
}
// Add some basic registrations for use by others, e.g., for testing.
namespace {
template <typename T>
Status ScalarShape(const T&, TensorShape* shape) {
*shape = TensorShape({});
return Status::OK();
}
} // namespace
#define REGISTER_VARIANT_SHAPE_TYPE(T) \
REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(T, TF_STR(T), ScalarShape<T>);
// No encode/shape registered for std::complex<> and Eigen::half
// objects yet.
REGISTER_VARIANT_SHAPE_TYPE(int);
REGISTER_VARIANT_SHAPE_TYPE(float);
REGISTER_VARIANT_SHAPE_TYPE(bool);
REGISTER_VARIANT_SHAPE_TYPE(double);
#undef REGISTER_VARIANT_SHAPE_TYPE
UnaryVariantOpRegistry::VariantDecodeFn* UnaryVariantOpRegistry::GetDecodeFn(
const string& type_name) {
StringPiece type_name) {
auto found = decode_fns.find(type_name);
if (found == decode_fns.end()) return nullptr;
return &found->second;
@@ -79,7 +107,8 @@ void UnaryVariantOpRegistry::RegisterDecodeFn(
CHECK_EQ(existing, nullptr)
<< "Unary VariantDecodeFn for type_name: " << type_name
<< " already registered";
decode_fns.insert(std::pair<string, VariantDecodeFn>(type_name, decode_fn));
decode_fns.insert(std::pair<StringPiece, VariantDecodeFn>(
GetPersistentStringPiece(type_name), decode_fn));
}
bool DecodeUnaryVariant(Variant* variant) {
@@ -103,13 +132,6 @@ bool DecodeUnaryVariant(Variant* variant) {
// Add some basic registrations for use by others, e.g., for testing.
namespace {
string MaybeRemoveTFPrefix(const StringPiece& str) {
return str.starts_with("::tensorflow::") ? str.substr(14).ToString()
: str.ToString();
}
} // namespace
#define REGISTER_VARIANT_DECODE_TYPE(T) \
REGISTER_UNARY_VARIANT_DECODE_FUNCTION(T, TF_STR(T));
@@ -122,30 +144,31 @@ REGISTER_VARIANT_DECODE_TYPE(double);
#undef REGISTER_VARIANT_DECODE_TYPE
// Special casing ZerosLikeFn per device.
UnaryVariantOpRegistry::VariantZerosLikeFn*
UnaryVariantOpRegistry::GetZerosLikeFn(const string& device,
const string& type_name) {
auto found = zeros_like_fns.find(std::make_pair(device, type_name));
if (found == zeros_like_fns.end()) return nullptr;
// Special casing UnaryOpFn per op and per device.
UnaryVariantOpRegistry::VariantUnaryOpFn* UnaryVariantOpRegistry::GetUnaryOpFn(
VariantUnaryOp op, StringPiece device, StringPiece type_name) {
auto found = unary_op_fns.find(std::make_tuple(op, device, type_name));
if (found == unary_op_fns.end()) return nullptr;
return &found->second;
}
void UnaryVariantOpRegistry::RegisterZerosLikeFn(
const string& device, const string& type_name,
const VariantZerosLikeFn& zeros_like_fn) {
CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantZerosLike";
VariantZerosLikeFn* existing = GetZerosLikeFn(device, type_name);
void UnaryVariantOpRegistry::RegisterUnaryOpFn(
VariantUnaryOp op, const string& device, const string& type_name,
const VariantUnaryOpFn& unary_op_fn) {
CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantUnaryOp";
VariantUnaryOpFn* existing = GetUnaryOpFn(op, device, type_name);
CHECK_EQ(existing, nullptr)
<< "Unary VariantZerosLikeFn for type_name: " << type_name
<< "Unary VariantUnaryOpFn for type_name: " << type_name
<< " already registered for device type: " << device;
zeros_like_fns.insert(
std::pair<std::pair<string, string>, VariantZerosLikeFn>(
std::make_pair(device, type_name), zeros_like_fn));
unary_op_fns.insert(
std::pair<std::tuple<VariantUnaryOp, StringPiece, StringPiece>,
VariantUnaryOpFn>(
std::make_tuple(op, GetPersistentStringPiece(device),
GetPersistentStringPiece(type_name)),
unary_op_fn));
}
namespace {
template <typename T>
Status ZerosLikeVariantPrimitiveType(OpKernelContext* ctx, const T& t,
T* t_out) {
@@ -154,9 +177,10 @@ Status ZerosLikeVariantPrimitiveType(OpKernelContext* ctx, const T& t,
}
} // namespace
#define REGISTER_VARIANT_ZEROS_LIKE_TYPE(T) \
REGISTER_UNARY_VARIANT_ZEROS_LIKE_FUNCTION( \
DEVICE_CPU, T, TF_STR(T), ZerosLikeVariantPrimitiveType<T>);
#define REGISTER_VARIANT_ZEROS_LIKE_TYPE(T) \
REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP, \
DEVICE_CPU, T, TF_STR(T), \
ZerosLikeVariantPrimitiveType<T>);
// No zeros_like registered for std::complex<> or Eigen::half objects yet.
REGISTER_VARIANT_ZEROS_LIKE_TYPE(int);
@@ -166,4 +190,51 @@ REGISTER_VARIANT_ZEROS_LIKE_TYPE(bool);
#undef REGISTER_VARIANT_ZEROS_LIKE_TYPE
// Special casing BinaryOpFn per op and per device.
UnaryVariantOpRegistry::VariantBinaryOpFn*
UnaryVariantOpRegistry::GetBinaryOpFn(VariantBinaryOp op, StringPiece device,
StringPiece type_name) {
auto found = binary_op_fns.find(std::make_tuple(op, device, type_name));
if (found == binary_op_fns.end()) return nullptr;
return &found->second;
}
void UnaryVariantOpRegistry::RegisterBinaryOpFn(
VariantBinaryOp op, const string& device, const string& type_name,
const VariantBinaryOpFn& add_fn) {
CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantBinaryOp";
VariantBinaryOpFn* existing = GetBinaryOpFn(op, device, type_name);
CHECK_EQ(existing, nullptr)
<< "Unary VariantBinaryOpFn for type_name: " << type_name
<< " already registered for device type: " << device;
binary_op_fns.insert(
std::pair<std::tuple<VariantBinaryOp, StringPiece, StringPiece>,
VariantBinaryOpFn>(
std::make_tuple(op, GetPersistentStringPiece(device),
GetPersistentStringPiece(type_name)),
add_fn));
}
namespace {
template <typename T>
Status AddVariantPrimitiveType(OpKernelContext* ctx, const T& a, const T& b,
T* out) {
*out = a + b;
return Status::OK();
}
} // namespace
#define REGISTER_VARIANT_ADD_TYPE(T) \
REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU, \
T, TF_STR(T), \
AddVariantPrimitiveType<T>);
// No add registered for std::complex<> or Eigen::half objects yet.
REGISTER_VARIANT_ADD_TYPE(int);
REGISTER_VARIANT_ADD_TYPE(float);
REGISTER_VARIANT_ADD_TYPE(double);
REGISTER_VARIANT_ADD_TYPE(bool);
#undef REGISTER_VARIANT_ADD_TYPE
} // namespace tensorflow


@@ -17,11 +17,13 @@ limitations under the License.
#define TENSORFLOW_FRAMEWORK_VARIANT_OP_REGISTRY_H_
#include <string>
#include <unordered_set>
#include <vector>
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/lib/hash/hash.h"
namespace tensorflow {
@@ -30,49 +32,110 @@ class OpKernelContext;
// for different variant types. To be used by ShapeOp, RankOp, and
// SizeOp, decoding, etc.
enum VariantUnaryOp {
INVALID_VARIANT_UNARY_OP = 0,
ZEROS_LIKE_VARIANT_UNARY_OP = 1,
};
enum VariantBinaryOp {
INVALID_VARIANT_BINARY_OP = 0,
ADD_VARIANT_BINARY_OP = 1,
};
class UnaryVariantOpRegistry {
public:
typedef std::function<Status(const Variant& v, TensorShape*)> VariantShapeFn;
typedef std::function<bool(Variant*)> VariantDecodeFn;
typedef std::function<Status(OpKernelContext*, const Variant&, Variant*)>
VariantZerosLikeFn;
VariantUnaryOpFn;
typedef std::function<Status(OpKernelContext*, const Variant&, const Variant&,
Variant*)>
VariantBinaryOpFn;
// Add a shape lookup function to the registry.
void RegisterShapeFn(const string& type_name, const VariantShapeFn& shape_fn);
// Returns nullptr if no shape function was found for the given TypeName.
VariantShapeFn* GetShapeFn(const string& type_name);
VariantShapeFn* GetShapeFn(StringPiece type_name);
// Add a decode function to the registry.
void RegisterDecodeFn(const string& type_name,
const VariantDecodeFn& decode_fn);
// Returns nullptr if no decode function was found for the given TypeName.
VariantDecodeFn* GetDecodeFn(const string& type_name);
VariantDecodeFn* GetDecodeFn(StringPiece type_name);
// Add a zeros-like function to the registry.
void RegisterZerosLikeFn(const string& device, const string& type_name,
const VariantZerosLikeFn& zeros_like_fn);
// Add a unary op function to the registry.
void RegisterUnaryOpFn(VariantUnaryOp op, const string& device,
const string& type_name,
const VariantUnaryOpFn& unary_op_fn);
// Returns nullptr if no zeros-like function was found for the given
// device and TypeName.
VariantZerosLikeFn* GetZerosLikeFn(const string& device,
const string& type_name);
// Returns nullptr if no unary op function was found for the given
// op, device, and TypeName.
VariantUnaryOpFn* GetUnaryOpFn(VariantUnaryOp op, StringPiece device,
StringPiece type_name);
// Add a binary op function to the registry.
void RegisterBinaryOpFn(VariantBinaryOp op, const string& device,
const string& type_name,
const VariantBinaryOpFn& add_fn);
// Returns nullptr if no binary op function was found for the given
// op, device and TypeName.
VariantBinaryOpFn* GetBinaryOpFn(VariantBinaryOp op, StringPiece device,
StringPiece type_name);
// Get a pointer to a global UnaryVariantOpRegistry object
static UnaryVariantOpRegistry* Global();
// Get a pointer to a global persistent string storage object.
// ISO/IEC C++ working draft N4296 clarifies that insertion into an
// std::unordered_set does not invalidate memory locations of
// *values* inside the set (though it may invalidate existing
// iterators). In other words, one may safely point a StringPiece to
// a value in the set without that StringPiece being invalidated by
// future insertions.
static std::unordered_set<string>* PersistentStringStorage();
private:
std::unordered_map<string, VariantShapeFn> shape_fns;
std::unordered_map<string, VariantDecodeFn> decode_fns;
// Map std::pair<device, type_name> to function.
struct PairHash {
template <typename T, typename U>
std::size_t operator()(const std::pair<T, U>& x) const {
return std::hash<T>()(x.first) ^ std::hash<U>()(x.second);
std::unordered_map<StringPiece, VariantShapeFn, StringPiece::Hasher>
shape_fns;
std::unordered_map<StringPiece, VariantDecodeFn, StringPiece::Hasher>
decode_fns;
// Map std::tuple<Op, device, type_name> to function.
struct TupleHash {
template <typename Op>
std::size_t operator()(
const std::tuple<Op, StringPiece, StringPiece>& x) const {
// The hash of an enum is just its value as a std::size_t.
std::size_t ret = static_cast<std::size_t>(std::get<0>(x));
StringPiece::Hasher sp_hasher;
ret = Hash64Combine(ret, sp_hasher(std::get<1>(x)));
ret = Hash64Combine(ret, sp_hasher(std::get<2>(x)));
return ret;
}
};
std::unordered_map<std::pair<string, string>, VariantZerosLikeFn, PairHash>
zeros_like_fns;
std::unordered_map<std::tuple<VariantUnaryOp, StringPiece, StringPiece>,
VariantUnaryOpFn, TupleHash>
unary_op_fns;
std::unordered_map<std::tuple<VariantBinaryOp, StringPiece, StringPiece>,
VariantBinaryOpFn, TupleHash>
binary_op_fns;
// Find or insert a string into a persistent string storage
// container; return the StringPiece pointing to the permanent
// string location.
static StringPiece GetPersistentStringPiece(const string& str) {
const auto string_storage = PersistentStringStorage();
auto found = string_storage->find(str);
if (found == string_storage->end()) {
auto inserted = string_storage->insert(str);
return StringPiece(*inserted.first);
} else {
return StringPiece(*found);
}
}
};
// Gets a TensorShape from a Tensor containing a scalar Variant.
@@ -94,26 +157,57 @@ Status GetUnaryVariantShape(const Tensor& variant_tensor, TensorShape* shape);
//
bool DecodeUnaryVariant(Variant* variant);
// Sets *z_out = zeros_like(v). The variant v must have a registered
// ZerosLike function for the given Device. Returns an Internal error
// if v does not have a registered zeros_like function for this device, or if
// ZerosLike fails.
// Sets *v_out = unary_op(v). The variant v must have a registered
// UnaryOp function for the given Device. Returns an Internal error
// if v does not have a registered unary_op function for this device, or if
// UnaryOp fails.
//
// REQUIRES:
// v_out is not null.
//
template <typename Device>
Status CreateZerosLikeVariant(OpKernelContext* ctx, const Variant& v,
Variant* v_out) {
Status UnaryOpVariant(OpKernelContext* ctx, VariantUnaryOp op, const Variant& v,
Variant* v_out) {
const string& device = DeviceName<Device>::value;
UnaryVariantOpRegistry::VariantZerosLikeFn* zeros_like_fn =
UnaryVariantOpRegistry::Global()->GetZerosLikeFn(device, v.TypeName());
if (zeros_like_fn == nullptr) {
UnaryVariantOpRegistry::VariantUnaryOpFn* unary_op_fn =
UnaryVariantOpRegistry::Global()->GetUnaryOpFn(op, device, v.TypeName());
if (unary_op_fn == nullptr) {
return errors::Internal(
"No unary variant zeros_like function found for Variant type_name: ",
v.TypeName(), " for device type: ", device);
"No unary variant unary_op function found for unary variant op enum: ",
op, " Variant type_name: ", v.TypeName(), " for device type: ", device);
}
return (*zeros_like_fn)(ctx, v, v_out);
return (*unary_op_fn)(ctx, v, v_out);
}
// Sets *out = binary_op(a, b). The variants a and b must be the same type
// and have a registered binary_op function for the given Device. Returns an
// Internal error if a and b are not the same type_name, if a does not
// have a registered op function for this device, or if
// BinaryOp fails.
//
// REQUIRES:
// out is not null.
//
template <typename Device>
Status BinaryOpVariants(OpKernelContext* ctx, VariantBinaryOp op,
const Variant& a, const Variant& b, Variant* out) {
if (a.TypeName() != b.TypeName()) {
return errors::Internal(
"BianryOpVariants: Variants a and b have different "
"type names: '",
a.TypeName(), "' vs. '", b.TypeName(), "'");
}
const string& device = DeviceName<Device>::value;
UnaryVariantOpRegistry::VariantBinaryOpFn* binary_op_fn =
UnaryVariantOpRegistry::Global()->GetBinaryOpFn(op, device, a.TypeName());
if (binary_op_fn == nullptr) {
return errors::Internal(
"No unary variant binary_op function found for binary variant op "
"enum: ",
op, " Variant type_name: '", a.TypeName(),
"' for device type: ", device);
}
return (*binary_op_fn)(ctx, a, b, out);
}
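A minimal call sketch (assumptions: an OpKernelContext* ctx running on CPU, the conventional CPUDevice alias for Eigen::ThreadPoolDevice, and the int registration from variant_op_registry.cc, whose TF_STR(int) type_name matches Variant's demangled TypeName for int on common toolchains):
typedef Eigen::ThreadPoolDevice CPUDevice;
Variant a = 3, b = 4, out;
TF_CHECK_OK(
    BinaryOpVariants<CPUDevice>(ctx, ADD_VARIANT_BINARY_OP, a, b, &out));
// AddVariantPrimitiveType<int> ran on CPU: *out.get<int>() == 7.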
namespace variant_op_registry_fn_registration {
@@ -165,30 +259,65 @@ class UnaryVariantDecodeRegistration {
};
template <typename T>
class UnaryVariantZerosLikeRegistration {
class UnaryVariantUnaryOpRegistration {
typedef std::function<Status(OpKernelContext* ctx, const T& t, T* t_out)>
LocalVariantZerosLikeFn;
LocalVariantUnaryOpFn;
public:
UnaryVariantZerosLikeRegistration(
const string& device, const string& type_name,
const LocalVariantZerosLikeFn& zeros_like_fn) {
auto wrapped_fn = [type_name, zeros_like_fn](OpKernelContext* ctx,
const Variant& v,
Variant* v_out) -> Status {
UnaryVariantUnaryOpRegistration(VariantUnaryOp op, const string& device,
const string& type_name,
const LocalVariantUnaryOpFn& unary_op_fn) {
auto wrapped_fn = [type_name, unary_op_fn](OpKernelContext* ctx,
const Variant& v,
Variant* v_out) -> Status {
CHECK_NOTNULL(v_out);
*v_out = T();
if (v.get<T>() == nullptr) {
return errors::Internal(
"VariantZerosLikeFn: Could not access object, type_name: ",
"VariantUnaryOpFn: Could not access object, type_name: ",
type_name);
}
const T& t = *v.get<T>();
T* t_out = v_out->get<T>();
return zeros_like_fn(ctx, t, t_out);
return unary_op_fn(ctx, t, t_out);
};
UnaryVariantOpRegistry::Global()->RegisterZerosLikeFn(device, type_name,
wrapped_fn);
UnaryVariantOpRegistry::Global()->RegisterUnaryOpFn(op, device, type_name,
wrapped_fn);
}
};
template <typename T>
class UnaryVariantBinaryOpRegistration {
typedef std::function<Status(OpKernelContext* ctx, const T& a, const T& b,
T* out)>
LocalVariantBinaryOpFn;
public:
UnaryVariantBinaryOpRegistration(VariantBinaryOp op, const string& device,
const string& type_name,
const LocalVariantBinaryOpFn& binary_op_fn) {
auto wrapped_fn = [type_name, binary_op_fn](
OpKernelContext* ctx, const Variant& a,
const Variant& b, Variant* out) -> Status {
CHECK_NOTNULL(out);
*out = T();
if (a.get<T>() == nullptr) {
return errors::Internal(
"VariantBinaryOpFn: Could not access object 'a', type_name: ",
type_name);
}
if (b.get<T>() == nullptr) {
return errors::Internal(
"VariantBinaryOpFn: Could not access object 'b', type_name: ",
type_name);
}
const T& t_a = *a.get<T>();
const T& t_b = *b.get<T>();
T* t_out = out->get<T>();
return binary_op_fn(ctx, t_a, t_b, t_out);
};
UnaryVariantOpRegistry::Global()->RegisterBinaryOpFn(op, device, type_name,
wrapped_fn);
}
};
@@ -223,25 +352,47 @@ class UnaryVariantZerosLikeRegistration {
T> \
register_unary_variant_op_decoder_fn_##ctr(type_name)
// Register a unary zeros_like variant function with the signature:
// Status ZerosLikeFn(OpKernelContext* ctx, const T& t, T* t_out);
// to Variants having TypeName type_name, for device string device.
#define REGISTER_UNARY_VARIANT_ZEROS_LIKE_FUNCTION(device, T, type_name, \
zeros_like_function) \
REGISTER_UNARY_VARIANT_ZEROS_LIKE_FUNCTION_UNIQ_HELPER( \
__COUNTER__, device, T, type_name, zeros_like_function)
// Register a unary unary_op variant function with the signature:
// Status UnaryOpFn(OpKernelContext* ctx, const T& t, T* t_out);
// to Variants having TypeName type_name, for device string device,
// for UnaryVariantOp enum op.
#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(op, device, T, type_name, \
unary_op_function) \
REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER( \
__COUNTER__, op, device, T, type_name, unary_op_function)
#define REGISTER_UNARY_VARIANT_ZEROS_LIKE_FUNCTION_UNIQ_HELPER( \
ctr, device, T, type_name, zeros_like_function) \
REGISTER_UNARY_VARIANT_ZEROS_LIKE_FUNCTION_UNIQ(ctr, device, T, type_name, \
zeros_like_function)
#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER( \
ctr, op, device, T, type_name, unary_op_function) \
REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ(ctr, op, device, T, type_name, \
unary_op_function)
#define REGISTER_UNARY_VARIANT_ZEROS_LIKE_FUNCTION_UNIQ( \
ctr, device, T, type_name, zeros_like_function) \
static variant_op_registry_fn_registration:: \
UnaryVariantZerosLikeRegistration<T> \
register_unary_variant_op_decoder_fn_##ctr(device, type_name, \
zeros_like_function)
#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ( \
ctr, op, device, T, type_name, unary_op_function) \
static variant_op_registry_fn_registration::UnaryVariantUnaryOpRegistration< \
T> \
register_unary_variant_op_decoder_fn_##ctr(op, device, type_name, \
unary_op_function)
// Register a binary_op variant function with the signature:
// Status BinaryOpFn(OpKernelContext* ctx, const T& a, const T& b, T* out);
// to Variants having TypeName type_name, for device string device,
// for BinaryVariantOp enum OP.
#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(op, device, T, type_name, \
binary_op_function) \
REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ_HELPER( \
__COUNTER__, op, device, T, type_name, binary_op_function)
#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ_HELPER( \
ctr, op, device, T, type_name, binary_op_function) \
REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ( \
ctr, op, device, T, type_name, binary_op_function)
#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ( \
ctr, op, device, T, type_name, binary_op_function) \
static variant_op_registry_fn_registration:: \
UnaryVariantBinaryOpRegistration<T> \
register_unary_variant_op_decoder_fn_##ctr(op, device, type_name, \
binary_op_function)
} // end namespace tensorflow

View File

@ -50,7 +50,7 @@ struct VariantValue {
if (v.early_exit) {
return errors::InvalidArgument("early exit zeros_like!");
}
v_out->zeros_like_set = 1; // CPU
v_out->value = 1; // CPU
return Status::OK();
}
static Status GPUZerosLikeFn(OpKernelContext* ctx, const VariantValue& v,
@ -58,11 +58,27 @@ struct VariantValue {
if (v.early_exit) {
return errors::InvalidArgument("early exit zeros_like!");
}
v_out->zeros_like_set = 2; // GPU
v_out->value = 2; // GPU
return Status::OK();
}
static Status CPUAddFn(OpKernelContext* ctx, const VariantValue& a,
const VariantValue& b, VariantValue* out) {
if (a.early_exit) {
return errors::InvalidArgument("early exit add!");
}
out->value = a.value + b.value; // CPU
return Status::OK();
}
static Status GPUAddFn(OpKernelContext* ctx, const VariantValue& a,
const VariantValue& b, VariantValue* out) {
if (a.early_exit) {
return errors::InvalidArgument("early exit add!");
}
out->value = -(a.value + b.value); // GPU
return Status::OK();
}
bool early_exit;
int zeros_like_set;
int value;
};
REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(VariantValue, "TEST VariantValue",
@ -70,13 +86,23 @@ REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(VariantValue, "TEST VariantValue",
REGISTER_UNARY_VARIANT_DECODE_FUNCTION(VariantValue, "TEST VariantValue");
REGISTER_UNARY_VARIANT_ZEROS_LIKE_FUNCTION(DEVICE_CPU, VariantValue,
"TEST VariantValue",
VariantValue::CPUZerosLikeFn);
REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP,
DEVICE_CPU, VariantValue,
"TEST VariantValue",
VariantValue::CPUZerosLikeFn);
REGISTER_UNARY_VARIANT_ZEROS_LIKE_FUNCTION(DEVICE_GPU, VariantValue,
"TEST VariantValue",
VariantValue::GPUZerosLikeFn);
REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP,
DEVICE_GPU, VariantValue,
"TEST VariantValue",
VariantValue::GPUZerosLikeFn);
REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU,
VariantValue, "TEST VariantValue",
VariantValue::CPUAddFn);
REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_GPU,
VariantValue, "TEST VariantValue",
VariantValue::GPUAddFn);
} // namespace
@ -104,8 +130,9 @@ TEST(VariantOpShapeRegistryTest, TestBasic) {
TEST(VariantOpShapeRegistryTest, TestDuplicate) {
UnaryVariantOpRegistry registry;
UnaryVariantOpRegistry::VariantShapeFn f;
registry.RegisterShapeFn("fjfjfj", f);
EXPECT_DEATH(registry.RegisterShapeFn("fjfjfj", f),
string kTypeName = "fjfjfj";
registry.RegisterShapeFn(kTypeName, f);
EXPECT_DEATH(registry.RegisterShapeFn(kTypeName, f),
"fjfjfj already registered");
}
@ -133,71 +160,146 @@ TEST(VariantOpDecodeRegistryTest, TestBasic) {
TEST(VariantOpDecodeRegistryTest, TestDuplicate) {
UnaryVariantOpRegistry registry;
UnaryVariantOpRegistry::VariantDecodeFn f;
registry.RegisterDecodeFn("fjfjfj", f);
EXPECT_DEATH(registry.RegisterDecodeFn("fjfjfj", f),
string kTypeName = "fjfjfj";
registry.RegisterDecodeFn(kTypeName, f);
EXPECT_DEATH(registry.RegisterDecodeFn(kTypeName, f),
"fjfjfj already registered");
}
TEST(VariantOpZerosLikeRegistryTest, TestBasicCPU) {
EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetZerosLikeFn(
DEVICE_CPU, "YOU SHALL NOT PASS"),
EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetUnaryOpFn(
ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, "YOU SHALL NOT PASS"),
nullptr);
VariantValue vv_early_exit{true /* early_exit */, 0 /* zeros_like_set */};
VariantValue vv_early_exit{true /* early_exit */, 0 /* value */};
Variant v = vv_early_exit;
Variant v_out = VariantValue();
OpKernelContext* null_context_pointer = nullptr;
Status s0 =
CreateZerosLikeVariant<CPUDevice>(null_context_pointer, v, &v_out);
Status s0 = UnaryOpVariant<CPUDevice>(null_context_pointer,
ZEROS_LIKE_VARIANT_UNARY_OP, v, &v_out);
EXPECT_FALSE(s0.ok());
EXPECT_TRUE(
StringPiece(s0.error_message()).contains("early exit zeros_like"));
VariantValue vv_ok{false /* early_exit */, 0 /* zeros_like_set */};
VariantValue vv_ok{false /* early_exit */, 0 /* value */};
v = vv_ok;
TF_EXPECT_OK(
CreateZerosLikeVariant<CPUDevice>(null_context_pointer, v, &v_out));
TF_EXPECT_OK(UnaryOpVariant<CPUDevice>(
null_context_pointer, ZEROS_LIKE_VARIANT_UNARY_OP, v, &v_out));
VariantValue* vv_out = CHECK_NOTNULL(v_out.get<VariantValue>());
EXPECT_EQ(vv_out->zeros_like_set, 1); // CPU
EXPECT_EQ(vv_out->value, 1); // CPU
}
#if GOOGLE_CUDA
TEST(VariantOpZerosLikeRegistryTest, TestBasicGPU) {
EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetZerosLikeFn(
DEVICE_GPU, "YOU SHALL NOT PASS"),
TEST(VariantOpUnaryOpRegistryTest, TestBasicGPU) {
EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetUnaryOpFn(
ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU, "YOU SHALL NOT PASS"),
nullptr);
VariantValue vv_early_exit{true /* early_exit */, 0 /* zeros_like_set */};
VariantValue vv_early_exit{true /* early_exit */, 0 /* value */};
Variant v = vv_early_exit;
Variant v_out = VariantValue();
OpKernelContext* null_context_pointer = nullptr;
Status s0 =
CreateZerosLikeVariant<GPUDevice>(null_context_pointer, v, &v_out);
Status s0 = UnaryOpVariant<GPUDevice>(null_context_pointer,
ZEROS_LIKE_VARIANT_UNARY_OP, v, &v_out);
EXPECT_FALSE(s0.ok());
EXPECT_TRUE(
StringPiece(s0.error_message()).contains("early exit zeros_like"));
VariantValue vv_ok{false /* early_exit */, 0 /* zeros_like_set */};
VariantValue vv_ok{false /* early_exit */, 0 /* value */};
v = vv_ok;
TF_EXPECT_OK(
CreateZerosLikeVariant<GPUDevice>(null_context_pointer, v, &v_out));
TF_EXPECT_OK(UnaryOpVariant<GPUDevice>(
null_context_pointer, ZEROS_LIKE_VARIANT_UNARY_OP, v, &v_out));
VariantValue* vv_out = CHECK_NOTNULL(v_out.get<VariantValue>());
EXPECT_EQ(vv_out->zeros_like_set, 2); // GPU
EXPECT_EQ(vv_out->value, 2); // GPU
}
#endif // GOOGLE_CUDA
TEST(VariantOpZerosLikeRegistryTest, TestDuplicate) {
TEST(VariantOpUnaryOpRegistryTest, TestDuplicate) {
UnaryVariantOpRegistry registry;
UnaryVariantOpRegistry::VariantZerosLikeFn f;
UnaryVariantOpRegistry::VariantUnaryOpFn f;
string kTypeName = "fjfjfj";
registry.RegisterZerosLikeFn(DEVICE_CPU, "fjfjfj", f);
EXPECT_DEATH(registry.RegisterZerosLikeFn(DEVICE_CPU, "fjfjfj", f),
registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, kTypeName,
f);
EXPECT_DEATH(registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP,
DEVICE_CPU, kTypeName, f),
"fjfjfj already registered");
registry.RegisterZerosLikeFn(DEVICE_GPU, "fjfjfj", f);
EXPECT_DEATH(registry.RegisterZerosLikeFn(DEVICE_GPU, "fjfjfj", f),
registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU, kTypeName,
f);
EXPECT_DEATH(registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP,
DEVICE_GPU, kTypeName, f),
"fjfjfj already registered");
}
TEST(VariantOpAddRegistryTest, TestBasicCPU) {
EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetBinaryOpFn(
ADD_VARIANT_BINARY_OP, DEVICE_CPU, "YOU SHALL NOT PASS"),
nullptr);
VariantValue vv_early_exit{true /* early_exit */, 3 /* value */};
VariantValue vv_other{true /* early_exit */, 4 /* value */};
Variant v_a = vv_early_exit;
Variant v_b = vv_other;
Variant v_out = VariantValue();
OpKernelContext* null_context_pointer = nullptr;
Status s0 = BinaryOpVariants<CPUDevice>(
null_context_pointer, ADD_VARIANT_BINARY_OP, v_a, v_b, &v_out);
EXPECT_FALSE(s0.ok());
EXPECT_TRUE(StringPiece(s0.error_message()).contains("early exit add"));
VariantValue vv_ok{false /* early_exit */, 3 /* value */};
v_a = vv_ok;
TF_EXPECT_OK(BinaryOpVariants<CPUDevice>(
null_context_pointer, ADD_VARIANT_BINARY_OP, v_a, v_b, &v_out));
VariantValue* vv_out = CHECK_NOTNULL(v_out.get<VariantValue>());
EXPECT_EQ(vv_out->value, 7); // CPU
}
#if GOOGLE_CUDA
TEST(VariantOpAddRegistryTest, TestBasicGPU) {
EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetBinaryOpFn(
ADD_VARIANT_BINARY_OP, DEVICE_GPU, "YOU SHALL NOT PASS"),
nullptr);
VariantValue vv_early_exit{true /* early_exit */, 3 /* value */};
VariantValue vv_other{true /* early_exit */, 4 /* value */};
Variant v_a = vv_early_exit;
Variant v_b = vv_other;
Variant v_out = VariantValue();
OpKernelContext* null_context_pointer = nullptr;
Status s0 = BinaryOpVariants<GPUDevice>(
null_context_pointer, ADD_VARIANT_BINARY_OP, v_a, v_b, &v_out);
EXPECT_FALSE(s0.ok());
EXPECT_TRUE(StringPiece(s0.error_message()).contains("early exit add"));
VariantValue vv_ok{false /* early_exit */, 3 /* value */};
v_a = vv_ok;
TF_EXPECT_OK(BinaryOpVariants<GPUDevice>(
null_context_pointer, ADD_VARIANT_BINARY_OP, v_a, v_b, &v_out));
VariantValue* vv_out = CHECK_NOTNULL(v_out.get<VariantValue>());
EXPECT_EQ(vv_out->value, -7); // GPU
}
#endif // GOOGLE_CUDA
TEST(VariantOpAddRegistryTest, TestDuplicate) {
UnaryVariantOpRegistry registry;
UnaryVariantOpRegistry::VariantBinaryOpFn f;
string kTypeName = "fjfjfj";
registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_CPU, kTypeName, f);
EXPECT_DEATH(registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_CPU,
kTypeName, f),
"fjfjfj already registered");
registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_GPU, kTypeName, f);
EXPECT_DEATH(registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_GPU,
kTypeName, f),
"fjfjfj already registered");
}

View File

@ -34,8 +34,8 @@ TensorId ParseTensorName(StringPiece name) {
// whole name string forms the first part of the tensor name.
const char* base = name.data();
const char* p = base + name.size() - 1;
int index = 0;
int mul = 1;
unsigned int index = 0;
unsigned int mul = 1;
while (p > base && (*p >= '0' && *p <= '9')) {
index += ((*p - '0') * mul);
mul *= 10;

View File

@ -24,6 +24,9 @@ limitations under the License.
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/logging.h"
@ -33,7 +36,7 @@ typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
typedef Eigen::SyclDevice SYCLDevice;
#endif // TENSORFLOW_USE_SYCL
#endif // TENSORFLOW_USE_SYCL
template <typename Device, typename T>
class AddNOp : public OpKernel {
@ -150,6 +153,65 @@ class AddNOp : public OpKernel {
}
};
template <typename Device>
class AddNOp<Device, Variant> : public OpKernel {
public:
explicit AddNOp(OpKernelConstruction* context) : OpKernel(context) {}
void Compute(OpKernelContext* ctx) override {
if (!ctx->ValidateInputsAreSameShape(this)) return;
const Tensor& input0 = ctx->input(0);
const int num = ctx->num_inputs();
if (num == 1) {
ctx->set_output(0, input0);
return;
}
for (int i = 0; i < num; ++i) {
// Step 1: ensure the inputs are scalar (unary) Variants.
OP_REQUIRES(
ctx, ctx->input(i).dims() == 0,
errors::InvalidArgument(
"AddN of non-scalar Tensor with dtype=DT_VARIANT is not "
"supported; inputs[",
i, " has shape: ", ctx->input(i).shape().DebugString(), "."));
}
TensorShape common_shape;
OP_REQUIRES_OK(ctx, GetUnaryVariantShape(ctx->input(0), &common_shape));
// Step 2: access all variants and ensure shapes match.
for (int i = 1; i < num; ++i) {
TensorShape check_shape;
OP_REQUIRES_OK(ctx, GetUnaryVariantShape(ctx->input(i), &check_shape));
OP_REQUIRES(ctx, common_shape == check_shape,
errors::InvalidArgument(
"AddN of Variants of differing shapes; inputs[0] shape: ",
common_shape.DebugString(), ", inputs[", i,
"] shape: ", check_shape.DebugString()));
}
// Step 3: attempt to add using
// BinaryOpVariants(ADD_VARIANT_BINARY_OP, ...)
// For the output create a default-constructed variant object.
// TODO(ebrevdo): Perform summation in a tree-structure.
Tensor out(cpu_allocator(), DT_VARIANT, TensorShape({}));
Variant* v_out = &(out.scalar<Variant>()());
OP_REQUIRES_OK(
ctx, BinaryOpVariants<Device>(
ctx, ADD_VARIANT_BINARY_OP, ctx->input(0).scalar<Variant>()(),
ctx->input(1).scalar<Variant>()(), v_out));
for (int i = 2; i < num; ++i) {
const Variant tmp = std::move(*v_out);
const Variant& inp = ctx->input(i).scalar<Variant>()();
OP_REQUIRES_OK(ctx, BinaryOpVariants<Device>(ctx, ADD_VARIANT_BINARY_OP,
inp, tmp, v_out));
}
ctx->set_output(0, out);
}
};
#define REGISTER_ADDN(type, dev) \
REGISTER_KERNEL_BUILDER( \
Name("AddN").Device(DEVICE_##dev).TypeConstraint<type>("T"), \
@ -158,6 +220,8 @@ class AddNOp : public OpKernel {
#define REGISTER_ADDN_CPU(type) REGISTER_ADDN(type, CPU)
TF_CALL_NUMBER_TYPES(REGISTER_ADDN_CPU);
REGISTER_ADDN_CPU(Variant);
#undef REGISTER_ADDN_CPU
#if GOOGLE_CUDA
@ -176,6 +240,16 @@ REGISTER_KERNEL_BUILDER(Name("AddN")
.HostMemory("inputs")
.HostMemory("sum"),
AddNOp<CPUDevice, int32>);
// TODO(ebrevdo): Once rendezvous has been properly set up for
// Variants, we'll no longer need a HostMemory attribute for this case.
REGISTER_KERNEL_BUILDER(Name("AddN")
.Device(DEVICE_GPU)
.TypeConstraint<Variant>("T")
.HostMemory("inputs")
.HostMemory("sum"),
AddNOp<GPUDevice, Variant>);
#endif // GOOGLE_CUDA
#ifdef TENSORFLOW_USE_SYCL
@ -191,7 +265,7 @@ REGISTER_KERNEL_BUILDER(Name("AddN")
.HostMemory("inputs")
.HostMemory("sum"),
AddNOp<CPUDevice, int32>);
#endif // TENSORFLOW_USE_SYCL
#endif // TENSORFLOW_USE_SYCL
#undef REGISTER_ADDN

View File

@ -279,13 +279,15 @@ class ZerosLikeOp : public OpKernel {
const Tensor& input = ctx->input(0);
const Device& d = ctx->eigen_device<Device>();
if (std::is_same<T, Variant>::value) {
OP_REQUIRES(ctx, input.dims() == 0,
errors::InvalidArgument(
"ZerosLike of non-unary Variant not supported."));
OP_REQUIRES(
ctx, input.dims() == 0,
errors::InvalidArgument("ZerosLike non-scalar Tensor with "
"dtype=DT_VARIANT is not supported."));
const Variant& v = input.scalar<Variant>()();
Tensor out(cpu_allocator(), DT_VARIANT, TensorShape({}));
Variant* out_v = &(out.scalar<Variant>()());
OP_REQUIRES_OK(ctx, CreateZerosLikeVariant<Device>(ctx, v, out_v));
OP_REQUIRES_OK(ctx, UnaryOpVariant<Device>(
ctx, ZEROS_LIKE_VARIANT_UNARY_OP, v, out_v));
ctx->set_output(0, out);
} else {
Tensor* out = nullptr;

View File

@ -292,7 +292,8 @@ class RemoteCallOp : public AsyncOpKernel {
OP_REQUIRES_OK_ASYNC(ctx, ctx->input("target", &target), done);
AttrValueMap attr_values = func_->attr();
AttrValue v;
v.set_s(target->scalar<string>()());
const string& target_device = target->scalar<string>()();
v.set_s(target_device);
AddAttr("_target", v, &attr_values);
FunctionLibraryRuntime* lib = ctx->function_library();
@ -310,6 +311,11 @@ class RemoteCallOp : public AsyncOpKernel {
FunctionLibraryRuntime::Options opts;
opts.step_id = ctx->step_id();
opts.runner = ctx->runner();
opts.source_device = lib->device()->name();
if (opts.source_device != target_device) {
opts.remote_execution = true;
}
opts.rendezvous = ctx->rendezvous();
std::vector<Tensor> args;
args.reserve(arguments.size());
for (const Tensor& argument : arguments) {
@ -334,10 +340,13 @@ class RemoteCallOp : public AsyncOpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(RemoteCallOp);
};
REGISTER_KERNEL_BUILDER(Name("RemoteCall").Device(DEVICE_CPU), RemoteCallOp);
REGISTER_KERNEL_BUILDER(Name("RemoteCall").Device(DEVICE_GPU), RemoteCallOp);
REGISTER_KERNEL_BUILDER(
Name("RemoteCall").Device(DEVICE_CPU).HostMemory("target"), RemoteCallOp);
REGISTER_KERNEL_BUILDER(
Name("RemoteCall").Device(DEVICE_GPU).HostMemory("target"), RemoteCallOp);
#if TENSORFLOW_USE_SYCL
REGISTER_KERNEL_BUILDER(Name("RemoteCall").Device(DEVICE_SYCL), RemoteCallOp);
REGISTER_KERNEL_BUILDER(
Name("RemoteCall").Device(DEVICE_SYCL).HostMemory("target"), RemoteCallOp);
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow

View File

@ -920,6 +920,13 @@ class MaxPoolingGradWithArgmaxOp : public OpKernel {
public:
explicit MaxPoolingGradWithArgmaxOp(OpKernelConstruction* context)
: OpKernel(context) {
string data_format_str;
auto status = context->GetAttr("data_format", &data_format_str);
if (status.ok()) {
OP_REQUIRES(context, FormatFromString(data_format_str, &data_format_),
errors::InvalidArgument("Invalid data format"));
}
OP_REQUIRES_OK(context, context->GetAttr("ksize", &ksize_));
OP_REQUIRES(context, ksize_.size() == 4,
errors::InvalidArgument("Sliding window ksize field must "
@ -959,6 +966,7 @@ class MaxPoolingGradWithArgmaxOp : public OpKernel {
std::vector<int32> ksize_;
std::vector<int32> stride_;
Padding padding_;
TensorFormat data_format_;
};
template <typename Device, typename T>
@ -1051,17 +1059,36 @@ class MaxPoolingNoMaskOp<GPUDevice, T> : public OpKernel {
TensorShape out_shape =
ShapeFromFormat(data_format_, params.tensor_in_batch, params.out_height,
params.out_width, params.depth);
if (use_dnn_ && data_format_ == FORMAT_NCHW) {
// Assuming qint8 <--> NCHW_VECT_C (int8x4) here.
constexpr bool is_int8x4 = std::is_same<T, qint8>::value;
OP_REQUIRES(context, (is_int8x4 == (data_format_ == FORMAT_NCHW_VECT_C)),
errors::InvalidArgument(
"qint8 should be used with data_format NCHW_VECT_C."));
// These is_int8x4 checks avoid linker errors for missing qint8 kernels.
if (!is_int8x4 && use_dnn_ && data_format_ == FORMAT_NCHW) {
DnnPoolingOp<T>::Compute(
context, perftools::gputools::dnn::PoolingMode::kMaximum, ksize_,
stride_, padding_, data_format_, tensor_in, out_shape);
} else {
CHECK(data_format_ == FORMAT_NHWC)
<< "Non-Cudnn MaxPool only supports NHWC format";
Tensor* output = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
LaunchMaxPoolingNoMask<Device, T>::launch(context, params, tensor_in,
output);
if (is_int8x4) {
LaunchMaxPoolingNoMask_NCHW_VECT_C<Device>::launch(context, params,
tensor_in, output);
} else if (data_format_ == FORMAT_NHWC) {
LaunchMaxPoolingNoMask<Device, T>::launch(context, params, tensor_in,
output);
} else {
LOG(FATAL) << "MaxPool currently only supports the following (layout, "
"type) combinations: (NHWC, non-qint8), "
"(NCHW, non-qint8) or (NCHW_VECT_C, qint8). The "
"requested combination ("
<< ToString(data_format_) << ", "
<< DataTypeString(DataTypeToEnum<T>::v())
<< ") is not supported.";
}
}
}
@ -1346,6 +1373,26 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_MAX_POOL_KERNELS);
.TypeConstraint<int64>("Targmax"), \
MaxPoolingGradGradWithArgmaxOp<GPUDevice, T>);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_ONLY_POOL_KERNELS);
REGISTER_KERNEL_BUILDER(
Name("MaxPool").Device(DEVICE_GPU).TypeConstraint<qint8>("T"),
MaxPoolingNoMaskOp<GPUDevice, qint8>);
REGISTER_KERNEL_BUILDER(Name("MaxPoolV2")
.Device(DEVICE_GPU)
.HostMemory("ksize")
.HostMemory("strides")
.TypeConstraint<qint8>("T"),
MaxPoolingV2Op<GPUDevice, qint8>);
REGISTER_KERNEL_BUILDER(Name("MaxPoolV2")
.Device(DEVICE_GPU)
.HostMemory("ksize")
.HostMemory("strides")
.TypeConstraint<qint8>("T")
.Label("eigen_tensor"),
MaxPoolingV2Op<GPUDevice, qint8>);
#undef REGISTER_GPU_ONLY_POOL_KERNELS
#endif // GOOGLE_CUDA

View File

@ -17,7 +17,9 @@ limitations under the License.
#define TENSORFLOW_KERNELS_MAXPOOLING_OP_H_
// Functor definition for MaxPoolingOp, must be compilable by nvcc.
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/kernels/eigen_pooling.h"
#include "tensorflow/core/platform/types.h"
@ -37,6 +39,14 @@ struct SpatialMaxPooling {
}
};
template <typename Device>
struct SpatialMaxPooling<Device, qint8> {
void operator()(const Device& d, typename TTypes<qint8, 4>::Tensor output,
typename TTypes<qint8, 4>::ConstTensor input, int window_rows,
int window_cols, int row_stride, int col_stride,
const Eigen::PaddingType& padding) {}
};
} // namespace functor
} // namespace tensorflow

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/kernels/maxpooling_op.h"
#include "tensorflow/core/kernels/maxpooling_op_gpu.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"
@ -89,6 +90,42 @@ __global__ void MaxPoolForwardNCHW(const int nthreads, const dtype* bottom_data,
}
}
// The parameters for MaxPoolForwardNoMaskKernel_NCHW_VECT_C are the same as for
// MaxPoolForwardNCHW above, except that mask is not supported, and each
// element of the input and output contains 4 adjacent channel values for
// the same (x, y) coordinate.
// (so channels = outer_channels, output_size = real output size / 4).
__global__ void MaxPoolForwardNoMaskKernel_NCHW_VECT_C(
const int nthreads, const int32* bottom_data, const int height,
const int width, const int channels, const int pooled_height,
const int pooled_width, const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w, const int pad_t, const int pad_l,
int32* top_data) {
// TODO(pauldonnelly): Implement a better optimized version of this kernel.
const int32 kMinINT8X4 = 0x80808080;
CUDA_1D_KERNEL_LOOP(index, nthreads) {
int pw = index % pooled_width;
int ph = (index / pooled_width) % pooled_height;
int c = (index / pooled_width / pooled_height) % channels;
int n = index / pooled_width / pooled_height / channels;
int hstart = ph * stride_h - pad_t;
int wstart = pw * stride_w - pad_l;
int hend = min(hstart + kernel_h, height);
int wend = min(wstart + kernel_w, width);
hstart = max(hstart, 0);
wstart = max(wstart, 0);
int32 maxval = kMinINT8X4;
const int32* bottom_data_n = bottom_data + n * channels * height * width;
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
int idx = (c * height + h) * width + w;
maxval = __vmaxs4(maxval, bottom_data_n[idx]);
}
}
top_data[index] = maxval;
}
}
template <typename dtype>
__global__ void MaxPoolForwardNHWC(const int nthreads, const dtype* bottom_data,
const int height, const int width,
@ -328,6 +365,25 @@ __global__ void MaxPoolGradBackward(const int nthreads, const dtype* top_diff,
namespace functor {
// Note: channels is the outer channels (dim 1) which has already been
// divided by 4.
bool MaxPoolForwardNoMask_NCHW_VECT_C::operator()(
const int32* bottom_data, const int batch, const int height,
const int width, int channels, const int pooled_height,
const int pooled_width, const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w, const int pad_t, const int pad_l,
int32* top_data, const Eigen::GpuDevice& d) {
const int kThreadsPerBlock = 1024;
const int output_size = batch * channels * pooled_height * pooled_width;
MaxPoolForwardNoMaskKernel_NCHW_VECT_C<<<
(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock,
0, d.stream()>>>(output_size, bottom_data, height, width, channels,
pooled_height, pooled_width, kernel_h, kernel_w,
stride_h, stride_w, pad_t, pad_l, top_data);
d.synchronize();
return d.ok();
}
template <typename T>
bool MaxPoolForwardWithOptionalArgmax<T>::operator()(
const T* bottom_data, const int batch, const int height, const int width,

View File

@ -42,6 +42,15 @@ struct MaxPoolForwardWithOptionalArgmax {
const Eigen::GpuDevice& d);
};
struct MaxPoolForwardNoMask_NCHW_VECT_C {
bool operator()(const int32* bottom_data, const int batch, const int height,
const int width, int channels, const int pooled_height,
const int pooled_width, const int kernel_h,
const int kernel_w, const int stride_h, const int stride_w,
const int pad_t, const int pad_l, int32* top_data,
const Eigen::GpuDevice& d);
};
template <typename T>
struct MaxPoolBackwardWithArgmax {
bool operator()(const int output_size, const int input_size,

View File

@ -22,7 +22,6 @@ limitations under the License.
#if GOOGLE_CUDA
#include "tensorflow/core/kernels/conv_2d.h"
#include "tensorflow/core/kernels/maxpooling_op_gpu.h"
#include "tensorflow/core/kernels/pooling_ops_common_gpu.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif // GOOGLE_CUDA
@ -34,12 +33,18 @@ PoolParameters::PoolParameters(OpKernelContext* context,
const std::vector<int32>& stride,
Padding padding, TensorFormat data_format,
const TensorShape& tensor_in_shape) {
// For maxpooling, tensor_in should have 4 dimensions.
OP_REQUIRES(context, tensor_in_shape.dims() == 4,
errors::InvalidArgument("tensor_in must be 4-dimensional"));
// For maxpooling, tensor_in should have 2 spatial dimensions.
// Note: the total number of dimensions could be 4 for NHWC, NCHW,
// or 5 for NCHW_VECT_C.
OP_REQUIRES(context,
GetTensorSpatialDims(tensor_in_shape.dims(), data_format) == 2,
errors::InvalidArgument(
"tensor_in_shape must have 2 spatial dimensions. ",
tensor_in_shape.dims(), " ", data_format));
this->data_format = data_format;
depth = GetTensorDim(tensor_in_shape, data_format, 'C');
depth = GetTensorDim(tensor_in_shape, data_format, 'C') *
(data_format == FORMAT_NCHW_VECT_C ? 4 : 1);
tensor_in_cols = GetTensorDim(tensor_in_shape, data_format, 'W');
tensor_in_rows = GetTensorDim(tensor_in_shape, data_format, 'H');
tensor_in_batch = GetTensorDim(tensor_in_shape, data_format, 'N');

View File

@ -29,6 +29,10 @@ limitations under the License.
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/util/work_sharder.h"
#if GOOGLE_CUDA
#include "tensorflow/core/kernels/maxpooling_op_gpu.h"
#endif // GOOGLE_CUDA
namespace tensorflow {
typedef Eigen::GpuDevice GPUDevice;
@ -256,6 +260,30 @@ class MaxPoolingOp : public OpKernel {
TensorFormat data_format_;
};
template <typename Device>
struct LaunchMaxPoolingNoMask_NCHW_VECT_C;
#ifdef GOOGLE_CUDA
template <>
struct LaunchMaxPoolingNoMask_NCHW_VECT_C<Eigen::GpuDevice> {
static void launch(OpKernelContext* context, const PoolParameters& params,
const Tensor& input, Tensor* output) {
bool status = functor::MaxPoolForwardNoMask_NCHW_VECT_C()(
reinterpret_cast<const int32*>(input.flat<qint8>().data()),
params.tensor_in_batch, params.tensor_in_rows, params.tensor_in_cols,
params.depth, params.out_height, params.out_width, params.window_rows,
params.window_cols, params.row_stride, params.col_stride,
params.pad_rows, params.pad_cols,
reinterpret_cast<int32*>(output->flat<qint8>().data()),
context->eigen_gpu_device());
if (!status) {
context->SetStatus(errors::Internal(
"Failed launching LaunchMaxPoolingNoMask_NCHW_VECT_C"));
}
}
};
#endif
template <typename Device, typename T>
class MaxPoolingV2Op : public OpKernel {
public:
@ -266,8 +294,11 @@ class MaxPoolingV2Op : public OpKernel {
OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
errors::InvalidArgument("Invalid data format"));
OP_REQUIRES(
context, data_format_ == FORMAT_NHWC,
errors::InvalidArgument("Default MaxPoolingOp only supports NHWC."));
context,
data_format_ == FORMAT_NHWC || data_format_ == FORMAT_NCHW_VECT_C,
errors::InvalidArgument(
"MaxPoolingV2Op only supports NHWC or NCHW_VECT_C. Got: ",
data_format));
} else {
data_format_ = FORMAT_NHWC;
}
@ -315,8 +346,8 @@ class MaxPoolingV2Op : public OpKernel {
errors::Unimplemented(
"Pooling is not yet supported on the batch dimension."));
PoolParameters params{context, ksize, stride,
padding_, FORMAT_NHWC, tensor_in.shape()};
PoolParameters params{context, ksize, stride,
padding_, data_format_, tensor_in.shape()};
if (!context->status().ok()) {
return;
}
@ -368,13 +399,21 @@ class MaxPoolingV2Op : public OpKernel {
// Spatial MaxPooling implementation.
//
// TODO(vrv): Remove this once we no longer need it.
#ifdef GOOGLE_CUDA
if (std::is_same<Device, GPUDevice>::value) {
Eigen::PaddingType pt = BrainPadding2EigenPadding(padding);
functor::SpatialMaxPooling<Device, T>()(
context->eigen_device<Device>(), output->tensor<T, 4>(),
tensor_in.tensor<T, 4>(), params.window_rows, params.window_cols,
params.row_stride, params.col_stride, pt);
} else {
if (std::is_same<T, qint8>::value) {
LaunchMaxPoolingNoMask_NCHW_VECT_C<GPUDevice>::launch(
context, params, tensor_in, output);
} else {
functor::SpatialMaxPooling<Device, T>()(
context->eigen_device<Device>(), output->tensor<T, 4>(),
tensor_in.tensor<T, 4>(), params.window_rows, params.window_cols,
params.row_stride, params.col_stride, pt);
}
} else
#endif
{
typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>
ConstEigenMatrixMap;
typedef Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic>>

View File

@ -82,8 +82,6 @@ Status SqliteQueryConnection::GetNext(std::vector<Tensor>* out_tensors,
int rc = sqlite3_step(stmt_);
if (rc == SQLITE_ROW) {
for (int i = 0; i < column_count_; i++) {
// TODO(b/64276939) Support other tensorflow types. Interpret columns as
// the types that the client specifies.
DataType dt = output_types_[i];
Tensor tensor(cpu_allocator(), dt, {});
FillTensorWithResultSetEntry(dt, i, &tensor);
@ -125,11 +123,46 @@ void SqliteQueryConnection::FillTensorWithResultSetEntry(
tensor->scalar<string>()() = value;
break;
}
case DT_INT8: {
int8 value = sqlite3_column_int(stmt_, column_index);
tensor->scalar<int8>()() = value;
break;
}
case DT_INT16: {
int16 value = sqlite3_column_int(stmt_, column_index);
tensor->scalar<int16>()() = value;
break;
}
case DT_INT32: {
int32 value = sqlite3_column_int(stmt_, column_index);
tensor->scalar<int32>()() = value;
break;
}
case DT_INT64: {
int64 value = sqlite3_column_int64(stmt_, column_index);
tensor->scalar<int64>()() = value;
break;
}
case DT_UINT8: {
uint8 value = sqlite3_column_int(stmt_, column_index);
tensor->scalar<uint8>()() = value;
break;
}
case DT_UINT16: {
uint16 value = sqlite3_column_int(stmt_, column_index);
tensor->scalar<uint16>()() = value;
break;
}
case DT_BOOL: {
int value = sqlite3_column_int(stmt_, column_index);
tensor->scalar<bool>()() = value ? true : false;
break;
}
case DT_DOUBLE: {
double value = sqlite3_column_double(stmt_, column_index);
tensor->scalar<double>()() = value;
break;
}
// Error preemptively thrown by SqlDatasetOp::MakeDataset in this case.
default: {
LOG(FATAL)

View File

@ -34,13 +34,15 @@ class SqlDatasetOp : public DatasetOpKernel {
explicit SqlDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
// TODO(b/64276939) Remove this check when we add support for other
// tensorflow types.
for (const DataType& dt : output_types_) {
OP_REQUIRES(
ctx, dt == DT_STRING || dt == DT_INT32,
errors::InvalidArgument(
"Each element of `output_types_` must be DT_STRING or DT_INT32"));
OP_REQUIRES(ctx,
dt == DT_STRING || dt == DT_INT8 || dt == DT_INT16 ||
dt == DT_INT32 || dt == DT_INT64 || dt == DT_UINT8 ||
dt == DT_UINT16 || dt == DT_BOOL || dt == DT_DOUBLE,
errors::InvalidArgument(
"Each element of `output_types_` must be one of: "
"DT_STRING, DT_INT8, DT_INT16, DT_INT32, DT_INT64, "
"DT_UINT8, DT_UINT16, DT_BOOL, DT_DOUBLE "));
}
for (const PartialTensorShape& pts : output_shapes_) {
OP_REQUIRES(ctx, pts.dims() == 0,

View File

@ -303,6 +303,49 @@ op {
is_aggregate: true
is_commutative: true
}
op {
name: "AddN"
input_arg {
name: "inputs"
type_attr: "T"
number_attr: "N"
}
output_arg {
name: "sum"
type_attr: "T"
}
attr {
name: "N"
type: "int"
has_minimum: true
minimum: 1
}
attr {
name: "T"
type: "type"
allowed_values {
list {
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT64
type: DT_INT32
type: DT_UINT8
type: DT_UINT16
type: DT_INT16
type: DT_INT8
type: DT_COMPLEX64
type: DT_COMPLEX128
type: DT_QINT8
type: DT_QUINT8
type: DT_QINT32
type: DT_HALF
type: DT_VARIANT
}
}
}
is_aggregate: true
is_commutative: true
}
op {
name: "AddSparseToTensorsMap"
input_arg {
@ -13473,6 +13516,74 @@ op {
}
}
}
op {
name: "MaxPool"
input_arg {
name: "input"
type_attr: "T"
}
output_arg {
name: "output"
type_attr: "T"
}
attr {
name: "T"
type: "type"
default_value {
type: DT_FLOAT
}
allowed_values {
list {
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
type: DT_INT64
type: DT_UINT8
type: DT_INT16
type: DT_INT8
type: DT_UINT16
type: DT_HALF
type: DT_QINT8
}
}
}
attr {
name: "ksize"
type: "list(int)"
has_minimum: true
minimum: 4
}
attr {
name: "strides"
type: "list(int)"
has_minimum: true
minimum: 4
}
attr {
name: "padding"
type: "string"
allowed_values {
list {
s: "SAME"
s: "VALID"
}
}
}
attr {
name: "data_format"
type: "string"
default_value {
s: "NHWC"
}
allowed_values {
list {
s: "NHWC"
s: "NCHW"
s: "NCHW_VECT_C"
}
}
}
}
op {
name: "MaxPool3D"
input_arg {
@ -14435,6 +14546,70 @@ op {
}
}
}
op {
name: "MaxPoolV2"
input_arg {
name: "input"
type_attr: "T"
}
input_arg {
name: "ksize"
type: DT_INT32
}
input_arg {
name: "strides"
type: DT_INT32
}
output_arg {
name: "output"
type_attr: "T"
}
attr {
name: "T"
type: "type"
default_value {
type: DT_FLOAT
}
allowed_values {
list {
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
type: DT_INT64
type: DT_UINT8
type: DT_INT16
type: DT_INT8
type: DT_UINT16
type: DT_HALF
type: DT_QINT8
}
}
}
attr {
name: "padding"
type: "string"
allowed_values {
list {
s: "SAME"
s: "VALID"
}
}
}
attr {
name: "data_format"
type: "string"
default_value {
s: "NHWC"
}
allowed_values {
list {
s: "NHWC"
s: "NCHW"
s: "NCHW_VECT_C"
}
}
}
}
op {
name: "MaxPoolWithArgmax"
input_arg {

View File

@ -28,7 +28,7 @@ REGISTER_OP("AddN")
.Input("inputs: N * T")
.Output("sum: T")
.Attr("N: int >= 1")
.Attr("T: numbertype")
.Attr("T: {numbertype, variant}")
.SetIsCommutative()
.SetIsAggregate()
.SetShapeFn([](InferenceContext* c) {

View File

@ -1344,11 +1344,13 @@ output: The gradients for LRN.
// --------------------------------------------------------------------------
REGISTER_OP("MaxPool")
.Attr("T: realnumbertype = DT_FLOAT")
.Attr(
"T: {float, double, int32, int64, uint8, int16, int8, uint16, "
"half, qint8} = DT_FLOAT")
.Attr("ksize: list(int) >= 4")
.Attr("strides: list(int) >= 4")
.Attr(GetPaddingAttrString())
.Attr(GetConvnetDataFormatAttrString())
.Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'")
.Input("input: T")
.Output("output: T")
.SetShapeFn(shape_inference::MaxPoolShape)
@ -1369,9 +1371,11 @@ output: The max pooled output tensor.
)doc");
REGISTER_OP("MaxPoolV2")
.Attr("T: realnumbertype = DT_FLOAT")
.Attr(
"T: {float, double, int32, int64, uint8, int16, int8, uint16, "
"half, qint8} = DT_FLOAT")
.Attr(GetPaddingAttrString())
.Attr(GetConvnetDataFormatAttrString())
.Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'")
.Input("input: T")
.Input("ksize: int32")
.Input("strides: int32")

View File

@ -334,6 +334,7 @@ op {
type: DT_QUINT8
type: DT_QINT32
type: DT_HALF
type: DT_VARIANT
}
}
}
@ -12628,6 +12629,7 @@ op {
type: DT_INT8
type: DT_UINT16
type: DT_HALF
type: DT_QINT8
}
}
}
@ -12667,6 +12669,7 @@ op {
list {
s: "NHWC"
s: "NCHW"
s: "NCHW_VECT_C"
}
}
}
@ -13401,6 +13404,7 @@ op {
type: DT_INT8
type: DT_UINT16
type: DT_HALF
type: DT_QINT8
}
}
}
@ -13426,6 +13430,7 @@ op {
list {
s: "NHWC"
s: "NCHW"
s: "NCHW_VECT_C"
}
}
}

View File

@ -216,7 +216,7 @@ seq2seq_attention_model.py:363:build_graph:self._add_train_o..., cpu: 1.28sec, a
```shell
# The following example generates a timeline.
tfprof> graph -step 0 -max_depth 100000 -output timeline:outfile=<filename>
tfprof> graph -step -1 -max_depth 100000 -output timeline:outfile=<filename>
generating trace file.

View File

@ -14,7 +14,12 @@
### Command Line Inputs
tfprof command line tool uses the following inputs:
tfprof command line tool uses the following input:
<b>--profile_path:</b> A ProfileProto binary proto file.
See QuickStart on generating the file.
<b>THE OLD WAY BELOW IS DEPRECATED:</b>
<b>--graph_path:</b> GraphDef proto file (required). Used to build in-memory
data structure of the model. For example, graph.pbtxt written by tf.Supervisor

View File

@ -84,7 +84,6 @@ string RunProfile(const string& command, const string& options,
} // namespace
bool NewProfiler(const string* graph, const string* op_log) {
CHECK(!tf_stat) << "Currently only 1 living tfprof profiler is allowed";
CHECK(graph) << "graph mustn't be null";
std::unique_ptr<GraphDef> graph_ptr(new GraphDef());
if (!graph_ptr->ParseFromString(*graph)) {

View File

@ -175,22 +175,22 @@ class ExecStep {
std::map<int32, std::pair<int64, uint64>> output_memory_;
};
#define GRAPH_NODE_BYTES(type) \
do { \
if (execs_.empty()) { \
return 0; \
} \
if (step >= 0) { \
auto exec = execs_.find(step); \
CHECK(exec != execs_.end()) << "unknown step " << step; \
return exec->second.type##_bytes(); \
} \
\
int64 bytes = 0; \
for (const auto& exec : execs_) { \
bytes += exec.second.type##_bytes(); \
} \
return bytes / execs_.size(); \
#define GRAPH_NODE_BYTES(type) \
do { \
if (execs_.empty()) { \
return 0; \
} \
if (step >= 0) { \
auto exec = execs_.find(step); \
if (exec == execs_.end()) return 0; \
return exec->second.type##_bytes(); \
} \
\
int64 bytes = 0; \
for (const auto& exec : execs_) { \
bytes += exec.second.type##_bytes(); \
} \
return bytes / execs_.size(); \
} while (0)
class TFGraphNode {
@ -372,7 +372,9 @@ class TFGraphNode {
}
if (step >= 0) {
auto exec = execs_.find(step);
CHECK(exec != execs_.end());
if (exec == execs_.end()) {
return 0;
}
return exec->second.run_count();
}
int64 total_run_count = 0;
@ -390,7 +392,9 @@ class TFGraphNode {
}
if (step >= 0) {
auto exec = execs_.find(step);
CHECK(exec != execs_.end());
if (exec == execs_.end()) {
return 0;
}
return exec->second.exec_micros();
}
@ -410,7 +414,9 @@ class TFGraphNode {
}
if (step >= 0) {
auto exec = execs_.find(step);
CHECK(exec != execs_.end());
if (exec == execs_.end()) {
return 0;
}
return exec->second.accelerator_exec_micros();
}
@ -430,7 +436,9 @@ class TFGraphNode {
}
if (step >= 0) {
auto exec = execs_.find(step);
CHECK(exec != execs_.end());
if (exec == execs_.end()) {
return 0;
}
return exec->second.cpu_exec_micros();
}
@ -448,20 +456,26 @@ class TFGraphNode {
int64 all_start_micros(int64 step) const {
auto exec = execs_.find(step);
CHECK(exec != execs_.end()) << "unknown step " << step;
if (exec == execs_.end()) {
return 0;
}
return exec->second.all_start_micros();
}
int64 latest_end_micros(int64 step) const {
auto exec = execs_.find(step);
CHECK(exec != execs_.end()) << "unknown step " << step;
if (exec == execs_.end()) {
return 0;
}
return exec->second.latest_end_micros();
}
const std::map<string, std::vector<std::pair<int64, int64>>>& op_execs(
int64 step) const {
auto exec = execs_.find(step);
CHECK(exec != execs_.end()) << "unknown step " << step;
if (exec == execs_.end()) {
return empty_op_execs_;
}
return exec->second.op_execs();
}
@ -469,33 +483,45 @@ class TFGraphNode {
int64 accelerator_temp_bytes(int64 step) const {
auto exec = execs_.find(step);
CHECK(exec != execs_.end()) << "unknown step " << step;
if (exec == execs_.end()) {
return 0;
}
return exec->second.accelerator_temp_bytes();
}
int64 host_temp_bytes(int64 step) const {
auto exec = execs_.find(step);
CHECK(exec != execs_.end()) << "unknown step " << step;
if (exec == execs_.end()) {
return 0;
}
return exec->second.host_temp_bytes();
}
int64 accelerator_persistent_bytes(int64 step) const {
auto exec = execs_.find(step);
CHECK(exec != execs_.end()) << "unknown step " << step;
if (exec == execs_.end()) {
return 0;
}
return exec->second.accelerator_persistent_bytes();
}
int64 host_persistent_bytes(int64 step) const {
auto exec = execs_.find(step);
CHECK(exec != execs_.end()) << "unknown step " << step;
if (exec == execs_.end()) {
return 0;
}
return exec->second.host_persistent_bytes();
}
const std::map<int32, std::pair<int64, uint64>>& output_memory(
int64 step) const {
auto exec = execs_.find(step);
CHECK(exec != execs_.end()) << "unknown step " << step;
if (exec == execs_.end()) {
return empty_output_memory_;
}
return exec->second.output_memory();
}
int64 allocator_bytes_in_use(int64 step) const {
auto exec = execs_.find(step);
CHECK(exec != execs_.end()) << "unknown step " << step;
if (exec == execs_.end()) {
return 0;
}
return exec->second.allocator_bytes_in_use();
}
@ -566,6 +592,9 @@ class TFGraphNode {
std::set<string> op_types_;
std::map<int64, ExecStep> execs_;
std::map<int32, std::pair<int64, uint64>> empty_output_memory_;
std::map<string, std::vector<std::pair<int64, int64>>> empty_op_execs_;
};
class TFMultiGraphNode {

View File

@ -88,6 +88,9 @@ TFStats::TFStats(const string& filename,
node_pb.second.name(), std::move(node)));
}
has_code_traces_ = profile.has_trace();
for (int64 s : profile.steps()) {
steps_.insert(s);
}
}
void TFStats::BuildView(const string& cmd) {
@ -136,6 +139,14 @@ const GraphNodeProto& TFStats::ShowGraphNode(const string& cmd,
if (cmd == kCmds[0]) {
return scope_view_->Show(opts);
} else if (cmd == kCmds[1]) {
if (opts.step < 0 && opts.output_type == kOutput[0]) {
for (int64 step : steps_) {
Options nopts = opts;
nopts.step = step;
graph_view_->Show(nopts);
}
return empty_graph_node_;
}
return graph_view_->Show(opts);
} else {
fprintf(stderr, "Unknown command: %s\n", cmd.c_str());
@ -148,7 +159,11 @@ const MultiGraphNodeProto& TFStats::ShowMultiGraphNode(
if (!Validate(opts)) {
return empty_multi_graph_node_;
}
if (cmd == kCmds[2] && has_code_traces()) {
if (cmd == kCmds[2]) {
if (!has_code_traces()) {
fprintf(stderr, "No code trace information\n");
return empty_multi_graph_node_;
}
return code_view_->Show(opts);
} else if (cmd == kCmds[3]) {
return op_view_->Show(opts);
@ -212,7 +227,9 @@ void TFStats::AddOpLogProto(std::unique_ptr<OpLogProto> op_log) {
}
if (entry.has_code_def()) {
has_code_traces_ = true;
node->second->AddCode(entry.code_def());
if (node->second->code().traces_size() == 0) {
node->second->AddCode(entry.code_def());
}
}
}
}
@ -258,9 +275,11 @@ void TFStats::WriteProfile(const string& filename) {
}
(*profile.mutable_nodes())[it->second->id()].MergeFrom(
it->second->ToProto(nodes_map_));
if (it->second->code().traces_size() > 0) {
profile.set_has_trace(true);
}
}
profile.set_has_trace(has_code_traces_);
for (int64 s : steps_) {
profile.add_steps(s);
}
Status s =
WriteStringToFile(Env::Default(), filename, profile.SerializeAsString());
@ -271,7 +290,12 @@ void TFStats::WriteProfile(const string& filename) {
bool TFStats::Validate(const Options& opts) const {
if (opts.step >= 0 && steps_.find(opts.step) == steps_.end()) {
fprintf(stderr, "Options -step=%lld not found\n", opts.step);
fprintf(stderr,
"Options -step=%lld not found.\nAvailable steps: ", opts.step);
for (int64 s : steps_) {
fprintf(stderr, "%lld ", s);
}
fprintf(stderr, "\n");
return false;
}
return true;

View File

@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/profiler/internal/tfprof_utils.h"
namespace tensorflow {
@ -303,11 +304,12 @@ void Timeline::GenerateCodeTimeline(const CodeNode* node) {
}
void Timeline::OutputTimeline() {
string outfile = strings::Printf("%s_%lld", outfile_.c_str(), step());
Status s =
WriteStringToFile(Env::Default(), outfile_, chrome_formatter_.Format());
WriteStringToFile(Env::Default(), outfile, chrome_formatter_.Format());
if (!s.ok()) {
fprintf(stderr, "Failed to write timeline file: %s\nError: %s\n",
outfile_.c_str(), s.ToString().c_str());
outfile.c_str(), s.ToString().c_str());
return;
}
fprintf(stdout, "\n******************************************************\n");
@ -315,7 +317,7 @@ void Timeline::OutputTimeline() {
"Timeline file is written to %s.\n"
"Open a Chrome browser, enter URL chrome://tracing and "
"load the timeline file.",
outfile_.c_str());
outfile.c_str());
fprintf(stdout, "\n******************************************************\n");
fflush(stdout);
}

View File

@ -70,7 +70,7 @@ TEST_F(TFProfTimelineTest, GraphView) {
tf_stats_->ShowGraphNode("graph", opts);
string dump_str;
TF_CHECK_OK(ReadFileToString(Env::Default(), dump_file, &dump_str));
TF_CHECK_OK(ReadFileToString(Env::Default(), dump_file + "_0", &dump_str));
EXPECT_EQ(1754536562981488144ull, Hash64(dump_str));
}
@ -84,7 +84,7 @@ TEST_F(TFProfTimelineTest, ScopeView) {
tf_stats_->ShowGraphNode("scope", opts);
string dump_str;
TF_CHECK_OK(ReadFileToString(Env::Default(), dump_file, &dump_str));
TF_CHECK_OK(ReadFileToString(Env::Default(), dump_file + "_0", &dump_str));
EXPECT_EQ(17545174915963890413ull, Hash64(dump_str));
}

View File

@ -42,6 +42,8 @@ message ProfileProto {
map<int64, ProfileNode> nodes = 1;
// Whether or not has code traces.
bool has_trace = 2;
// Traced steps.
repeated int64 steps = 3;
}
message ProfileNode {

View File

@ -632,6 +632,22 @@ define an attr with constraints, you can use the following `<attr-type-expr>`s:
tf.number_type(t=tf.bool) # Invalid
```
Lists can be combined with other lists and single types. The following
op allows attr `t` to be any of the numeric types, or the bool type:
```c++
REGISTER_OP("NumberOrBooleanType")
.Attr("t: {numbertype, bool}");
```
For this op:
```python
tf.number_or_boolean_type(t=tf.int32) # Valid
tf.number_or_boolean_type(t=tf.bool) # Valid
tf.number_or_boolean_type(t=tf.string) # Invalid
```
* `int >= <n>`: The value must be an int whose value is greater than or equal to
`<n>`, where `<n>` is a natural number.

View File

@ -1,5 +1,13 @@
# Installing TensorFlow
We've built and tested TensorFlow on the following 64-bit laptop/desktop
operating systems:
* Mac OS X 10.11 (El Capitan) or later.
* Ubuntu 14.04 or later.
* Windows 7 or later.
Although you might be able to install TensorFlow on other laptop or desktop
systems, we only support (and only fix issues in) the preceding configurations.
The following guides explain how to install a version of TensorFlow
that enables you to write applications in Python:

View File

@ -1,43 +1,182 @@
# Performance Guide
This guide contains a collection of best practices for optimizing your
TensorFlow code. The best practices apply to both new and experienced
Tensorflow users. As a complement to the best practices in this document, the
@{$performance_models$High-Performance Models} document links to example code
and details for creating models that scale on a variety of hardware.
This guide contains a collection of best practices for optimizing TensorFlow
code. The guide is divided into a few sections:
## Best Practices
While optimizing different types of models can call for different approaches,
the topics below cover best practices to get the most performance from
TensorFlow. Although these suggestions focus on image-based models, we will
regularly add tips for all kinds of models. The following list highlights key
best practices:
* [General best practices](#general_best_practices) covers topics that are
common across a variety of model types and hardware.
* [Optimizing for GPU](#optimizing_for_gpu) details tips specifically relevant
to GPUs.
* [Optimizing for CPU](#optimizing_for_cpu) details CPU specific information.
* Build and install from source
* Utilize queues for reading data
* Preprocessing on the CPU
* Use `NCHW` image data format
* Place shared parameters on the GPU
* Use fused batch norm
## General best practices
The following sections detail the preceding suggestions.
The sections below cover best practices that are relevant to a variety of
hardware and models, broken down into the following topics:
### Build and install from source
* [Input pipeline optimizations](#input-pipeline-optimization)
* [Data formats](#data-formats)
* [Common fused Ops](#common-fused-ops)
* [Building and installing from source](#building-and-installing-from-source)
To install the most optimized version of TensorFlow, build and install
TensorFlow from source by following [Installing TensorFlow from Source](../install/install_sources).
Building from source with compiler optimizations for the target hardware and
ensuring the latest CUDA platform and cuDNN libraries are installed results in
the highest performing installs.
### Input pipeline optimization
For the most stable experience, build from the [latest release](https://github.com/tensorflow/tensorflow/releases)
branch. To get the latest performance changes and accept some stability risk,
build from [master](https://github.com/tensorflow/tensorflow).
Typical models retrieve data from disk and preprocess it before sending the data
through the network. For example, models that process JPEG images will follow
this flow: load image from disk, decode JPEG into a tensor, crop and pad,
possibly flip and distort, and then batch. This flow is referred to as the input
pipeline. As GPUs and other hardware accelerators get faster, preprocessing of
data can be a bottleneck.
If there is a need to build TensorFlow on a platform that has different hardware
than the target, then cross-compile with the highest optimizations for the target
platform. The following command is an example of telling `bazel` to compile for
a specific platform:
Determining if the input pipeline is the bottleneck can be complicated. One of
the most straightforward methods is to reduce the model to a single operation
(trivial model) after the input pipeline and measure the examples per second. If
the difference in examples per second for the full model and the trivial model
is minimal, then the input pipeline is likely a bottleneck. Below are some other
approaches to identifying issues:
* Check if a GPU is underutilized by running `watch -n 2 nvidia-smi`. If GPU
utilization is not approaching 80-100%, then the input pipeline may be the
bottleneck.
* Generate a timeline and look for large blocks of white space (waiting). An
example of generating a timeline exists as part of the @{$jit$XLA JIT}
tutorial; a minimal sketch is also shown after this list.
* Check CPU usage. It is possible to have an optimized input pipeline and lack
the CPU cycles to process the pipeline.
* Estimate the throughput needed and verify the disk used is capable of that
level of throughput. Some cloud solutions have network attached disks that
start as low as 50 MB/sec, which is slower than spinning disks (150 MB/sec),
SATA SSDs (500 MB/sec), and PCIe SSDs (2,000+ MB/sec).
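As a concrete reference for the timeline bullet above, here is a minimal
sketch of tracing a single step and writing a Chrome-trace file. It assumes
an existing session `sess` and a fetchable `train_op`; both names are
illustrative placeholders:

```python
# Minimal timeline sketch; assumes an existing session `sess` and a
# fetchable `train_op`. View the resulting file at chrome://tracing.
import tensorflow as tf
from tensorflow.python.client import timeline

run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
run_metadata = tf.RunMetadata()
sess.run(train_op, options=run_options, run_metadata=run_metadata)

trace = timeline.Timeline(step_stats=run_metadata.step_stats)
with open('timeline.json', 'w') as f:
  f.write(trace.generate_chrome_trace_format())
```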
#### Preprocessing on the CPU
Placing input pipeline operations on the CPU can significantly improve
performance. Utilizing the CPU for the input pipeline frees the GPU to focus on
training. To ensure preprocessing is on the CPU, wrap the preprocessing
operations as shown below:
```python
with tf.device('/cpu:0'):
# function to get and process images or data.
distorted_inputs = load_and_distort_images()
```
If using `tf.estimator.Estimator`, the input function is automatically placed on
the CPU.
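For instance, a hedged sketch of how this looks end to end is below;
`model_fn` is a placeholder, and `load_and_distort_images` is the illustrative
function from the snippet above:

```python
# Illustrative sketch: tf.estimator.Estimator runs input_fn on the CPU,
# keeping preprocessing off the GPU. `model_fn` is a placeholder.
import tensorflow as tf

def input_fn():
  features, labels = load_and_distort_images()  # as in the snippet above
  return features, labels

estimator = tf.estimator.Estimator(model_fn=model_fn)
estimator.train(input_fn=input_fn, steps=1000)
```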
#### Using the Dataset API
The @{$datasets$Dataset API} is replacing `queue_runner` as the recommended API
for building input pipelines. The API was added to contrib as part of TensorFlow
1.2 and will move to core in the near future. This
[ResNet example](https://github.com/tensorflow/models/tree/master/tutorials/image/cifar10_estimator/cifar10_main.py)
([arXiv:1512.03385](https://arxiv.org/abs/1512.03385))
training CIFAR-10 illustrates the use of the Dataset API along with
`tf.estimator.Estimator`. The Dataset API utilizes C++ multi-threading and has a
much lower overhead than the Python-based `queue_runner` that is limited by
Python's multi-threading performance.
While feeding data using a `feed_dict` offers a high level of flexibility, in
most instances using `feed_dict` does not scale optimally. However, in instances
where only a single GPU is being used the difference can be negligible. Using
the Dataset API is still strongly recommended. Try to avoid the following:
```python
# feed_dict often results in suboptimal performance when using large inputs.
sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
```
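By contrast, a minimal sketch of an equivalent pipeline built with the
contrib-era Dataset API is shown below. The `parse_fn` feature spec and the
`train.tfrecords` filename are illustrative assumptions, not part of any
particular model:

```python
import tensorflow as tf

def parse_fn(example_proto):
  # Illustrative feature spec; adapt to your own records.
  features = tf.parse_single_example(
      example_proto,
      {"image": tf.FixedLenFeature([], tf.string),
       "label": tf.FixedLenFeature([], tf.int64)})
  image = tf.decode_raw(features["image"], tf.uint8)
  return image, features["label"]

dataset = tf.contrib.data.TFRecordDataset(["train.tfrecords"])
dataset = dataset.map(parse_fn)  # parsing runs in C++-backed threads
dataset = dataset.batch(128).repeat()
images, labels = dataset.make_one_shot_iterator().get_next()
```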
#### Use large files
Reading large numbers of small files significantly impacts I/O performance.
One approach to get maximum I/O throughput is to preprocess input data into
larger (~100MB) `TFRecord` files. For smaller data sets (200MB-1GB), the best
approach is often to load the entire data set into memory. The document
[Downloading and converting to TFRecord format](https://github.com/tensorflow/models/tree/master/slim#Data)
includes information and scripts for creating `TFRecords` and this
[script](https://github.com/tensorflow/models/tree/master/tutorials/image/cifar10_estimator/generate_cifar10_tfrecords.py)
converts the CIFAR-10 data set into `TFRecords`.
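As a rough illustration of the packing step, the sketch below writes many
small examples into a single `TFRecord` file; the `examples` iterable of
`(image_bytes, label)` pairs is an assumed placeholder:

```python
import tensorflow as tf

# `examples` is a placeholder iterable of (bytes, int) pairs.
writer = tf.python_io.TFRecordWriter("train.tfrecords")
for image_bytes, label in examples:
  example = tf.train.Example(features=tf.train.Features(feature={
      "image": tf.train.Feature(
          bytes_list=tf.train.BytesList(value=[image_bytes])),
      "label": tf.train.Feature(
          int64_list=tf.train.Int64List(value=[label])),
  }))
  writer.write(example.SerializeToString())
writer.close()
```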
### Data formats
Data format refers to the structure of the Tensor passed to a given Op. The
discussion below is specifically about 4D Tensors representing images. In
TensorFlow the parts of the 4D tensor are often referred to by the following
letters:
* N refers to the number of images in a batch.
* H refers to the number of pixels in the vertical (height) dimension.
* W refers to the number of pixels in the horizontal (width) dimension.
* C refers to the channels. For example, 1 for black and white or grayscale
and 3 for RGB.
Within TensorFlow there are two naming conventions representing the two most
common data formats:
* `NCHW` or `channels_first`
* `NHWC` or `channels_last`
`NHWC` is the TensorFlow default and `NCHW` is the optimal format to use when
training on NVIDIA GPUs using [cuDNN](https://developer.nvidia.com/cudnn).
The best practice is to build models that work with both data formats. This
simplifies training on GPUs and then running inference on CPUs. If TensorFlow is
compiled with the [Intel MKL](#tensorflow_with_intel_mkl-dnn) optimizations,
many operations, especially those related to CNN based models, will be optimized
and support `NCHW`. If not using the MKL, some operations are not supported on
CPU when using `NCHW`.
The brief history of these two formats is that TensorFlow started by using
`NHWC` because it was a little faster on CPUs. In the long term, we are working
on tools to automatically rewrite graphs to make switching between the formats
transparent, and to take advantage of micro-optimizations where a GPU Op may be
faster using `NHWC` than the normally more efficient `NCHW`.
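One common way to follow this practice is to parameterize model code on the
data format, as in the hedged sketch below; the function name and layer sizes
are illustrative only:

```python
# Sketch of model code parameterized on data format, so the same
# graph-building function can run NCHW on GPU and NHWC on CPU.
import tensorflow as tf

def conv_block(inputs, data_format='channels_last'):
  # 'channels_first' corresponds to NCHW; 'channels_last' to NHWC.
  net = tf.layers.conv2d(inputs, filters=64, kernel_size=3,
                         padding='same', data_format=data_format)
  return tf.layers.max_pooling2d(net, pool_size=2, strides=2,
                                 data_format=data_format)
```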
### Common fused Ops
Fused Ops combine multiple operations into a single kernel for improved
performance. There are many fused Ops within TensorFlow and @{$xla$XLA} will
create fused Ops when possible to automatically improve performance. Collected
below are select fused Ops that can greatly improve performance and may be
overlooked.
#### Fused batch norm
Fused batch norm combines the multiple operations needed to do batch
normalization into a single kernel. Batch norm is an expensive process that for
some models makes up a large percentage of the operation time. Using fused batch
norm can result in a 12%-30% speedup.
There are two commonly used batch norms, and both support fusing. The core
@{tf.layers.batch_normalization} added the `fused` option starting in TensorFlow 1.3.
```python
bn = tf.layers.batch_normalization(
input_layer, fused=True, data_format='NCHW')
```
The contrib @{tf.contrib.layers.batch_norm} method has had `fused` as an option
since before TensorFlow 1.0.
```python
bn = tf.contrib.layers.batch_norm(input_layer, fused=True, data_format='NCHW')
```
### Building and installing from source
The default TensorFlow binaries target the broadest range of hardware to make
TensorFlow accessible to everyone. If using CPUs for training or inference, it
is recommended to compile TensorFlow with all of the optimizations available for
the CPU in use. Speedups for training and inference on CPU are documented below
in [Comparing compiler optimizations](#comparing-compiler-optimizations).
To install the most optimized version of TensorFlow,
@{$install_sources$build and install} from source. If there is a need to build
TensorFlow on a platform that has different hardware than the target, then
cross-compile with the highest optimizations for the target platform. The
following command is an example of using `bazel` to compile for a specific
platform:
```bash
# This command optimizes for Intel's Broadwell processor
bazel build -c opt --copt=-march="broadwell" --config=cuda //tensorflow/tools/pip_package:build_pip_package
```
#### Environment, build, and install tips
* TensorFlow checks on startup whether it has been compiled with the
optimizations available on the CPU. If the optimizations are not included,
TensorFlow will emit warnings, e.g. AVX, AVX2, and FMA instructions not
included.
* `./configure` asks which compute capability to include in the build. This
  does not impact overall performance but does impact initial startup. After
  running TensorFlow once, the compiled kernels are cached by CUDA. If using
  a docker container, the data is not cached and the penalty is paid each time
  TensorFlow starts. The best practice is to include the
  [compute capabilities](http://developer.nvidia.com/cuda-gpus)
  of the GPUs that will be used, e.g. P100: 6.0, Titan X (Pascal): 6.1, Titan
  X (Maxwell): 5.2, and K80: 3.7; see the example build after this list.
* Use a version of gcc that supports all of the optimizations of the target
CPU. The recommended minimum gcc version is 4.8.3. On OS X, upgrade to the
latest Xcode version and use the version of clang that comes with Xcode.
* Install the latest stable CUDA platform and cuDNN libraries supported by
TensorFlow.
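As a sketch of combining these tips, the commands below configure and build a
GPU-enabled package non-interactively. The compute capability list is
illustrative, and `TF_CUDA_COMPUTE_CAPABILITIES` is assumed to be the
environment variable `./configure` reads for this setting:

```bash
# Include the compute capabilities of the GPUs that will be used.
export TF_CUDA_COMPUTE_CAPABILITIES="3.7,5.2,6.0,6.1"  # K80, Maxwell, P100, Pascal
./configure
# Build a pip package optimized for the local CPU and CUDA.
bazel build -c opt --config=cuda //tensorflow/tools/pip_package:build_pip_package
```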
## Optimizing for GPU

This section contains GPU-specific tips that are not covered in the
[General best practices](#general-best-practices). One common cause of poor
performance is underutilizing GPUs, or essentially "starving" them of data by
not setting up an efficient pipeline. Make sure to set up an input pipeline to
utilize queues and stream data effectively. Review the
@{$reading_data#reading_from_files$Reading Data guide} for implementation
details. One way to identify a "starved" GPU is to generate and review
timelines. A detailed tutorial for timelines does not exist, but a quick
example of generating a timeline exists as part of the @{$jit$XLA JIT}
tutorial, and a short sketch is included below. Another simple way to check if
a GPU is underutilized is to run `watch nvidia-smi`; if GPU utilization is not
approaching 100%, the GPU is not getting data fast enough.
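Below is a minimal sketch of capturing a timeline for one training step
(assuming TensorFlow 1.x; `sess` and `train_step` stand in for the surrounding
training code). Gaps in the resulting trace indicate periods where the GPU sat
idle waiting for input:

```python
import tensorflow as tf
from tensorflow.python.client import timeline

# Trace one step with full tracing enabled.
run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
run_metadata = tf.RunMetadata()
sess.run(train_step, options=run_options, run_metadata=run_metadata)

# Convert the step stats to a Chrome trace viewable at chrome://tracing.
trace = timeline.Timeline(step_stats=run_metadata.step_stats)
with open('timeline.json', 'w') as f:
  f.write(trace.generate_chrome_trace_format())
```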
Obtaining optimal performance on multiple GPUs is a challenge. A common
approach is to use data parallelism.
Scaling through the use of data parallelism involves making multiple copies of
the model, which are referred to as "towers", and then placing one tower on each
of the GPUs. Each tower operates on a different mini-batch of data and then
updates variables, also known as parameters, that need to be shared between
each of the towers. How each tower gets the updated variables and how the
gradients are applied has an impact on the performance, scaling, and convergence
of the model. The rest of this section provides an overview of variable
placement and the towering of a model on multiple GPUs.
@{$performance_models$High-Performance Models} gets into more details regarding
more complex methods that can be used to share and update variables between
towers.
The best approach to handling variable updates depends on the model, hardware,
and even how the hardware has been configured. An example of this is that two
systems can be built with NVIDIA Tesla P100s, but one may be using PCIe and
the other [NVLink](http://www.nvidia.com/object/nvlink.html). In that
scenario, the optimal solution for each system may be different. For
real-world examples, read
the @{$benchmarks$benchmark} page which details the settings that were optimal
for a variety of platforms. Below is a summary of what was learned from
benchmarking various platforms and configurations:
* **Tesla K80**: If the GPUs are on the same PCI Express root complex and are
able to use [NVIDIA GPUDirect](https://developer.nvidia.com/gpudirect) Peer
to Peer, then placing the variables equally across the GPUs used for
training is the best approach. If the GPUs cannot use GPUDirect, then
placing the variables on the CPU is the best option.
* **Titan X (Maxwell and Pascal), M40, P100, and similar**: For models like
ResNet and InceptionV3, placing variables on the CPU is the optimal setting,
but for models with a lot of variables like AlexNet and VGG, using GPUs with
`NCCL` is better.
A common approach to managing where variables are placed is to create a method
to determine where each Op is to be placed and use that method in place of a
specific device name when calling `with tf.device():`. Consider a scenario where
a model is being trained on 2 GPUs and the variables are to be placed on the
CPU. There would be a loop for creating and placing the "towers" on each of the
2 GPUs. A custom device placement method would be created that watches for Ops
of type `Variable`, `VariableV2`, and `VarHandleOp` and indicates that they are
to be placed on the CPU. All other Ops would be placed on the target GPU.
The building of the graph would proceed as follows:
* On the first loop a "tower" of the model would be created for `gpu:0`.
During the placement of the Ops, the custom device placement method would
indicate that variables are to be placed on `cpu:0` and all other Ops on
`gpu:0`.
* On the second loop, `reuse` is set to `True` to indicate that variables are
to be reused and then the "tower" is created on `gpu:1`. During the
placement of the Ops associated with the "tower", the variables that were
placed on `cpu:0` are reused and all other Ops are created and placed on
`gpu:1`.
The final result is all of the variables are placed on the CPU with each GPU
having a copy of all of the computational Ops associated with the model.
The code snippet below illustrates two different approaches for variable
placement: one is placing variables on the CPU; the other is placing variables
equally across the GPUs.
```python
import operator

import tensorflow as tf


class GpuParamServerDeviceSetter(object):
  """Used with tf.device() to place variables on the least loaded GPU.

  A common use for this class is to pass a list of GPU devices, e.g. ['gpu:0',
  'gpu:1','gpu:2'], as ps_devices. When each variable is placed, it will be
  placed on the least loaded gpu. All other Ops, which will be the computation
  Ops, will be placed on the worker_device.
  """

  def __init__(self, worker_device, ps_devices):
    """Initializer for GpuParamServerDeviceSetter.

    Args:
      worker_device: the device to use for computation Ops.
      ps_devices: a list of devices to use for Variable Ops. Each variable is
        assigned to the least loaded device.
    """
    self.ps_devices = ps_devices
    self.worker_device = worker_device
    self.ps_sizes = [0] * len(self.ps_devices)

  def __call__(self, op):
    if op.device:
      return op.device
    if op.type not in ['Variable', 'VariableV2', 'VarHandleOp']:
      return self.worker_device

    # Gets the least loaded ps_device.
    device_index, _ = min(enumerate(self.ps_sizes), key=operator.itemgetter(1))
    device_name = self.ps_devices[device_index]
    var_size = op.outputs[0].get_shape().num_elements()
    self.ps_sizes[device_index] += var_size

    return device_name


def _create_device_setter(is_cpu_ps, worker, num_gpus):
  """Create device setter object."""
  if is_cpu_ps:
    # tf.train.replica_device_setter supports placing variables on the CPU,
    # all on one GPU, or on ps_servers defined in a cluster_spec.
    return tf.train.replica_device_setter(
        worker_device=worker, ps_device='/cpu:0', ps_tasks=1)
  else:
    gpus = ['/gpu:%d' % i for i in range(num_gpus)]
    return GpuParamServerDeviceSetter(worker, gpus)


# The method below is a modified snippet from the full example.
def _resnet_model_fn():
  # When set to False, variables are placed on the least loaded GPU. If set
  # to True, the variables will be placed on the CPU.
  is_cpu_ps = False

  # Loops over the number of GPUs and creates a copy ("tower") of the model on
  # each GPU.
  for i in range(num_gpus):
    worker = '/gpu:%d' % i
    # Creates a device setter used to determine where Ops are to be placed.
    device_setter = _create_device_setter(is_cpu_ps, worker, FLAGS.num_gpus)
    # Creates variables on the first loop. On subsequent loops reuse is set
    # to True, which results in the "towers" sharing variables.
    with tf.variable_scope('resnet', reuse=bool(i != 0)):
      with tf.name_scope('tower_%d' % i) as name_scope:
        # tf.device calls the device_setter for each Op that is created.
        # device_setter returns the device the Op is to be placed on.
        with tf.device(device_setter):
          # Creates the "tower".
          _tower_fn(is_training, weight_decay, tower_features[i],
                    tower_labels[i], tower_losses, tower_gradvars,
                    tower_preds, False)
```
In the near future, the above code will be for illustration purposes only, as
there will be easy-to-use, high-level methods supporting a wide range of
popular approaches. This
[example](https://github.com/tensorflow/models/tree/master/tutorials/image/cifar10_estimator)
will continue to get updated as the API expands and evolves to address multi-GPU
scenarios.
### Preprocessing on the CPU

Placing preprocessing operations on the CPU can significantly improve
performance. When preprocessing occurs on the GPU, the flow of data is
CPU -> GPU (preprocessing) -> CPU -> GPU (training); the data is bounced back
and forth between the CPU and GPU. When preprocessing is placed on the CPU,
the data flow is CPU (preprocessing) -> GPU (training). Another benefit is
that preprocessing on the CPU frees GPU time to focus on training. Placing
preprocessing on the CPU can result in a 6X+ increase in samples/sec
processed, which could lead to training in 1/6th of the time. To ensure
preprocessing is on the CPU, wrap the preprocessing operations as shown below:

```python
with tf.device('/cpu:0'):
  # function to get and process images or data.
  distorted_inputs = load_and_distort_images()
```

## Optimizing for CPU

CPUs, which includes Intel® Xeon Phi™, achieve optimal performance when
TensorFlow is @{$install_sources$built from source} with all of the
instructions supported by the target CPU.

Beyond using the latest instruction sets, Intel® has added support for the
Intel® Math Kernel Library for Deep Neural Networks (Intel® MKL-DNN) to
TensorFlow. While the name is not completely accurate, these optimizations are
often simply referred to as 'MKL' or 'TensorFlow with MKL'. [TensorFlow
with Intel® MKL-DNN](#tensorflow_with_intel_mkl_dnn) contains details on the
MKL optimizations.

The two configurations listed below are used to optimize CPU performance by
adjusting the thread pools.

* `intra_op_parallelism_threads`: Nodes that can use multiple threads to
  parallelize their execution will schedule the individual pieces into this
  pool.
* `inter_op_parallelism_threads`: All ready nodes are scheduled in this pool.

These configurations are set via `tf.ConfigProto` and passed to `tf.Session`
in the `config` attribute as shown in the snippet below. If either option is
unset or set to 0, it will default to the number of logical CPU cores. Testing
has shown that the default is effective for systems ranging from one CPU with
4 cores to multiple CPUs with 70+ combined logical cores. A common alternative
optimization is to set the number of threads in both pools equal to the number
of physical cores rather than logical cores.

```python
config = tf.ConfigProto()
config.intra_op_parallelism_threads = 44
config.inter_op_parallelism_threads = 44
tf.Session(config=config)
```
The [Comparing compiler optimizations](#comparing-compiler-optimizations)
section contains the results of tests that used different compiler
optimizations.
### TensorFlow with Intel® MKL DNN
Intel® has added optimizations to TensorFlow for Intel® Xeon® and Intel® Xeon
Phi™ through the use of Intel® Math Kernel Library for Deep Neural Networks
(Intel® MKL-DNN) optimized primitives. The optimizations also provide speedups
for the consumer line of processors, e.g. i5 and i7 Intel processors. The Intel
published paper
[TensorFlow* Optimizations on Modern Intel® Architecture](https://software.intel.com/en-us/articles/tensorflow-optimizations-on-modern-intel-architecture)
contains additional details on the implementation.
> Note: MKL was added as of TensorFlow 1.2 and currently only works on Linux. It
> also does not work when also using `--config=cuda`.
In addition to providing significant performance improvements for training CNN
based models, compiling with the MKL creates a binary that is optimized for AVX
and AVX2. The result is a single binary that is optimized and compatible with
most modern (post-2011) processors.
TensorFlow can be compiled with the MKL optimizations using the following
commands, depending on the version of the TensorFlow source used.
For TensorFlow source versions after 1.3.0:
```bash
./configure
# Pick the desired options
bazel build --config=mkl -c opt //tensorflow/tools/pip_package:build_pip_package
```
For TensorFlow versions 1.2.0 through 1.3.0:
```bash
./configure
Do you wish to build TensorFlow with MKL support? [y/N] Y
Do you wish to download MKL LIB from the web? [Y/n] Y
# Select the defaults for the rest of the options.
bazel build --config=mkl --copt="-DEIGEN_USE_VML" -c opt //tensorflow/tools/pip_package:build_pip_package
```
#### Tuning MKL for the best performance
This section details the different configurations and environment variables that
can be used to tune the MKL to get optimal performance. Before tweaking various
environment variables make sure the model is using the `NCHW` (`channels_first`)
[data format](#data-formats). The MKL is optimized for `NCHW` and Intel is
working to get near performance parity when using `NHWC`.
MKL uses the following environment variables to tune performance:
* KMP_BLOCKTIME - Sets the time, in milliseconds, that a thread should wait,
after completing the execution of a parallel region, before sleeping.
* KMP_AFFINITY - Enables the run-time library to bind threads to physical
processing units.
* KMP_SETTINGS - Enables (true) or disables (false) the printing of OpenMP*
run-time library environment variables during program execution.
* OMP_NUM_THREADS - Specifies the number of threads to use.
More details on the KMP variables are on
[Intel's](https://software.intel.com/en-us/node/522775) site, and the OMP
variables are on
[gnu.org](https://gcc.gnu.org/onlinedocs/libgomp/Environment-Variables.html).
While there can be substantial gains from adjusting the environment variables,
which is discussed below, the simplified advice is to set the
`inter_op_parallelism_threads` equal to the number of physical CPUs and to set
the following environment variables:
* KMP_BLOCKTIME=0
* KMP_AFFINITY=granularity=fine,verbose,compact,1,0
Example setting MKL variables with command-line arguments:
```bash
KMP_BLOCKTIME=0 KMP_AFFINITY=granularity=fine,verbose,compact,1,0 \
KMP_SETTINGS=1 python your_python_script.py
```
Example setting MKL variables with python `os.environ`:
```python
import os

os.environ["KMP_BLOCKTIME"] = str(FLAGS.kmp_blocktime)
os.environ["KMP_SETTINGS"] = str(FLAGS.kmp_settings)
os.environ["KMP_AFFINITY"] = FLAGS.kmp_affinity
if FLAGS.num_intra_threads > 0:
  os.environ["OMP_NUM_THREADS"] = str(FLAGS.num_intra_threads)
```
There are models and hardware platforms that benefit from different settings.
Each variable that impacts performance is discussed below.
* **KMP_BLOCKTIME**: The MKL default is 200ms, which was not optimal in our
  testing. 0 (0ms) was a good default for the CNN based models that were
  tested. The best performance for AlexNet was achieved at 30ms, and both
  GoogleNet and VGG11 performed best when set at 1ms.
* **KMP_AFFINITY**: The recommended setting is
`granularity=fine,verbose,compact,1,0`.
* **OMP_NUM_THREADS**: This defaults to the number of physical cores.
Adjusting this parameter beyond matching the number of cores can have an
impact when using Intel® Xeon Phi™ (Knights Landing) for some models. See
[TensorFlow* Optimizations on Modern Intel® Architecture](https://software.intel.com/en-us/articles/tensorflow-optimizations-on-modern-intel-architecture)
for optimal settings.
* **intra_op_parallelism_threads**: Setting this equal to the number of
  physical cores is recommended. Setting the value to 0, which is the default,
  results in the value being set to the number of logical cores; this is an
  option to try for some architectures. This value and `OMP_NUM_THREADS`
  should be equal (see the combined sketch after this list).
* **inter_op_parallelism_threads**: Setting this equal to the number of
sockets is recommended. Setting the value to 0, which is the default,
results in the value being set to the number of logical cores.
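The sketch below combines these recommendations, assuming TensorFlow 1.x and
an illustrative machine with 2 sockets and 44 physical cores in total; the
numbers are placeholders to be matched to the actual hardware:

```python
import os

import tensorflow as tf

# Environment variables read by the MKL's OpenMP runtime.
os.environ["KMP_BLOCKTIME"] = "0"
os.environ["KMP_AFFINITY"] = "granularity=fine,verbose,compact,1,0"
os.environ["OMP_NUM_THREADS"] = "44"  # number of physical cores

# Thread pools sized per the guidance above.
config = tf.ConfigProto()
config.intra_op_parallelism_threads = 44  # equal to OMP_NUM_THREADS
config.inter_op_parallelism_threads = 2   # number of sockets
sess = tf.Session(config=config)
```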
### Comparing compiler optimizations
Collected below are performance results running training and inference on
different types of CPUs on different platforms with various compiler
optimizations. The models used were ResNet-50
([arXiv:1512.03385](https://arxiv.org/abs/1512.03385)) and
InceptionV3 ([arXiv:1512.00567](https://arxiv.org/abs/1512.00567)).
For each test, when the MKL optimization was used the environment variable
KMP_BLOCKTIME was set to 0 (0ms) and KMP_AFFINITY to
`granularity=fine,verbose,compact,1,0`.
#### Inference InceptionV3
**Environment**
* Instance Type: AWS EC2 m4.xlarge
* CPU: Intel(R) Xeon(R) CPU E5-2686 v4 @ 2.30GHz (Broadwell)
* Dataset: ImageNet
* TensorFlow Version: 1.2.0 RC2
* Test Script: [tf_cnn_benchmarks.py](https://github.com/tensorflow/benchmarks/blob/mkl_experiment/scripts/tf_cnn_benchmarks/tf_cnn_benchmarks.py)
**Batch Size: 1**
Command executed for the MKL test:
```bash
python tf_cnn_benchmarks.py --forward_only=True --device=cpu --mkl=True \
--kmp_blocktime=0 --nodistortions --model=inception3 --data_format=NCHW \
--batch_size=1 --num_inter_threads=1 --num_intra_threads=4 \
--data_dir=<path to ImageNet TFRecords>
```
| Optimization | Data Format | Images/Sec | Intra threads | Inter Threads |
: : : (step time) : : :
| ------------ | ----------- | ------------ | ------------- | ------------- |
| AVX2 | NHWC | 6.8 (147ms) | 4 | 0 |
| MKL | NCHW | 6.6 (151ms) | 4 | 1 |
| MKL | NHWC | 5.95 (168ms) | 4 | 1 |
| AVX | NHWC | 4.7 (211ms) | 4 | 0 |
| SSE3 | NHWC | 2.7 (370ms) | 4 | 0 |
**Batch Size: 32**
Command executed for the MKL test:
```bash
python tf_cnn_benchmarks.py --forward_only=True --device=cpu --mkl=True \
--kmp_blocktime=0 --nodistortions --model=inception3 --data_format=NCHW \
--batch_size=32 --num_inter_threads=1 --num_intra_threads=4 \
--data_dir=<path to ImageNet TFRecords>
```
| Optimization | Data Format | Images/Sec | Intra threads | Inter Threads |
: : : (step time) : : :
| ------------ | ----------- | ------------- | ------------- | ------------- |
| MKL | NCHW | 10.24 | 4 | 1 |
: : : (3125ms) : : :
| MKL | NHWC | 8.9 (3595ms) | 4 | 1 |
| AVX2 | NHWC | 7.3 (4383ms) | 4 | 0 |
| AVX | NHWC | 5.1 (6275ms) | 4 | 0 |
| SSE3 | NHWC | 2.8 (11428ms) | 4 | 0 |
#### Inference ResNet-50
**Environment**
* Instance Type: AWS EC2 m4.xlarge
* CPU: Intel(R) Xeon(R) CPU E5-2686 v4 @ 2.30GHz (Broadwell)
* Dataset: ImageNet
* TensorFlow Version: 1.2.0 RC2
* Test Script: [tf_cnn_benchmarks.py](https://github.com/tensorflow/benchmarks/blob/mkl_experiment/scripts/tf_cnn_benchmarks/tf_cnn_benchmarks.py)
**Batch Size: 1**
Command executed for the MKL test:
```bash
python tf_cnn_benchmarks.py --forward_only=True --device=cpu --mkl=True \
--kmp_blocktime=0 --nodistortions --model=resnet50 --data_format=NCHW \
--batch_size=1 --num_inter_threads=1 --num_intra_threads=4 \
--data_dir=<path to ImageNet TFRecords>
```
| Optimization | Data Format | Images/Sec | Intra threads | Inter Threads |
: : : (step time) : : :
| ------------ | ----------- | ------------ | ------------- | ------------- |
| AVX2 | NHWC | 6.8 (147ms) | 4 | 0 |
| MKL | NCHW | 6.6 (151ms) | 4 | 1 |
| MKL | NHWC | 5.95 (168ms) | 4 | 1 |
| AVX | NHWC | 4.7 (211ms) | 4 | 0 |
| SSE3 | NHWC | 2.7 (370ms) | 4 | 0 |
**Batch Size: 32**
Command executed for the MKL test:
```bash
python tf_cnn_benchmarks.py --forward_only=True --device=cpu --mkl=True \
--kmp_blocktime=0 --nodistortions --model=resnet50 --data_format=NCHW \
--batch_size=32 --num_inter_threads=1 --num_intra_threads=4 \
--data_dir=<path to ImageNet TFRecords>
```
| Optimization | Data Format | Images/Sec | Intra threads | Inter Threads |
: : : (step time) : : :
| ------------ | ----------- | ------------- | ------------- | ------------- |
| MKL | NCHW | 10.24 | 4 | 1 |
: : : (3125ms) : : :
| MKL | NHWC | 8.9 (3595ms) | 4 | 1 |
| AVX2 | NHWC | 7.3 (4383ms) | 4 | 0 |
| AVX | NHWC | 5.1 (6275ms) | 4 | 0 |
| SSE3 | NHWC | 2.8 (11428ms) | 4 | 0 |
#### Training InceptionV3
**Environment**
* Instance Type: Dedicated AWS EC2 r4.16xlarge (Broadwell)
* CPU: Intel Xeon E5-2686 v4 (Broadwell) Processors
* Dataset: ImageNet
* TensorFlow Version: 1.2.0 RC2
* Test Script: [tf_cnn_benchmarks.py](https://github.com/tensorflow/benchmarks/blob/mkl_experiment/scripts/tf_cnn_benchmarks/tf_cnn_benchmarks.py)
Command executed for MKL test:
```bash
python tf_cnn_benchmarks.py --device=cpu --mkl=True --kmp_blocktime=0 \
--nodistortions --model=inception3 --data_format=NCHW --batch_size=32 \
--num_inter_threads=2 --num_intra_threads=36 \
--data_dir=<path to ImageNet TFRecords>
```
Optimization | Data Format | Images/Sec | Intra threads | Inter Threads
------------ | ----------- | ---------- | ------------- | -------------
MKL | NCHW | 20.8 | 36 | 2
AVX2 | NHWC | 6.2 | 36 | 0
AVX | NHWC | 5.7 | 36 | 0
SSE3 | NHWC | 4.3 | 36 | 0
ResNet and [AlexNet](http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf)
were also run on this configuration but in an ad hoc manner. There were not
enough runs executed to publish a coherent table of results. The incomplete
results strongly indicated the final result would be similar to the table above
with MKL providing significant 3x+ gains over AVX2.