Merge commit for internal changes

This commit is contained in:
Vijay Vasudevan 2016-03-28 18:24:13 -07:00
commit 71320c0909
59 changed files with 1943 additions and 1029 deletions

View File

@ -1,6 +1,6 @@
package(default_visibility = ["//visibility:public"]) package(default_visibility = ["//visibility:public"])
archive_dir = "eigen-eigen-36b0586de49f" archive_dir = "eigen-eigen-3d9f227afae2"
cc_library( cc_library(
name = "eigen", name = "eigen",

View File

@ -619,14 +619,19 @@ ANDROID_TF_COPTS = [
"-std=c++11", "-std=c++11",
"-DMIN_LOG_LEVEL=0", "-DMIN_LOG_LEVEL=0",
"-DTF_LEAN_BINARY", "-DTF_LEAN_BINARY",
"-O2",
] ]
# Native library support for Android applications. # Native library support for Android applications.
# Does not contain operators, use :android_tensorflow_lib if you want full # Does not contain operators, use :android_tensorflow_lib if you want full
# operator support. # operator support.
# Compiles to a trivial library on non-android to prevent irrelevant #
# build errors. # Compiles to a trivial library on non-Android to prevent irrelevant
# build errors. If not building this as part of an android_binary,
# a command such as the following must be used:
# bazel build -c opt tensorflow/core:android_tensorflow_lib \
# --crosstool_top=//third_party/java/android/android_ndk_linux/toolchains:everything \
# --cpu=armeabi-v7a \
# --host_crosstool_top=@bazel_tools//tools/cpp:toolchain
cc_library( cc_library(
name = "android_tensorflow_lib_lite", name = "android_tensorflow_lib_lite",
srcs = select({ srcs = select({
@ -643,7 +648,7 @@ cc_library(
"public/session.h", "public/session.h",
], ],
copts = select({ copts = select({
":android": ANDROID_TF_COPTS, ":android": ANDROID_TF_COPTS + ["-Os"],
"//conditions:default": [], "//conditions:default": [],
}), }),
tags = [ tags = [
@ -656,19 +661,23 @@ cc_library(
":protos_cc", ":protos_cc",
"//third_party/eigen3", "//third_party/eigen3",
], ],
alwayslink = 1,
) )
# Full Tensorflow library with operator support. Use this unless reducing # Full Tensorflow library with operator support. Use this unless reducing
# binary size (by packaging a reduced operator set) is a concern. # binary size (by packaging a reduced operator set) is a concern.
cc_library( cc_library(
name = "android_tensorflow_lib", name = "android_tensorflow_lib",
srcs = [ srcs = select({
":android_op_registrations_and_gradients", ":android": [
"//tensorflow/core/kernels:android_core_ops", ":android_op_registrations_and_gradients",
"//tensorflow/core/kernels:android_extended_ops", "//tensorflow/core/kernels:android_core_ops",
], "//tensorflow/core/kernels:android_extended_ops",
],
"//conditions:default": [],
}),
copts = select({ copts = select({
":android": ANDROID_TF_COPTS, ":android": ANDROID_TF_COPTS + ["-O2"],
"//conditions:default": [], "//conditions:default": [],
}), }),
tags = [ tags = [
@ -682,6 +691,7 @@ cc_library(
":protos_cc", ":protos_cc",
"//third_party/eigen3", "//third_party/eigen3",
], ],
alwayslink = 1,
) )
filegroup( filegroup(

View File

@ -124,7 +124,7 @@ class UntypedCall : public core::RefCounted {
} }
private: private:
UntypedCall* call_; // `this` owns one reference. UntypedCall* const call_; // `this` owns one reference.
Callback callback_; Callback callback_;
}; };
}; };
@ -149,30 +149,14 @@ class Call : public UntypedCall<Service> {
Call<Service, GrpcService, RequestMessage, ResponseMessage>*); Call<Service, GrpcService, RequestMessage, ResponseMessage>*);
Call(HandleRequestFunction handle_request_function) Call(HandleRequestFunction handle_request_function)
: handle_request_function_(handle_request_function), : handle_request_function_(handle_request_function), responder_(&ctx_) {}
responder_(&ctx_),
cancel_tag_(new typename UntypedCall<Service>::Tag(
this, &UntypedCall<Service>::RequestCancelled)) {
// The `ctx_` borrows the `cancel_tag_` until
// `this->RequestReceived()` is called.
ctx_.AsyncNotifyWhenDone(cancel_tag_.get());
}
virtual ~Call() {} virtual ~Call() {}
void RequestReceived(Service* service, bool ok) override { void RequestReceived(Service* service, bool ok) override {
if (ok) { if (ok) {
// At this point, the `cancel_tag_` becomes owned by the
// completion queue.
cancel_tag_.release();
this->Ref(); this->Ref();
(service->*handle_request_function_)(this); (service->*handle_request_function_)(this);
} else {
// `!ok` implies we never received a request for this call, and
// the `cancel_tag_` will never be added to the completion
// queue, so we free it here.
cancel_tag_.reset();
} }
} }
@ -190,6 +174,9 @@ class Call : public UntypedCall<Service> {
cancel_callback_(); cancel_callback_();
} }
} }
// NOTE(mrry): This can be called before or after RequestReceived, so we
// release `cancel_tag_` (in order to allow the event loop to free it).
cancel_tag_.release();
} }
// Registers `callback` as the function that should be called if and when this // Registers `callback` as the function that should be called if and when this
@ -213,9 +200,13 @@ class Call : public UntypedCall<Service> {
static void EnqueueRequest(GrpcService* grpc_service, static void EnqueueRequest(GrpcService* grpc_service,
::grpc::ServerCompletionQueue* cq, ::grpc::ServerCompletionQueue* cq,
EnqueueFunction enqueue_function, EnqueueFunction enqueue_function,
HandleRequestFunction handle_request_function) { HandleRequestFunction handle_request_function,
bool supports_cancel) {
auto call = new Call<Service, GrpcService, RequestMessage, ResponseMessage>( auto call = new Call<Service, GrpcService, RequestMessage, ResponseMessage>(
handle_request_function); handle_request_function);
if (supports_cancel) {
call->RegisterCancellationHandler();
}
(grpc_service->*enqueue_function)( (grpc_service->*enqueue_function)(
&call->ctx_, &call->request, &call->responder_, cq, cq, &call->ctx_, &call->request, &call->responder_, cq, cq,
@ -228,6 +219,15 @@ class Call : public UntypedCall<Service> {
ResponseMessage response; ResponseMessage response;
private: private:
// Creates a completion queue tag for handling cancellation by the client.
// NOTE: This method must be called before this call is enqueued on a
// completion queue.
void RegisterCancellationHandler() {
cancel_tag_.reset(new typename UntypedCall<Service>::Tag(
this, &UntypedCall<Service>::RequestCancelled));
ctx_.AsyncNotifyWhenDone(cancel_tag_.get());
}
HandleRequestFunction handle_request_function_; HandleRequestFunction handle_request_function_;
::grpc::ServerContext ctx_; ::grpc::ServerContext ctx_;
::grpc::ServerAsyncResponseWriter<ResponseMessage> responder_; ::grpc::ServerAsyncResponseWriter<ResponseMessage> responder_;

View File

@ -88,7 +88,7 @@ class GrpcMasterService : public AsyncServiceInterface {
// The implementation of the request handler for each RPC method // The implementation of the request handler for each RPC method
// must ensure that it calls ENQUEUE_REQUEST() for that RPC method, // must ensure that it calls ENQUEUE_REQUEST() for that RPC method,
// to keep accepting new requests. // to keep accepting new requests.
#define ENQUEUE_REQUEST(method) \ #define ENQUEUE_REQUEST(method, supports_cancel) \
do { \ do { \
mutex_lock l(mu_); \ mutex_lock l(mu_); \
if (!is_shutdown_) { \ if (!is_shutdown_) { \
@ -96,19 +96,20 @@ class GrpcMasterService : public AsyncServiceInterface {
method##Request, method##Response>:: \ method##Request, method##Response>:: \
EnqueueRequest(&master_service_, cq_, \ EnqueueRequest(&master_service_, cq_, \
&grpc::MasterService::AsyncService::Request##method, \ &grpc::MasterService::AsyncService::Request##method, \
&GrpcMasterService::method##Handler); \ &GrpcMasterService::method##Handler, \
(supports_cancel)); \
} \ } \
} while (0) } while (0)
void HandleRPCsLoop() { void HandleRPCsLoop() {
ENQUEUE_REQUEST(CreateSession); ENQUEUE_REQUEST(CreateSession, true);
ENQUEUE_REQUEST(ExtendSession); ENQUEUE_REQUEST(ExtendSession, false);
for (int i = 0; i < 100; ++i) { for (int i = 0; i < 100; ++i) {
ENQUEUE_REQUEST(RunStep); ENQUEUE_REQUEST(RunStep, true);
} }
ENQUEUE_REQUEST(CloseSession); ENQUEUE_REQUEST(CloseSession, false);
ENQUEUE_REQUEST(ListDevices); ENQUEUE_REQUEST(ListDevices, false);
ENQUEUE_REQUEST(Reset); ENQUEUE_REQUEST(Reset, false);
void* tag; void* tag;
bool ok; bool ok;
@ -146,7 +147,7 @@ class GrpcMasterService : public AsyncServiceInterface {
[call](const Status& status) { [call](const Status& status) {
call->SendResponse(ToGrpcStatus(status)); call->SendResponse(ToGrpcStatus(status));
}); });
ENQUEUE_REQUEST(CreateSession); ENQUEUE_REQUEST(CreateSession, true);
} }
// RPC handler for extending a session. // RPC handler for extending a session.
@ -156,7 +157,7 @@ class GrpcMasterService : public AsyncServiceInterface {
[call](const Status& status) { [call](const Status& status) {
call->SendResponse(ToGrpcStatus(status)); call->SendResponse(ToGrpcStatus(status));
}); });
ENQUEUE_REQUEST(ExtendSession); ENQUEUE_REQUEST(ExtendSession, false);
} }
// RPC handler for running one step in a session. // RPC handler for running one step in a session.
@ -169,7 +170,7 @@ class GrpcMasterService : public AsyncServiceInterface {
delete call_opts; delete call_opts;
call->SendResponse(ToGrpcStatus(status)); call->SendResponse(ToGrpcStatus(status));
}); });
ENQUEUE_REQUEST(RunStep); ENQUEUE_REQUEST(RunStep, true);
} }
// RPC handler for deleting a session. // RPC handler for deleting a session.
@ -179,7 +180,7 @@ class GrpcMasterService : public AsyncServiceInterface {
[call](const Status& status) { [call](const Status& status) {
call->SendResponse(ToGrpcStatus(status)); call->SendResponse(ToGrpcStatus(status));
}); });
ENQUEUE_REQUEST(CloseSession); ENQUEUE_REQUEST(CloseSession, false);
} }
// RPC handler for listing devices. // RPC handler for listing devices.
@ -189,7 +190,7 @@ class GrpcMasterService : public AsyncServiceInterface {
[call](const Status& status) { [call](const Status& status) {
call->SendResponse(ToGrpcStatus(status)); call->SendResponse(ToGrpcStatus(status));
}); });
ENQUEUE_REQUEST(ListDevices); ENQUEUE_REQUEST(ListDevices, false);
} }
// RPC handler for resetting all sessions. // RPC handler for resetting all sessions.
@ -198,7 +199,7 @@ class GrpcMasterService : public AsyncServiceInterface {
[call](const Status& status) { [call](const Status& status) {
call->SendResponse(ToGrpcStatus(status)); call->SendResponse(ToGrpcStatus(status));
}); });
ENQUEUE_REQUEST(Reset); ENQUEUE_REQUEST(Reset, false);
} }
#undef ENQUEUE_REQUEST #undef ENQUEUE_REQUEST

View File

@ -95,7 +95,7 @@ class GrpcWorkerService : public AsyncServiceInterface {
// The implementation of the request handler for each RPC method // The implementation of the request handler for each RPC method
// must ensure that it calls ENQUEUE_REQUEST() for that RPC method, // must ensure that it calls ENQUEUE_REQUEST() for that RPC method,
// to keep accepting new requests. // to keep accepting new requests.
#define ENQUEUE_REQUEST(method) \ #define ENQUEUE_REQUEST(method, supports_cancel) \
do { \ do { \
mutex_lock l(shutdown_mu_); \ mutex_lock l(shutdown_mu_); \
if (!is_shutdown_) { \ if (!is_shutdown_) { \
@ -103,7 +103,8 @@ class GrpcWorkerService : public AsyncServiceInterface {
method##Request, method##Response>:: \ method##Request, method##Response>:: \
EnqueueRequest(&worker_service_, cq_, \ EnqueueRequest(&worker_service_, cq_, \
&grpc::WorkerService::AsyncService::Request##method, \ &grpc::WorkerService::AsyncService::Request##method, \
&GrpcWorkerService::method##Handler); \ &GrpcWorkerService::method##Handler, \
(supports_cancel)); \
} \ } \
} while (0) } while (0)
@ -116,18 +117,18 @@ class GrpcWorkerService : public AsyncServiceInterface {
// method, by re-enqueuing a request before the previous one // method, by re-enqueuing a request before the previous one
// completes, and we may decide to bound some of the request // completes, and we may decide to bound some of the request
// types. // types.
ENQUEUE_REQUEST(GetStatus); ENQUEUE_REQUEST(GetStatus, false);
ENQUEUE_REQUEST(CleanupAll); ENQUEUE_REQUEST(CleanupAll, false);
ENQUEUE_REQUEST(RegisterGraph); ENQUEUE_REQUEST(RegisterGraph, false);
ENQUEUE_REQUEST(DeregisterGraph); ENQUEUE_REQUEST(DeregisterGraph, false);
// TODO(mrry): Consider enqueuing more of these request types. // TODO(mrry): Consider enqueuing more of these request types.
ENQUEUE_REQUEST(RecvTensor); ENQUEUE_REQUEST(RecvTensor, true);
ENQUEUE_REQUEST(RunGraph); ENQUEUE_REQUEST(RunGraph, true);
ENQUEUE_REQUEST(CleanupGraph); ENQUEUE_REQUEST(CleanupGraph, false);
ENQUEUE_REQUEST(Logging); ENQUEUE_REQUEST(Logging, false);
ENQUEUE_REQUEST(Tracing); ENQUEUE_REQUEST(Tracing, false);
void* tag; void* tag;
bool ok; bool ok;
@ -181,7 +182,7 @@ class GrpcWorkerService : public AsyncServiceInterface {
} }
call->SendResponse(::grpc::Status::OK); call->SendResponse(::grpc::Status::OK);
}); });
ENQUEUE_REQUEST(GetStatus); ENQUEUE_REQUEST(GetStatus, false);
} }
void CleanupAllHandler( void CleanupAllHandler(
@ -192,7 +193,7 @@ class GrpcWorkerService : public AsyncServiceInterface {
env_->device_mgr->ClearContainers(containers); env_->device_mgr->ClearContainers(containers);
call->SendResponse(::grpc::Status::OK); call->SendResponse(::grpc::Status::OK);
}); });
ENQUEUE_REQUEST(CleanupAll); ENQUEUE_REQUEST(CleanupAll, false);
} }
void RegisterGraphHandler( void RegisterGraphHandler(
@ -203,7 +204,7 @@ class GrpcWorkerService : public AsyncServiceInterface {
call->request.graph_options(), call->response.mutable_graph_handle()); call->request.graph_options(), call->response.mutable_graph_handle());
call->SendResponse(ToGrpcStatus(s)); call->SendResponse(ToGrpcStatus(s));
}); });
ENQUEUE_REQUEST(RegisterGraph); ENQUEUE_REQUEST(RegisterGraph, false);
} }
void DeregisterGraphHandler( void DeregisterGraphHandler(
@ -212,18 +213,18 @@ class GrpcWorkerService : public AsyncServiceInterface {
Status s = env_->graph_mgr->Deregister(call->request.graph_handle()); Status s = env_->graph_mgr->Deregister(call->request.graph_handle());
call->SendResponse(ToGrpcStatus(s)); call->SendResponse(ToGrpcStatus(s));
}); });
ENQUEUE_REQUEST(DeregisterGraph); ENQUEUE_REQUEST(DeregisterGraph, false);
} }
void RunGraphHandler(WorkerCall<RunGraphRequest, RunGraphResponse>* call) { void RunGraphHandler(WorkerCall<RunGraphRequest, RunGraphResponse>* call) {
env_->compute_pool->Schedule([this, call]() { DoRunGraph(call); }); env_->compute_pool->Schedule([this, call]() { DoRunGraph(call); });
ENQUEUE_REQUEST(RunGraph); ENQUEUE_REQUEST(RunGraph, true);
} }
void RecvTensorHandler( void RecvTensorHandler(
WorkerCall<RecvTensorRequest, RecvTensorResponse>* call) { WorkerCall<RecvTensorRequest, RecvTensorResponse>* call) {
env_->compute_pool->Schedule([this, call]() { DoRecvTensor(call); }); env_->compute_pool->Schedule([this, call]() { DoRecvTensor(call); });
ENQUEUE_REQUEST(RecvTensor); ENQUEUE_REQUEST(RecvTensor, true);
} }
void CleanupGraphHandler( void CleanupGraphHandler(
@ -233,7 +234,7 @@ class GrpcWorkerService : public AsyncServiceInterface {
env_->rendezvous_mgr->Cleanup(step_id); env_->rendezvous_mgr->Cleanup(step_id);
call->SendResponse(::grpc::Status::OK); call->SendResponse(::grpc::Status::OK);
}); });
ENQUEUE_REQUEST(CleanupGraph); ENQUEUE_REQUEST(CleanupGraph, false);
} }
void LoggingHandler(WorkerCall<LoggingRequest, LoggingResponse>* call) { void LoggingHandler(WorkerCall<LoggingRequest, LoggingResponse>* call) {
@ -241,7 +242,7 @@ class GrpcWorkerService : public AsyncServiceInterface {
Status s = DoLogging(call); Status s = DoLogging(call);
call->SendResponse(ToGrpcStatus(s)); call->SendResponse(ToGrpcStatus(s));
}); });
ENQUEUE_REQUEST(Logging); ENQUEUE_REQUEST(Logging, false);
} }
void TracingHandler(WorkerCall<TracingRequest, TracingResponse>* call) { void TracingHandler(WorkerCall<TracingRequest, TracingResponse>* call) {
@ -249,7 +250,7 @@ class GrpcWorkerService : public AsyncServiceInterface {
Status s = DoTracing(call); Status s = DoTracing(call);
call->SendResponse(ToGrpcStatus(s)); call->SendResponse(ToGrpcStatus(s));
}); });
ENQUEUE_REQUEST(Tracing); ENQUEUE_REQUEST(Tracing, false);
} }
#undef ENQUEUE_REQUEST #undef ENQUEUE_REQUEST

View File

@ -89,7 +89,8 @@ Status MemoryTypesForNode(const OpRegistryInterface* op_registry,
// Look up the Kernel registered for this node def. // Look up the Kernel registered for this node def.
const KernelDef* kdef = nullptr; const KernelDef* kdef = nullptr;
status = FindKernelDef(device_type, ndef, &kdef); status =
FindKernelDef(device_type, ndef, &kdef, nullptr /* kernel_class_name */);
if (!status.ok() || HasTypeList(*op_def)) { if (!status.ok() || HasTypeList(*op_def)) {
// When there is no kernel def for this op or the op's arg is a // When there is no kernel def for this op or the op's arg is a

View File

@ -594,10 +594,11 @@ Status OpKernelContext::MatchSignature(const DataTypeSlice expected_inputs,
// OpKernel registration ------------------------------------------------------ // OpKernel registration ------------------------------------------------------
struct KernelRegistration { struct KernelRegistration {
KernelRegistration(const KernelDef& d, KernelRegistration(const KernelDef& d, StringPiece c,
kernel_factory::OpKernelRegistrar::Factory f) kernel_factory::OpKernelRegistrar::Factory f)
: def(d), factory(f) {} : def(d), kernel_class_name(c.ToString()), factory(f) {}
const KernelDef def; const KernelDef def;
const string kernel_class_name;
const kernel_factory::OpKernelRegistrar::Factory factory; const kernel_factory::OpKernelRegistrar::Factory factory;
}; };
@ -624,12 +625,13 @@ static string Key(StringPiece op_type, DeviceType device_type,
namespace kernel_factory { namespace kernel_factory {
void OpKernelRegistrar::InitInternal(const KernelDef* kernel_def, void OpKernelRegistrar::InitInternal(const KernelDef* kernel_def,
StringPiece kernel_class_name,
Factory factory) { Factory factory) {
const string key = const string key =
Key(kernel_def->op(), DeviceType(kernel_def->device_type()), Key(kernel_def->op(), DeviceType(kernel_def->device_type()),
kernel_def->label()); kernel_def->label());
GlobalKernelRegistryTyped()->insert( GlobalKernelRegistryTyped()->insert(std::make_pair(
std::make_pair(key, KernelRegistration(*kernel_def, factory))); key, KernelRegistration(*kernel_def, kernel_class_name, factory)));
delete kernel_def; delete kernel_def;
} }
@ -724,7 +726,7 @@ Status FindKernelRegistration(DeviceType device_type, const NodeDef& node_def,
} // namespace } // namespace
Status FindKernelDef(DeviceType device_type, const NodeDef& node_def, Status FindKernelDef(DeviceType device_type, const NodeDef& node_def,
const KernelDef** def) { const KernelDef** def, string* kernel_class_name) {
const KernelRegistration* reg = nullptr; const KernelRegistration* reg = nullptr;
TF_RETURN_IF_ERROR(FindKernelRegistration(device_type, node_def, &reg)); TF_RETURN_IF_ERROR(FindKernelRegistration(device_type, node_def, &reg));
if (reg == nullptr) { if (reg == nullptr) {
@ -733,7 +735,8 @@ Status FindKernelDef(DeviceType device_type, const NodeDef& node_def,
" devices compatible with node ", " devices compatible with node ",
SummarizeNodeDef(node_def)); SummarizeNodeDef(node_def));
} }
*def = &reg->def; if (def != nullptr) *def = &reg->def;
if (kernel_class_name != nullptr) *kernel_class_name = reg->kernel_class_name;
return Status::OK(); return Status::OK();
} }

View File

@ -1025,7 +1025,6 @@ namespace register_kernel {
typedef ::tensorflow::KernelDefBuilder Name; typedef ::tensorflow::KernelDefBuilder Name;
} // namespace register_kernel } // namespace register_kernel
#define REGISTER_KERNEL_BUILDER(kernel_builder, ...) \ #define REGISTER_KERNEL_BUILDER(kernel_builder, ...) \
REGISTER_KERNEL_BUILDER_UNIQ_HELPER(__COUNTER__, kernel_builder, __VA_ARGS__) REGISTER_KERNEL_BUILDER_UNIQ_HELPER(__COUNTER__, kernel_builder, __VA_ARGS__)
@ -1035,18 +1034,20 @@ typedef ::tensorflow::KernelDefBuilder Name;
#define REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, ...) \ #define REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, ...) \
static ::tensorflow::kernel_factory::OpKernelRegistrar \ static ::tensorflow::kernel_factory::OpKernelRegistrar \
registrar__body__##ctr##__object( \ registrar__body__##ctr##__object( \
SHOULD_REGISTER_OP_KERNEL(__FILE__) \ SHOULD_REGISTER_OP_KERNEL(#__VA_ARGS__) \
? ::tensorflow::register_kernel::kernel_builder.Build() \ ? ::tensorflow::register_kernel::kernel_builder.Build() \
: nullptr, \ : nullptr, \
#__VA_ARGS__, \
[](::tensorflow::OpKernelConstruction* context) \ [](::tensorflow::OpKernelConstruction* context) \
-> ::tensorflow::OpKernel* { return new __VA_ARGS__(context); }) -> ::tensorflow::OpKernel* { return new __VA_ARGS__(context); })
void* GlobalKernelRegistry(); void* GlobalKernelRegistry();
// If node_def has a corresponding kernel registered on device_type, // If node_def has a corresponding kernel registered on device_type,
// returns OK and fill in the kernel def. // returns OK and fill in the kernel def and kernel_class_name. <def> and
// <kernel_class_name> may be null.
Status FindKernelDef(DeviceType device_type, const NodeDef& node_def, Status FindKernelDef(DeviceType device_type, const NodeDef& node_def,
const KernelDef** def); const KernelDef** def, string* kernel_class_name);
// Treats 'registry_ptr' as a pointer to KernelRegistry. For each kernel 'k' // Treats 'registry_ptr' as a pointer to KernelRegistry. For each kernel 'k'
// registered with the current library's global kernel registry (obtained by // registered with the current library's global kernel registry (obtained by
@ -1058,16 +1059,19 @@ namespace kernel_factory {
class OpKernelRegistrar { class OpKernelRegistrar {
public: public:
typedef OpKernel* (*Factory)(OpKernelConstruction*); typedef OpKernel* (*Factory)(OpKernelConstruction*);
OpKernelRegistrar(const KernelDef* kernel_def, Factory factory) {
OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name,
Factory factory) {
// Perform the check in the header to allow compile-time optimization // Perform the check in the header to allow compile-time optimization
// to a no-op, allowing the linker to remove the kernel symbols. // to a no-op, allowing the linker to remove the kernel symbols.
if (kernel_def != nullptr) { if (kernel_def != nullptr) {
InitInternal(kernel_def, factory); InitInternal(kernel_def, kernel_class_name, factory);
} }
} }
private: private:
void InitInternal(const KernelDef* kernel_def, Factory factory); void InitInternal(const KernelDef* kernel_def, StringPiece kernel_class_name,
Factory factory);
}; };
} // namespace kernel_factory } // namespace kernel_factory

View File

@ -422,6 +422,27 @@ class OpKernelBuilderTest : public ::testing::Test {
} }
} }
} }
string GetKernelClassName(const string& op_type, DeviceType device_type,
const std::vector<string>& attrs,
DataTypeSlice input_types = {}) {
NodeDef def = CreateNodeDef(op_type, attrs);
for (size_t i = 0; i < input_types.size(); ++i) {
def.add_input("a:0");
}
const KernelDef* kernel_def = nullptr;
string kernel_class_name;
const Status status =
FindKernelDef(device_type, def, &kernel_def, &kernel_class_name);
if (status.ok()) {
return kernel_class_name;
} else if (errors::IsNotFound(status)) {
return "not found";
} else {
return status.ToString();
}
}
}; };
REGISTER_OP("BuildCPU"); REGISTER_OP("BuildCPU");
@ -429,7 +450,9 @@ REGISTER_KERNEL_BUILDER(Name("BuildCPU").Device(DEVICE_CPU), DummyKernel);
TEST_F(OpKernelBuilderTest, BuilderCPU) { TEST_F(OpKernelBuilderTest, BuilderCPU) {
ExpectSuccess("BuildCPU", DEVICE_CPU, {}); ExpectSuccess("BuildCPU", DEVICE_CPU, {});
EXPECT_EQ("DummyKernel", GetKernelClassName("BuildCPU", DEVICE_CPU, {}));
ExpectFailure("BuildCPU", DEVICE_GPU, {}, error::NOT_FOUND); ExpectFailure("BuildCPU", DEVICE_GPU, {}, error::NOT_FOUND);
EXPECT_EQ("not found", GetKernelClassName("BuildCPU", DEVICE_GPU, {}));
} }
REGISTER_OP("BuildGPU"); REGISTER_OP("BuildGPU");
@ -472,12 +495,26 @@ REGISTER_KERNEL_BUILDER(Name("BuildTypeListAttr")
TEST_F(OpKernelBuilderTest, BuilderTypeListAttr) { TEST_F(OpKernelBuilderTest, BuilderTypeListAttr) {
ExpectSuccess("BuildTypeListAttr", DEVICE_CPU, {"T|list(type)|[]"}); ExpectSuccess("BuildTypeListAttr", DEVICE_CPU, {"T|list(type)|[]"});
EXPECT_EQ("DummyKernel", GetKernelClassName("BuildTypeListAttr", DEVICE_CPU,
{"T|list(type)|[]"}));
ExpectSuccess("BuildTypeListAttr", DEVICE_CPU, {"T|list(type)|[DT_BOOL]"}); ExpectSuccess("BuildTypeListAttr", DEVICE_CPU, {"T|list(type)|[DT_BOOL]"});
EXPECT_EQ("DummyKernel", GetKernelClassName("BuildTypeListAttr", DEVICE_CPU,
{"T|list(type)|[]"}));
ExpectSuccess("BuildTypeListAttr", DEVICE_CPU, ExpectSuccess("BuildTypeListAttr", DEVICE_CPU,
{"T|list(type)|[DT_BOOL, DT_BOOL]"}); {"T|list(type)|[DT_BOOL, DT_BOOL]"});
ExpectFailure("BuildTypeListAttr", DEVICE_CPU, {"T|list(type)|[DT_FLOAT]"}, ExpectFailure("BuildTypeListAttr", DEVICE_CPU, {"T|list(type)|[DT_FLOAT]"},
error::NOT_FOUND); error::NOT_FOUND);
EXPECT_EQ("not found", GetKernelClassName("BuildTypeListAttr", DEVICE_CPU,
{"T|list(type)|[DT_FLOAT]"}));
ExpectFailure("BuildTypeListAttr", DEVICE_CPU, {}, error::INVALID_ARGUMENT); ExpectFailure("BuildTypeListAttr", DEVICE_CPU, {}, error::INVALID_ARGUMENT);
EXPECT_TRUE(
StringPiece(GetKernelClassName("BuildTypeListAttr", DEVICE_CPU, {}))
.contains("Invalid argument: "));
ExpectFailure("BuildTypeListAttr", DEVICE_CPU, {"T|int|7"}, ExpectFailure("BuildTypeListAttr", DEVICE_CPU, {"T|int|7"},
error::INVALID_ARGUMENT); error::INVALID_ARGUMENT);
} }
@ -776,6 +813,9 @@ TEST_F(LabelTest, Default) {
ExpectSuccess("LabeledKernel", DEVICE_CPU, {}); ExpectSuccess("LabeledKernel", DEVICE_CPU, {});
auto* get_labeled_kernel = static_cast<BaseKernel*>(op_kernel.get()); auto* get_labeled_kernel = static_cast<BaseKernel*>(op_kernel.get());
EXPECT_EQ(0, get_labeled_kernel->Which()); EXPECT_EQ(0, get_labeled_kernel->Which());
EXPECT_EQ("LabeledKernel<0>",
GetKernelClassName("LabeledKernel", DEVICE_CPU, {}));
} }
TEST_F(LabelTest, Specified) { TEST_F(LabelTest, Specified) {
@ -783,6 +823,8 @@ TEST_F(LabelTest, Specified) {
ExpectSuccess("LabeledKernel", DEVICE_CPU, {"_kernel|string|'one'"}); ExpectSuccess("LabeledKernel", DEVICE_CPU, {"_kernel|string|'one'"});
auto* get_labeled_kernel = static_cast<BaseKernel*>(op_kernel.get()); auto* get_labeled_kernel = static_cast<BaseKernel*>(op_kernel.get());
EXPECT_EQ(1, get_labeled_kernel->Which()); EXPECT_EQ(1, get_labeled_kernel->Which());
EXPECT_EQ("LabeledKernel<1>", GetKernelClassName("LabeledKernel", DEVICE_CPU,
{"_kernel|string|'one'"}));
} }
TEST_F(LabelTest, Duplicate) { TEST_F(LabelTest, Duplicate) {

View File

@ -34,13 +34,12 @@ limitations under the License.
// out. // out.
#include "ops_to_register.h" #include "ops_to_register.h"
// Files which are not included in the whitelist provided by this // Op kernel classes for which ShouldRegisterOpKernel returns false will not be
// graph-specific header file will not be allowed to register their // registered.
// operator kernels. #define SHOULD_REGISTER_OP_KERNEL(clz) \
#define SHOULD_REGISTER_OP_KERNEL(filename) \ (strstr(kNecessaryOpKernelClasses, "," clz ",") != nullptr)
(strstr(kNecessaryOpFiles, filename) != nullptr)
// Ops for which ShouldRegisterOp return false will no be registered. // Ops for which ShouldRegisterOp returns false will not be registered.
#define SHOULD_REGISTER_OP(op) ShouldRegisterOp(op) #define SHOULD_REGISTER_OP(op) ShouldRegisterOp(op)
// If kRequiresSymbolicGradients is false, then no gradient ops are registered. // If kRequiresSymbolicGradients is false, then no gradient ops are registered.

View File

@ -260,6 +260,7 @@ tf_kernel_libraries(
"constant_op", "constant_op",
"diag_op", "diag_op",
"edit_distance_op", "edit_distance_op",
"gather_nd_op",
"gather_op", "gather_op",
"identity_op", "identity_op",
"immutable_constant_op", "immutable_constant_op",

View File

@ -27,7 +27,8 @@ namespace tensorflow {
// that 0 <= limit if Index is signed. Intended for use in performance // that 0 <= limit if Index is signed. Intended for use in performance
// critical contexts where 0 <= index < limit is almost always true. // critical contexts where 0 <= index < limit is almost always true.
template <typename Ta, typename Tb> template <typename Ta, typename Tb>
EIGEN_ALWAYS_INLINE bool FastBoundsCheck(const Ta index, const Tb limit) { EIGEN_ALWAYS_INLINE EIGEN_DEVICE_FUNC bool FastBoundsCheck(const Ta index,
const Tb limit) {
static_assert(std::is_integral<Ta>::value && std::is_integral<Tb>::value, static_assert(std::is_integral<Ta>::value && std::is_integral<Tb>::value,
"FastBoundsCheck can only be used on integer types."); "FastBoundsCheck can only be used on integer types.");
typedef typename std::make_unsigned<decltype(index + limit)>::type UIndex; typedef typename std::make_unsigned<decltype(index + limit)>::type UIndex;

View File

@ -22,22 +22,27 @@ limitations under the License.
#include "tensorflow/core/framework/numeric_types.h" #include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/tensor_types.h"
// The following functors (sign, tanh, sigmoid, etc.) are not defined
// by Eigen. When their equivalent are added into the Eigen, we can
// replace them using type aliases.
namespace Eigen { namespace Eigen {
namespace internal { namespace internal {
// TODO(rmlarsen): Get rid of fmod2 once fmod is upstreamed to Eigen.
template <typename T> template <typename T>
struct scalar_fmod2_op { struct scalar_fmod2_op {
EIGEN_EMPTY_STRUCT_CTOR(scalar_fmod2_op) EIGEN_EMPTY_STRUCT_CTOR(scalar_fmod2_op)
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& a, EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& a,
const T& b) const { const T& b) const {
return fmod(a, b); return std::fmod(a, b);
} }
}; };
template <typename T>
struct functor_traits<scalar_fmod2_op<T>> {
enum {
Cost = 13, // Reciprocal throughput of FPREM on Haswell.
PacketAccess = false,
};
};
// scalar_left and scalar_right are template helpers to partially // scalar_left and scalar_right are template helpers to partially
// apply a binary function. // apply a binary function.
// //
@ -489,7 +494,7 @@ template <typename T>
struct fmod : base<T, Eigen::internal::scalar_fmod2_op<T> > {}; struct fmod : base<T, Eigen::internal::scalar_fmod2_op<T> > {};
template <typename T> template <typename T>
struct mod : base<T, Eigen::internal::scalar_mod2_op<T> > {}; struct mod : base<T, Eigen::internal::scalar_mod2_op<T>> {};
template <typename T> template <typename T>
struct pow : base<T, Eigen::internal::scalar_binary_pow_op<T, T> > {}; struct pow : base<T, Eigen::internal::scalar_binary_pow_op<T, T> > {};

View File

@ -36,7 +36,6 @@ namespace Eigen {
* It is possible to swap the order of the width and height dimensions provided that the same order is used in the input, the kernel, and the output. * It is possible to swap the order of the width and height dimensions provided that the same order is used in the input, the kernel, and the output.
* *
*/ */
template <typename OutputBackward, typename Kernel> template <typename OutputBackward, typename Kernel>
EIGEN_ALWAYS_INLINE static const typename internal::conditional< EIGEN_ALWAYS_INLINE static const typename internal::conditional<
internal::traits<OutputBackward>::Layout == ColMajor, internal::traits<OutputBackward>::Layout == ColMajor,
@ -45,14 +44,18 @@ EIGEN_ALWAYS_INLINE static const typename internal::conditional<
internal::traits<OutputBackward>::NumDimensions>, internal::traits<OutputBackward>::NumDimensions>,
const TensorContractionOp< const TensorContractionOp<
const array< const array<
IndexPair<typename internal::traits<OutputBackward>::Index>, 2>, IndexPair<typename internal::traits<OutputBackward>::Index>, 1>,
const Eigen::TensorForcedEvalOp<const TensorReshapingOp< const Eigen::TensorForcedEvalOp<const TensorReshapingOp<
const DSizes<typename internal::traits<OutputBackward>::Index, const DSizes<typename internal::traits<OutputBackward>::Index,
3>, 2>,
const TensorReverseOp<const array<bool, 4>, const Kernel> > >, const TensorShufflingOp<
const array<
typename internal::traits<OutputBackward>::Index, 4>,
const TensorReverseOp<const array<bool, 4>,
const Kernel> > > >,
const TensorReshapingOp< const TensorReshapingOp<
const DSizes<typename internal::traits<OutputBackward>::Index, const DSizes<typename internal::traits<OutputBackward>::Index,
3>, 2>,
const TensorImagePatchOp<Dynamic, Dynamic, const TensorImagePatchOp<Dynamic, Dynamic,
const OutputBackward> > > >, const OutputBackward> > > >,
TensorReshapingOp< TensorReshapingOp<
@ -60,17 +63,20 @@ EIGEN_ALWAYS_INLINE static const typename internal::conditional<
internal::traits<OutputBackward>::NumDimensions>, internal::traits<OutputBackward>::NumDimensions>,
const TensorContractionOp< const TensorContractionOp<
const array< const array<
IndexPair<typename internal::traits<OutputBackward>::Index>, 2>, IndexPair<typename internal::traits<OutputBackward>::Index>, 1>,
const TensorReshapingOp< const TensorReshapingOp<
const DSizes<typename internal::traits<OutputBackward>::Index, const DSizes<typename internal::traits<OutputBackward>::Index,
3>, 2>,
const TensorImagePatchOp<Dynamic, Dynamic, const TensorImagePatchOp<Dynamic, Dynamic,
const OutputBackward> >, const OutputBackward> >,
const Eigen::TensorForcedEvalOp<const TensorReshapingOp< const Eigen::TensorForcedEvalOp<const TensorReshapingOp<
const DSizes<typename internal::traits<OutputBackward>::Index, const DSizes<typename internal::traits<OutputBackward>::Index,
3>, 2>,
const TensorReverseOp<const array<bool, 4>, const TensorShufflingOp<
const Kernel> > > > > >::type const array<
typename internal::traits<OutputBackward>::Index, 4>,
const TensorReverseOp<const array<bool, 4>,
const Kernel> > > > > > >::type
SpatialConvolutionBackwardInput( SpatialConvolutionBackwardInput(
const Kernel& kernel, const OutputBackward& output_backward, const Kernel& kernel, const OutputBackward& output_backward,
typename internal::traits<OutputBackward>::Index inputRows, typename internal::traits<OutputBackward>::Index inputRows,
@ -134,49 +140,57 @@ SpatialConvolutionBackwardInput(
kernel_reverse[3] = false; kernel_reverse[3] = false;
} }
DSizes<TensorIndex, 3> kernel_dims; // Reorder the dimensions to filters X patch_rows X patch_cols X channels
array<TensorIndex, 4> kernel_shuffle;
if (isColMajor) { if (isColMajor) {
kernel_dims[0] = kernelFilters; kernel_shuffle[0] = 0;
kernel_dims[1] = kernelChannels; kernel_shuffle[1] = 2;
kernel_dims[2] = kernelRows * kernelCols; kernel_shuffle[2] = 3;
kernel_shuffle[3] = 1;
} else { } else {
kernel_dims[0] = kernelRows * kernelCols; kernel_shuffle[0] = 2;
kernel_shuffle[1] = 0;
kernel_shuffle[2] = 1;
kernel_shuffle[3] = 3;
}
// Collapse the dims
DSizes<TensorIndex, 2> kernel_dims;
if (isColMajor) {
kernel_dims[0] = kernelFilters * kernelRows * kernelCols;
kernel_dims[1] = kernelChannels; kernel_dims[1] = kernelChannels;
kernel_dims[2] = kernelFilters; } else {
kernel_dims[1] = kernelFilters * kernelRows * kernelCols;
kernel_dims[0] = kernelChannels;
} }
// The output_backward has dimensions out_depth X out_rows X out_cols X OTHERS // The output_backward has dimensions out_depth X out_rows X out_cols X OTHERS
// When we extract the image patches from output_backward, it will have dimensions // When we extract the image patches from output_backward, it will have dimensions
// out_depth X (patch_rows * patch_cols) X (input_rows * input_cols * OTHERS) // out_depth X (patch_rows * patch_cols) X (input_rows * input_cols * OTHERS)
DSizes<TensorIndex, 3> pre_contract_dims; DSizes<TensorIndex, 2> pre_contract_dims;
if (isColMajor) { if (isColMajor) {
pre_contract_dims[0] = kernelFilters; pre_contract_dims[0] = kernelFilters * kernelRows * kernelCols;
pre_contract_dims[1] = kernelRows * kernelCols; pre_contract_dims[1] = inputRows * inputCols;
pre_contract_dims[2] = inputRows * inputCols;
for (int i = 3; i < NumDims; ++i) { for (int i = 3; i < NumDims; ++i) {
pre_contract_dims[2] *= out.dimension(i); pre_contract_dims[1] *= out.dimension(i);
} }
} else { } else {
pre_contract_dims[2] = kernelFilters; pre_contract_dims[1] = kernelFilters * kernelRows * kernelCols;
pre_contract_dims[1] = kernelRows * kernelCols;
pre_contract_dims[0] = inputRows * inputCols; pre_contract_dims[0] = inputRows * inputCols;
for (int i = 0; i < NumDims - 3; ++i) { for (int i = 0; i < NumDims - 3; ++i) {
pre_contract_dims[0] *= out.dimension(i); pre_contract_dims[0] *= out.dimension(i);
} }
} }
// We will contract along dimensions (0, 2) in kernel and (0, 1) in // We will contract along the fused dimension that contains the kernelFilters,
// output_backward, if this is col-major, and // the kernelRows and the kernelCols.
// dimensions (0, 2) in kernel and (1, 2) in output_backward, if this row-major. array<IndexPair<TensorIndex>, 1> contract_dims;
array<IndexPair<TensorIndex>, 2> contract_dims;
if (isColMajor) { if (isColMajor) {
// col-major: kernel.contract(output.patches) // col-major: kernel.contract(output.patches)
contract_dims[0] = IndexPair<TensorIndex>(0, 0); contract_dims[0] = IndexPair<TensorIndex>(0, 0);
contract_dims[1] = IndexPair<TensorIndex>(2, 1);
} else { } else {
// row-major: output.patches.contract(kernel) // row-major: output.patches.contract(kernel)
contract_dims[0] = IndexPair<TensorIndex>(1, 0); contract_dims[0] = IndexPair<TensorIndex>(1, 1);
contract_dims[1] = IndexPair<TensorIndex>(2, 2);
} }
// Post contraction, the dimensions of the input_backprop is // Post contraction, the dimensions of the input_backprop is
@ -201,6 +215,7 @@ SpatialConvolutionBackwardInput(
return choose( return choose(
Cond<internal::traits<OutputBackward>::Layout == ColMajor>(), Cond<internal::traits<OutputBackward>::Layout == ColMajor>(),
kernel.reverse(kernel_reverse) kernel.reverse(kernel_reverse)
.shuffle(kernel_shuffle)
.reshape(kernel_dims) .reshape(kernel_dims)
.eval() .eval()
.contract(output_backward .contract(output_backward
@ -217,7 +232,10 @@ SpatialConvolutionBackwardInput(
padding_bottom, padding_left, padding_right, padding_bottom, padding_left, padding_right,
OutScalar(0)) OutScalar(0))
.reshape(pre_contract_dims) .reshape(pre_contract_dims)
.contract(kernel.reverse(kernel_reverse).reshape(kernel_dims).eval(), .contract(kernel.reverse(kernel_reverse)
.shuffle(kernel_shuffle)
.reshape(kernel_dims)
.eval(),
contract_dims) contract_dims)
.reshape(post_contract_dims)); .reshape(post_contract_dims));
} }
@ -239,15 +257,43 @@ SpatialConvolutionBackwardInput(
* It is possible to swap the order of the width and height dimensions provided that the same order is used in the input, the kernel, and the output. * It is possible to swap the order of the width and height dimensions provided that the same order is used in the input, the kernel, and the output.
* *
*/ */
// TODO(gpapan): Resolve a bug in TensorContractionInputMapper at SpatialConvolutions.h that yangke circumvented by using .reshape().reshape().
// This can significantly accelerate SpatialConvolutionBackwardKernel.
template <typename OutputBackward, typename Input> template <typename OutputBackward, typename Input>
EIGEN_ALWAYS_INLINE EIGEN_ALWAYS_INLINE static const typename internal::conditional<
static const typename internal::conditional<
internal::traits<OutputBackward>::Layout == ColMajor, internal::traits<OutputBackward>::Layout == ColMajor,
const TensorShufflingOp<const array<typename internal::traits<OutputBackward>::Index, 4>, const TensorReverseOp<const array<bool, 4>, const TensorReshapingOp<const DSizes<typename internal::traits<OutputBackward>::Index, 4>, const TensorContractionOp<const array<IndexPair<typename internal::traits<Input>::Index>, 2>, const TensorReshapingOp<const DSizes<typename internal::traits<Input>::Index, 3>, const Input>, const TensorReshapingOp<const DSizes<typename internal::traits<OutputBackward>::Index, 4>, const TensorReshapingOp<const DSizes<typename internal::traits<OutputBackward>::Index, 4>, const TensorImagePatchOp<Dynamic, Dynamic, const OutputBackward> > > > > > >, TensorReshapingOp<
const TensorShufflingOp<const array<typename internal::traits<OutputBackward>::Index, 4>, const TensorReverseOp<const array<bool, 4>, const TensorReshapingOp<const DSizes<typename internal::traits<OutputBackward>::Index, 4>, const TensorContractionOp<const array<IndexPair<typename internal::traits<Input>::Index>, 2>, const TensorReshapingOp<const DSizes<typename internal::traits<OutputBackward>::Index, 4>, const TensorReshapingOp<const DSizes<typename internal::traits<OutputBackward>::Index, 4>, const TensorImagePatchOp<Dynamic, Dynamic, const OutputBackward> > >, const TensorReshapingOp<const DSizes<typename internal::traits<Input>::Index, 3>, const Input> > > > > >::type const DSizes<typename internal::traits<Input>::Index, 4>,
const TensorContractionOp<
const array<IndexPair<typename internal::traits<Input>::Index>, 1>,
const TensorReshapingOp<
const DSizes<typename internal::traits<Input>::Index, 2>,
const OutputBackward>,
const TensorShufflingOp<
const array<typename internal::traits<OutputBackward>::Index, 2>,
const TensorReshapingOp<
const DSizes<typename internal::traits<Input>::Index, 2>,
const TensorImagePatchOp<Dynamic, Dynamic, const Input>
>
>
>
>,
TensorReshapingOp<
const DSizes<typename internal::traits<Input>::Index, 4>,
const TensorContractionOp<
const array<IndexPair<typename internal::traits<Input>::Index>, 1>,
const TensorShufflingOp<
const array<typename internal::traits<OutputBackward>::Index, 2>,
const TensorReshapingOp<
const DSizes<typename internal::traits<Input>::Index, 2>,
const TensorImagePatchOp<Dynamic, Dynamic, const Input>
>
>,
const TensorReshapingOp<
const DSizes<typename internal::traits<Input>::Index, 2>,
const OutputBackward>
>
>
>::type
SpatialConvolutionBackwardKernel(const Input& input, const OutputBackward& output_backward, typename internal::traits<Input>::Index kernelRows, typename internal::traits<Input>::Index kernelCols, const DenseIndex stride = 1, const DenseIndex in_stride = 1) { SpatialConvolutionBackwardKernel(const Input& input, const OutputBackward& output_backward, typename internal::traits<Input>::Index kernelRows, typename internal::traits<Input>::Index kernelCols, const DenseIndex stride = 1, const DenseIndex in_stride = 1) {
typedef typename internal::traits<Input>::Index TensorIndex; typedef typename internal::traits<Input>::Index TensorIndex;
@ -283,127 +329,93 @@ SpatialConvolutionBackwardKernel(const Input& input, const OutputBackward& outpu
const TensorIndex kernelColsEff = kernelCols + (kernelCols - 1) * (in_stride - 1); const TensorIndex kernelColsEff = kernelCols + (kernelCols - 1) * (in_stride - 1);
// Computing the forward padding // Computing the forward padding
const TensorIndex forward_pad_top = ((outputRows - 1) * stride + kernelRowsEff - inputRows) / 2; const TensorIndex padRows = numext::maxi<Index>(
const TensorIndex forward_pad_left = ((outputCols - 1) * stride + kernelColsEff - inputCols) / 2; 0, (outputRows - 1) * stride + kernelRowsEff - inputRows);
const TensorIndex padCols = numext::maxi<Index>(
0, (outputCols - 1) * stride + kernelColsEff - inputCols);
const TensorIndex padding_top = padRows / 2;
const TensorIndex padding_bottom = padRows - padding_top;
const TensorIndex padding_left = padCols / 2;
const TensorIndex padding_right = padCols - padding_left;
// TODO: factor out the padding computation. // Reshaped out
const TensorIndex padding_top = kernelRowsEff - 1 - forward_pad_top; DSizes<TensorIndex, 2> output_dims;
const TensorIndex padding_left = kernelColsEff - 1 - forward_pad_left;
const TensorIndex padding_bottom = inputRows + kernelRowsEff - 1 - (outputRows - 1) * stride - 1 - padding_top;
const TensorIndex padding_right = inputCols + kernelColsEff - 1 - (outputCols - 1) * stride - 1 - padding_left;
eigen_assert(padding_top >= 0);
eigen_assert(padding_left >= 0);
eigen_assert(padding_bottom >= 0);
eigen_assert(padding_right >= 0);
// The output_backward has dimensions out_depth X out_rows X out_cols X OTHERS
// When we extract the image patches from output_backward (with input as the
// kernel), it will have dimensions
// (out_depth) X (input_rows * input_cols) X (kernel_rows * kernel_cols) X OTHERS
DSizes<TensorIndex, 4> pre_contract_dims;
if (isColMajor) { if (isColMajor) {
pre_contract_dims[0] = kernelFilters; output_dims[0] = kernelFilters;
pre_contract_dims[1] = inputRows * inputCols; output_dims[1] = outputRows * outputCols;
pre_contract_dims[2] = kernelRows * kernelCols;
pre_contract_dims[3] = 1;
for (int i = 3; i < NumDims; ++i) { for (int i = 3; i < NumDims; ++i) {
pre_contract_dims[3] *= out.dimension(i); output_dims[1] *= out.dimension(i);
} }
} else { } else {
pre_contract_dims[3] = kernelFilters; output_dims[1] = kernelFilters;
pre_contract_dims[2] = inputRows * inputCols; output_dims[0] = outputCols * outputRows;
pre_contract_dims[1] = kernelRows * kernelCols;
pre_contract_dims[0] = 1;
for (int i = 0; i < NumDims - 3; ++i) { for (int i = 0; i < NumDims - 3; ++i) {
pre_contract_dims[0] *= out.dimension(i); output_dims[0] *= out.dimension(i);
} }
} }
// The input has dimensions in_depth X (input_rows * input_cols) X OTHERS // Reshaped extract_image_patches(in)
DSizes<TensorIndex, 3> input_dims; DSizes<TensorIndex, 2> pre_contract_dims;
if (isColMajor) { if (isColMajor) {
input_dims[0] = kernelChannels; pre_contract_dims[0] = kernelChannels * kernelRows * kernelCols;
input_dims[1] = inputRows * inputCols; pre_contract_dims[1] = outputRows * outputCols;
input_dims[2] = 1;
for (int i = 3; i < NumDims; ++i) { for (int i = 3; i < NumDims; ++i) {
input_dims[2] *= in.dimension(i); pre_contract_dims[1] *= in.dimension(i);
} }
eigen_assert(input_dims[2] == pre_contract_dims[3]); eigen_assert(output_dims[1] == pre_contract_dims[1]);
} else { } else {
input_dims[2] = kernelChannels; pre_contract_dims[1] = kernelCols * kernelRows * kernelChannels;
input_dims[1] = inputRows * inputCols; pre_contract_dims[0] = outputRows * outputCols;
input_dims[0] = 1;
for (int i = 0; i < NumDims - 3; ++i) { for (int i = 0; i < NumDims - 3; ++i) {
input_dims[0] *= in.dimension(i); pre_contract_dims[0] *= in.dimension(i);
} }
eigen_assert(input_dims[0] == pre_contract_dims[0]); eigen_assert(output_dims[0] == pre_contract_dims[0]);
} }
// We will contract along dimensions (1, 2) in in and (1, 3) in out, if array<TensorIndex, 2> shuffle_dims;
// this is col-major. shuffle_dims[0] = 1;
// For row-major, it's dimensions (0, 1) in in and (0, 2) in out. shuffle_dims[1] = 0;
array<IndexPair<TensorIndex>, 2> contract_dims;
if (isColMajor) {
// col-major: in.contract(output.patches)
contract_dims[0] = IndexPair<TensorIndex>(1, 1);
contract_dims[1] = IndexPair<TensorIndex>(2, 3);
} else {
// row-major: output.patches.contract(in)
contract_dims[0] = IndexPair<TensorIndex>(0, 0);
contract_dims[1] = IndexPair<TensorIndex>(2, 1);
}
// After the contraction, the kernel will have dimension array<IndexPair<TensorIndex>, 1> contract_dims;
// in_depth X out_depth X kernel_rows X kernel_cols contract_dims[0] = IndexPair<TensorIndex>(1, 0);
// We will need to shuffle the first two dimensions and reverse the latter
// two dimensions. // After the contraction, the kernel will have the desired shape
// The end shape is
// out_depth X in_shape X kernel_rows X kernel_cols // out_depth X in_shape X kernel_rows X kernel_cols
// This is the shape of the kernel *before* the shuffling.
DSizes<TensorIndex, 4> kernel_dims; DSizes<TensorIndex, 4> kernel_dims;
if (isColMajor) { if (isColMajor) {
kernel_dims[0] = kernelChannels; kernel_dims[0] = kernelFilters;
kernel_dims[1] = kernelFilters; kernel_dims[1] = kernelChannels;
kernel_dims[2] = kernelRows; kernel_dims[2] = kernelRows;
kernel_dims[3] = kernelCols; kernel_dims[3] = kernelCols;
} else { } else {
kernel_dims[0] = kernelCols; kernel_dims[3] = kernelFilters;
kernel_dims[2] = kernelChannels;
kernel_dims[1] = kernelRows; kernel_dims[1] = kernelRows;
kernel_dims[2] = kernelFilters; kernel_dims[0] = kernelCols;
kernel_dims[3] = kernelChannels;
} }
array<TensorIndex, 4> kernel_shuffle; return choose(
if (isColMajor) { Cond<internal::traits<Input>::Layout == ColMajor>(),
kernel_shuffle[0] = 1; output_backward.reshape(output_dims)
kernel_shuffle[1] = 0; .contract(
kernel_shuffle[2] = 2; input.extract_image_patches(
kernel_shuffle[3] = 3; kernelRows, kernelCols, stride, stride,
} else { in_stride, in_stride, 1, 1, padding_top, padding_bottom,
kernel_shuffle[0] = 0; padding_left, padding_right, OutScalar(0))
kernel_shuffle[1] = 1; .reshape(pre_contract_dims)
kernel_shuffle[2] = 3; .shuffle(shuffle_dims),
kernel_shuffle[3] = 2; contract_dims)
} .reshape(kernel_dims),
input.extract_image_patches(
array<bool, 4> kernel_reverse; kernelRows, kernelCols, stride, stride,
if (isColMajor) { in_stride, in_stride, 1, 1, padding_top, padding_bottom,
kernel_reverse[0] = false; padding_left, padding_right, OutScalar(0))
kernel_reverse[1] = false; .reshape(pre_contract_dims)
kernel_reverse[2] = true; .shuffle(shuffle_dims)
kernel_reverse[3] = true; .contract(
} else { output_backward.reshape(output_dims),
kernel_reverse[0] = true; contract_dims)
kernel_reverse[1] = true; .reshape(kernel_dims));
kernel_reverse[2] = false;
kernel_reverse[3] = false;
}
return choose(Cond<internal::traits<Input>::Layout == ColMajor>(),
input.reshape(input_dims).contract(output_backward.extract_image_patches(inputRows, inputCols, in_stride, in_stride, 1, 1, stride, stride, padding_top, padding_bottom, padding_left, padding_right, OutScalar(0)).reshape(pre_contract_dims).reshape(pre_contract_dims), contract_dims).reshape(kernel_dims).reverse(kernel_reverse).shuffle(kernel_shuffle),
output_backward.extract_image_patches(inputRows, inputCols, in_stride, in_stride, 1, 1, stride, stride, padding_top, padding_bottom, padding_left, padding_right, OutScalar(0)).reshape(pre_contract_dims).reshape(pre_contract_dims).contract(input.reshape(input_dims), contract_dims).reshape(kernel_dims).reverse(kernel_reverse).shuffle(kernel_shuffle));
} }
} // end namespace Eigen } // end namespace Eigen

View File

@ -0,0 +1,234 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// See docs in ../ops/array_ops.cc.
#define EIGEN_USE_THREADS
#include <atomic>
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/gather_nd_op.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/util.h"
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
template <typename Device, typename T, typename Index>
class GatherNdOp : public OpKernel {
public:
explicit GatherNdOp(OpKernelConstruction* c) : OpKernel(c) {
const DataType dt = DataTypeToEnum<T>::v();
const DataType index_t = DataTypeToEnum<Index>::v();
OP_REQUIRES_OK(c, c->MatchSignature({dt, index_t}, {dt}));
}
void Compute(OpKernelContext* c) override {
const Tensor& params = c->input(0);
const Tensor& indices = c->input(1);
OP_REQUIRES(c, TensorShapeUtils::IsVectorOrHigher(params.shape()),
errors::InvalidArgument("params must be at least a vector"));
OP_REQUIRES(c, TensorShapeUtils::IsMatrixOrHigher(indices.shape()),
errors::InvalidArgument("indices must be at least a matrix"));
OP_REQUIRES(
c, indices.dim_size(indices.dims() - 1) == params.dims(),
errors::InvalidArgument(
"index innermost dimension length must equal params rank; saw: ",
indices.dim_size(indices.dims() - 1), " vs. ", params.dims()));
// Check that we have enough index space
const int64 N_big = indices.NumElements() / params.dims();
OP_REQUIRES(c, N_big <= std::numeric_limits<int>::max(),
errors::InvalidArgument(
"indices has too many elements for int indexing: ", N_big,
" > ", std::numeric_limits<int>::max()));
const int N = indices.NumElements() / params.dims();
OP_REQUIRES(
c, params.NumElements() <= std::numeric_limits<Index>::max(),
errors::InvalidArgument("params.NumElements() too large for ",
DataTypeString(DataTypeToEnum<Index>::v()),
" indexing: ", params.NumElements(), " > ",
std::numeric_limits<Index>::max()));
// The result shape is indices.shape[:-1]
TensorShape result_shape(indices.shape());
result_shape.RemoveDim(result_shape.dims() - 1);
Tensor* out = nullptr;
OP_REQUIRES_OK(c, c->allocate_output(0, result_shape, &out));
if (N > 0) {
auto indices_mat = indices.flat_inner_dims<Index>();
auto out_flat = out->flat<T>();
Index bad_i = -1;
switch (params.dims()) {
#define PARAMS_CASE(NDIM) \
case NDIM: { \
functor::GatherNd<Device, T, Index, NDIM> functor; \
auto params_tensor = params.tensor<T, NDIM>(); \
bad_i = functor(c->eigen_device<Device>(), params_tensor, indices_mat, \
out_flat); \
} break
PARAMS_CASE(1);
PARAMS_CASE(2);
PARAMS_CASE(3);
PARAMS_CASE(4);
PARAMS_CASE(5);
default:
OP_REQUIRES(c, false,
errors::InvalidArgument(
"Only param tensors with ranks between 1 and 5 "
"are currently supported. Tensor rank: ",
params.dims()));
}
OP_REQUIRES(c, bad_i < 0,
errors::InvalidArgument(
"flat indices[", bad_i, ", :] = [",
str_util::Join(gtl::ArraySlice<Index>(
&indices_mat(bad_i, 0), params.dims()),
", "),
"] does not index into param (shape: ",
params.shape().DebugString(), ")."));
}
}
};
// Specialization of GatherNd to CPU
namespace generator {
// Eigen generator that produces, for each flat output location `loc`, the
// element Tparams(Tindices(loc, 0), ..., Tindices(loc, NDIM-1)).  When a
// row of Tindices is out of bounds it records that row in *error_loc and
// yields a default-constructed T instead of reading Tparams.
template <typename T, typename Index, int NDIM>
class GatherNdGenerator {
 public:
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
  GatherNdGenerator(typename TTypes<Index>::ConstMatrix Tindices,
                    typename TTypes<T, NDIM>::ConstTensor Tparams,
                    Index* error_loc)
      : Tindices_(Tindices), Tparams_(Tparams), error_loc_(*error_loc) {}

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
  operator()(const Eigen::array<Eigen::DenseIndex, 1>& loc_array) const {
    Index loc = loc_array[0];
    Eigen::array<Eigen::DenseIndex, NDIM> ix;
    bool out_of_bounds = false;
    for (int i = 0; i < NDIM; ++i) {
      // Row `loc` of Tindices holds the NDIM-tuple addressing Tparams.
      Index ix_i = Tindices_(loc, i);
      ix[i] = ix_i;
      out_of_bounds |= !FastBoundsCheck(ix_i, Tparams_.dimension(i));
    }
    if (out_of_bounds) {
      error_loc_ = loc;
      return T();  // Return 0, 0.0, or '', etc.
    } else {
      return Tparams_(ix);
    }
  }

 private:
  typename TTypes<Index>::ConstMatrix Tindices_;
  typename TTypes<T, NDIM>::ConstTensor Tparams_;
  // Aliases caller-owned storage; only written when an OOB row is seen.
  Index& error_loc_;
};
} // namespace generator
namespace functor {
// CPU specialization: evaluates GatherNdGenerator over the flat output
// on the thread-pool device.
template <typename T, typename Index, int NDIM>
struct GatherNd<CPUDevice, T, Index, NDIM> {
  // Returns -1 on success, or the row of Tindices that contained an
  // out-of-bounds index (the corresponding output element is T()).
  Index operator()(const CPUDevice& d,
                   typename TTypes<T, NDIM>::ConstTensor Tparams,
                   typename TTypes<Index>::ConstMatrix Tindices,
                   typename TTypes<T>::Flat Tout) {
    Index error_loc(-1);
    generator::GatherNdGenerator<T, Index, NDIM> gather_nd_generator(Tindices,
                                                                     Tparams,
                                                                     &error_loc);
    Tout.device(d) = Tout.generate(gather_nd_generator);

    // error_loc() returns -1 if there's no out-of-bounds index,
    // otherwise it returns the location of an OOB index in Tindices.
    return error_loc;
  }
};
} // namespace functor
// Registers GatherNd for device `dev`, element type `type`, and index
// type `index_type`.
#define REGISTER_GATHER_ND_FULL(dev, type, index_type)                 \
  REGISTER_KERNEL_BUILDER(Name("GatherNd")                             \
                              .Device(DEVICE_##dev)                    \
                              .TypeConstraint<type>("Tparams")         \
                              .TypeConstraint<index_type>("Tindices"), \
                          GatherNdOp<dev##Device, type, index_type>)

// Registers both supported index types (int32 and int64).
#define REGISTER_GATHER_ND_ALL_INDICES(dev, type) \
  REGISTER_GATHER_ND_FULL(dev, type, int32);      \
  REGISTER_GATHER_ND_FULL(dev, type, int64)

#define REGISTER_GATHER_ND_CPU(type) REGISTER_GATHER_ND_ALL_INDICES(CPU, type)

// CPU kernels exist for every TF type.
TF_CALL_ALL_TYPES(REGISTER_GATHER_ND_CPU);

#undef REGISTER_GATHER_ND_CPU
#if GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.
namespace functor {
// Declares (without defining) the GPU functor specialization for one
// (T, Index, NDIM) combination; the definitions are compiled by nvcc in
// the .cu.cc file.
#define DECLARE_GPU_SPECS_INDEX_NDIM(T, Index, NDIM)                      \
  template <>                                                             \
  Index GatherNd<GPUDevice, T, Index, NDIM>::operator()(                  \
      const GPUDevice& d, typename TTypes<T, NDIM>::ConstTensor Tparams,  \
      typename TTypes<Index>::ConstMatrix Tindices,                       \
      typename TTypes<T>::Flat Tout);                                     \
  extern template struct GatherNd<GPUDevice, T, Index, NDIM>

// Covers param ranks 1 through 5 for a given index type.
#define DECLARE_GPU_SPECS_INDEX(T, Index)    \
  DECLARE_GPU_SPECS_INDEX_NDIM(T, Index, 1); \
  DECLARE_GPU_SPECS_INDEX_NDIM(T, Index, 2); \
  DECLARE_GPU_SPECS_INDEX_NDIM(T, Index, 3); \
  DECLARE_GPU_SPECS_INDEX_NDIM(T, Index, 4); \
  DECLARE_GPU_SPECS_INDEX_NDIM(T, Index, 5)

// Both supported index types.
#define DECLARE_GPU_SPECS(T)         \
  DECLARE_GPU_SPECS_INDEX(T, int32); \
  DECLARE_GPU_SPECS_INDEX(T, int64)

TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);

#undef DECLARE_GPU_SPECS
#undef DECLARE_GPU_SPECS_INDEX
} // namespace functor
// Registration of the GPU implementations.
#define REGISTER_GATHER_ND_GPU(type) REGISTER_GATHER_ND_ALL_INDICES(GPU, type)
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GATHER_ND_GPU);
#undef REGISTER_GATHER_ND_GPU
#endif // GOOGLE_CUDA
#undef REGISTER_GATHER_ND_ALL_INDICES
#undef REGISTER_GATHER_ND_FULL
} // namespace tensorflow

View File

@ -0,0 +1,43 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_KERNELS_GATHER_ND_OP_H_
#define TENSORFLOW_KERNELS_GATHER_ND_OP_H_
// Functor definition for GatherOp, must be compilable by nvcc.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/bounds_check.h"
namespace tensorflow {
class OpKernelContext;
namespace functor {
// Device-dispatched functor for the GatherNd op.
// T: element type of Tparams/Tout; Index: int32 or int64; NDIM: rank of
// Tparams (equals the innermost dimension length of Tindices).
template <typename Device, typename T, typename Index, int NDIM>
struct GatherNd {
  // Performs gather op on (Tparams, Tindices), writing to Tout.
  // Returns an index to Tindices if the value at that index is out of range.
  // Returns -1 if all values of Tindices are in range.
  Index operator()(const Device& d,
                   typename TTypes<T, NDIM>::ConstTensor Tparams,
                   typename TTypes<Index>::ConstMatrix Tindices,
                   typename TTypes<T>::Flat Tout);
};
} // namespace functor
} // namespace tensorflow
#endif // TENSORFLOW_KERNELS_GATHER_ND_OP_H_

View File

@ -0,0 +1,105 @@
/* Copyright 2016 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/gather_nd_op.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"
namespace tensorflow {
typedef Eigen::GpuDevice GPUDevice;
namespace generator {
// GPU variant of the GatherNd generator.  Unlike the CPU version it has
// no error-location out-parameter (there is no cheap way to report the
// bad row back to the host), so out-of-bounds indices silently yield T(0).
template <typename T, typename Index, int NDIM>
class GatherNdGenerator {
 public:
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
  GatherNdGenerator(typename TTypes<const Index, 2>::Tensor32Bit Tindices,
                    typename TTypes<const T, NDIM>::Tensor32Bit Tparams)
      : Tindices_(Tindices), Tparams_(Tparams) {}

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
  operator()(const Eigen::array<int, 1>& loc_array) const {
    int loc = loc_array[0];
    Eigen::array<int, NDIM> ix;
    bool out_of_bounds = false;
    for (int i = 0; i < NDIM; ++i) {
      // Row `loc` of Tindices holds the NDIM-tuple addressing Tparams.
      int ix_i = Tindices_(loc, i);
      ix[i] = ix_i;
      out_of_bounds |= !FastBoundsCheck(ix_i, Tparams_.dimension(i));
    }
    if (out_of_bounds) {
      return T(0);  // TODO(ebrevdo): Pass error back to host.
    } else {
      return Tparams_(ix);
    }
  }

 private:
  typename TTypes<const Index, 2>::Tensor32Bit Tindices_;
  typename TTypes<const T, NDIM>::Tensor32Bit Tparams_;
};
} // namespace generator
namespace functor {
// GPU specialization: evaluates the generator using 32-bit indexing for
// better device performance.
template <typename T, typename Index, int NDIM>
struct GatherNd<GPUDevice, T, Index, NDIM> {
  // Always returns -1: index validation is disabled on GPU (see TODO).
  Index operator()(const GPUDevice& d,
                   typename TTypes<T, NDIM>::ConstTensor Tparams,
                   typename TTypes<Index>::ConstMatrix Tindices,
                   typename TTypes<T>::Flat Tout) {
    generator::GatherNdGenerator<T, Index, NDIM> gather_nd_generator(
        To32Bit(Tindices), To32Bit(Tparams));
    To32Bit(Tout).device(d) = To32Bit(Tout).generate(gather_nd_generator);

    // TODO(ebrevdo): enable indices validation on GPU.
    // Right now checking for indices out of bound in the kernel would
    // require copying code between GPU/CPU, and is too slow.
    return -1;
  }
};
} // namespace functor
// Explicitly instantiates the GPU functor for one (T, Index, NDIM)
// combination, providing the definitions declared in gather_nd_op.cc.
#define DEFINE_GPU_SPECS_INDEX_NDIM(T, Index, NDIM) \
  template struct functor::GatherNd<GPUDevice, T, Index, NDIM>;

// Param ranks 1 through 5 for a given index type.
#define DEFINE_GPU_SPECS_INDEX(T, Index)    \
  DEFINE_GPU_SPECS_INDEX_NDIM(T, Index, 1); \
  DEFINE_GPU_SPECS_INDEX_NDIM(T, Index, 2); \
  DEFINE_GPU_SPECS_INDEX_NDIM(T, Index, 3); \
  DEFINE_GPU_SPECS_INDEX_NDIM(T, Index, 4); \
  DEFINE_GPU_SPECS_INDEX_NDIM(T, Index, 5);

// Both supported index types.
#define DEFINE_GPU_SPECS(T)         \
  DEFINE_GPU_SPECS_INDEX(T, int32); \
  DEFINE_GPU_SPECS_INDEX(T, int64);

TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS);

#undef DEFINE_GPU_SPECS
#undef DEFINE_GPU_SPECS_INDEX
} // namespace tensorflow
#endif // GOOGLE_CUDA

View File

@ -0,0 +1,133 @@
/* Copyright 2015 Google Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <functional>
#include <memory>
#include <vector>
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/testlib.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
namespace tensorflow {
namespace test {
namespace graph {
// Adds a "GatherNd" node to graph `g` that gathers values from `in0`
// using the indices supplied by `in1`, returning the created node.
class Node* GatherNd(Graph* g, class Node* in0, class Node* in1) {
  class Node* result;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "GatherNd")
                  .Input(in0)
                  .Input(in1)
                  .Finalize(g, &result));
  return result;
}
} // namespace graph
} // namespace test
namespace {
// Test fixture for the GatherNd kernel. Params are always float; the
// index dtype is chosen by the caller of MakeOp.
class GatherNdOpTest : public OpsTestBase {
 protected:
  // Builds and initializes the "GatherNd" op under test. `index_type`
  // selects the dtype of the indices input (e.g. DT_INT32 or DT_INT64).
  void MakeOp(DataType index_type) {
    TF_ASSERT_OK(NodeDefBuilder("myop", "GatherNd")
                     .Input(FakeInput(DT_FLOAT))
                     .Input(FakeInput(index_type))
                     .Finalize(node_def()));
    TF_ASSERT_OK(InitOp());
  }
};
TEST_F(GatherNdOpTest, Simple) {
  MakeOp(DT_INT32);

  // Feed and run: params is a rank-1 tensor of 5 floats; indices is a
  // 2x1 matrix, so each row is a single coordinate into params.
  AddInputFromArray<float>(TensorShape({5}), {0, 1, 2, 8, 4});
  AddInputFromArray<int32>(TensorShape({2, 1}), {3, 4});
  TF_ASSERT_OK(RunOpKernel());

  // Check the output: elements params[3] and params[4], i.e. {8, 4}.
  Tensor expected(allocator(), DT_FLOAT, TensorShape({2}));
  test::FillValues<float>(&expected, {8, 4});
  test::ExpectTensorEqual<float>(expected, *GetOutput(0));
}
// Number of gather lookups performed per benchmark iteration.
constexpr int kLookups = 2000;

// Builds a benchmark graph: a randomly-initialized float params tensor of
// shape {dim, 8, 16, 32} and kLookups random rank-4 coordinates into it,
// wired to a single GatherNd node. `Index` selects the index dtype.
// (A stale comment about a fixed 512MB buffer, plus the commented-out
// kRows computation it referred to, have been removed: the params size
// here is proportional to `dim`, not fixed.)
template <typename Index>
static Graph* GatherNd(int dim) {
  Graph* g = new Graph(OpRegistry::Global());
  Tensor params(DT_FLOAT, TensorShape({dim, 8, 16, 32}));
  params.flat<float>().setRandom();

  // Fixed-seed RNG so the generated indices are reproducible across runs.
  random::PhiloxRandom philox(301, 17);
  random::SimplePhilox rnd(&philox);
  Tensor indices(DataTypeToEnum<Index>::value, TensorShape({kLookups, 4}));
  auto indices_mat = indices.matrix<Index>();
  for (int i = 0; i < kLookups; i++) {
    // One in-range coordinate per dimension of params.
    indices_mat(i, 0) = rnd.Uniform(dim);
    indices_mat(i, 1) = rnd.Uniform(8);
    indices_mat(i, 2) = rnd.Uniform(16);
    indices_mat(i, 3) = rnd.Uniform(32);
  }
  test::graph::GatherNd(g, test::graph::Constant(g, params),
                        test::graph::Constant(g, indices));
  return g;
}
// Defines and registers a GatherNd benchmark for the given DEVICE
// (cpu/gpu) and INDEX dtype (int32/int64). The benchmark argument `dim`
// sets the leading dimension of the params tensor; items/bytes processed
// are reported as iters * kLookups * dim.
#define BM_GATHER_ND(DEVICE, INDEX)                                 \
  static void BM_##DEVICE##_gather_nd_##INDEX(int iters, int dim) { \
    const int64 tot = static_cast<int64>(iters) * kLookups * dim;   \
    testing::ItemsProcessed(tot);                                   \
    testing::BytesProcessed(tot * sizeof(float));                   \
    testing::UseRealTime();                                         \
    test::Benchmark(#DEVICE, GatherNd<INDEX>(dim)).Run(iters);      \
  }                                                                 \
  BENCHMARK(BM_##DEVICE##_gather_nd_##INDEX)                        \
      ->Arg(1)                                                      \
      ->Arg(10)                                                     \
      ->Arg(20)                                                     \
      ->Arg(64)                                                     \
      ->Arg(100)                                                    \
      ->Arg(200)                                                    \
      ->Arg(1000)

// Instantiate benchmarks for both devices and both index dtypes.
BM_GATHER_ND(cpu, int32);
BM_GATHER_ND(gpu, int32);
BM_GATHER_ND(cpu, int64);
BM_GATHER_ND(gpu, int64);
} // namespace
} // namespace tensorflow

View File

@ -25,6 +25,9 @@ Status InitializableLookupTable::Find(const Tensor& keys, Tensor* values,
if (!is_initialized()) { if (!is_initialized()) {
return errors::FailedPrecondition("Table not initialized."); return errors::FailedPrecondition("Table not initialized.");
} }
// Do not let the use migrate before the check; table is used without
// a lock by the readers.
std::atomic_thread_fence(std::memory_order_acquire);
TF_RETURN_IF_ERROR(CheckFindArguments(keys, *values, default_value)); TF_RETURN_IF_ERROR(CheckFindArguments(keys, *values, default_value));
return DoFind(keys, values, default_value); return DoFind(keys, values, default_value);
} }
@ -48,6 +51,10 @@ Status InitializableLookupTable::Initialize(InitTableIterator& iter) {
if (!errors::IsOutOfRange(iter.status())) { if (!errors::IsOutOfRange(iter.status())) {
return iter.status(); return iter.status();
} }
// Prevent compiler/memory reordering of is_initialized and
// the initialization itself.
std::atomic_thread_fence(std::memory_order_release);
is_initialized_ = true; is_initialized_ = true;
return Status::OK(); return Status::OK();
} }

View File

@ -69,7 +69,11 @@ class HashTable : public InitializableLookupTable {
public: public:
size_t size() const override { size_t size() const override {
// return the size of the table only if it's initialized, otherwise 0. // return the size of the table only if it's initialized, otherwise 0.
return table_ && is_initialized_ ? table_->size() : 0; if (!is_initialized_) {
return 0;
}
std::atomic_thread_fence(std::memory_order_acquire);
return table_ ? table_->size() : 0;
} }
DataType key_dtype() const override { return DataTypeToEnum<K>::v(); } DataType key_dtype() const override { return DataTypeToEnum<K>::v(); }

View File

@ -407,6 +407,34 @@ this operation will permute `params` accordingly.
</div> </div>
)doc"); )doc");
// --------------------------------------------------------------------------
REGISTER_OP("GatherNd")
.Input("params: Tparams")
.Input("indices: Tindices")
.Output("output: Tparams")
.Attr("Tparams: type")
.Attr("Tindices: {int32,int64}")
.Doc(R"doc(
Gather values from `params` according to `indices`.
`indices` must be integer tensor, containing indices into `params`.
It must be shape `[d_0, ..., d_N, R]` where `R` is the rank of `params`.
The innermost dimension of `indices` (with length `R`) corresponds to the
indices of `params`.
Produces an output tensor with shape `[d_0, ..., d_{n-1}]` where:
output[i, j, k, ...] = params[indices[i, j, k, ..., :]]
e.g. for `indices` a matrix:
output[i] = params[indices[i, :]]
params: R-D. The tensor from which to gather values.
indices: (N+1)-D. Index tensor having shape `[d_0, ..., d_N, R]`.
output: N-D. Values from `params` gathered from indices given by `indices`.
)doc");
// -------------------------------------------------------------------------- // --------------------------------------------------------------------------
REGISTER_OP("Identity") REGISTER_OP("Identity")
.Input("input: T") .Input("input: T")

View File

@ -6480,6 +6480,35 @@ op {
} }
} }
} }
op {
name: "GatherNd"
input_arg {
name: "params"
type_attr: "Tparams"
}
input_arg {
name: "indices"
type_attr: "Tindices"
}
output_arg {
name: "output"
type_attr: "Tparams"
}
attr {
name: "Tparams"
type: "type"
}
attr {
name: "Tindices"
type: "type"
allowed_values {
list {
type: DT_INT32
type: DT_INT64
}
}
}
}
op { op {
name: "Greater" name: "Greater"
input_arg { input_arg {

View File

@ -4066,6 +4066,40 @@ op {
summary: "Gather slices from `params` according to `indices`." summary: "Gather slices from `params` according to `indices`."
description: "`indices` must be an integer tensor of any dimension (usually 0-D or 1-D).\nProduces an output tensor with shape `indices.shape + params.shape[1:]` where:\n\n # Scalar indices\n output[:, ..., :] = params[indices, :, ... :]\n\n # Vector indices\n output[i, :, ..., :] = params[indices[i], :, ... :]\n\n # Higher rank indices\n output[i, ..., j, :, ... :] = params[indices[i, ..., j], :, ..., :]\n\nIf `indices` is a permutation and `len(indices) == params.shape[0]` then\nthis operation will permute `params` accordingly.\n\n<div style=\"width:70%; margin:auto; margin-bottom:10px; margin-top:20px;\">\n<img style=\"width:100%\" src=\"../../images/Gather.png\" alt>\n</div>" description: "`indices` must be an integer tensor of any dimension (usually 0-D or 1-D).\nProduces an output tensor with shape `indices.shape + params.shape[1:]` where:\n\n # Scalar indices\n output[:, ..., :] = params[indices, :, ... :]\n\n # Vector indices\n output[i, :, ..., :] = params[indices[i], :, ... :]\n\n # Higher rank indices\n output[i, ..., j, :, ... :] = params[indices[i, ..., j], :, ..., :]\n\nIf `indices` is a permutation and `len(indices) == params.shape[0]` then\nthis operation will permute `params` accordingly.\n\n<div style=\"width:70%; margin:auto; margin-bottom:10px; margin-top:20px;\">\n<img style=\"width:100%\" src=\"../../images/Gather.png\" alt>\n</div>"
} }
op {
name: "GatherNd"
input_arg {
name: "params"
description: "R-D. The tensor from which to gather values."
type_attr: "Tparams"
}
input_arg {
name: "indices"
description: "(N+1)-D. Index tensor having shape `[d_0, ..., d_N, R]`."
type_attr: "Tindices"
}
output_arg {
name: "output"
description: "N-D. Values from `params` gathered from indices given by `indices`."
type_attr: "Tparams"
}
attr {
name: "Tparams"
type: "type"
}
attr {
name: "Tindices"
type: "type"
allowed_values {
list {
type: DT_INT32
type: DT_INT64
}
}
}
summary: "Gather values from `params` according to `indices`."
description: "`indices` must be integer tensor, containing indices into `params`.\nIt must be shape `[d_0, ..., d_N, R]` where `R` is the rank of `params`.\nThe innermost dimension of `indices` (with length `R`) corresponds to the\nindices of `params`.\n\nProduces an output tensor with shape `[d_0, ..., d_{n-1}]` where:\n\n output[i, j, k, ...] = params[indices[i, j, k, ..., :]]\n\ne.g. for `indices` a matrix:\n\n output[i] = params[indices[i, :]]"
}
op { op {
name: "Greater" name: "Greater"
input_arg { input_arg {

View File

@ -0,0 +1 @@
tensorflow-git-owners

View File

@ -1149,6 +1149,39 @@ this operation will permute `params` accordingly.
A `Tensor`. Has the same type as `params`. A `Tensor`. Has the same type as `params`.
- - -
### `tf.gather_nd(params, indices, name=None)` {#gather_nd}
Gather values from `params` according to `indices`.
`indices` must be integer tensor, containing indices into `params`.
It must be shape `[d_0, ..., d_N, R]` where `R` is the rank of `params`.
The innermost dimension of `indices` (with length `R`) corresponds to the
indices of `params`.
Produces an output tensor with shape `[d_0, ..., d_{n-1}]` where:
output[i, j, k, ...] = params[indices[i, j, k, ..., :]]
e.g. for `indices` a matrix:
output[i] = params[indices[i, :]]
##### Args:
* <b>`params`</b>: A `Tensor`. R-D. The tensor from which to gather values.
* <b>`indices`</b>: A `Tensor`. Must be one of the following types: `int32`, `int64`.
(N+1)-D. Index tensor having shape `[d_0, ..., d_N, R]`.
* <b>`name`</b>: A name for the operation (optional).
##### Returns:
A `Tensor`. Has the same type as `params`.
N-D. Values from `params` gathered from indices given by `indices`.
- - - - - -
### `tf.dynamic_partition(data, partitions, num_partitions, name=None)` {#dynamic_partition} ### `tf.dynamic_partition(data, partitions, num_partitions, name=None)` {#dynamic_partition}

View File

@ -411,6 +411,30 @@ subtraction, it usually shouldn't hurt much either.
* <b>`ValueError`</b>: If `regularizer` does not return a scalar output. * <b>`ValueError`</b>: If `regularizer` does not return a scalar output.
- - -
### `tf.contrib.layers.make_all(module_name, doc_string_modules=None)` {#make_all}
Generate `__all__` from the docstring of one or more modules.
Usage: `make_all(__name__)` or
`make_all(__name__, [sys.modules(__name__), other_module])`. The doc string
modules must each have a docstring, and `__all__` will contain all symbols with
`@@` references, where that symbol currently exists in the module named
`module_name`.
##### Args:
* <b>`module_name`</b>: The name of the module (usually `__name__`).
* <b>`doc_string_modules`</b>: a list of modules from which to take docstring.
If None, then a list containing only the module named `module_name` is used.
##### Returns:
A list suitable for use as `__all__`.
- - - - - -
### `tf.contrib.layers.optimize_loss(loss, global_step, learning_rate, optimizer, clip_gradients=None, moving_average_decay=0.9, learning_rate_decay_fn=None, variables=None)` {#optimize_loss} ### `tf.contrib.layers.optimize_loss(loss, global_step, learning_rate, optimizer, clip_gradients=None, moving_average_decay=0.9, learning_rate_decay_fn=None, variables=None)` {#optimize_loss}

View File

@ -260,202 +260,6 @@ Example 2:
## Higher Order Operators
TensorFlow provides several higher order operators to simplify the common
map-reduce programming patterns.
- - -
### `tf.map_fn(fn, elems, dtype=None, parallel_iterations=10, back_prop=True, swap_memory=False, name=None)` {#map_fn}
The map operator on the list of tensors resulted from unpacking `elems`
along the first dimension.
This map operator repeatedly applies the callable `fn` to a sequence of
elements from first to last. The elements are made of the tensors unpacked
from `elems`. `dtype` is the data type of the return value of `fn`. Users
must provide `dtype` if it is different from the data type of `elems`.
Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
of the result tensor is `[len(values)] + fn(values[0]).shape`.
##### Args:
* <b>`fn`</b>: The callable to be performed.
* <b>`elems`</b>: A tensor to be unpacked to apply `fn`.
* <b>`dtype`</b>: (optional) The output type of `fn`.
* <b>`parallel_iterations`</b>: (optional) The number of iterations allowed to run
in parallel.
* <b>`back_prop`</b>: (optional) True enables back propagation.
* <b>`swap_memory`</b>: (optional) True enables GPU-CPU memory swapping.
* <b>`name`</b>: (optional) Name prefix for the returned tensors.
##### Returns:
A tensor that packs the results of applying `fn` to the list of tensors
unpacked from `elems`, from first to last.
##### Raises:
* <b>`TypeError`</b>: if `fn` is not callable.
##### Example:
```python
elems = [1, 2, 3, 4, 5, 6]
squares = map_fn(lambda x: x * x, elems)
# squares == [1, 4, 9, 16, 25, 36]
```
- - -
### `tf.foldl(fn, elems, initializer=None, parallel_iterations=10, back_prop=True, swap_memory=False, name=None)` {#foldl}
The foldl operator on the list of tensors resulted from unpacking `elems`
along the first dimension.
This foldl operator repeatedly applies the callable `fn` to a sequence
of elements from first to last. The elements are made of the tensors
unpacked from `elems` on dimension 0. The callable fn takes two tensors as
arguments. The first argument is the accumulated value computed from the
preceding invocation of fn. If `initializer` is None, `elems` must contain
at least one element, and its first element is used as the initializer.
Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
of the result tensor is fn(initializer, values[0]).shape`.
##### Args:
* <b>`fn`</b>: The callable to be performed.
* <b>`elems`</b>: A tensor to be unpacked on dimension 0.
* <b>`initializer`</b>: (optional) The initial value for the accumulator.
* <b>`parallel_iterations`</b>: (optional) The number of iterations allowed to run
in parallel.
* <b>`back_prop`</b>: (optional) True enables back propagation.
* <b>`swap_memory`</b>: (optional) True enables GPU-CPU memory swapping.
* <b>`name`</b>: (optional) Name prefix for the returned tensors.
##### Returns:
A tensor resulting from applying `fn` consecutively to the list of tensors
unpacked from `elems`, from first to last.
##### Raises:
* <b>`TypeError`</b>: if `fn` is not callable.
##### Example:
```python
elems = [1, 2, 3, 4, 5, 6]
sum = foldl(lambda a, x: a + x, elems)
# sum == 21
```
- - -
### `tf.foldr(fn, elems, initializer=None, parallel_iterations=10, back_prop=True, swap_memory=False, name=None)` {#foldr}
The foldr operator on the list of tensors resulted from unpacking `elems`
along the first dimension.
This foldr operator repeatedly applies the callable `fn` to a sequence
of elements from last to first. The elements are made of the tensors
unpacked from `elems`. The callable fn takes two tensors as arguments.
The first argument is the accumulated value computed from the preceding
invocation of fn. If `initializer` is None, `elems` must contain at least
one element, and its first element is used as the initializer.
Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
of the result tensor is `fn(initializer, values[0]).shape`.
##### Args:
* <b>`fn`</b>: The callable to be performed.
* <b>`elems`</b>: A tensor that is unpacked into a sequence of tensors to apply `fn`.
* <b>`initializer`</b>: (optional) The initial value for the accumulator.
* <b>`parallel_iterations`</b>: (optional) The number of iterations allowed to run
in parallel.
* <b>`back_prop`</b>: (optional) True enables back propagation.
* <b>`swap_memory`</b>: (optional) True enables GPU-CPU memory swapping.
* <b>`name`</b>: (optional) Name prefix for the returned tensors.
##### Returns:
A tensor resulting from applying `fn` consecutively to the list of tensors
unpacked from `elems`, from last to first.
##### Raises:
* <b>`TypeError`</b>: if `fn` is not callable.
##### Example:
```python
elems = [1, 2, 3, 4, 5, 6]
sum = foldr(lambda a, x: a + x, elems)
# sum == 21
```
- - -
### `tf.scan(fn, elems, initializer=None, parallel_iterations=10, back_prop=True, swap_memory=False, name=None)` {#scan}
The scan operator on the list of tensors resulted from unpacking `elems`
along the first dimension.
This scan operator repeatedly applies the callable `fn` to a sequence
of elements from first to last. The elements are made of the tensors
unpacked from `elems` on dimension 0. The callable fn takes two tensors as
arguments. The first argument is the accumulated value computed from the
preceding invocation of fn. If `initializer` is None, `elems` must contain
at least one element, and its first element is used as the initializer.
Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
of the result tensor is `[len(values)] + fn(initializer, values[0]).shape`.
##### Args:
* <b>`fn`</b>: The callable to be performed.
* <b>`elems`</b>: A tensor to be unpacked on dimension 0.
* <b>`initializer`</b>: (optional) The initial value for the accumulator.
* <b>`parallel_iterations`</b>: (optional) The number of iterations allowed to run
in parallel.
* <b>`back_prop`</b>: (optional) True enables back propagation.
* <b>`swap_memory`</b>: (optional) True enables GPU-CPU memory swapping.
* <b>`name`</b>: (optional) Name prefix for the returned tensors.
##### Returns:
A tensor that packs the results of applying `fn` to the list of tensors
unpacked from `elems`, from first to last.
##### Raises:
* <b>`TypeError`</b>: if `fn` is not callable.
##### Example:
```python
elems = [1, 2, 3, 4, 5, 6]
sum = scan(lambda a, x: a + x, elems)
# sum == [1, 3, 6, 10, 15, 21]
```
## Logical Operators ## Logical Operators
TensorFlow provides several operations that you can use to add logical operators TensorFlow provides several operations that you can use to add logical operators

View File

@ -409,6 +409,31 @@ Returns a list of values in the collection with the given `name`.
collected. collected.
- - -
#### `tf.Graph.get_collection_ref(name)` {#Graph.get_collection_ref}
Returns a list of values in the collection with the given `name`.
If the collection exists, this returns the list itself, which can
be modified in place to change the collection. If the collection does
not exist, it is created as an empty list and the list is returned.
This is different from `get_collection()` which always returns a copy of
the collection list if it exists and never creates an empty collection.
##### Args:
* <b>`name`</b>: The key for the collection. For example, the `GraphKeys` class
contains many standard names for collections.
##### Returns:
The list of values in the collection with the given `name`, or an empty
list if no value has been added to that collection.
- - - - - -
@ -1752,6 +1777,29 @@ for more details.
collected. collected.
- - -
### `tf.get_collection_ref(key)` {#get_collection_ref}
Wrapper for `Graph.get_collection_ref()` using the default graph.
See [`Graph.get_collection_ref()`](../../api_docs/python/framework.md#Graph.get_collection_ref)
for more details.
##### Args:
* <b>`key`</b>: The key for the collection. For example, the `GraphKeys` class
contains many standard names for collections.
##### Returns:
The list of values in the collection with the given `name`, or an empty
list if no value has been added to that collection. Note that this returns
the collection list itself, which can be modified in place to change the
collection.
- - - - - -
### `class tf.GraphKeys` {#GraphKeys} ### `class tf.GraphKeys` {#GraphKeys}

View File

@ -0,0 +1,202 @@
<!-- This file is machine generated: DO NOT EDIT! -->
# Higher Order Functions
Note: Functions taking `Tensor` arguments can also take anything accepted by
[`tf.convert_to_tensor`](framework.md#convert_to_tensor).
[TOC]
Functional operations.
## Higher Order Operators
TensorFlow provides several higher order operators to simplify the common
map-reduce programming patterns.
- - -
### `tf.map_fn(fn, elems, dtype=None, parallel_iterations=10, back_prop=True, swap_memory=False, name=None)` {#map_fn}
map on the list of tensors unpacked from `elems` on dimension 0.
This map operator repeatedly applies the callable `fn` to a sequence of
elements from first to last. The elements are made of the tensors unpacked
from `elems`. `dtype` is the data type of the return value of `fn`. Users
must provide `dtype` if it is different from the data type of `elems`.
Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
of the result tensor is `[len(values)] + fn(values[0]).shape`.
##### Args:
* <b>`fn`</b>: The callable to be performed.
* <b>`elems`</b>: A tensor to be unpacked to apply `fn`.
* <b>`dtype`</b>: (optional) The output type of `fn`.
* <b>`parallel_iterations`</b>: (optional) The number of iterations allowed to run
in parallel.
* <b>`back_prop`</b>: (optional) True enables back propagation.
* <b>`swap_memory`</b>: (optional) True enables GPU-CPU memory swapping.
* <b>`name`</b>: (optional) Name prefix for the returned tensors.
##### Returns:
A tensor that packs the results of applying `fn` to the list of tensors
unpacked from `elems`, from first to last.
##### Raises:
* <b>`TypeError`</b>: if `fn` is not callable.
##### Example:
```python
elems = [1, 2, 3, 4, 5, 6]
squares = map_fn(lambda x: x * x, elems)
# squares == [1, 4, 9, 16, 25, 36]
```
- - -
### `tf.foldl(fn, elems, initializer=None, parallel_iterations=10, back_prop=True, swap_memory=False, name=None)` {#foldl}
foldl on the list of tensors unpacked from `elems` on dimension 0.
This foldl operator repeatedly applies the callable `fn` to a sequence
of elements from first to last. The elements are made of the tensors
unpacked from `elems` on dimension 0. The callable fn takes two tensors as
arguments. The first argument is the accumulated value computed from the
preceding invocation of fn. If `initializer` is None, `elems` must contain
at least one element, and its first element is used as the initializer.
Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
of the result tensor is `fn(initializer, values[0]).shape`.
##### Args:
* <b>`fn`</b>: The callable to be performed.
* <b>`elems`</b>: A tensor to be unpacked on dimension 0.
* <b>`initializer`</b>: (optional) The initial value for the accumulator.
* <b>`parallel_iterations`</b>: (optional) The number of iterations allowed to run
in parallel.
* <b>`back_prop`</b>: (optional) True enables back propagation.
* <b>`swap_memory`</b>: (optional) True enables GPU-CPU memory swapping.
* <b>`name`</b>: (optional) Name prefix for the returned tensors.
##### Returns:
A tensor resulting from applying `fn` consecutively to the list of tensors
unpacked from `elems`, from first to last.
##### Raises:
* <b>`TypeError`</b>: if `fn` is not callable.
##### Example:
```python
elems = [1, 2, 3, 4, 5, 6]
sum = foldl(lambda a, x: a + x, elems)
# sum == 21
```
- - -
### `tf.foldr(fn, elems, initializer=None, parallel_iterations=10, back_prop=True, swap_memory=False, name=None)` {#foldr}
foldr on the list of tensors unpacked from `elems` on dimension 0.
This foldr operator repeatedly applies the callable `fn` to a sequence
of elements from last to first. The elements are made of the tensors
unpacked from `elems`. The callable fn takes two tensors as arguments.
The first argument is the accumulated value computed from the preceding
invocation of fn. If `initializer` is None, `elems` must contain at least
one element, and its first element is used as the initializer.
Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
of the result tensor is `fn(initializer, values[0]).shape`.
##### Args:
* <b>`fn`</b>: The callable to be performed.
* <b>`elems`</b>: A tensor that is unpacked into a sequence of tensors to apply `fn`.
* <b>`initializer`</b>: (optional) The initial value for the accumulator.
* <b>`parallel_iterations`</b>: (optional) The number of iterations allowed to run
in parallel.
* <b>`back_prop`</b>: (optional) True enables back propagation.
* <b>`swap_memory`</b>: (optional) True enables GPU-CPU memory swapping.
* <b>`name`</b>: (optional) Name prefix for the returned tensors.
##### Returns:
A tensor resulting from applying `fn` consecutively to the list of tensors
unpacked from `elems`, from last to first.
##### Raises:
* <b>`TypeError`</b>: if `fn` is not callable.
##### Example:
```python
elems = [1, 2, 3, 4, 5, 6]
sum = foldr(lambda a, x: a + x, elems)
# sum == 21
```
- - -
### `tf.scan(fn, elems, initializer=None, parallel_iterations=10, back_prop=True, swap_memory=False, name=None)` {#scan}
scan on the list of tensors unpacked from `elems` on dimension 0.
This scan operator repeatedly applies the callable `fn` to a sequence
of elements from first to last. The elements are made of the tensors
unpacked from `elems` on dimension 0. The callable fn takes two tensors as
arguments. The first argument is the accumulated value computed from the
preceding invocation of fn. If `initializer` is None, `elems` must contain
at least one element, and its first element is used as the initializer.
Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
of the result tensor is `[len(values)] + fn(initializer, values[0]).shape`.
##### Args:
* <b>`fn`</b>: The callable to be performed.
* <b>`elems`</b>: A tensor to be unpacked on dimension 0.
* <b>`initializer`</b>: (optional) The initial value for the accumulator.
* <b>`parallel_iterations`</b>: (optional) The number of iterations allowed to run
in parallel.
* <b>`back_prop`</b>: (optional) True enables back propagation.
* <b>`swap_memory`</b>: (optional) True enables GPU-CPU memory swapping.
* <b>`name`</b>: (optional) Name prefix for the returned tensors.
##### Returns:
A tensor that packs the results of applying `fn` to the list of tensors
unpacked from `elems`, from first to last.
##### Raises:
* <b>`TypeError`</b>: if `fn` is not callable.
##### Example:
```python
elems = [1, 2, 3, 4, 5, 6]
sum = scan(lambda a, x: a + x, elems)
# sum == [1, 3, 6, 10, 15, 21]
```

View File

@ -1234,29 +1234,3 @@ false and no bounding boxes are supplied, an error is raised.
Provide as input to `tf.image.draw_bounding_boxes`. Provide as input to `tf.image.draw_bounding_boxes`.
## Other Functions and Classes
- - -
### `tf.contrib.layers.make_all(module_name, doc_string_modules=None)` {#make_all}
Generate `__all__` from the docstring of one or more modules.
Usage: `make_all(__name__)` or
`make_all(__name__, [sys.modules(__name__), other_module])`. The doc string
modules must each a docstring, and `__all__` will contain all symbols with
`@@` references, where that symbol currently exists in the module named
`module_name`.
##### Args:
* <b>`module_name`</b>: The name of the module (usually `__name__`).
* <b>`doc_string_modules`</b>: a list of modules from which to take docstring.
If None, then a list containing only the module named `module_name` is used.
##### Returns:
A list suitable for use as `__all__`.

View File

@ -13,6 +13,7 @@
* [`Dimension`](../../api_docs/python/framework.md#Dimension) * [`Dimension`](../../api_docs/python/framework.md#Dimension)
* [`DType`](../../api_docs/python/framework.md#DType) * [`DType`](../../api_docs/python/framework.md#DType)
* [`get_collection`](../../api_docs/python/framework.md#get_collection) * [`get_collection`](../../api_docs/python/framework.md#get_collection)
* [`get_collection_ref`](../../api_docs/python/framework.md#get_collection_ref)
* [`get_default_graph`](../../api_docs/python/framework.md#get_default_graph) * [`get_default_graph`](../../api_docs/python/framework.md#get_default_graph)
* [`get_seed`](../../api_docs/python/framework.md#get_seed) * [`get_seed`](../../api_docs/python/framework.md#get_seed)
* [`Graph`](../../api_docs/python/framework.md#Graph) * [`Graph`](../../api_docs/python/framework.md#Graph)
@ -95,6 +96,7 @@
* [`dynamic_stitch`](../../api_docs/python/array_ops.md#dynamic_stitch) * [`dynamic_stitch`](../../api_docs/python/array_ops.md#dynamic_stitch)
* [`expand_dims`](../../api_docs/python/array_ops.md#expand_dims) * [`expand_dims`](../../api_docs/python/array_ops.md#expand_dims)
* [`gather`](../../api_docs/python/array_ops.md#gather) * [`gather`](../../api_docs/python/array_ops.md#gather)
* [`gather_nd`](../../api_docs/python/array_ops.md#gather_nd)
* [`one_hot`](../../api_docs/python/array_ops.md#one_hot) * [`one_hot`](../../api_docs/python/array_ops.md#one_hot)
* [`pack`](../../api_docs/python/array_ops.md#pack) * [`pack`](../../api_docs/python/array_ops.md#pack)
* [`pad`](../../api_docs/python/array_ops.md#pad) * [`pad`](../../api_docs/python/array_ops.md#pad)
@ -229,8 +231,6 @@
* [`cond`](../../api_docs/python/control_flow_ops.md#cond) * [`cond`](../../api_docs/python/control_flow_ops.md#cond)
* [`count_up_to`](../../api_docs/python/control_flow_ops.md#count_up_to) * [`count_up_to`](../../api_docs/python/control_flow_ops.md#count_up_to)
* [`equal`](../../api_docs/python/control_flow_ops.md#equal) * [`equal`](../../api_docs/python/control_flow_ops.md#equal)
* [`foldl`](../../api_docs/python/control_flow_ops.md#foldl)
* [`foldr`](../../api_docs/python/control_flow_ops.md#foldr)
* [`greater`](../../api_docs/python/control_flow_ops.md#greater) * [`greater`](../../api_docs/python/control_flow_ops.md#greater)
* [`greater_equal`](../../api_docs/python/control_flow_ops.md#greater_equal) * [`greater_equal`](../../api_docs/python/control_flow_ops.md#greater_equal)
* [`group`](../../api_docs/python/control_flow_ops.md#group) * [`group`](../../api_docs/python/control_flow_ops.md#group)
@ -244,16 +244,20 @@
* [`logical_not`](../../api_docs/python/control_flow_ops.md#logical_not) * [`logical_not`](../../api_docs/python/control_flow_ops.md#logical_not)
* [`logical_or`](../../api_docs/python/control_flow_ops.md#logical_or) * [`logical_or`](../../api_docs/python/control_flow_ops.md#logical_or)
* [`logical_xor`](../../api_docs/python/control_flow_ops.md#logical_xor) * [`logical_xor`](../../api_docs/python/control_flow_ops.md#logical_xor)
* [`map_fn`](../../api_docs/python/control_flow_ops.md#map_fn)
* [`no_op`](../../api_docs/python/control_flow_ops.md#no_op) * [`no_op`](../../api_docs/python/control_flow_ops.md#no_op)
* [`not_equal`](../../api_docs/python/control_flow_ops.md#not_equal) * [`not_equal`](../../api_docs/python/control_flow_ops.md#not_equal)
* [`Print`](../../api_docs/python/control_flow_ops.md#Print) * [`Print`](../../api_docs/python/control_flow_ops.md#Print)
* [`scan`](../../api_docs/python/control_flow_ops.md#scan)
* [`select`](../../api_docs/python/control_flow_ops.md#select) * [`select`](../../api_docs/python/control_flow_ops.md#select)
* [`tuple`](../../api_docs/python/control_flow_ops.md#tuple) * [`tuple`](../../api_docs/python/control_flow_ops.md#tuple)
* [`verify_tensor_all_finite`](../../api_docs/python/control_flow_ops.md#verify_tensor_all_finite) * [`verify_tensor_all_finite`](../../api_docs/python/control_flow_ops.md#verify_tensor_all_finite)
* [`where`](../../api_docs/python/control_flow_ops.md#where) * [`where`](../../api_docs/python/control_flow_ops.md#where)
* **[Higher Order Functions](../../api_docs/python/functional_ops.md)**:
* [`foldl`](../../api_docs/python/functional_ops.md#foldl)
* [`foldr`](../../api_docs/python/functional_ops.md#foldr)
* [`map_fn`](../../api_docs/python/functional_ops.md#map_fn)
* [`scan`](../../api_docs/python/functional_ops.md#scan)
* **[Images](../../api_docs/python/image.md)**: * **[Images](../../api_docs/python/image.md)**:
* [`adjust_brightness`](../../api_docs/python/image.md#adjust_brightness) * [`adjust_brightness`](../../api_docs/python/image.md#adjust_brightness)
* [`adjust_contrast`](../../api_docs/python/image.md#adjust_contrast) * [`adjust_contrast`](../../api_docs/python/image.md#adjust_contrast)
@ -272,7 +276,6 @@
* [`flip_up_down`](../../api_docs/python/image.md#flip_up_down) * [`flip_up_down`](../../api_docs/python/image.md#flip_up_down)
* [`grayscale_to_rgb`](../../api_docs/python/image.md#grayscale_to_rgb) * [`grayscale_to_rgb`](../../api_docs/python/image.md#grayscale_to_rgb)
* [`hsv_to_rgb`](../../api_docs/python/image.md#hsv_to_rgb) * [`hsv_to_rgb`](../../api_docs/python/image.md#hsv_to_rgb)
* [`make_all`](../../api_docs/python/image.md#make_all)
* [`pad_to_bounding_box`](../../api_docs/python/image.md#pad_to_bounding_box) * [`pad_to_bounding_box`](../../api_docs/python/image.md#pad_to_bounding_box)
* [`per_image_whitening`](../../api_docs/python/image.md#per_image_whitening) * [`per_image_whitening`](../../api_docs/python/image.md#per_image_whitening)
* [`random_brightness`](../../api_docs/python/image.md#random_brightness) * [`random_brightness`](../../api_docs/python/image.md#random_brightness)
@ -466,6 +469,7 @@
* [`fully_connected`](../../api_docs/python/contrib.layers.md#fully_connected) * [`fully_connected`](../../api_docs/python/contrib.layers.md#fully_connected)
* [`l1_regularizer`](../../api_docs/python/contrib.layers.md#l1_regularizer) * [`l1_regularizer`](../../api_docs/python/contrib.layers.md#l1_regularizer)
* [`l2_regularizer`](../../api_docs/python/contrib.layers.md#l2_regularizer) * [`l2_regularizer`](../../api_docs/python/contrib.layers.md#l2_regularizer)
* [`make_all`](../../api_docs/python/contrib.layers.md#make_all)
* [`optimize_loss`](../../api_docs/python/contrib.layers.md#optimize_loss) * [`optimize_loss`](../../api_docs/python/contrib.layers.md#optimize_loss)
* [`sum_regularizer`](../../api_docs/python/contrib.layers.md#sum_regularizer) * [`sum_regularizer`](../../api_docs/python/contrib.layers.md#sum_regularizer)
* [`summarize_activation`](../../api_docs/python/contrib.layers.md#summarize_activation) * [`summarize_activation`](../../api_docs/python/contrib.layers.md#summarize_activation)

View File

@ -101,6 +101,7 @@ from tensorflow.python.framework import framework_lib
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import constant_op from tensorflow.python.ops import constant_op
from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import histogram_ops from tensorflow.python.ops import histogram_ops
from tensorflow.python.ops import io_ops from tensorflow.python.ops import io_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
@ -118,8 +119,8 @@ _whitelist = set([app, compat, contrib, errors, flags, gfile, image,
# strings of other modules. # strings of other modules.
__all__ = make_all(__name__, __all__ = make_all(__name__,
[framework_lib, array_ops, client_lib, constant_op, [framework_lib, array_ops, client_lib, constant_op,
control_flow_ops, histogram_ops, io_ops, math_ops, nn, control_flow_ops, functional_ops, histogram_ops, io_ops,
script_ops, sparse_ops, state_ops, train]) math_ops, nn, script_ops, sparse_ops, state_ops, train])
# Symbols whitelisted for export without documentation. # Symbols whitelisted for export without documentation.
# TODO(cwhipkey): review these and move to contrib, expose through # TODO(cwhipkey): review these and move to contrib, expose through

View File

@ -484,6 +484,9 @@ class Library(Document):
names = self._members.items() names = self._members.items()
else: else:
names = inspect.getmembers(self._module) names = inspect.getmembers(self._module)
all_names = getattr(self._module, "__all__", None)
if all_names is not None:
names = [(n, m) for n, m in names if n in all_names]
leftovers = [] leftovers = []
for name, _ in names: for name, _ in names:
if name in self._members and name not in self._documented: if name in self._members and name not in self._documented:

View File

@ -43,6 +43,7 @@
@@add_to_collection @@add_to_collection
@@get_collection @@get_collection
@@get_collection_ref
@@GraphKeys @@GraphKeys
## Defining new operations ## Defining new operations
@ -82,6 +83,7 @@ from tensorflow.python.framework.ops import reset_default_graph
from tensorflow.python.framework.ops import GraphKeys from tensorflow.python.framework.ops import GraphKeys
from tensorflow.python.framework.ops import add_to_collection from tensorflow.python.framework.ops import add_to_collection
from tensorflow.python.framework.ops import get_collection from tensorflow.python.framework.ops import get_collection
from tensorflow.python.framework.ops import get_collection_ref
from tensorflow.python.framework.ops import convert_to_tensor from tensorflow.python.framework.ops import convert_to_tensor
from tensorflow.python.framework.ops import convert_to_tensor_or_indexed_slices from tensorflow.python.framework.ops import convert_to_tensor_or_indexed_slices
from tensorflow.python.framework.random_seed import get_seed from tensorflow.python.framework.random_seed import get_seed

View File

@ -83,6 +83,7 @@ def all_libraries(module_to_name, members, documented):
prefix=PREFIX_TEXT), prefix=PREFIX_TEXT),
library("histogram_ops", "Histograms"), library("histogram_ops", "Histograms"),
library("control_flow_ops", "Control Flow", prefix=PREFIX_TEXT), library("control_flow_ops", "Control Flow", prefix=PREFIX_TEXT),
library("functional_ops", "Higher Order Functions", prefix=PREFIX_TEXT),
library("image", "Images", tf.image, exclude_symbols=["ResizeMethod"], library("image", "Images", tf.image, exclude_symbols=["ResizeMethod"],
prefix=PREFIX_TEXT), prefix=PREFIX_TEXT),
library("sparse_ops", "Sparse Tensors", library("sparse_ops", "Sparse Tensors",

View File

@ -1787,6 +1787,7 @@ class Graph(object):
@@add_to_collection @@add_to_collection
@@get_collection @@get_collection
@@get_collection_ref
@@as_graph_element @@as_graph_element
@@get_operation_by_name @@get_operation_by_name
@ -2396,6 +2397,30 @@ class Graph(object):
for name in set(names): for name in set(names):
self.add_to_collection(name, value) self.add_to_collection(name, value)
def get_collection_ref(self, name):
"""Returns a list of values in the collection with the given `name`.
If the collection exists, this returns the list itself, which can
be modified in place to change the collection. If the collection does
not exist, it is created as an empty list and the list is returned.
This is different from `get_collection()` which always returns a copy of
the collection list if it exists and never creates an empty collection.
Args:
name: The key for the collection. For example, the `GraphKeys` class
contains many standard names for collections.
Returns:
The list of values in the collection with the given `name`, or an empty
list if no value has been added to that collection.
"""
coll_list = self._collections.get(name, None)
if coll_list is None:
coll_list = []
self._collections[name] = coll_list
return coll_list
def get_collection(self, name, scope=None): def get_collection(self, name, scope=None):
"""Returns a list of values in the collection with the given `name`. """Returns a list of values in the collection with the given `name`.
@ -2411,11 +2436,14 @@ class Graph(object):
list contains the values in the order under which they were list contains the values in the order under which they were
collected. collected.
""" """
coll_list = self._collections.get(name, None)
if coll_list is None:
return []
if scope is None: if scope is None:
return self._collections.get(name, list()) return list(coll_list)
else: else:
c = [] c = []
for item in self._collections.get(name, list()): for item in coll_list:
if hasattr(item, "name") and item.name.startswith(scope): if hasattr(item, "name") and item.name.startswith(scope):
c.append(item) c.append(item)
return c return c
@ -3547,6 +3575,25 @@ def add_to_collections(names, value):
get_default_graph().add_to_collections(names, value) get_default_graph().add_to_collections(names, value)
def get_collection_ref(key):
"""Wrapper for `Graph.get_collection_ref()` using the default graph.
See [`Graph.get_collection_ref()`](../../api_docs/python/framework.md#Graph.get_collection_ref)
for more details.
Args:
key: The key for the collection. For example, the `GraphKeys` class
contains many standard names for collections.
Returns:
The list of values in the collection with the given `name`, or an empty
list if no value has been added to that collection. Note that this returns
the collection list itself, which can be modified in place to change the
collection.
"""
return get_default_graph().get_collection_ref(key)
def get_collection(key, scope=None): def get_collection(key, scope=None):
"""Wrapper for `Graph.get_collection()` using the default graph. """Wrapper for `Graph.get_collection()` using the default graph.

View File

@ -751,12 +751,41 @@ class CollectionTest(test_util.TensorFlowTestCase):
blank2 = ObjectWithName("junk/foo") blank2 = ObjectWithName("junk/foo")
g.add_to_collection("blah", blank2) g.add_to_collection("blah", blank2)
self.assertEqual(["foo"], g.get_collection("other"))
self.assertEqual([12, 34], g.get_collection("key")) self.assertEqual([12, 34], g.get_collection("key"))
self.assertEqual([], g.get_collection("nothing")) self.assertEqual([], g.get_collection("nothing"))
self.assertEqual([27, blank1, blank2], g.get_collection("blah")) self.assertEqual([27, blank1, blank2], g.get_collection("blah"))
self.assertEqual([blank1], g.get_collection("blah", "prefix")) self.assertEqual([blank1], g.get_collection("blah", "prefix"))
# Make sure that get_collection() returns a first-level
# copy of the collection, while get_collection_ref() returns
# the original list.
other_collection_snapshot = g.get_collection("other")
other_collection_ref = g.get_collection_ref("other")
self.assertEqual(["foo"], other_collection_snapshot)
self.assertEqual(["foo"], other_collection_ref)
g.add_to_collection("other", "bar")
self.assertEqual(["foo"], other_collection_snapshot)
self.assertEqual(["foo", "bar"], other_collection_ref)
self.assertEqual(["foo", "bar"], g.get_collection("other"))
self.assertTrue(other_collection_ref is g.get_collection_ref("other"))
# Verify that getting an empty collection ref returns a modifiable list.
empty_coll_ref = g.get_collection_ref("empty")
self.assertEqual([], empty_coll_ref)
empty_coll = g.get_collection("empty")
self.assertEqual([], empty_coll)
self.assertFalse(empty_coll is empty_coll_ref)
empty_coll_ref2 = g.get_collection_ref("empty")
self.assertTrue(empty_coll_ref2 is empty_coll_ref)
# Add to the collection.
empty_coll_ref.append("something")
self.assertEqual(["something"], empty_coll_ref)
self.assertEqual(["something"], empty_coll_ref2)
self.assertEqual([], empty_coll)
self.assertEqual(["something"], g.get_collection("empty"))
empty_coll_ref3 = g.get_collection_ref("empty")
self.assertTrue(empty_coll_ref3 is empty_coll_ref)
def testDefaulGraph(self): def testDefaulGraph(self):
with ops.Graph().as_default(): with ops.Graph().as_default():
ops.add_to_collection("key", 90) ops.add_to_collection("key", 90)

View File

@ -20,7 +20,6 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import math import math
import sys
import numpy as np import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin from six.moves import xrange # pylint: disable=redefined-builtin
@ -1182,75 +1181,6 @@ class ControlFlowTest(tf.test.TestCase):
self.assertEqual(0, value_x) self.assertEqual(0, value_x)
self.assertEqual(73, value_x_grad) self.assertEqual(73, value_x_grad)
def testFoldl_Simple(self):
with self.test_session():
elems = tf.constant([1, 2, 3, 4, 5, 6], name="data")
r = control_flow_ops.foldl(
lambda a, x: tf.mul(tf.add(a, x), 2), elems)
self.assertAllEqual(208, r.eval())
r = control_flow_ops.foldl(
lambda a, x: tf.mul(tf.add(a, x), 2), elems, initializer=10)
self.assertAllEqual(880, r.eval())
def testFoldr_Simple(self):
with self.test_session():
elems = tf.constant([1, 2, 3, 4, 5, 6], name="data")
r = control_flow_ops.foldr(
lambda a, x: tf.mul(tf.add(a, x), 2), elems)
self.assertAllEqual(450, r.eval())
r = control_flow_ops.foldr(
lambda a, x: tf.mul(tf.add(a, x), 2), elems, initializer=10)
self.assertAllEqual(1282, r.eval())
def testFold_Grad(self):
with self.test_session():
elems = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data")
v = tf.constant(2.0, name="v")
r = control_flow_ops.foldl(
lambda a, x: tf.mul(a, x), elems, initializer=v)
r = tf.gradients(r, v)[0]
self.assertAllEqual(720.0, r.eval())
r = control_flow_ops.foldr(
lambda a, x: tf.mul(a, x), elems, initializer=v)
r = tf.gradients(r, v)[0]
self.assertAllEqual(720.0, r.eval())
def testMap_Simple(self):
with self.test_session():
nums = [1, 2, 3, 4, 5, 6]
elems = tf.constant(nums, name="data")
r = control_flow_ops.map_fn(
lambda x: tf.mul(tf.add(x, 3), 2), elems)
self.assertAllEqual(np.array([(x + 3) * 2 for x in nums]), r.eval())
def testScan_Simple(self):
with self.test_session():
elems = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data")
v = tf.constant(2.0, name="v")
r = control_flow_ops.scan(lambda a, x: tf.mul(a, x), elems)
self.assertAllEqual([1., 2., 6., 24., 120., 720.], r.eval())
r = control_flow_ops.scan(
lambda a, x: tf.mul(a, x), elems, initializer=v)
self.assertAllEqual([2., 4., 12., 48., 240., 1440.], r.eval())
def testScan_Grad(self):
with self.test_session():
elems = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data")
v = tf.constant(2.0, name="v")
r = control_flow_ops.scan(
lambda a, x: tf.mul(a, x), elems, initializer=v)
r = tf.gradients(r, v)[0]
self.assertAllEqual(873.0, r.eval())
def testOneValueCond(self): def testOneValueCond(self):
with self.test_session(): with self.test_session():
c = tf.placeholder(tf.int32, shape=[]) c = tf.placeholder(tf.int32, shape=[])

View File

@ -0,0 +1,94 @@
# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.kernels.bcast_ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
class FunctionalOpsTest(tf.test.TestCase):
def testFoldl_Simple(self):
with self.test_session():
elems = tf.constant([1, 2, 3, 4, 5, 6], name="data")
r = tf.foldl(lambda a, x: tf.mul(tf.add(a, x), 2), elems)
self.assertAllEqual(208, r.eval())
r = tf.foldl(
lambda a, x: tf.mul(tf.add(a, x), 2), elems, initializer=10)
self.assertAllEqual(880, r.eval())
def testFoldr_Simple(self):
with self.test_session():
elems = tf.constant([1, 2, 3, 4, 5, 6], name="data")
r = tf.foldr(lambda a, x: tf.mul(tf.add(a, x), 2), elems)
self.assertAllEqual(450, r.eval())
r = tf.foldr(
lambda a, x: tf.mul(tf.add(a, x), 2), elems, initializer=10)
self.assertAllEqual(1282, r.eval())
def testFold_Grad(self):
with self.test_session():
elems = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data")
v = tf.constant(2.0, name="v")
r = tf.foldl(
lambda a, x: tf.mul(a, x), elems, initializer=v)
r = tf.gradients(r, v)[0]
self.assertAllEqual(720.0, r.eval())
r = tf.foldr(
lambda a, x: tf.mul(a, x), elems, initializer=v)
r = tf.gradients(r, v)[0]
self.assertAllEqual(720.0, r.eval())
def testMap_Simple(self):
with self.test_session():
nums = [1, 2, 3, 4, 5, 6]
elems = tf.constant(nums, name="data")
r = tf.map_fn(lambda x: tf.mul(tf.add(x, 3), 2), elems)
self.assertAllEqual(np.array([(x + 3) * 2 for x in nums]), r.eval())
def testScan_Simple(self):
with self.test_session():
elems = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data")
v = tf.constant(2.0, name="v")
r = tf.scan(lambda a, x: tf.mul(a, x), elems)
self.assertAllEqual([1., 2., 6., 24., 120., 720.], r.eval())
r = tf.scan(
lambda a, x: tf.mul(a, x), elems, initializer=v)
self.assertAllEqual([2., 4., 12., 48., 240., 1440.], r.eval())
def testScan_Grad(self):
with self.test_session():
elems = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data")
v = tf.constant(2.0, name="v")
r = tf.scan(lambda a, x: tf.mul(a, x), elems, initializer=v)
r = tf.gradients(r, v)[0]
self.assertAllEqual(873.0, r.eval())
if __name__ == "__main__":
tf.test.main()

View File

@ -0,0 +1,122 @@
# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.tf.gather_nd."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import time
import numpy as np
import tensorflow as tf
class GatherNdTest(tf.test.TestCase):
use_gpu = False
def _testSimpleDtype(self, dtype):
with self.test_session(use_gpu=self.use_gpu):
params = tf.constant(np.array([8, 1, 2, 3, 7, 5], dtype=dtype))
indices = tf.constant([[4], [4], [0]])
gather_nd_t = tf.gather_nd(params, indices)
gather_nd_val = gather_nd_t.eval()
self.assertAllEqual(np.array([7, 7, 8], dtype=dtype), gather_nd_val)
self.assertEqual([3], gather_nd_t.get_shape())
def testSimpleDtype(self):
self._testSimpleDtype(np.float32)
self._testSimpleDtype(np.float64)
self._testSimpleDtype(np.int32)
self._testSimpleDtype(np.int64)
self._testSimpleDtype(np.complex64)
self._testSimpleDtype("|S") # byte strings in python2 + 3
def testHigherRankParams(self):
with self.test_session(use_gpu=self.use_gpu):
shape = (10, 20, 5, 1, 17)
params = np.random.rand(*shape)
indices = np.vstack([
np.random.randint(0, s, size=2000) for s in shape]).T
gather_nd_t = tf.gather_nd(params, indices)
gather_nd_val = gather_nd_t.eval()
expected = params[tuple(indices.T)]
self.assertAllEqual(expected, gather_nd_val)
self.assertEqual([2000], gather_nd_t.get_shape())
def testHigherRankParamsAndIndices(self):
with self.test_session(use_gpu=self.use_gpu):
shape = (10, 20, 5, 1, 17)
params = np.random.rand(*shape)
indices = np.vstack([
np.random.randint(0, s, size=2000) for s in shape]).T
indices_reshaped = indices.reshape([10, 10, 20, 5])
gather_nd_t = tf.gather_nd(params, indices_reshaped)
gather_nd_val = gather_nd_t.eval()
expected = params[tuple(indices.T)]
self.assertAllEqual(expected.reshape([10, 10, 20]), gather_nd_val)
self.assertEqual([10, 10, 20], gather_nd_t.get_shape())
def testUnknownIndices(self):
params = tf.constant([[0, 1, 2]])
indices = tf.placeholder(tf.int32)
gather_nd_t = tf.gather_nd(params, indices)
shape = gather_nd_t.get_shape()
self.assertEqual(shape.ndims, None)
self.assertEqual(shape[0].value, None)
def testBadIndices(self):
with self.test_session(use_gpu=False):
params = [0, 1, 2]
indices = [[[0], [7]]] # Make this one higher rank
gather_nd = tf.gather_nd(params, indices)
with self.assertRaisesOpError(
r"flat indices\[1, :\] = \[7\] does not index into param "
r"\(shape: \[3\]\)"):
gather_nd.eval()
class GatherNdGpuTest(GatherNdTest):
use_gpu = True
class GatherNdOpBenchmark(tf.test.Benchmark):
def benchmark_gather_nd_op(self):
shape = (100, 47, 18, 170, 13)
np.random.seed(127)
params = np.random.rand(*shape)
indices = np.vstack([
np.random.randint(0, s, size=10000) for s in shape]).T
with tf.Session():
t_params = tf.Variable(params)
t_indices = tf.Variable(indices)
gather_op = tf.gather_nd(t_params, t_indices)
tf.initialize_all_variables().run()
for _ in range(10):
gather_op.eval()
t1 = time.time()
for _ in range(1000):
gather_op.eval()
t2 = time.time()
self.report_benchmark(iters=1000, wall_time=(t2-t1)/1000.0)
if __name__ == "__main__":
tf.test.main()

View File

@ -193,6 +193,11 @@ def _GatherGrad(op, grad):
return [ops.IndexedSlices(values, indices, dense_shape), None] return [ops.IndexedSlices(values, indices, dense_shape), None]
@ops.RegisterGradient("GatherNd")
def _GatherNdGrad(unused_op, unused_grad):
raise NotImplementedError("Gradient for gather_nd is not implemented.")
@ops.RegisterGradient("Identity") @ops.RegisterGradient("Identity")
def _IdGrad(_, grad): def _IdGrad(_, grad):
return grad return grad

View File

@ -57,6 +57,7 @@ or join multiple tensors together.
@@space_to_depth @@space_to_depth
@@depth_to_space @@depth_to_space
@@gather @@gather
@@gather_nd
@@dynamic_partition @@dynamic_partition
@@dynamic_stitch @@dynamic_stitch
@@boolean_mask @@boolean_mask
@ -879,6 +880,16 @@ def _GatherShape(op):
return [indices_shape.concatenate(params_shape[1:])] return [indices_shape.concatenate(params_shape[1:])]
@ops.RegisterShape("GatherNd")
def _GatherNdShape(op):
"""Shape function for array_ops.gather_nd."""
params_shape = op.inputs[0].get_shape()
indices_shape = op.inputs[1].get_shape().with_rank_at_least(2)
if indices_shape.ndims is not None:
indices_shape[-1].merge_with(params_shape.ndims)
return [indices_shape[:-1]]
@ops.RegisterShape("Unique") @ops.RegisterShape("Unique")
def _UniqueShape(op): def _UniqueShape(op):
"""Shape function for array_ops.Unique.""" """Shape function for array_ops.Unique."""

View File

@ -26,16 +26,6 @@ the execution of operations and add conditional dependencies to your graph.
@@cond @@cond
@@case @@case
## Higher Order Operators
TensorFlow provides several higher order operators to simplify the common
map-reduce programming patterns.
@@map_fn
@@foldl
@@foldr
@@scan
## Logical Operators ## Logical Operators
TensorFlow provides several operations that you can use to add logical operators TensorFlow provides several operations that you can use to add logical operators
@ -1785,269 +1775,6 @@ def tuple(tensors, name=None, control_inputs=None):
return tpl return tpl
# TODO(yuanbyu, mrry): Handle stride to support sliding windows.
def foldl(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
swap_memory=False, name=None):
"""The foldl operator on the list of tensors resulted from unpacking `elems`
along the first dimension.
This foldl operator repeatedly applies the callable `fn` to a sequence
of elements from first to last. The elements are made of the tensors
unpacked from `elems` on dimension 0. The callable fn takes two tensors as
arguments. The first argument is the accumulated value computed from the
preceding invocation of fn. If `initializer` is None, `elems` must contain
at least one element, and its first element is used as the initializer.
Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
of the result tensor is fn(initializer, values[0]).shape`.
Args:
fn: The callable to be performed.
elems: A tensor to be unpacked on dimension 0.
initializer: (optional) The initial value for the accumulator.
parallel_iterations: (optional) The number of iterations allowed to run
in parallel.
back_prop: (optional) True enables back propagation.
swap_memory: (optional) True enables GPU-CPU memory swapping.
name: (optional) Name prefix for the returned tensors.
Returns:
A tensor resulting from applying `fn` consecutively to the list of tensors
unpacked from `elems`, from first to last.
Raises:
TypeError: if `fn` is not callable.
Example:
```python
elems = [1, 2, 3, 4, 5, 6]
sum = foldl(lambda a, x: a + x, elems)
# sum == 21
```
"""
with ops.op_scope([elems], name, "foldl") as name:
if not callable(fn):
raise TypeError("fn must be callable.")
# Convert elems to tensor array.
n = array_ops.shape(elems)[0]
elems_ta = tensor_array_ops.TensorArray(dtype=elems.dtype, size=n,
dynamic_size=False)
elems_ta = elems_ta.unpack(elems)
if initializer is None:
a = elems_ta.read(0)
i = constant_op.constant(1)
else:
a = ops.convert_to_tensor(initializer)
i = constant_op.constant(0)
def compute(i, a):
a = fn(a, elems_ta.read(i))
return [i + 1, a]
_, r_a = While(lambda i, a: i < n, compute, [i, a],
parallel_iterations=parallel_iterations,
back_prop=back_prop, swap_memory=swap_memory)
return r_a
def foldr(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
swap_memory=False, name=None):
"""The foldr operator on the list of tensors resulted from unpacking `elems`
along the first dimension.
This foldr operator repeatedly applies the callable `fn` to a sequence
of elements from last to first. The elements are made of the tensors
unpacked from `elems`. The callable fn takes two tensors as arguments.
The first argument is the accumulated value computed from the preceding
invocation of fn. If `initializer` is None, `elems` must contain at least
one element, and its first element is used as the initializer.
Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
of the result tensor is `fn(initializer, values[0]).shape`.
Args:
fn: The callable to be performed.
elems: A tensor that is unpacked into a sequence of tensors to apply `fn`.
initializer: (optional) The initial value for the accumulator.
parallel_iterations: (optional) The number of iterations allowed to run
in parallel.
back_prop: (optional) True enables back propagation.
swap_memory: (optional) True enables GPU-CPU memory swapping.
name: (optional) Name prefix for the returned tensors.
Returns:
A tensor resulting from applying `fn` consecutively to the list of tensors
unpacked from `elems`, from last to first.
Raises:
TypeError: if `fn` is not callable.
Example:
```python
elems = [1, 2, 3, 4, 5, 6]
sum = foldr(lambda a, x: a + x, elems)
# sum == 21
```
"""
with ops.op_scope([elems], name, "foldr") as name:
if not callable(fn):
raise TypeError("fn must be callable.")
# Convert elems to tensor array.
n = array_ops.shape(elems)[0]
elems_ta = tensor_array_ops.TensorArray(dtype=elems.dtype, size=n,
dynamic_size=False)
elems_ta = elems_ta.unpack(elems)
if initializer is None:
i = n - 1
a = elems_ta.read(i)
else:
i = n
a = ops.convert_to_tensor(initializer)
def compute(i, a):
i -= 1
a = fn(a, elems_ta.read(i))
return [i, a]
_, r_a = While(lambda i, a: i > 0, compute, [i, a],
parallel_iterations=parallel_iterations,
back_prop=back_prop, swap_memory=swap_memory)
return r_a
def map_fn(fn, elems, dtype=None, parallel_iterations=10, back_prop=True,
swap_memory=False, name=None):
"""The map operator on the list of tensors resulted from unpacking `elems`
along the first dimension.
This map operator repeatedly applies the callable `fn` to a sequence of
elements from first to last. The elements are made of the tensors unpacked
from `elems`. `dtype` is the data type of the return value of `fn`. Users
must provide `dtype` if it is different from the data type of `elems`.
Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
of the result tensor is `[len(values)] + fn(values[0]).shape`.
Args:
fn: The callable to be performed.
elems: A tensor to be unpacked to apply `fn`.
dtype: (optional) The output type of `fn`.
parallel_iterations: (optional) The number of iterations allowed to run
in parallel.
back_prop: (optional) True enables back propagation.
swap_memory: (optional) True enables GPU-CPU memory swapping.
name: (optional) Name prefix for the returned tensors.
Returns:
A tensor that packs the results of applying `fn` to the list of tensors
unpacked from `elems`, from first to last.
Raises:
TypeError: if `fn` is not callable.
Example:
```python
elems = [1, 2, 3, 4, 5, 6]
squares = map_fn(lambda x: x * x, elems)
# squares == [1, 4, 9, 16, 25, 36]
```
"""
with ops.op_scope([elems], name, "map") as name:
if not callable(fn):
raise TypeError("fn must be callable.")
dtype = dtype if dtype else elems.dtype
# Convert elems to tensor array.
n = array_ops.shape(elems)[0]
elems_ta = tensor_array_ops.TensorArray(dtype=elems.dtype, size=n,
dynamic_size=False)
elems_ta = elems_ta.unpack(elems)
i = constant_op.constant(0)
acc_ta = tensor_array_ops.TensorArray(dtype=dtype, size=n,
dynamic_size=False)
def compute(i, ta):
ta = ta.write(i, fn(elems_ta.read(i)))
return [i + 1, ta]
_, r_a = While(lambda i, a: i < n, compute, [i, acc_ta],
parallel_iterations=parallel_iterations,
back_prop=back_prop, swap_memory=swap_memory)
return r_a.pack()
def scan(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
swap_memory=False, name=None):
"""The scan operator on the list of tensors resulted from unpacking `elems`
along the first dimension.
This scan operator repeatedly applies the callable `fn` to a sequence
of elements from first to last. The elements are made of the tensors
unpacked from `elems` on dimension 0. The callable fn takes two tensors as
arguments. The first argument is the accumulated value computed from the
preceding invocation of fn. If `initializer` is None, `elems` must contain
at least one element, and its first element is used as the initializer.
Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
of the result tensor is `[len(values)] + fn(initializer, values[0]).shape`.
Args:
fn: The callable to be performed.
elems: A tensor to be unpacked on dimension 0.
initializer: (optional) The initial value for the accumulator.
parallel_iterations: (optional) The number of iterations allowed to run
in parallel.
back_prop: (optional) True enables back propagation.
swap_memory: (optional) True enables GPU-CPU memory swapping.
name: (optional) Name prefix for the returned tensors.
Returns:
A tensor that packs the results of applying `fn` to the list of tensors
unpacked from `elems`, from first to last.
Raises:
TypeError: if `fn` is not callable.
Example:
```python
elems = [1, 2, 3, 4, 5, 6]
sum = scan(lambda a, x: a + x, elems)
# sum == [1, 3, 6, 10, 15, 21]
```
"""
with ops.op_scope([elems], name, "scan") as name:
if not callable(fn):
raise TypeError("fn must be callable.")
# Convert elems to tensor array.
n = array_ops.shape(elems)[0]
elems_ta = tensor_array_ops.TensorArray(dtype=elems.dtype, size=n,
dynamic_size=False)
elems_ta = elems_ta.unpack(elems)
if initializer is None:
a = elems_ta.read(0)
i = constant_op.constant(1)
else:
a = ops.convert_to_tensor(initializer)
i = constant_op.constant(0)
# Create a tensor array to store the intermediate values.
acc_ta = tensor_array_ops.TensorArray(dtype=a.dtype, size=n,
dynamic_size=False)
if initializer is None:
acc_ta = acc_ta.write(0, a)
def compute(i, a, ta):
a = fn(a, elems_ta.read(i))
ta = ta.write(i, a)
return [i + 1, a, ta]
_, _, r_a = While(lambda i, a, ta: i < n, compute, [i, a, acc_ta],
parallel_iterations=parallel_iterations,
back_prop=back_prop, swap_memory=swap_memory)
return r_a.pack()
def case(pred_fn_pairs, default, exclusive=False, name="case"): def case(pred_fn_pairs, default, exclusive=False, name="case"):
"""Create a case operation. """Create a case operation.

View File

@ -13,13 +13,28 @@
# limitations under the License. # limitations under the License.
# ============================================================================= # =============================================================================
"""Functional operations.""" """Functional operations.
## Higher Order Operators
TensorFlow provides several higher order operators to simplify the common
map-reduce programming patterns.
@@map_fn
@@foldl
@@foldr
@@scan
"""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import constant_op
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import tensor_array_ops
# pylint: disable=wildcard-import # pylint: disable=wildcard-import
from tensorflow.python.ops.gen_functional_ops import * from tensorflow.python.ops.gen_functional_ops import *
# pylint: enable=wildcard-import # pylint: enable=wildcard-import
@ -28,6 +43,269 @@ from tensorflow.python.ops.gen_functional_ops import _symbolic_gradient
# pylint: enable=unused-import # pylint: enable=unused-import
# TODO(yuanbyu, mrry): Handle stride to support sliding windows.
def foldl(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
          swap_memory=False, name=None):
  """foldl on the list of tensors unpacked from `elems` on dimension 0.

  Repeatedly applies the callable `fn` to a sequence of elements from first
  to last. The elements are the tensors unpacked from `elems` on dimension 0.
  `fn` takes two tensors: the accumulated value from the previous invocation,
  and the current element. If `initializer` is None, `elems` must contain at
  least one element and its first element is used as the initializer.

  Suppose that `elems` is unpacked into `values`, a list of tensors. The
  shape of the result tensor is `fn(initializer, values[0]).shape`.

  Args:
    fn: The callable to be performed.
    elems: A tensor to be unpacked on dimension 0.
    initializer: (optional) The initial value for the accumulator.
    parallel_iterations: (optional) The number of iterations allowed to run
                         in parallel.
    back_prop: (optional) True enables back propagation.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    name: (optional) Name prefix for the returned tensors.

  Returns:
    A tensor resulting from applying `fn` consecutively to the list of tensors
    unpacked from `elems`, from first to last.

  Raises:
    TypeError: if `fn` is not callable.

  Example:
    ```python
    elems = [1, 2, 3, 4, 5, 6]
    sum = foldl(lambda a, x: a + x, elems)
    # sum == 21
    ```
  """
  with ops.op_scope([elems], name, "foldl") as name:
    if not callable(fn):
      raise TypeError("fn must be callable.")
    # Unpack `elems` along dimension 0 into a TensorArray.
    num_elems = array_ops.shape(elems)[0]
    input_ta = tensor_array_ops.TensorArray(dtype=elems.dtype,
                                            size=num_elems,
                                            dynamic_size=False)
    input_ta = input_ta.unpack(elems)
    # Seed the accumulator: either the caller-supplied initializer, or the
    # first unpacked element (in which case the loop starts at index 1).
    if initializer is not None:
      acc = ops.convert_to_tensor(initializer)
      index = constant_op.constant(0)
    else:
      acc = input_ta.read(0)
      index = constant_op.constant(1)
    def _step(index, acc):
      # Fold the next element into the accumulator.
      return [index + 1, fn(acc, input_ta.read(index))]
    _, result = control_flow_ops.While(
        lambda index, acc: index < num_elems, _step, [index, acc],
        parallel_iterations=parallel_iterations,
        back_prop=back_prop, swap_memory=swap_memory)
    return result
def foldr(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
          swap_memory=False, name=None):
  """foldr on the list of tensors unpacked from `elems` on dimension 0.

  Repeatedly applies the callable `fn` to a sequence of elements from last
  to first. The elements are the tensors unpacked from `elems`. `fn` takes
  two tensors: the accumulated value from the previous invocation, and the
  current element. If `initializer` is None, `elems` must contain at least
  one element, and its last element is used as the initializer.

  Suppose that `elems` is unpacked into `values`, a list of tensors. The
  shape of the result tensor is `fn(initializer, values[0]).shape`.

  Args:
    fn: The callable to be performed.
    elems: A tensor that is unpacked into a sequence of tensors to apply `fn`.
    initializer: (optional) The initial value for the accumulator.
    parallel_iterations: (optional) The number of iterations allowed to run
                         in parallel.
    back_prop: (optional) True enables back propagation.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    name: (optional) Name prefix for the returned tensors.

  Returns:
    A tensor resulting from applying `fn` consecutively to the list of tensors
    unpacked from `elems`, from last to first.

  Raises:
    TypeError: if `fn` is not callable.

  Example:
    ```python
    elems = [1, 2, 3, 4, 5, 6]
    sum = foldr(lambda a, x: a + x, elems)
    # sum == 21
    ```
  """
  with ops.op_scope([elems], name, "foldr") as name:
    if not callable(fn):
      raise TypeError("fn must be callable.")
    # Unpack `elems` along dimension 0 into a TensorArray.
    num_elems = array_ops.shape(elems)[0]
    input_ta = tensor_array_ops.TensorArray(dtype=elems.dtype,
                                            size=num_elems,
                                            dynamic_size=False)
    input_ta = input_ta.unpack(elems)
    # With an explicit initializer the loop consumes indices n-1..0;
    # otherwise the last element seeds the accumulator and the loop
    # consumes indices n-2..0.
    if initializer is not None:
      index = num_elems
      acc = ops.convert_to_tensor(initializer)
    else:
      index = num_elems - 1
      acc = input_ta.read(index)
    def _step(index, acc):
      # Walk backwards, folding the preceding element into the accumulator.
      index -= 1
      return [index, fn(acc, input_ta.read(index))]
    _, result = control_flow_ops.While(
        lambda index, acc: index > 0, _step, [index, acc],
        parallel_iterations=parallel_iterations,
        back_prop=back_prop, swap_memory=swap_memory)
    return result
def map_fn(fn, elems, dtype=None, parallel_iterations=10, back_prop=True,
           swap_memory=False, name=None):
  """map on the list of tensors unpacked from `elems` on dimension 0.

  Repeatedly applies the callable `fn` to a sequence of elements from first
  to last. The elements are the tensors unpacked from `elems`. `dtype` is the
  data type of the return value of `fn`; callers must provide `dtype` if it
  differs from the data type of `elems`.

  Suppose that `elems` is unpacked into `values`, a list of tensors. The
  shape of the result tensor is `[len(values)] + fn(values[0]).shape`.

  Args:
    fn: The callable to be performed.
    elems: A tensor to be unpacked to apply `fn`.
    dtype: (optional) The output type of `fn`.
    parallel_iterations: (optional) The number of iterations allowed to run
                         in parallel.
    back_prop: (optional) True enables back propagation.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    name: (optional) Name prefix for the returned tensors.

  Returns:
    A tensor that packs the results of applying `fn` to the list of tensors
    unpacked from `elems`, from first to last.

  Raises:
    TypeError: if `fn` is not callable.

  Example:
    ```python
    elems = [1, 2, 3, 4, 5, 6]
    squares = map_fn(lambda x: x * x, elems)
    # squares == [1, 4, 9, 16, 25, 36]
    ```
  """
  with ops.op_scope([elems], name, "map") as name:
    if not callable(fn):
      raise TypeError("fn must be callable.")
    # Outputs default to the input dtype when the caller gives no `dtype`.
    result_dtype = dtype if dtype else elems.dtype
    # Unpack `elems` along dimension 0 into a TensorArray.
    num_elems = array_ops.shape(elems)[0]
    input_ta = tensor_array_ops.TensorArray(dtype=elems.dtype,
                                            size=num_elems,
                                            dynamic_size=False)
    input_ta = input_ta.unpack(elems)
    # A second TensorArray accumulates the per-element results.
    output_ta = tensor_array_ops.TensorArray(dtype=result_dtype,
                                             size=num_elems,
                                             dynamic_size=False)
    def _step(index, output_ta):
      # Apply `fn` to one element and record the result.
      return [index + 1, output_ta.write(index, fn(input_ta.read(index)))]
    _, final_ta = control_flow_ops.While(
        lambda index, output_ta: index < num_elems, _step,
        [constant_op.constant(0), output_ta],
        parallel_iterations=parallel_iterations,
        back_prop=back_prop, swap_memory=swap_memory)
    return final_ta.pack()
def scan(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
         swap_memory=False, name=None):
  """scan on the list of tensors unpacked from `elems` on dimension 0.

  Repeatedly applies the callable `fn` to a sequence of elements from first
  to last, emitting every intermediate accumulator value. The elements are
  the tensors unpacked from `elems` on dimension 0. `fn` takes two tensors:
  the accumulated value from the previous invocation, and the current
  element. If `initializer` is None, `elems` must contain at least one
  element, and its first element is used as the initializer.

  Suppose that `elems` is unpacked into `values`, a list of tensors. The
  shape of the result tensor is `[len(values)] + fn(initializer,
  values[0]).shape`.

  Args:
    fn: The callable to be performed.
    elems: A tensor to be unpacked on dimension 0.
    initializer: (optional) The initial value for the accumulator.
    parallel_iterations: (optional) The number of iterations allowed to run
                         in parallel.
    back_prop: (optional) True enables back propagation.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    name: (optional) Name prefix for the returned tensors.

  Returns:
    A tensor that packs the results of applying `fn` to the list of tensors
    unpacked from `elems`, from first to last.

  Raises:
    TypeError: if `fn` is not callable.

  Example:
    ```python
    elems = [1, 2, 3, 4, 5, 6]
    sum = scan(lambda a, x: a + x, elems)
    # sum == [1, 3, 6, 10, 15, 21]
    ```
  """
  with ops.op_scope([elems], name, "scan") as name:
    if not callable(fn):
      raise TypeError("fn must be callable.")
    # Unpack `elems` along dimension 0 into a TensorArray.
    num_elems = array_ops.shape(elems)[0]
    input_ta = tensor_array_ops.TensorArray(dtype=elems.dtype,
                                            size=num_elems,
                                            dynamic_size=False)
    input_ta = input_ta.unpack(elems)
    # Seed the accumulator: either the caller-supplied initializer, or the
    # first unpacked element (in which case the loop starts at index 1).
    if initializer is not None:
      acc = ops.convert_to_tensor(initializer)
      index = constant_op.constant(0)
    else:
      acc = input_ta.read(0)
      index = constant_op.constant(1)
    # A second TensorArray accumulates every intermediate value.
    output_ta = tensor_array_ops.TensorArray(dtype=acc.dtype,
                                             size=num_elems,
                                             dynamic_size=False)
    if initializer is None:
      # Without an initializer the first element doubles as the first output.
      output_ta = output_ta.write(0, acc)
    def _step(index, acc, output_ta):
      # Fold the next element in and record the new accumulator value.
      acc = fn(acc, input_ta.read(index))
      return [index + 1, acc, output_ta.write(index, acc)]
    _, _, final_ta = control_flow_ops.While(
        lambda index, acc, output_ta: index < num_elems, _step,
        [index, acc, output_ta],
        parallel_iterations=parallel_iterations,
        back_prop=back_prop, swap_memory=swap_memory)
    return final_ta.pack()
@ops.RegisterShape("SymbolicGradient") @ops.RegisterShape("SymbolicGradient")
def _symbolic_gradient_shape(op): def _symbolic_gradient_shape(op):
# Say, (u, v) = f(x, y, z), _symbolic_gradient(f) is a function of # Say, (u, v) = f(x, y, z), _symbolic_gradient(f) is a function of

View File

@ -37,11 +37,8 @@ from tensorflow.python.ops.control_flow_ops import no_op
from tensorflow.python.ops.control_flow_ops import tuple from tensorflow.python.ops.control_flow_ops import tuple
from tensorflow.python.ops.control_flow_ops import cond from tensorflow.python.ops.control_flow_ops import cond
from tensorflow.python.ops.control_flow_ops import case from tensorflow.python.ops.control_flow_ops import case
from tensorflow.python.ops.control_flow_ops import foldl
from tensorflow.python.ops.control_flow_ops import foldr
from tensorflow.python.ops.control_flow_ops import map_fn
from tensorflow.python.ops.control_flow_ops import scan
from tensorflow.python.ops.data_flow_ops import * from tensorflow.python.ops.data_flow_ops import *
from tensorflow.python.ops.functional_ops import *
from tensorflow.python.ops.gradients import * from tensorflow.python.ops.gradients import *
from tensorflow.python.ops.histogram_ops import * from tensorflow.python.ops.histogram_ops import *
from tensorflow.python.ops.init_ops import * from tensorflow.python.ops.init_ops import *

View File

@ -359,7 +359,8 @@ def _pure_variable_scope(name_or_scope, reuse=None, initializer=None,
""" """
get_variable_scope() # Ensure that a default exists, then get a pointer. get_variable_scope() # Ensure that a default exists, then get a pointer.
default_varscope = ops.get_collection(_VARSCOPE_KEY) # Get the reference to the collection as we want to modify it in place.
default_varscope = ops.get_collection_ref(_VARSCOPE_KEY)
try: try:
old = default_varscope[0] old = default_varscope[0]
reuse = reuse or old.reuse # Re-using is inherited by sub-scopes. reuse = reuse or old.reuse # Re-using is inherited by sub-scopes.

View File

@ -256,7 +256,7 @@ class SessionManager(object):
self._safe_close(sess) self._safe_close(sess)
logging.info("Waiting for model to be ready: %s", not_ready) logging.info("Waiting for model to be ready: %s", not_ready)
time.sleep(self._recovery_wait_secs) time.sleep(self._recovery_wait_secs)
sess = session.Session(master, graph=self._graph) sess = session.Session(target, graph=self._graph, config=config)
return sess return sess

View File

@ -104,6 +104,13 @@ class SyncReplicasOptimizer(optimizer.Optimizer):
# Now you can call `minimize()` or `compute_gradients()` and # Now you can call `minimize()` or `compute_gradients()` and
# `apply_gradients()` normally # `apply_gradients()` normally
grads = opt.minimize(total_loss, global_step=self.global_step) grads = opt.minimize(total_loss, global_step=self.global_step)
# You can now call get_init_tokens_op() and get_chief_queue_runner().
# Note that get_init_tokens_op() must be called before creating session
# because it modifies the graph.
init_token_op = opt.get_init_tokens_op()
chief_queue_runner = opt.get_chief_queue_runner()
``` ```
In the training program, every worker will run the train_op as if not In the training program, every worker will run the train_op as if not
@ -114,9 +121,9 @@ class SyncReplicasOptimizer(optimizer.Optimizer):
# After the session is created by the superviser and before the main while # After the session is created by the superviser and before the main while
# loop: # loop:
if is_chief and FLAGS.sync_replicas: if is_chief and FLAGS.sync_replicas:
sv.start_queue_runners(sess, [opt.get_chief_queue_runner()]) sv.start_queue_runners(sess, [chief_queue_runner])
# Insert initial tokens to the queue. # Insert initial tokens to the queue.
sess.run(opt.get_init_tokens_op()) sess.run(init_token_op)
``` ```
@@__init__ @@__init__

View File

@ -89,37 +89,41 @@ const DOMAIN_EDGE_WIDTH_SCALE = [1, 5E6];
/** /**
* Parameters that affect how the graph is rendered on the screen. * Parameters that affect how the graph is rendered on the screen.
*/ */
export interface RenderGraphParams { const PARAMS = {
/** /**
* Whether to extract high degree nodes from the core part of the graph. * Whether to extract high degree nodes from the core part of the graph.
*/ */
enableExtraction: boolean; enableExtraction: true,
/** /**
* Maximum in-degree that a node can have without being considered as * Maximum in-degree that a node can have without being considered as
* high in-degree node. * high in-degree node.
*/ */
maxInDegree: number; maxInDegree: 4,
/** /**
* Maximum out-degree that a node can have without being considered as * Maximum out-degree that a node can have without being considered as
* high out-degree node. * high out-degree node.
*/ */
maxOutDegree: number; maxOutDegree: 4,
/** /**
* Maximum number of control edges a node can have before they aren't * Maximum number of control edges a node can have before they aren't
* displayed. * displayed.
*/ */
maxControlDegree: number; maxControlDegree: 4,
/** /**
* Types patterns for predefined out-extract nodes, which are * Types patterns for predefined out-extract nodes, which are
* sink-like nodes that will be extracted from the main graph. * sink-like nodes that will be extracted from the main graph.
*/ */
outExtractTypes: string[]; outExtractTypes: [
"NoOp" // NoOps are sink-like used for managing control dependencies.
],
/** /**
* Types patterns for predefined in-extract nodes, which are * Types patterns for predefined in-extract nodes, which are
* source-like nodes that will be extracted from the main graph. * source-like nodes that will be extracted from the main graph.
*/ */
inExtractTypes: string[]; inExtractTypes: [
"Variable"
],
/** /**
* When removing edges from a high degree node, remove all of its edges if * When removing edges from a high degree node, remove all of its edges if
@ -127,32 +131,33 @@ export interface RenderGraphParams {
* the node has high in-degree, or all out-edges if the node has high * the node has high in-degree, or all out-edges if the node has high
* out-degree. * out-degree.
*/ */
detachAllEdgesForHighDegree: boolean; detachAllEdgesForHighDegree: true,
/** /**
* After extracting high in/out degree nodes and predefined * After extracting high in/out degree nodes and predefined
* source-like/sink-like, extract isolated nodes to the side * source-like/sink-like, extract isolated nodes to the side
* if this extractIsolatedNodesWithAnnotationsOnOneSide is true. * if this extractIsolatedNodesWithAnnotationsOnOneSide is true.
*/ */
extractIsolatedNodesWithAnnotationsOnOneSide: boolean; extractIsolatedNodesWithAnnotationsOnOneSide: true,
/** /**
* Whether to add bridge nodes and edges to the core when building the * Whether to add bridge nodes and edges to the core when building the
* subhierarchy of an expanded metanode. See buildSubhierarchy(). * subhierarchy of an expanded metanode. See buildSubhierarchy().
*/ */
enableBridgegraph: boolean; enableBridgegraph: true,
/** /**
* 2 colors, for the minimum and maximum value respectively, whenever we * 2 colors, for the minimum and maximum value respectively, whenever we
* have a gradient scale. * have a gradient scale.
*/ */
minMaxColors: string[]; minMaxColors: ["#fff5f0", "#fb6a4a"],
/** /**
* Maximum number of annotations to be displayed on a node. * Maximum number of annotations to be displayed on a node before an
* ellipsis is used.
*/ */
maxAnnotations: number; maxAnnotations: 5
} };
/** /**
* Stores the rendering information, such as x and y coordinates, * Stores the rendering information, such as x and y coordinates,
@ -161,7 +166,6 @@ export interface RenderGraphParams {
export class RenderGraphInfo { export class RenderGraphInfo {
private hierarchy: hierarchy.Hierarchy; private hierarchy: hierarchy.Hierarchy;
private index: {[nodeName: string]: RenderNodeInfo}; private index: {[nodeName: string]: RenderNodeInfo};
private params: RenderGraphParams;
private deviceColorMap: d3.scale.Ordinal<string, string>; private deviceColorMap: d3.scale.Ordinal<string, string>;
private memoryUsageScale: d3.scale.Linear<string, string>; private memoryUsageScale: d3.scale.Linear<string, string>;
private computeTimeScale: d3.scale.Linear<string, string>; private computeTimeScale: d3.scale.Linear<string, string>;
@ -173,7 +177,7 @@ export class RenderGraphInfo {
private hasSubhierarchy: {[nodeName: string]: boolean}; private hasSubhierarchy: {[nodeName: string]: boolean};
root: RenderGroupNodeInfo; root: RenderGroupNodeInfo;
constructor(hierarchy: hierarchy.Hierarchy, params: RenderGraphParams) { constructor(hierarchy: hierarchy.Hierarchy) {
this.hierarchy = hierarchy; this.hierarchy = hierarchy;
this.index = {}; this.index = {};
this.deviceColorMap = d3.scale.ordinal<string>() this.deviceColorMap = d3.scale.ordinal<string>()
@ -199,7 +203,7 @@ export class RenderGraphInfo {
}); });
this.memoryUsageScale = d3.scale.linear<string, string>() this.memoryUsageScale = d3.scale.linear<string, string>()
.domain(memoryExtent) .domain(memoryExtent)
.range(params.minMaxColors); .range(PARAMS.minMaxColors);
// Find also the minimum and maximum compute time. // Find also the minimum and maximum compute time.
let computeTimeExtent = d3.extent(topLevelGraph.nodes(), let computeTimeExtent = d3.extent(topLevelGraph.nodes(),
@ -212,12 +216,11 @@ export class RenderGraphInfo {
}); });
this.computeTimeScale = d3.scale.linear<string, string>() this.computeTimeScale = d3.scale.linear<string, string>()
.domain(computeTimeExtent) .domain(computeTimeExtent)
.range(params.minMaxColors); .range(PARAMS.minMaxColors);
// Maps node name to whether the rendering hierarchy was already // Maps node name to whether the rendering hierarchy was already
// constructed. // constructed.
this.hasSubhierarchy = {}; this.hasSubhierarchy = {};
this.params = params;
this.root = new RenderGroupNodeInfo(hierarchy.root); this.root = new RenderGroupNodeInfo(hierarchy.root);
this.index[hierarchy.root.name] = this.root; this.index[hierarchy.root.name] = this.root;
this.buildSubhierarchy(hierarchy.root.name); this.buildSubhierarchy(hierarchy.root.name);
@ -373,13 +376,13 @@ export class RenderGraphInfo {
_.each((<OpNode>childNode).inEmbeddings, embedding => { _.each((<OpNode>childNode).inEmbeddings, embedding => {
let renderMetaedgeInfo = new RenderMetaedgeInfo(null); let renderMetaedgeInfo = new RenderMetaedgeInfo(null);
addInAnnotation(childRenderInfo, embedding, null, renderMetaedgeInfo, addInAnnotation(childRenderInfo, embedding, null, renderMetaedgeInfo,
AnnotationType.CONSTANT, this.params); AnnotationType.CONSTANT);
this.index[embedding.name] = new RenderNodeInfo(embedding); this.index[embedding.name] = new RenderNodeInfo(embedding);
}); });
_.each((<OpNode>childNode).outEmbeddings, embedding => { _.each((<OpNode>childNode).outEmbeddings, embedding => {
let renderMetaedgeInfo = new RenderMetaedgeInfo(null); let renderMetaedgeInfo = new RenderMetaedgeInfo(null);
addOutAnnotation(childRenderInfo, embedding, null, renderMetaedgeInfo, addOutAnnotation(childRenderInfo, embedding, null, renderMetaedgeInfo,
AnnotationType.SUMMARY, this.params); AnnotationType.SUMMARY);
this.index[embedding.name] = new RenderNodeInfo(embedding); this.index[embedding.name] = new RenderNodeInfo(embedding);
}); });
} }
@ -393,9 +396,9 @@ export class RenderGraphInfo {
coreGraph.setEdge(edgeObj.v, edgeObj.w, renderMetaedgeInfo); coreGraph.setEdge(edgeObj.v, edgeObj.w, renderMetaedgeInfo);
}); });
if (this.params.enableExtraction && if (PARAMS.enableExtraction &&
renderGroupNodeInfo.node.type === NodeType.META) { renderGroupNodeInfo.node.type === NodeType.META) {
extractHighDegrees(renderGroupNodeInfo, this.params); extractHighDegrees(renderGroupNodeInfo);
} }
// Record that we constructed the rendering hierarchy for this node, so we // Record that we constructed the rendering hierarchy for this node, so we
@ -469,7 +472,7 @@ export class RenderGraphInfo {
// either node is high-degree with respect to control edges. This will // either node is high-degree with respect to control edges. This will
// be a signal to show it as an annotation instead of a bridge edge. // be a signal to show it as an annotation instead of a bridge edge.
let isHighDegreeControlEdge = !bridgeMetaedge.numRegularEdges && let isHighDegreeControlEdge = !bridgeMetaedge.numRegularEdges &&
otherCounts.control[otherName] > this.params.maxControlDegree; otherCounts.control[otherName] > PARAMS.maxControlDegree;
let [, childAnnotations] = let [, childAnnotations] =
inbound ? inbound ?
@ -478,8 +481,8 @@ export class RenderGraphInfo {
let isOtherHighDegree = let isOtherHighDegree =
inbound ? inbound ?
otherCounts.out[otherName] > this.params.maxOutDegree : otherCounts.out[otherName] > PARAMS.maxOutDegree :
otherCounts.in[otherName] > this.params.maxInDegree; otherCounts.in[otherName] > PARAMS.maxInDegree;
// The adjoining render metaedge info from the parent's coreGraph, if any. // The adjoining render metaedge info from the parent's coreGraph, if any.
// It will either be a Metaedge involving this node directly, if it // It will either be a Metaedge involving this node directly, if it
@ -493,7 +496,7 @@ export class RenderGraphInfo {
// - the child is in the core (not extracted for being high-degree), and // - the child is in the core (not extracted for being high-degree), and
// - there's a path (in the traversal sense) between child and other. // - there's a path (in the traversal sense) between child and other.
let canDrawBridgePath = false; let canDrawBridgePath = false;
if (this.params.enableBridgegraph && if (PARAMS.enableBridgegraph &&
!isOtherHighDegree && !isOtherHighDegree &&
!isHighDegreeControlEdge && !isHighDegreeControlEdge &&
childRenderInfo.isInCore()) { childRenderInfo.isInCore()) {
@ -581,7 +584,7 @@ export class RenderGraphInfo {
otherRenderInfo, otherRenderInfo,
new RenderMetaedgeInfo(bridgeMetaedge), new RenderMetaedgeInfo(bridgeMetaedge),
AnnotationType.SHORTCUT, AnnotationType.SHORTCUT,
inbound), this.params); inbound));
return; return;
} }
@ -869,13 +872,13 @@ export class AnnotationList {
* Append an annotation to the list, or a stand-in ellipsis annotation instead * Append an annotation to the list, or a stand-in ellipsis annotation instead
* if this would make it too many. * if this would make it too many.
*/ */
push(annotation: Annotation, params: RenderGraphParams): void { push(annotation: Annotation): void {
if (annotation.node.name in this.nodeNames) { if (annotation.node.name in this.nodeNames) {
return; // Skip duplicate annotation. return; // Skip duplicate annotation.
} }
this.nodeNames[annotation.node.name] = true; this.nodeNames[annotation.node.name] = true;
if (this.list.length < params.maxAnnotations) { if (this.list.length < PARAMS.maxAnnotations) {
this.list.push(annotation); this.list.push(annotation);
return; return;
} }
@ -1101,19 +1104,18 @@ export class RenderMetaedgeInfo {
function addInAnnotation(node: RenderNodeInfo, predecessor: Node, function addInAnnotation(node: RenderNodeInfo, predecessor: Node,
predecessorRenderInfo: RenderNodeInfo, predecessorRenderInfo: RenderNodeInfo,
edge: RenderMetaedgeInfo, type: AnnotationType, edge: RenderMetaedgeInfo, type: AnnotationType): void {
params: RenderGraphParams): void {
let annotation = new Annotation(predecessor, predecessorRenderInfo, edge, let annotation = new Annotation(predecessor, predecessorRenderInfo, edge,
type, true); type, true);
node.inAnnotations.push(annotation, params); node.inAnnotations.push(annotation);
} }
function addOutAnnotation(node: RenderNodeInfo, successor: Node, function addOutAnnotation(node: RenderNodeInfo, successor: Node,
successorRenderInfo: RenderNodeInfo, edge: RenderMetaedgeInfo, successorRenderInfo: RenderNodeInfo, edge: RenderMetaedgeInfo,
type: AnnotationType, params: RenderGraphParams): void { type: AnnotationType): void {
let annotation = new Annotation(successor, successorRenderInfo, edge, let annotation = new Annotation(successor, successorRenderInfo, edge,
type, false); type, false);
node.outAnnotations.push(annotation, params); node.outAnnotations.push(annotation);
} }
function setGraphDepth(graph: graphlib.Graph<RenderNodeInfo, any>, function setGraphDepth(graph: graphlib.Graph<RenderNodeInfo, any>,
@ -1181,7 +1183,7 @@ function setGroupNodeDepth(renderInfo: RenderGroupNodeInfo,
*/ */
function createShortcut( function createShortcut(
graph: graphlib.Graph<RenderNodeInfo, RenderMetaedgeInfo>, graph: graphlib.Graph<RenderNodeInfo, RenderMetaedgeInfo>,
v: string, w: string, params: RenderGraphParams) { v: string, w: string) {
let src = graph.node(v); let src = graph.node(v);
let sink = graph.node(w); let sink = graph.node(w);
let edge = graph.edge(v, w); let edge = graph.edge(v, w);
@ -1197,8 +1199,8 @@ function createShortcut(
} }
// Add each annotation. // Add each annotation.
addOutAnnotation(src, sink.node, sink, edge, AnnotationType.SHORTCUT, params); addOutAnnotation(src, sink.node, sink, edge, AnnotationType.SHORTCUT);
addInAnnotation(sink, src.node, src, edge, AnnotationType.SHORTCUT, params); addInAnnotation(sink, src.node, src, edge, AnnotationType.SHORTCUT);
// Remove the edge from the core graph. // Remove the edge from the core graph.
graph.removeEdge(v, w); graph.removeEdge(v, w);
@ -1212,18 +1214,18 @@ function createShortcut(
* edges. Otherwise, only extract all in-edges. * edges. Otherwise, only extract all in-edges.
*/ */
function makeOutExtract(renderNode: RenderGroupNodeInfo, n: string, function makeOutExtract(renderNode: RenderGroupNodeInfo, n: string,
params: RenderGraphParams, forceDetach?: boolean) { forceDetach?: boolean) {
let graph = renderNode.coreGraph; let graph = renderNode.coreGraph;
let child = graph.node(n); let child = graph.node(n);
child.isOutExtract = true; child.isOutExtract = true;
_.each(graph.predecessors(n), (p, index) => { _.each(graph.predecessors(n), (p, index) => {
createShortcut(graph, p, n, params); createShortcut(graph, p, n);
}); });
if (params.detachAllEdgesForHighDegree || forceDetach) { if (PARAMS.detachAllEdgesForHighDegree || forceDetach) {
_.each(graph.successors(n), (s, index) => { _.each(graph.successors(n), (s, index) => {
createShortcut(graph, n, s, params); createShortcut(graph, n, s);
}); });
} }
@ -1243,18 +1245,18 @@ function makeOutExtract(renderNode: RenderGroupNodeInfo, n: string,
* edges. Otherwise, only remove all out-edges. * edges. Otherwise, only remove all out-edges.
*/ */
export function makeInExtract(renderNode: RenderGroupNodeInfo, n: string, export function makeInExtract(renderNode: RenderGroupNodeInfo, n: string,
params: RenderGraphParams, forceDetach?: boolean) { forceDetach?: boolean) {
let graph = renderNode.coreGraph; let graph = renderNode.coreGraph;
let child = graph.node(n); let child = graph.node(n);
child.isInExtract = true; child.isInExtract = true;
_.each(graph.successors(n), (s, index) => { _.each(graph.successors(n), (s, index) => {
createShortcut(graph, n, s, params); createShortcut(graph, n, s);
}); });
if (params.detachAllEdgesForHighDegree || forceDetach) { if (PARAMS.detachAllEdgesForHighDegree || forceDetach) {
_.each(graph.predecessors(n), (p, index) => { _.each(graph.predecessors(n), (p, index) => {
createShortcut(graph, p, n, params); createShortcut(graph, p, n);
}); });
} }
@ -1289,40 +1291,37 @@ function hasTypeIn(node: Node, types: string[]): boolean {
} }
/** Move nodes that are specified to be excluded out of the core graph. */ /** Move nodes that are specified to be excluded out of the core graph. */
function extractSpecifiedNodes(renderNode: RenderGroupNodeInfo, function extractSpecifiedNodes(renderNode: RenderGroupNodeInfo) {
params: RenderGraphParams) {
let graph = renderNode.coreGraph; let graph = renderNode.coreGraph;
_.each(graph.nodes(), n => { _.each(graph.nodes(), n => {
let renderInfo = graph.node(n); let renderInfo = graph.node(n);
if (renderInfo.node.include === InclusionType.EXCLUDE) { if (renderInfo.node.include === InclusionType.EXCLUDE) {
if (renderNode.coreGraph.outEdges(n).length > if (renderNode.coreGraph.outEdges(n).length >
renderNode.coreGraph.inEdges(n).length) { renderNode.coreGraph.inEdges(n).length) {
makeOutExtract(renderNode, n, params, true); makeOutExtract(renderNode, n, true);
} else { } else {
makeInExtract(renderNode, n, params, true); makeInExtract(renderNode, n, true);
} }
} }
}); });
} }
/** Remove edges from pre-defined out-extract patterns */ /** Remove edges from pre-defined out-extract patterns */
function extractPredefinedSink(renderNode: RenderGroupNodeInfo, function extractPredefinedSink(renderNode: RenderGroupNodeInfo) {
params: RenderGraphParams) {
let graph = renderNode.coreGraph; let graph = renderNode.coreGraph;
_.each(graph.nodes(), n => { _.each(graph.nodes(), n => {
let renderInfo = graph.node(n); let renderInfo = graph.node(n);
if (renderInfo.node.include !== InclusionType.UNSPECIFIED) { if (renderInfo.node.include !== InclusionType.UNSPECIFIED) {
return; return;
} }
if (hasTypeIn(renderInfo.node, params.outExtractTypes)) { if (hasTypeIn(renderInfo.node, PARAMS.outExtractTypes)) {
makeOutExtract(renderNode, n, params); makeOutExtract(renderNode, n);
} }
}); });
} }
/** Remove edges from pre-defined in-extract patterns */ /** Remove edges from pre-defined in-extract patterns */
function extractPredefinedSource(renderNode: RenderGroupNodeInfo, function extractPredefinedSource(renderNode: RenderGroupNodeInfo) {
params: RenderGraphParams) {
let graph = renderNode.coreGraph; let graph = renderNode.coreGraph;
_.each(graph.nodes(), n => { _.each(graph.nodes(), n => {
@ -1330,17 +1329,16 @@ function extractPredefinedSource(renderNode: RenderGroupNodeInfo,
if (renderInfo.node.include !== InclusionType.UNSPECIFIED) { if (renderInfo.node.include !== InclusionType.UNSPECIFIED) {
return; return;
} }
if (hasTypeIn(renderInfo.node, params.inExtractTypes)) { if (hasTypeIn(renderInfo.node, PARAMS.inExtractTypes)) {
makeInExtract(renderNode, n, params); makeInExtract(renderNode, n);
} }
}); });
} }
/** Extract from nodes with in-degree > maxInDegree */ /** Extract from nodes with in-degree > maxInDegree */
function extractHighInDegree(renderNode: RenderGroupNodeInfo, function extractHighInDegree(renderNode: RenderGroupNodeInfo) {
params: RenderGraphParams) {
let graph = renderNode.coreGraph; let graph = renderNode.coreGraph;
let maxInDegree = params.maxInDegree; let maxInDegree = PARAMS.maxInDegree;
// detect first so degrees don't get affected by other removal // detect first so degrees don't get affected by other removal
let highInDegreeNames = _.filter(graph.nodes(), n => { let highInDegreeNames = _.filter(graph.nodes(), n => {
@ -1363,15 +1361,14 @@ function extractHighInDegree(renderNode: RenderGroupNodeInfo,
}); });
_.each(highInDegreeNames, n => { _.each(highInDegreeNames, n => {
makeOutExtract(renderNode, n, params); makeOutExtract(renderNode, n);
}); });
} }
/** Extract nodes with out-degree > maxOutDegree */ /** Extract nodes with out-degree > maxOutDegree */
function extractHighOutDegree(renderNode: RenderGroupNodeInfo, function extractHighOutDegree(renderNode: RenderGroupNodeInfo) {
params: RenderGraphParams) {
let graph = renderNode.coreGraph; let graph = renderNode.coreGraph;
let maxOutDegree = params.maxOutDegree; let maxOutDegree = PARAMS.maxOutDegree;
// detect first so degrees don't get affected by other removal // detect first so degrees don't get affected by other removal
let highOutDegreeNames = _.filter(graph.nodes(), n => { let highOutDegreeNames = _.filter(graph.nodes(), n => {
@ -1394,13 +1391,12 @@ function extractHighOutDegree(renderNode: RenderGroupNodeInfo,
}); });
_.each(highOutDegreeNames, n => { _.each(highOutDegreeNames, n => {
makeInExtract(renderNode, n, params); makeInExtract(renderNode, n);
}); });
} }
/** Remove control edges from nodes that have too many control edges */ /** Remove control edges from nodes that have too many control edges */
function removeControlEdges(renderNode: RenderGroupNodeInfo, function removeControlEdges(renderNode: RenderGroupNodeInfo) {
params: RenderGraphParams) {
let graph = renderNode.coreGraph; let graph = renderNode.coreGraph;
// Collect control edges into a map by node name. // Collect control edges into a map by node name.
@ -1414,8 +1410,8 @@ function removeControlEdges(renderNode: RenderGroupNodeInfo,
// For each node with too many control edges, turn them into annotations. // For each node with too many control edges, turn them into annotations.
_.each(map, (edges, nodeName) => { _.each(map, (edges, nodeName) => {
if (edges.length > params.maxControlDegree) { if (edges.length > PARAMS.maxControlDegree) {
_.each(edges, e => createShortcut(graph, e.v, e.w, params)); _.each(edges, e => createShortcut(graph, e.v, e.w));
} }
}); });
} }
@ -1445,38 +1441,35 @@ export function mapIndexToHue(id: number): number {
* screw up the graph layout. * screw up the graph layout.
* *
* @param {Render.Node} renderNode Node to manipulate. * @param {Render.Node} renderNode Node to manipulate.
* @param {Object} params render Graph construction parameters. See
* <tf-graph-params>'s output
*/ */
function extractHighDegrees(renderNode: RenderGroupNodeInfo, function extractHighDegrees(renderNode: RenderGroupNodeInfo) {
params: RenderGraphParams) {
extractSpecifiedNodes(renderNode, params); extractSpecifiedNodes(renderNode);
if (params.outExtractTypes) { if (PARAMS.outExtractTypes) {
extractPredefinedSink(renderNode, params); extractPredefinedSink(renderNode);
} }
// This has to come before extract high in-degree to protect the core part // This has to come before extract high in-degree to protect the core part
// that takes many variables. // that takes many variables.
if (params.inExtractTypes) { if (PARAMS.inExtractTypes) {
extractPredefinedSource(renderNode, params); extractPredefinedSource(renderNode);
} }
// This has to come before extract high out-degree to protect the core part // This has to come before extract high out-degree to protect the core part
// that output to many places as there are more high-degree sinks than // that output to many places as there are more high-degree sinks than
// sources. // sources.
if (params.maxInDegree) { if (PARAMS.maxInDegree) {
extractHighInDegree(renderNode, params); extractHighInDegree(renderNode);
} }
if (params.maxOutDegree) { if (PARAMS.maxOutDegree) {
extractHighOutDegree(renderNode, params); extractHighOutDegree(renderNode);
} }
if (params.maxControlDegree) { if (PARAMS.maxControlDegree) {
removeControlEdges(renderNode, params); removeControlEdges(renderNode);
} }
// Extract isolated nodes, which can be // Extract isolated nodes, which can be
@ -1513,7 +1506,7 @@ function extractHighDegrees(renderNode: RenderGroupNodeInfo,
renderNode.isolatedOutExtract.push(child); renderNode.isolatedOutExtract.push(child);
child.node.include = InclusionType.EXCLUDE; child.node.include = InclusionType.EXCLUDE;
graph.removeNode(n); graph.removeNode(n);
} else if (params.extractIsolatedNodesWithAnnotationsOnOneSide) { } else if (PARAMS.extractIsolatedNodesWithAnnotationsOnOneSide) {
if (hasOutAnnotations && !hasInAnnotations) { if (hasOutAnnotations && !hasInAnnotations) {
child.isInExtract = true; // for ones with high out-annotations child.isInExtract = true; // for ones with high out-annotations
renderNode.isolatedInExtract.push(child); renderNode.isolatedInExtract.push(child);

View File

@ -1,113 +0,0 @@
<link rel="import" href="../polymer/polymer.html">
<!--
Module for adjusting render graph building parameter.
-->
<dom-module id="tf-graph-params">
</dom-module>
<script>
Polymer({
is: 'tf-graph-params',
properties: {
// PARAMETERS
enableExtraction: {
type: Boolean,
value: true
},
/** Maximum in-degree that a node can have without being considered as
* high in-degree node. */
maxInDegree: {
type: Number,
value: 4
},
/** Maximum out-degree that a node can have without being considered as
* high out-degree node. */
maxOutDegree: {
type: Number,
value: 4
},
/** Maximum number of control edges a node can have before they aren't
* displayed. */
maxControlDegree: {
type: Number,
value: 4
},
/**
* Types patterns for predefined out-extract nodes, which are
* sink-like nodes that will be extracted from the main graph.
*/
outExtractTypes: {
type: Array,
value: function() {
return [
'NoOp' // for "sgd", "momentum" group
];
}
},
/**
* Types patterns for predefined in-extract nodes, which are
* source-like nodes that will be extracted from the main graph.
*/
inExtractTypes: {
type: Array,
value: function() {
return ['Variable'];
}
},
/**
* When removing edges from a high degree node, remove all of its edges if
* detachAllEdgesForHighDegree is true. Otherwise remove all in-edges if
* the node has high in-degree, or all out-edges if the node has high
* out-degree.
*/
detachAllEdgesForHighDegree: {
type: Boolean,
value: true
},
/**
* After extracting high in/out degree nodes and predefined
* source-like/sink-like, extract isolated nodes to the side
* if this extractIsolatedNodesWithAnnotationsOnOneSide is true.
*/
extractIsolatedNodesWithAnnotationsOnOneSide: {
type: Boolean,
value: true
},
/**
* Whether to draw bridge paths inside of expanded group nodes.
*/
enableBridgegraph: {
type: Boolean,
value: true
},
/**
* Colors for the minimum and maximum values whenever we have a gradient
* scale.
*/
minMaxColors: {
type: Array,
value: function() {
return ["#fff5f0", "#fb6a4a"];
}
},
/**
* Maximum number of annotations to be displayed on a node before an
* ellipsis is used.
*/
maxAnnotations: {
type: Number,
value: 5
}
}
});
</script>

View File

@ -6,7 +6,6 @@
<link rel="import" href="../paper-toggle-button/paper-toggle-button.html"> <link rel="import" href="../paper-toggle-button/paper-toggle-button.html">
<link rel="import" href="../tf-graph-common/tf-graph-common.html"> <link rel="import" href="../tf-graph-common/tf-graph-common.html">
<link rel="import" href="tf-graph-scene.html"> <link rel="import" href="tf-graph-scene.html">
<link rel="import" href="tf-graph-params.html">
<dom-module id="tf-graph"> <dom-module id="tf-graph">
<template> <template>
<style> <style>
@ -37,7 +36,6 @@ paper-button {
} }
</style> </style>
<div class="container"> <div class="container">
<tf-graph-params id="graphParams"></tf-graph-params>
<div class="vertical"> <div class="vertical">
<h2>[[title]]</h2> <h2>[[title]]</h2>
<tf-graph-scene id="scene" class="auto" <tf-graph-scene id="scene" class="auto"
@ -91,13 +89,6 @@ Polymer({
readOnly: true, readOnly: true,
notify: true, notify: true,
}, },
// internal properties
_graphParams: {
type: Object,
value: function() {
return this.$.graphParams;
}
},
_renderDepth: { _renderDepth: {
type: Number, type: Number,
value: 1 value: 1
@ -108,9 +99,9 @@ Polymer({
}, },
}, },
observers: [ observers: [
'_buildRenderHierarchy(graphHierarchy, _graphParams)' '_buildRenderHierarchy(graphHierarchy)'
], ],
_buildRenderHierarchy: function(graphHierarchy, params) { _buildRenderHierarchy: function(graphHierarchy) {
tf.time('new tf.graph.render.Hierarchy', function() { tf.time('new tf.graph.render.Hierarchy', function() {
if (graphHierarchy.root.type !== tf.graph.NodeType.META) { if (graphHierarchy.root.type !== tf.graph.NodeType.META) {
// root must be metanode but sometimes Polymer's dom-if has not // root must be metanode but sometimes Polymer's dom-if has not
@ -118,8 +109,7 @@ Polymer({
// and thus mistakenly pass non-metanode to this module. // and thus mistakenly pass non-metanode to this module.
return; return;
} }
var renderGraph = new tf.graph.render.RenderGraphInfo(graphHierarchy, var renderGraph = new tf.graph.render.RenderGraphInfo(graphHierarchy);
params);
// Producing the 'color by' parameters to be consumed // Producing the 'color by' parameters to be consumed
// by the tf-graph-controls panel. It contains information about the // by the tf-graph-controls panel. It contains information about the
// min and max values and their respective colors, as well as list // min and max values and their respective colors, as well as list
@ -252,7 +242,7 @@ Polymer({
} }
// Rebuild the render hierarchy. // Rebuild the render hierarchy.
this._buildRenderHierarchy(this.graphHierarchy, this._graphParams); this._buildRenderHierarchy(this.graphHierarchy);
}, },
_nodeToggleSeriesGroup: function(event) { _nodeToggleSeriesGroup: function(event) {
// Toggle the group setting of the specified node appropriately. // Toggle the group setting of the specified node appropriately.
@ -270,7 +260,7 @@ Polymer({
tf.graph.hierarchy.build(this.basicGraph, this.hierarchyParams, hierarchyTracker) tf.graph.hierarchy.build(this.basicGraph, this.hierarchyParams, hierarchyTracker)
.then(function(graphHierarchy) { .then(function(graphHierarchy) {
this.set('graphHierarchy', graphHierarchy); this.set('graphHierarchy', graphHierarchy);
this._buildRenderHierarchy(this.graphHierarchy, this._graphParams); this._buildRenderHierarchy(this.graphHierarchy);
}.bind(this)); }.bind(this));
}, },
not: function(x) { not: function(x) {

View File

@ -13,8 +13,8 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
native.new_http_archive( native.new_http_archive(
name = "eigen_archive", name = "eigen_archive",
url = "https://bitbucket.org/eigen/eigen/get/36b0586de49f.tar.gz", url = "https://bitbucket.org/eigen/eigen/get/3d9f227afae2.tar.gz",
sha256 = "86da9dd97c91b6587a257add70d9478f4c463d6697d487564c5bfe83c4a0e8e0", sha256 = "bf2638b7e1085de0b430b000c07e090dc71c83dd7f5b934a06f68b7db02676bf",
build_file = path_prefix + "eigen.BUILD", build_file = path_prefix + "eigen.BUILD",
) )

View File

@ -1 +1 @@
#include "eigen-eigen-36b0586de49f/Eigen/Cholesky" #include "eigen-eigen-3d9f227afae2/Eigen/Cholesky"

View File

@ -1 +1 @@
#include "eigen-eigen-36b0586de49f/Eigen/Core" #include "eigen-eigen-3d9f227afae2/Eigen/Core"

View File

@ -1 +1 @@
#include "eigen-eigen-36b0586de49f/Eigen/Eigenvalues" #include "eigen-eigen-3d9f227afae2/Eigen/Eigenvalues"

View File

@ -1 +1 @@
#include "eigen-eigen-36b0586de49f/Eigen/LU" #include "eigen-eigen-3d9f227afae2/Eigen/LU"

View File

@ -1 +1 @@
#include "eigen-eigen-36b0586de49f/Eigen/QR" #include "eigen-eigen-3d9f227afae2/Eigen/QR"

View File

@ -1 +1 @@
#include "eigen-eigen-36b0586de49f/unsupported/Eigen/CXX11/Tensor" #include "eigen-eigen-3d9f227afae2/unsupported/Eigen/CXX11/Tensor"

View File

@ -3,6 +3,8 @@ build:cuda --define=using_cuda=true
build --force_python=py$PYTHON_MAJOR_VERSION build --force_python=py$PYTHON_MAJOR_VERSION
build --python$PYTHON_MAJOR_VERSION_path=$PYTHON_BINARY build --python$PYTHON_MAJOR_VERSION_path=$PYTHON_BINARY
build --define=use_fast_cpp_protos=true
build --define=allow_oversize_protos=true
build --spawn_strategy=standalone build --spawn_strategy=standalone
test --spawn_strategy=standalone test --spawn_strategy=standalone