commit 71320c0909
Merge commit for internal changes
@@ -1,6 +1,6 @@
 package(default_visibility = ["//visibility:public"])

-archive_dir = "eigen-eigen-36b0586de49f"
+archive_dir = "eigen-eigen-3d9f227afae2"

 cc_library(
 name = "eigen",
@@ -619,14 +619,19 @@ ANDROID_TF_COPTS = [
 "-std=c++11",
 "-DMIN_LOG_LEVEL=0",
 "-DTF_LEAN_BINARY",
-"-O2",
 ]

 # Native library support for Android applications.
 # Does not contain operators, use :android_tensorflow_lib if you want full
 # operator support.
-# Compiles to a trivial library on non-android to prevent irrelevant
-# build errors.
+#
+# Compiles to a trivial library on non-Android to prevent irrelevant
+# build errors. If not building this as part of an android_binary,
+# a command such as the following must be used:
+# bazel build -c opt tensorflow/core:android_tensorflow_lib \
+# --crosstool_top=//third_party/java/android/android_ndk_linux/toolchains:everything \
+# --cpu=armeabi-v7a \
+# --host_crosstool_top=@bazel_tools//tools/cpp:toolchain
 cc_library(
 name = "android_tensorflow_lib_lite",
 srcs = select({
@@ -643,7 +648,7 @@ cc_library(
 "public/session.h",
 ],
 copts = select({
-":android": ANDROID_TF_COPTS,
+":android": ANDROID_TF_COPTS + ["-Os"],
 "//conditions:default": [],
 }),
 tags = [
@@ -656,19 +661,23 @@ cc_library(
 ":protos_cc",
 "//third_party/eigen3",
 ],
+alwayslink = 1,
 )

 # Full Tensorflow library with operator support. Use this unless reducing
 # binary size (by packaging a reduced operator set) is a concern.
 cc_library(
 name = "android_tensorflow_lib",
-srcs = [
-":android_op_registrations_and_gradients",
-"//tensorflow/core/kernels:android_core_ops",
-"//tensorflow/core/kernels:android_extended_ops",
-],
+srcs = select({
+":android": [
+":android_op_registrations_and_gradients",
+"//tensorflow/core/kernels:android_core_ops",
+"//tensorflow/core/kernels:android_extended_ops",
+],
+"//conditions:default": [],
+}),
 copts = select({
-":android": ANDROID_TF_COPTS,
+":android": ANDROID_TF_COPTS + ["-O2"],
 "//conditions:default": [],
 }),
 tags = [
@@ -682,6 +691,7 @@ cc_library(
 ":protos_cc",
 "//third_party/eigen3",
 ],
+alwayslink = 1,
 )

 filegroup(
@@ -124,7 +124,7 @@ class UntypedCall : public core::RefCounted {
 }

 private:
-UntypedCall* call_; // `this` owns one reference.
+UntypedCall* const call_; // `this` owns one reference.
 Callback callback_;
 };
 };
@@ -149,30 +149,14 @@ class Call : public UntypedCall<Service> {
 Call<Service, GrpcService, RequestMessage, ResponseMessage>*);

 Call(HandleRequestFunction handle_request_function)
-: handle_request_function_(handle_request_function),
-responder_(&ctx_),
-cancel_tag_(new typename UntypedCall<Service>::Tag(
-this, &UntypedCall<Service>::RequestCancelled)) {
-// The `ctx_` borrows the `cancel_tag_` until
-// `this->RequestReceived()` is called.
-ctx_.AsyncNotifyWhenDone(cancel_tag_.get());
-}
+: handle_request_function_(handle_request_function), responder_(&ctx_) {}

 virtual ~Call() {}

 void RequestReceived(Service* service, bool ok) override {
 if (ok) {
-// At this point, the `cancel_tag_` becomes owned by the
-// completion queue.
-cancel_tag_.release();
-
 this->Ref();
 (service->*handle_request_function_)(this);
-} else {
-// `!ok` implies we never received a request for this call, and
-// the `cancel_tag_` will never be added to the completion
-// queue, so we free it here.
-cancel_tag_.reset();
 }
 }

@@ -190,6 +174,9 @@ class Call : public UntypedCall<Service> {
 cancel_callback_();
 }
 }
+// NOTE(mrry): This can be called before or after RequestReceived, so we
+// release `cancel_tag_` (in order to allow the event loop to free it).
+cancel_tag_.release();
 }

 // Registers `callback` as the function that should be called if and when this
@@ -213,9 +200,13 @@ class Call : public UntypedCall<Service> {
 static void EnqueueRequest(GrpcService* grpc_service,
 ::grpc::ServerCompletionQueue* cq,
 EnqueueFunction enqueue_function,
-HandleRequestFunction handle_request_function) {
+HandleRequestFunction handle_request_function,
+bool supports_cancel) {
 auto call = new Call<Service, GrpcService, RequestMessage, ResponseMessage>(
 handle_request_function);
+if (supports_cancel) {
+call->RegisterCancellationHandler();
+}

 (grpc_service->*enqueue_function)(
 &call->ctx_, &call->request, &call->responder_, cq, cq,
@@ -228,6 +219,15 @@ class Call : public UntypedCall<Service> {
 ResponseMessage response;

 private:
+// Creates a completion queue tag for handling cancellation by the client.
+// NOTE: This method must be called before this call is enqueued on a
+// completion queue.
+void RegisterCancellationHandler() {
+cancel_tag_.reset(new typename UntypedCall<Service>::Tag(
+this, &UntypedCall<Service>::RequestCancelled));
+ctx_.AsyncNotifyWhenDone(cancel_tag_.get());
+}
+
 HandleRequestFunction handle_request_function_;
 ::grpc::ServerContext ctx_;
 ::grpc::ServerAsyncResponseWriter<ResponseMessage> responder_;
@@ -88,7 +88,7 @@ class GrpcMasterService : public AsyncServiceInterface {
 // The implementation of the request handler for each RPC method
 // must ensure that it calls ENQUEUE_REQUEST() for that RPC method,
 // to keep accepting new requests.
-#define ENQUEUE_REQUEST(method) \
+#define ENQUEUE_REQUEST(method, supports_cancel) \
 do { \
 mutex_lock l(mu_); \
 if (!is_shutdown_) { \
@@ -96,19 +96,20 @@ class GrpcMasterService : public AsyncServiceInterface {
 method##Request, method##Response>:: \
 EnqueueRequest(&master_service_, cq_, \
 &grpc::MasterService::AsyncService::Request##method, \
-&GrpcMasterService::method##Handler); \
+&GrpcMasterService::method##Handler, \
+(supports_cancel)); \
 } \
 } while (0)

 void HandleRPCsLoop() {
-ENQUEUE_REQUEST(CreateSession);
-ENQUEUE_REQUEST(ExtendSession);
+ENQUEUE_REQUEST(CreateSession, true);
+ENQUEUE_REQUEST(ExtendSession, false);
 for (int i = 0; i < 100; ++i) {
-ENQUEUE_REQUEST(RunStep);
+ENQUEUE_REQUEST(RunStep, true);
 }
-ENQUEUE_REQUEST(CloseSession);
-ENQUEUE_REQUEST(ListDevices);
-ENQUEUE_REQUEST(Reset);
+ENQUEUE_REQUEST(CloseSession, false);
+ENQUEUE_REQUEST(ListDevices, false);
+ENQUEUE_REQUEST(Reset, false);

 void* tag;
 bool ok;
@@ -146,7 +147,7 @@ class GrpcMasterService : public AsyncServiceInterface {
 [call](const Status& status) {
 call->SendResponse(ToGrpcStatus(status));
 });
-ENQUEUE_REQUEST(CreateSession);
+ENQUEUE_REQUEST(CreateSession, true);
 }

 // RPC handler for extending a session.
@@ -156,7 +157,7 @@ class GrpcMasterService : public AsyncServiceInterface {
 [call](const Status& status) {
 call->SendResponse(ToGrpcStatus(status));
 });
-ENQUEUE_REQUEST(ExtendSession);
+ENQUEUE_REQUEST(ExtendSession, false);
 }

 // RPC handler for running one step in a session.
@@ -169,7 +170,7 @@ class GrpcMasterService : public AsyncServiceInterface {
 delete call_opts;
 call->SendResponse(ToGrpcStatus(status));
 });
-ENQUEUE_REQUEST(RunStep);
+ENQUEUE_REQUEST(RunStep, true);
 }

 // RPC handler for deleting a session.
@@ -179,7 +180,7 @@ class GrpcMasterService : public AsyncServiceInterface {
 [call](const Status& status) {
 call->SendResponse(ToGrpcStatus(status));
 });
-ENQUEUE_REQUEST(CloseSession);
+ENQUEUE_REQUEST(CloseSession, false);
 }

 // RPC handler for listing devices.
@@ -189,7 +190,7 @@ class GrpcMasterService : public AsyncServiceInterface {
 [call](const Status& status) {
 call->SendResponse(ToGrpcStatus(status));
 });
-ENQUEUE_REQUEST(ListDevices);
+ENQUEUE_REQUEST(ListDevices, false);
 }

 // RPC handler for resetting all sessions.
@@ -198,7 +199,7 @@ class GrpcMasterService : public AsyncServiceInterface {
 [call](const Status& status) {
 call->SendResponse(ToGrpcStatus(status));
 });
-ENQUEUE_REQUEST(Reset);
+ENQUEUE_REQUEST(Reset, false);
 }
 #undef ENQUEUE_REQUEST

@@ -95,7 +95,7 @@ class GrpcWorkerService : public AsyncServiceInterface {
 // The implementation of the request handler for each RPC method
 // must ensure that it calls ENQUEUE_REQUEST() for that RPC method,
 // to keep accepting new requests.
-#define ENQUEUE_REQUEST(method) \
+#define ENQUEUE_REQUEST(method, supports_cancel) \
 do { \
 mutex_lock l(shutdown_mu_); \
 if (!is_shutdown_) { \
@@ -103,7 +103,8 @@ class GrpcWorkerService : public AsyncServiceInterface {
 method##Request, method##Response>:: \
 EnqueueRequest(&worker_service_, cq_, \
 &grpc::WorkerService::AsyncService::Request##method, \
-&GrpcWorkerService::method##Handler); \
+&GrpcWorkerService::method##Handler, \
+(supports_cancel)); \
 } \
 } while (0)

@@ -116,18 +117,18 @@ class GrpcWorkerService : public AsyncServiceInterface {
 // method, by re-enqueuing a request before the previous one
 // completes, and we may decide to bound some of the request
 // types.
-ENQUEUE_REQUEST(GetStatus);
-ENQUEUE_REQUEST(CleanupAll);
-ENQUEUE_REQUEST(RegisterGraph);
-ENQUEUE_REQUEST(DeregisterGraph);
+ENQUEUE_REQUEST(GetStatus, false);
+ENQUEUE_REQUEST(CleanupAll, false);
+ENQUEUE_REQUEST(RegisterGraph, false);
+ENQUEUE_REQUEST(DeregisterGraph, false);

 // TODO(mrry): Consider enqueuing more of these request types.
-ENQUEUE_REQUEST(RecvTensor);
-ENQUEUE_REQUEST(RunGraph);
+ENQUEUE_REQUEST(RecvTensor, true);
+ENQUEUE_REQUEST(RunGraph, true);

-ENQUEUE_REQUEST(CleanupGraph);
-ENQUEUE_REQUEST(Logging);
-ENQUEUE_REQUEST(Tracing);
+ENQUEUE_REQUEST(CleanupGraph, false);
+ENQUEUE_REQUEST(Logging, false);
+ENQUEUE_REQUEST(Tracing, false);

 void* tag;
 bool ok;
@@ -181,7 +182,7 @@ class GrpcWorkerService : public AsyncServiceInterface {
 }
 call->SendResponse(::grpc::Status::OK);
 });
-ENQUEUE_REQUEST(GetStatus);
+ENQUEUE_REQUEST(GetStatus, false);
 }

 void CleanupAllHandler(
@@ -192,7 +193,7 @@ class GrpcWorkerService : public AsyncServiceInterface {
 env_->device_mgr->ClearContainers(containers);
 call->SendResponse(::grpc::Status::OK);
 });
-ENQUEUE_REQUEST(CleanupAll);
+ENQUEUE_REQUEST(CleanupAll, false);
 }

 void RegisterGraphHandler(
@@ -203,7 +204,7 @@ class GrpcWorkerService : public AsyncServiceInterface {
 call->request.graph_options(), call->response.mutable_graph_handle());
 call->SendResponse(ToGrpcStatus(s));
 });
-ENQUEUE_REQUEST(RegisterGraph);
+ENQUEUE_REQUEST(RegisterGraph, false);
 }

 void DeregisterGraphHandler(
@@ -212,18 +213,18 @@ class GrpcWorkerService : public AsyncServiceInterface {
 Status s = env_->graph_mgr->Deregister(call->request.graph_handle());
 call->SendResponse(ToGrpcStatus(s));
 });
-ENQUEUE_REQUEST(DeregisterGraph);
+ENQUEUE_REQUEST(DeregisterGraph, false);
 }

 void RunGraphHandler(WorkerCall<RunGraphRequest, RunGraphResponse>* call) {
 env_->compute_pool->Schedule([this, call]() { DoRunGraph(call); });
-ENQUEUE_REQUEST(RunGraph);
+ENQUEUE_REQUEST(RunGraph, true);
 }

 void RecvTensorHandler(
 WorkerCall<RecvTensorRequest, RecvTensorResponse>* call) {
 env_->compute_pool->Schedule([this, call]() { DoRecvTensor(call); });
-ENQUEUE_REQUEST(RecvTensor);
+ENQUEUE_REQUEST(RecvTensor, true);
 }

 void CleanupGraphHandler(
@@ -233,7 +234,7 @@ class GrpcWorkerService : public AsyncServiceInterface {
 env_->rendezvous_mgr->Cleanup(step_id);
 call->SendResponse(::grpc::Status::OK);
 });
-ENQUEUE_REQUEST(CleanupGraph);
+ENQUEUE_REQUEST(CleanupGraph, false);
 }

 void LoggingHandler(WorkerCall<LoggingRequest, LoggingResponse>* call) {
@@ -241,7 +242,7 @@ class GrpcWorkerService : public AsyncServiceInterface {
 Status s = DoLogging(call);
 call->SendResponse(ToGrpcStatus(s));
 });
-ENQUEUE_REQUEST(Logging);
+ENQUEUE_REQUEST(Logging, false);
 }

 void TracingHandler(WorkerCall<TracingRequest, TracingResponse>* call) {
@@ -249,7 +250,7 @@ class GrpcWorkerService : public AsyncServiceInterface {
 Status s = DoTracing(call);
 call->SendResponse(ToGrpcStatus(s));
 });
-ENQUEUE_REQUEST(Tracing);
+ENQUEUE_REQUEST(Tracing, false);
 }
 #undef ENQUEUE_REQUEST

@@ -89,7 +89,8 @@ Status MemoryTypesForNode(const OpRegistryInterface* op_registry,

 // Look up the Kernel registered for this node def.
 const KernelDef* kdef = nullptr;
-status = FindKernelDef(device_type, ndef, &kdef);
+status =
+FindKernelDef(device_type, ndef, &kdef, nullptr /* kernel_class_name */);

 if (!status.ok() || HasTypeList(*op_def)) {
 // When there is no kernel def for this op or the op's arg is a
@@ -594,10 +594,11 @@ Status OpKernelContext::MatchSignature(const DataTypeSlice expected_inputs,
 // OpKernel registration ------------------------------------------------------

 struct KernelRegistration {
-KernelRegistration(const KernelDef& d,
+KernelRegistration(const KernelDef& d, StringPiece c,
 kernel_factory::OpKernelRegistrar::Factory f)
-: def(d), factory(f) {}
+: def(d), kernel_class_name(c.ToString()), factory(f) {}
 const KernelDef def;
+const string kernel_class_name;
 const kernel_factory::OpKernelRegistrar::Factory factory;
 };

@@ -624,12 +625,13 @@ static string Key(StringPiece op_type, DeviceType device_type,
 namespace kernel_factory {

 void OpKernelRegistrar::InitInternal(const KernelDef* kernel_def,
+StringPiece kernel_class_name,
 Factory factory) {
 const string key =
 Key(kernel_def->op(), DeviceType(kernel_def->device_type()),
 kernel_def->label());
-GlobalKernelRegistryTyped()->insert(
-std::make_pair(key, KernelRegistration(*kernel_def, factory)));
+GlobalKernelRegistryTyped()->insert(std::make_pair(
+key, KernelRegistration(*kernel_def, kernel_class_name, factory)));
 delete kernel_def;
 }

@@ -724,7 +726,7 @@ Status FindKernelRegistration(DeviceType device_type, const NodeDef& node_def,
 } // namespace

 Status FindKernelDef(DeviceType device_type, const NodeDef& node_def,
-const KernelDef** def) {
+const KernelDef** def, string* kernel_class_name) {
 const KernelRegistration* reg = nullptr;
 TF_RETURN_IF_ERROR(FindKernelRegistration(device_type, node_def, &reg));
 if (reg == nullptr) {
@@ -733,7 +735,8 @@ Status FindKernelDef(DeviceType device_type, const NodeDef& node_def,
 " devices compatible with node ",
 SummarizeNodeDef(node_def));
 }
-*def = &reg->def;
+if (def != nullptr) *def = &reg->def;
+if (kernel_class_name != nullptr) *kernel_class_name = reg->kernel_class_name;
 return Status::OK();
 }

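Note: the widened FindKernelDef signature above tolerates null for either output pointer. A minimal usage sketch based only on the declarations in this diff (it assumes a NodeDef `ndef` and a DeviceType `device_type` are already in scope; error handling elided):

    // Look up only the registered kernel's class name for a node.
    string kernel_class_name;
    Status s = FindKernelDef(device_type, ndef, nullptr /* def */, &kernel_class_name);
    if (s.ok()) {
      LOG(INFO) << "kernel class: " << kernel_class_name;
    }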
@@ -1025,7 +1025,6 @@ namespace register_kernel {
 typedef ::tensorflow::KernelDefBuilder Name;
 } // namespace register_kernel

-
 #define REGISTER_KERNEL_BUILDER(kernel_builder, ...) \
 REGISTER_KERNEL_BUILDER_UNIQ_HELPER(__COUNTER__, kernel_builder, __VA_ARGS__)

@@ -1035,18 +1034,20 @@ typedef ::tensorflow::KernelDefBuilder Name;
 #define REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, ...) \
 static ::tensorflow::kernel_factory::OpKernelRegistrar \
 registrar__body__##ctr##__object( \
-SHOULD_REGISTER_OP_KERNEL(__FILE__) \
+SHOULD_REGISTER_OP_KERNEL(#__VA_ARGS__) \
 ? ::tensorflow::register_kernel::kernel_builder.Build() \
 : nullptr, \
+#__VA_ARGS__, \
 [](::tensorflow::OpKernelConstruction* context) \
 -> ::tensorflow::OpKernel* { return new __VA_ARGS__(context); })

 void* GlobalKernelRegistry();

 // If node_def has a corresponding kernel registered on device_type,
-// returns OK and fill in the kernel def.
+// returns OK and fill in the kernel def and kernel_class_name. <def> and
+// <kernel_class_name> may be null.
 Status FindKernelDef(DeviceType device_type, const NodeDef& node_def,
-const KernelDef** def);
+const KernelDef** def, string* kernel_class_name);

 // Treats 'registry_ptr' as a pointer to KernelRegistry. For each kernel 'k'
 // registered with the current library's global kernel registry (obtained by
@@ -1058,16 +1059,19 @@ namespace kernel_factory {
 class OpKernelRegistrar {
 public:
 typedef OpKernel* (*Factory)(OpKernelConstruction*);
-OpKernelRegistrar(const KernelDef* kernel_def, Factory factory) {
+OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name,
+Factory factory) {
 // Perform the check in the header to allow compile-time optimization
 // to a no-op, allowing the linker to remove the kernel symbols.
 if (kernel_def != nullptr) {
-InitInternal(kernel_def, factory);
+InitInternal(kernel_def, kernel_class_name, factory);
 }
 }

 private:
-void InitInternal(const KernelDef* kernel_def, Factory factory);
+void InitInternal(const KernelDef* kernel_def, StringPiece kernel_class_name,
+Factory factory);
 };

 } // namespace kernel_factory
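For reference, the kernel_class_name threaded through OpKernelRegistrar above is the stringified class passed by REGISTER_KERNEL_BUILDER (the #__VA_ARGS__ argument in the earlier macro hunk). A rough, abbreviated sketch of what a registration now constructs, using the BuildCPU/DummyKernel example from the tests below; the SHOULD_REGISTER_OP_KERNEL check is omitted here:

    static ::tensorflow::kernel_factory::OpKernelRegistrar registrar(
        ::tensorflow::register_kernel::Name("BuildCPU").Device(DEVICE_CPU).Build(),
        "DummyKernel",  // kernel_class_name, produced by #__VA_ARGS__
        [](::tensorflow::OpKernelConstruction* context) -> ::tensorflow::OpKernel* {
          return new DummyKernel(context);
        });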
@@ -422,6 +422,27 @@ class OpKernelBuilderTest : public ::testing::Test {
 }
 }
 }
+
+string GetKernelClassName(const string& op_type, DeviceType device_type,
+const std::vector<string>& attrs,
+DataTypeSlice input_types = {}) {
+NodeDef def = CreateNodeDef(op_type, attrs);
+for (size_t i = 0; i < input_types.size(); ++i) {
+def.add_input("a:0");
+}
+
+const KernelDef* kernel_def = nullptr;
+string kernel_class_name;
+const Status status =
+FindKernelDef(device_type, def, &kernel_def, &kernel_class_name);
+if (status.ok()) {
+return kernel_class_name;
+} else if (errors::IsNotFound(status)) {
+return "not found";
+} else {
+return status.ToString();
+}
+}
 };

 REGISTER_OP("BuildCPU");
@@ -429,7 +450,9 @@ REGISTER_KERNEL_BUILDER(Name("BuildCPU").Device(DEVICE_CPU), DummyKernel);

 TEST_F(OpKernelBuilderTest, BuilderCPU) {
 ExpectSuccess("BuildCPU", DEVICE_CPU, {});
+EXPECT_EQ("DummyKernel", GetKernelClassName("BuildCPU", DEVICE_CPU, {}));
 ExpectFailure("BuildCPU", DEVICE_GPU, {}, error::NOT_FOUND);
+EXPECT_EQ("not found", GetKernelClassName("BuildCPU", DEVICE_GPU, {}));
 }

 REGISTER_OP("BuildGPU");
@@ -472,12 +495,26 @@ REGISTER_KERNEL_BUILDER(Name("BuildTypeListAttr")

 TEST_F(OpKernelBuilderTest, BuilderTypeListAttr) {
 ExpectSuccess("BuildTypeListAttr", DEVICE_CPU, {"T|list(type)|[]"});
+EXPECT_EQ("DummyKernel", GetKernelClassName("BuildTypeListAttr", DEVICE_CPU,
+{"T|list(type)|[]"}));
+
 ExpectSuccess("BuildTypeListAttr", DEVICE_CPU, {"T|list(type)|[DT_BOOL]"});
+EXPECT_EQ("DummyKernel", GetKernelClassName("BuildTypeListAttr", DEVICE_CPU,
+{"T|list(type)|[]"}));
+
 ExpectSuccess("BuildTypeListAttr", DEVICE_CPU,
 {"T|list(type)|[DT_BOOL, DT_BOOL]"});
+
 ExpectFailure("BuildTypeListAttr", DEVICE_CPU, {"T|list(type)|[DT_FLOAT]"},
 error::NOT_FOUND);
+EXPECT_EQ("not found", GetKernelClassName("BuildTypeListAttr", DEVICE_CPU,
+{"T|list(type)|[DT_FLOAT]"}));
+
 ExpectFailure("BuildTypeListAttr", DEVICE_CPU, {}, error::INVALID_ARGUMENT);
+EXPECT_TRUE(
+StringPiece(GetKernelClassName("BuildTypeListAttr", DEVICE_CPU, {}))
+.contains("Invalid argument: "));
+
 ExpectFailure("BuildTypeListAttr", DEVICE_CPU, {"T|int|7"},
 error::INVALID_ARGUMENT);
 }
@@ -776,6 +813,9 @@ TEST_F(LabelTest, Default) {
 ExpectSuccess("LabeledKernel", DEVICE_CPU, {});
 auto* get_labeled_kernel = static_cast<BaseKernel*>(op_kernel.get());
 EXPECT_EQ(0, get_labeled_kernel->Which());
+
+EXPECT_EQ("LabeledKernel<0>",
+GetKernelClassName("LabeledKernel", DEVICE_CPU, {}));
 }

 TEST_F(LabelTest, Specified) {
@@ -783,6 +823,8 @@ TEST_F(LabelTest, Specified) {
 ExpectSuccess("LabeledKernel", DEVICE_CPU, {"_kernel|string|'one'"});
 auto* get_labeled_kernel = static_cast<BaseKernel*>(op_kernel.get());
 EXPECT_EQ(1, get_labeled_kernel->Which());
+EXPECT_EQ("LabeledKernel<1>", GetKernelClassName("LabeledKernel", DEVICE_CPU,
+{"_kernel|string|'one'"}));
 }

 TEST_F(LabelTest, Duplicate) {
@@ -34,13 +34,12 @@ limitations under the License.
 // out.
 #include "ops_to_register.h"

-// Files which are not included in the whitelist provided by this
-// graph-specific header file will not be allowed to register their
-// operator kernels.
-#define SHOULD_REGISTER_OP_KERNEL(filename) \
-(strstr(kNecessaryOpFiles, filename) != nullptr)
+// Op kernel classes for which ShouldRegisterOpKernel returns false will not be
+// registered.
+#define SHOULD_REGISTER_OP_KERNEL(clz) \
+(strstr(kNecessaryOpKernelClasses, "," clz ",") != nullptr)

-// Ops for which ShouldRegisterOp return false will no be registered.
+// Ops for which ShouldRegisterOp returns false will not be registered.
 #define SHOULD_REGISTER_OP(op) ShouldRegisterOp(op)

 // If kRequiresSymbolicGradients is false, then no gradient ops are registered.
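The rewritten macro above matches kernel class names against a comma-delimited whitelist instead of matching source file names. A self-contained sketch of the lookup it performs, with a hypothetical kNecessaryOpKernelClasses string (the real one is emitted into the generated ops_to_register.h):

    #include <cstring>
    // Hypothetical whitelist; the leading/trailing commas let the "," clz ","
    // probe avoid accidental substring matches between class names.
    static const char kNecessaryOpKernelClasses[] = ",DummyKernel,GatherNdOp,";
    #define SHOULD_REGISTER_OP_KERNEL(clz) \
      (strstr(kNecessaryOpKernelClasses, "," clz ",") != nullptr)
    // SHOULD_REGISTER_OP_KERNEL("DummyKernel")      -> true
    // SHOULD_REGISTER_OP_KERNEL("SomeOtherKernel")  -> false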
@@ -260,6 +260,7 @@ tf_kernel_libraries(
 "constant_op",
 "diag_op",
 "edit_distance_op",
+"gather_nd_op",
 "gather_op",
 "identity_op",
 "immutable_constant_op",
@@ -27,7 +27,8 @@ namespace tensorflow {
 // that 0 <= limit if Index is signed. Intended for use in performance
 // critical contexts where 0 <= index < limit is almost always true.
 template <typename Ta, typename Tb>
-EIGEN_ALWAYS_INLINE bool FastBoundsCheck(const Ta index, const Tb limit) {
+EIGEN_ALWAYS_INLINE EIGEN_DEVICE_FUNC bool FastBoundsCheck(const Ta index,
+const Tb limit) {
 static_assert(std::is_integral<Ta>::value && std::is_integral<Tb>::value,
 "FastBoundsCheck can only be used on integer types.");
 typedef typename std::make_unsigned<decltype(index + limit)>::type UIndex;
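The change above only adds EIGEN_DEVICE_FUNC so the helper can also be called from device code; the semantics are unchanged: it returns true exactly when 0 <= index < limit, using a single unsigned comparison. An assumed usage sketch, not taken from this commit:

    int64 index = -1;   // e.g. a computed or user-supplied index
    int64 limit = 10;
    if (!FastBoundsCheck(index, limit)) {
      // reject: index is negative or >= limit
    }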
@@ -22,22 +22,27 @@ limitations under the License.
 #include "tensorflow/core/framework/numeric_types.h"
 #include "tensorflow/core/framework/tensor_types.h"

-// The following functors (sign, tanh, sigmoid, etc.) are not defined
-// by Eigen. When their equivalent are added into the Eigen, we can
-// replace them using type aliases.
-
 namespace Eigen {
 namespace internal {

+// TODO(rmlarsen): Get rid of fmod2 once fmod is upstreamed to Eigen.
 template <typename T>
 struct scalar_fmod2_op {
 EIGEN_EMPTY_STRUCT_CTOR(scalar_fmod2_op)
 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& a,
 const T& b) const {
-return fmod(a, b);
+return std::fmod(a, b);
 }
 };

+template <typename T>
+struct functor_traits<scalar_fmod2_op<T>> {
+enum {
+Cost = 13, // Reciprocal throughput of FPREM on Haswell.
+PacketAccess = false,
+};
+};
+
 // scalar_left and scalar_right are template helpers to partially
 // apply a binary function.
 //
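The functor_traits specialization added above only informs Eigen's cost model (fmod is costly and has no packet/SIMD path); the functor itself is unchanged apart from the std::fmod qualification. A tiny illustrative sketch, not part of the commit:

    Eigen::internal::scalar_fmod2_op<float> fmod_op;
    float r = fmod_op(7.5f, 2.0f);  // r == 1.5f, computed via std::fmod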
@@ -489,7 +494,7 @@ template <typename T>
 struct fmod : base<T, Eigen::internal::scalar_fmod2_op<T> > {};

 template <typename T>
-struct mod : base<T, Eigen::internal::scalar_mod2_op<T> > {};
+struct mod : base<T, Eigen::internal::scalar_mod2_op<T>> {};

 template <typename T>
 struct pow : base<T, Eigen::internal::scalar_binary_pow_op<T, T> > {};
@@ -36,7 +36,6 @@ namespace Eigen {
 * It is possible to swap the order of the width and height dimensions provided that the same order is used in the input, the kernel, and the output.
 *
 */
-
 template <typename OutputBackward, typename Kernel>
 EIGEN_ALWAYS_INLINE static const typename internal::conditional<
 internal::traits<OutputBackward>::Layout == ColMajor,
@@ -45,14 +44,18 @@ EIGEN_ALWAYS_INLINE static const typename internal::conditional<
 internal::traits<OutputBackward>::NumDimensions>,
 const TensorContractionOp<
 const array<
-IndexPair<typename internal::traits<OutputBackward>::Index>, 2>,
+IndexPair<typename internal::traits<OutputBackward>::Index>, 1>,
 const Eigen::TensorForcedEvalOp<const TensorReshapingOp<
 const DSizes<typename internal::traits<OutputBackward>::Index,
-3>,
-const TensorReverseOp<const array<bool, 4>, const Kernel> > >,
+2>,
+const TensorShufflingOp<
+const array<
+typename internal::traits<OutputBackward>::Index, 4>,
+const TensorReverseOp<const array<bool, 4>,
+const Kernel> > > >,
 const TensorReshapingOp<
 const DSizes<typename internal::traits<OutputBackward>::Index,
-3>,
+2>,
 const TensorImagePatchOp<Dynamic, Dynamic,
 const OutputBackward> > > >,
 TensorReshapingOp<
@@ -60,17 +63,20 @@ EIGEN_ALWAYS_INLINE static const typename internal::conditional<
 internal::traits<OutputBackward>::NumDimensions>,
 const TensorContractionOp<
 const array<
-IndexPair<typename internal::traits<OutputBackward>::Index>, 2>,
+IndexPair<typename internal::traits<OutputBackward>::Index>, 1>,
 const TensorReshapingOp<
 const DSizes<typename internal::traits<OutputBackward>::Index,
-3>,
+2>,
 const TensorImagePatchOp<Dynamic, Dynamic,
 const OutputBackward> >,
 const Eigen::TensorForcedEvalOp<const TensorReshapingOp<
 const DSizes<typename internal::traits<OutputBackward>::Index,
-3>,
-const TensorReverseOp<const array<bool, 4>,
-const Kernel> > > > > >::type
+2>,
+const TensorShufflingOp<
+const array<
+typename internal::traits<OutputBackward>::Index, 4>,
+const TensorReverseOp<const array<bool, 4>,
+const Kernel> > > > > > >::type
 SpatialConvolutionBackwardInput(
 const Kernel& kernel, const OutputBackward& output_backward,
 typename internal::traits<OutputBackward>::Index inputRows,
@@ -134,49 +140,57 @@ SpatialConvolutionBackwardInput(
 kernel_reverse[3] = false;
 }

-DSizes<TensorIndex, 3> kernel_dims;
+// Reorder the dimensions to filters X patch_rows X patch_cols X channels
+array<TensorIndex, 4> kernel_shuffle;
 if (isColMajor) {
-kernel_dims[0] = kernelFilters;
-kernel_dims[1] = kernelChannels;
-kernel_dims[2] = kernelRows * kernelCols;
+kernel_shuffle[0] = 0;
+kernel_shuffle[1] = 2;
+kernel_shuffle[2] = 3;
+kernel_shuffle[3] = 1;
 } else {
-kernel_dims[0] = kernelRows * kernelCols;
+kernel_shuffle[0] = 2;
+kernel_shuffle[1] = 0;
+kernel_shuffle[2] = 1;
+kernel_shuffle[3] = 3;
+}
+
+// Collapse the dims
+DSizes<TensorIndex, 2> kernel_dims;
+if (isColMajor) {
+kernel_dims[0] = kernelFilters * kernelRows * kernelCols;
 kernel_dims[1] = kernelChannels;
-kernel_dims[2] = kernelFilters;
+} else {
+kernel_dims[1] = kernelFilters * kernelRows * kernelCols;
+kernel_dims[0] = kernelChannels;
 }

 // The output_backward has dimensions out_depth X out_rows X out_cols X OTHERS
 // When we extract the image patches from output_backward, it will have dimensions
 // out_depth X (patch_rows * patch_cols) X (input_rows * input_cols * OTHERS)
-DSizes<TensorIndex, 3> pre_contract_dims;
+DSizes<TensorIndex, 2> pre_contract_dims;
 if (isColMajor) {
-pre_contract_dims[0] = kernelFilters;
-pre_contract_dims[1] = kernelRows * kernelCols;
-pre_contract_dims[2] = inputRows * inputCols;
+pre_contract_dims[0] = kernelFilters * kernelRows * kernelCols;
+pre_contract_dims[1] = inputRows * inputCols;
 for (int i = 3; i < NumDims; ++i) {
-pre_contract_dims[2] *= out.dimension(i);
+pre_contract_dims[1] *= out.dimension(i);
 }
 } else {
-pre_contract_dims[2] = kernelFilters;
-pre_contract_dims[1] = kernelRows * kernelCols;
+pre_contract_dims[1] = kernelFilters * kernelRows * kernelCols;
 pre_contract_dims[0] = inputRows * inputCols;
 for (int i = 0; i < NumDims - 3; ++i) {
 pre_contract_dims[0] *= out.dimension(i);
 }
 }

-// We will contract along dimensions (0, 2) in kernel and (0, 1) in
-// output_backward, if this is col-major, and
-// dimensions (0, 2) in kernel and (1, 2) in output_backward, if this row-major.
-array<IndexPair<TensorIndex>, 2> contract_dims;
+// We will contract along the fused dimension that contains the kernelFilters,
+// the kernelRows and the kernelCols.
+array<IndexPair<TensorIndex>, 1> contract_dims;
 if (isColMajor) {
 // col-major: kernel.contract(output.patches)
 contract_dims[0] = IndexPair<TensorIndex>(0, 0);
-contract_dims[1] = IndexPair<TensorIndex>(2, 1);
 } else {
 // row-major: output.patches.contract(kernel)
-contract_dims[0] = IndexPair<TensorIndex>(1, 0);
-contract_dims[1] = IndexPair<TensorIndex>(2, 2);
+contract_dims[0] = IndexPair<TensorIndex>(1, 1);
 }

 // Post contraction, the dimensions of the input_backprop is
@@ -201,6 +215,7 @@ SpatialConvolutionBackwardInput(
 return choose(
 Cond<internal::traits<OutputBackward>::Layout == ColMajor>(),
 kernel.reverse(kernel_reverse)
+.shuffle(kernel_shuffle)
 .reshape(kernel_dims)
 .eval()
 .contract(output_backward
@@ -217,7 +232,10 @@ SpatialConvolutionBackwardInput(
 padding_bottom, padding_left, padding_right,
 OutScalar(0))
 .reshape(pre_contract_dims)
-.contract(kernel.reverse(kernel_reverse).reshape(kernel_dims).eval(),
+.contract(kernel.reverse(kernel_reverse)
+.shuffle(kernel_shuffle)
+.reshape(kernel_dims)
+.eval(),
 contract_dims)
 .reshape(post_contract_dims));
 }
@@ -239,15 +257,43 @@ SpatialConvolutionBackwardInput(
 * It is possible to swap the order of the width and height dimensions provided that the same order is used in the input, the kernel, and the output.
 *
 */
-// TODO(gpapan): Resolve a bug in TensorContractionInputMapper at SpatialConvolutions.h that yangke circumvented by using .reshape().reshape().
-// This can significantly accelerate SpatialConvolutionBackwardKernel.

 template <typename OutputBackward, typename Input>
-EIGEN_ALWAYS_INLINE
-static const typename internal::conditional<
+EIGEN_ALWAYS_INLINE static const typename internal::conditional<
 internal::traits<OutputBackward>::Layout == ColMajor,
-const TensorShufflingOp<const array<typename internal::traits<OutputBackward>::Index, 4>, const TensorReverseOp<const array<bool, 4>, const TensorReshapingOp<const DSizes<typename internal::traits<OutputBackward>::Index, 4>, const TensorContractionOp<const array<IndexPair<typename internal::traits<Input>::Index>, 2>, const TensorReshapingOp<const DSizes<typename internal::traits<Input>::Index, 3>, const Input>, const TensorReshapingOp<const DSizes<typename internal::traits<OutputBackward>::Index, 4>, const TensorReshapingOp<const DSizes<typename internal::traits<OutputBackward>::Index, 4>, const TensorImagePatchOp<Dynamic, Dynamic, const OutputBackward> > > > > > >,
-const TensorShufflingOp<const array<typename internal::traits<OutputBackward>::Index, 4>, const TensorReverseOp<const array<bool, 4>, const TensorReshapingOp<const DSizes<typename internal::traits<OutputBackward>::Index, 4>, const TensorContractionOp<const array<IndexPair<typename internal::traits<Input>::Index>, 2>, const TensorReshapingOp<const DSizes<typename internal::traits<OutputBackward>::Index, 4>, const TensorReshapingOp<const DSizes<typename internal::traits<OutputBackward>::Index, 4>, const TensorImagePatchOp<Dynamic, Dynamic, const OutputBackward> > >, const TensorReshapingOp<const DSizes<typename internal::traits<Input>::Index, 3>, const Input> > > > > >::type
+TensorReshapingOp<
+const DSizes<typename internal::traits<Input>::Index, 4>,
+const TensorContractionOp<
+const array<IndexPair<typename internal::traits<Input>::Index>, 1>,
+const TensorReshapingOp<
+const DSizes<typename internal::traits<Input>::Index, 2>,
+const OutputBackward>,
+const TensorShufflingOp<
+const array<typename internal::traits<OutputBackward>::Index, 2>,
+const TensorReshapingOp<
+const DSizes<typename internal::traits<Input>::Index, 2>,
+const TensorImagePatchOp<Dynamic, Dynamic, const Input>
+>
+>
+>
+>,
+TensorReshapingOp<
+const DSizes<typename internal::traits<Input>::Index, 4>,
+const TensorContractionOp<
+const array<IndexPair<typename internal::traits<Input>::Index>, 1>,
+const TensorShufflingOp<
+const array<typename internal::traits<OutputBackward>::Index, 2>,
+const TensorReshapingOp<
+const DSizes<typename internal::traits<Input>::Index, 2>,
+const TensorImagePatchOp<Dynamic, Dynamic, const Input>
+>
+>,
+const TensorReshapingOp<
+const DSizes<typename internal::traits<Input>::Index, 2>,
+const OutputBackward>
+>
+>
+>::type
 SpatialConvolutionBackwardKernel(const Input& input, const OutputBackward& output_backward, typename internal::traits<Input>::Index kernelRows, typename internal::traits<Input>::Index kernelCols, const DenseIndex stride = 1, const DenseIndex in_stride = 1) {

 typedef typename internal::traits<Input>::Index TensorIndex;
@ -283,127 +329,93 @@ SpatialConvolutionBackwardKernel(const Input& input, const OutputBackward& outpu
|
|||||||
const TensorIndex kernelColsEff = kernelCols + (kernelCols - 1) * (in_stride - 1);
|
const TensorIndex kernelColsEff = kernelCols + (kernelCols - 1) * (in_stride - 1);
|
||||||
|
|
||||||
// Computing the forward padding
|
// Computing the forward padding
|
||||||
const TensorIndex forward_pad_top = ((outputRows - 1) * stride + kernelRowsEff - inputRows) / 2;
|
const TensorIndex padRows = numext::maxi<Index>(
|
||||||
const TensorIndex forward_pad_left = ((outputCols - 1) * stride + kernelColsEff - inputCols) / 2;
|
0, (outputRows - 1) * stride + kernelRowsEff - inputRows);
|
||||||
|
const TensorIndex padCols = numext::maxi<Index>(
|
||||||
|
0, (outputCols - 1) * stride + kernelColsEff - inputCols);
|
||||||
|
const TensorIndex padding_top = padRows / 2;
|
||||||
|
const TensorIndex padding_bottom = padRows - padding_top;
|
||||||
|
const TensorIndex padding_left = padCols / 2;
|
||||||
|
const TensorIndex padding_right = padCols - padding_left;
|
||||||
|
|
||||||
// TODO: factor out the padding computation.
|
// Reshaped out
|
||||||
const TensorIndex padding_top = kernelRowsEff - 1 - forward_pad_top;
|
DSizes<TensorIndex, 2> output_dims;
|
||||||
const TensorIndex padding_left = kernelColsEff - 1 - forward_pad_left;
|
|
||||||
const TensorIndex padding_bottom = inputRows + kernelRowsEff - 1 - (outputRows - 1) * stride - 1 - padding_top;
|
|
||||||
const TensorIndex padding_right = inputCols + kernelColsEff - 1 - (outputCols - 1) * stride - 1 - padding_left;
|
|
||||||
|
|
||||||
eigen_assert(padding_top >= 0);
|
|
||||||
eigen_assert(padding_left >= 0);
|
|
||||||
eigen_assert(padding_bottom >= 0);
|
|
||||||
eigen_assert(padding_right >= 0);
|
|
||||||
|
|
||||||
// The output_backward has dimensions out_depth X out_rows X out_cols X OTHERS
|
|
||||||
// When we extract the image patches from output_backward (with input as the
|
|
||||||
// kernel), it will have dimensions
|
|
||||||
// (out_depth) X (input_rows * input_cols) X (kernel_rows * kernel_cols) X OTHERS
|
|
||||||
DSizes<TensorIndex, 4> pre_contract_dims;
|
|
||||||
if (isColMajor) {
|
if (isColMajor) {
|
||||||
pre_contract_dims[0] = kernelFilters;
|
output_dims[0] = kernelFilters;
|
||||||
pre_contract_dims[1] = inputRows * inputCols;
|
output_dims[1] = outputRows * outputCols;
|
||||||
pre_contract_dims[2] = kernelRows * kernelCols;
|
|
||||||
pre_contract_dims[3] = 1;
|
|
||||||
for (int i = 3; i < NumDims; ++i) {
|
for (int i = 3; i < NumDims; ++i) {
|
||||||
pre_contract_dims[3] *= out.dimension(i);
|
output_dims[1] *= out.dimension(i);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
pre_contract_dims[3] = kernelFilters;
|
output_dims[1] = kernelFilters;
|
    output_dims[0] = outputCols * outputRows;
    for (int i = 0; i < NumDims - 3; ++i) {
      output_dims[0] *= out.dimension(i);
    }
  }

  // Reshaped extract_image_patches(in)
  DSizes<TensorIndex, 2> pre_contract_dims;
  if (isColMajor) {
    pre_contract_dims[0] = kernelChannels * kernelRows * kernelCols;
    pre_contract_dims[1] = outputRows * outputCols;
    for (int i = 3; i < NumDims; ++i) {
      pre_contract_dims[1] *= in.dimension(i);
    }
    eigen_assert(output_dims[1] == pre_contract_dims[1]);
  } else {
    pre_contract_dims[1] = kernelCols * kernelRows * kernelChannels;
    pre_contract_dims[0] = outputRows * outputCols;
    for (int i = 0; i < NumDims - 3; ++i) {
      pre_contract_dims[0] *= in.dimension(i);
    }
    eigen_assert(output_dims[0] == pre_contract_dims[0]);
  }

  array<TensorIndex, 2> shuffle_dims;
  shuffle_dims[0] = 1;
  shuffle_dims[1] = 0;

  array<IndexPair<TensorIndex>, 1> contract_dims;
  contract_dims[0] = IndexPair<TensorIndex>(1, 0);

  // After the contraction, the kernel will have the desired shape
  // out_depth X in_shape X kernel_rows X kernel_cols
  DSizes<TensorIndex, 4> kernel_dims;
  if (isColMajor) {
    kernel_dims[0] = kernelFilters;
    kernel_dims[1] = kernelChannels;
    kernel_dims[2] = kernelRows;
    kernel_dims[3] = kernelCols;
  } else {
    kernel_dims[3] = kernelFilters;
    kernel_dims[2] = kernelChannels;
    kernel_dims[1] = kernelRows;
    kernel_dims[0] = kernelCols;
  }

  return choose(
      Cond<internal::traits<Input>::Layout == ColMajor>(),
      output_backward.reshape(output_dims)
          .contract(
              input.extract_image_patches(
                       kernelRows, kernelCols, stride, stride,
                       in_stride, in_stride, 1, 1, padding_top, padding_bottom,
                       padding_left, padding_right, OutScalar(0))
                  .reshape(pre_contract_dims)
                  .shuffle(shuffle_dims),
              contract_dims)
          .reshape(kernel_dims),
      input.extract_image_patches(
               kernelRows, kernelCols, stride, stride,
               in_stride, in_stride, 1, 1, padding_top, padding_bottom,
               padding_left, padding_right, OutScalar(0))
          .reshape(pre_contract_dims)
          .shuffle(shuffle_dims)
          .contract(
              output_backward.reshape(output_dims),
              contract_dims)
          .reshape(kernel_dims));
}

}  // end namespace Eigen
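For readers who want to check the shapes, here is a minimal NumPy sketch of the contraction the Eigen expression above implements: the gradient of a stride-1, VALID-padded, channels-last 2-D convolution with respect to its filter, obtained by contracting input patches against the output gradient. The function and variable names are illustrative only and are not part of this change.

```python
import numpy as np

def conv2d_filter_grad(x, dy, kh, kw):
    """Filter gradient of a stride-1, VALID 2-D convolution (channels last).

    x:  input image,     shape [H, W, C]
    dy: output gradient, shape [H - kh + 1, W - kw + 1, F]
    Returns dW of shape [kh, kw, C, F].
    """
    oh, ow, _ = dy.shape
    c = x.shape[-1]
    dw = np.zeros((kh, kw, c, dy.shape[-1]))
    for r in range(kh):
        for s in range(kw):
            # Window of the input seen by filter tap (r, s); contracting it
            # against dy over the output positions is the same contraction the
            # extract_image_patches + contract expression above performs.
            patch = x[r:r + oh, s:s + ow, :]                           # [OH, OW, C]
            dw[r, s] = np.tensordot(patch, dy, axes=([0, 1], [0, 1]))  # [C, F]
    return dw

x = np.random.randn(5, 5, 3)
dy = np.random.randn(3, 3, 4)
print(conv2d_filter_grad(x, dy, 3, 3).shape)  # (3, 3, 3, 4)
```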
tensorflow/core/kernels/gather_nd_op.cc (new file, 234 lines)
@ -0,0 +1,234 @@
/* Copyright 2016 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc.
#define EIGEN_USE_THREADS

#include <atomic>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/gather_nd_op.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/util.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

template <typename Device, typename T, typename Index>
class GatherNdOp : public OpKernel {
 public:
  explicit GatherNdOp(OpKernelConstruction* c) : OpKernel(c) {
    const DataType dt = DataTypeToEnum<T>::v();
    const DataType index_t = DataTypeToEnum<Index>::v();
    OP_REQUIRES_OK(c, c->MatchSignature({dt, index_t}, {dt}));
  }

  void Compute(OpKernelContext* c) override {
    const Tensor& params = c->input(0);
    const Tensor& indices = c->input(1);
    OP_REQUIRES(c, TensorShapeUtils::IsVectorOrHigher(params.shape()),
                errors::InvalidArgument("params must be at least a vector"));
    OP_REQUIRES(c, TensorShapeUtils::IsMatrixOrHigher(indices.shape()),
                errors::InvalidArgument("indices must be at least a matrix"));
    OP_REQUIRES(
        c, indices.dim_size(indices.dims() - 1) == params.dims(),
        errors::InvalidArgument(
            "index innermost dimension length must equal params rank; saw: ",
            indices.dim_size(indices.dims() - 1), " vs. ", params.dims()));

    // Check that we have enough index space
    const int64 N_big = indices.NumElements() / params.dims();
    OP_REQUIRES(c, N_big <= std::numeric_limits<int>::max(),
                errors::InvalidArgument(
                    "indices has too many elements for int indexing: ", N_big,
                    " > ", std::numeric_limits<int>::max()));
    const int N = indices.NumElements() / params.dims();
    OP_REQUIRES(
        c, params.NumElements() <= std::numeric_limits<Index>::max(),
        errors::InvalidArgument("params.NumElements() too large for ",
                                DataTypeString(DataTypeToEnum<Index>::v()),
                                " indexing: ", params.NumElements(), " > ",
                                std::numeric_limits<Index>::max()));

    // The result shape is indices.shape[:-1]
    TensorShape result_shape(indices.shape());
    result_shape.RemoveDim(result_shape.dims() - 1);

    Tensor* out = nullptr;
    OP_REQUIRES_OK(c, c->allocate_output(0, result_shape, &out));
    if (N > 0) {
      auto indices_mat = indices.flat_inner_dims<Index>();
      auto out_flat = out->flat<T>();

      Index bad_i = -1;
      switch (params.dims()) {
#define PARAMS_CASE(NDIM)                                                   \
  case NDIM: {                                                              \
    functor::GatherNd<Device, T, Index, NDIM> functor;                      \
    auto params_tensor = params.tensor<T, NDIM>();                          \
    bad_i = functor(c->eigen_device<Device>(), params_tensor, indices_mat,  \
                    out_flat);                                              \
  } break

        PARAMS_CASE(1);
        PARAMS_CASE(2);
        PARAMS_CASE(3);
        PARAMS_CASE(4);
        PARAMS_CASE(5);
        default:
          OP_REQUIRES(c, false,
                      errors::InvalidArgument(
                          "Only param tensors with ranks between 1 and 5 "
                          "are currently supported. Tensor rank: ",
                          params.dims()));
      }

      OP_REQUIRES(c, bad_i < 0,
                  errors::InvalidArgument(
                      "flat indices[", bad_i, ", :] = [",
                      str_util::Join(gtl::ArraySlice<Index>(
                                         &indices_mat(bad_i, 0), params.dims()),
                                     ", "),
                      "] does not index into param (shape: ",
                      params.shape().DebugString(), ")."));
    }
  }
};

// Specialization of GatherNd to CPU
namespace generator {

template <typename T, typename Index, int NDIM>
class GatherNdGenerator {
 public:
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
  GatherNdGenerator(typename TTypes<Index>::ConstMatrix Tindices,
                    typename TTypes<T, NDIM>::ConstTensor Tparams,
                    Index* error_loc)
      : Tindices_(Tindices), Tparams_(Tparams), error_loc_(*error_loc) {}

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
  operator()(const Eigen::array<Eigen::DenseIndex, 1>& loc_array) const {
    Index loc = loc_array[0];
    Eigen::array<Eigen::DenseIndex, NDIM> ix;
    bool out_of_bounds = false;
    for (int i = 0; i < NDIM; ++i) {
      Index ix_i = Tindices_(loc, i);
      ix[i] = ix_i;
      out_of_bounds |= !FastBoundsCheck(ix_i, Tparams_.dimension(i));
    }
    if (out_of_bounds) {
      error_loc_ = loc;
      return T();  // Return 0, 0.0, or '', etc.
    } else {
      return Tparams_(ix);
    }
  }

 private:
  typename TTypes<Index>::ConstMatrix Tindices_;
  typename TTypes<T, NDIM>::ConstTensor Tparams_;
  Index& error_loc_;
};

}  // namespace generator

namespace functor {

template <typename T, typename Index, int NDIM>
struct GatherNd<CPUDevice, T, Index, NDIM> {
  Index operator()(const CPUDevice& d,
                   typename TTypes<T, NDIM>::ConstTensor Tparams,
                   typename TTypes<Index>::ConstMatrix Tindices,
                   typename TTypes<T>::Flat Tout) {
    Index error_loc(-1);
    generator::GatherNdGenerator<T, Index, NDIM> gather_nd_generator(
        Tindices, Tparams, &error_loc);
    Tout.device(d) = Tout.generate(gather_nd_generator);

    // error_loc() returns -1 if there's no out-of-bounds index,
    // otherwise it returns the location of an OOB index in Tindices.
    return error_loc;
  }
};

}  // namespace functor

#define REGISTER_GATHER_ND_FULL(dev, type, index_type)                 \
  REGISTER_KERNEL_BUILDER(Name("GatherNd")                             \
                              .Device(DEVICE_##dev)                    \
                              .TypeConstraint<type>("Tparams")         \
                              .TypeConstraint<index_type>("Tindices"), \
                          GatherNdOp<dev##Device, type, index_type>)

#define REGISTER_GATHER_ND_ALL_INDICES(dev, type) \
  REGISTER_GATHER_ND_FULL(dev, type, int32);      \
  REGISTER_GATHER_ND_FULL(dev, type, int64)

#define REGISTER_GATHER_ND_CPU(type) REGISTER_GATHER_ND_ALL_INDICES(CPU, type)

TF_CALL_ALL_TYPES(REGISTER_GATHER_ND_CPU);

#undef REGISTER_GATHER_ND_CPU

#if GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPECS_INDEX_NDIM(T, Index, NDIM)                     \
  template <>                                                            \
  Index GatherNd<GPUDevice, T, Index, NDIM>::operator()(                 \
      const GPUDevice& d, typename TTypes<T, NDIM>::ConstTensor Tparams, \
      typename TTypes<Index>::ConstMatrix Tindices,                      \
      typename TTypes<T>::Flat Tout);                                    \
  extern template struct GatherNd<GPUDevice, T, Index, NDIM>

#define DECLARE_GPU_SPECS_INDEX(T, Index)    \
  DECLARE_GPU_SPECS_INDEX_NDIM(T, Index, 1); \
  DECLARE_GPU_SPECS_INDEX_NDIM(T, Index, 2); \
  DECLARE_GPU_SPECS_INDEX_NDIM(T, Index, 3); \
  DECLARE_GPU_SPECS_INDEX_NDIM(T, Index, 4); \
  DECLARE_GPU_SPECS_INDEX_NDIM(T, Index, 5)

#define DECLARE_GPU_SPECS(T)         \
  DECLARE_GPU_SPECS_INDEX(T, int32); \
  DECLARE_GPU_SPECS_INDEX(T, int64)

TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);

#undef DECLARE_GPU_SPECS
#undef DECLARE_GPU_SPECS_INDEX
}  // namespace functor

// Registration of the GPU implementations.
#define REGISTER_GATHER_ND_GPU(type) REGISTER_GATHER_ND_ALL_INDICES(GPU, type)

TF_CALL_GPU_NUMBER_TYPES(REGISTER_GATHER_ND_GPU);

#undef REGISTER_GATHER_ND_GPU

#endif  // GOOGLE_CUDA

#undef REGISTER_GATHER_ND_ALL_INDICES
#undef REGISTER_GATHER_ND_FULL

}  // namespace tensorflow
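As an illustration only, the following NumPy sketch mirrors the semantics of the kernel above, including how an out-of-range index row is reported back through a flag rather than raising inside the Eigen generator. The helper name is hypothetical, and the real kernel may record any one of the offending rows rather than the first.

```python
import numpy as np

def gather_nd_reference(params, indices):
    """Reference for GatherNd: indices has shape [d_0, ..., d_N, R], R = params.ndim."""
    r = params.ndim
    flat = indices.reshape(-1, r)
    out = np.zeros(flat.shape[0], dtype=params.dtype)
    bad_i = -1
    for i, ix in enumerate(flat):
        if all(0 <= ix[d] < params.shape[d] for d in range(r)):
            out[i] = params[tuple(ix)]
        elif bad_i < 0:
            bad_i = i  # like the kernel, record the bad row and emit a default value
    return out.reshape(indices.shape[:-1]), bad_i

# Mirrors GatherNdOpTest.Simple below: params [0, 1, 2, 8, 4], indices [[3], [4]] -> [8, 4].
params = np.array([0., 1., 2., 8., 4.], dtype=np.float32)
indices = np.array([[3], [4]], dtype=np.int32)
print(gather_nd_reference(params, indices))  # (array([8., 4.], dtype=float32), -1)
```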
tensorflow/core/kernels/gather_nd_op.h (new file, 43 lines)
@ -0,0 +1,43 @@
/* Copyright 2016 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_KERNELS_GATHER_ND_OP_H_
#define TENSORFLOW_KERNELS_GATHER_ND_OP_H_
// Functor definition for GatherOp, must be compilable by nvcc.

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/bounds_check.h"

namespace tensorflow {

class OpKernelContext;

namespace functor {
template <typename Device, typename T, typename Index, int NDIM>
struct GatherNd {
  // Performs gather op on (Tparams, Tindices), writing to Tout.
  // Returns an index to Tindices if the value at that index is out of range.
  // Returns -1 if all values of Tindices are in range.
  Index operator()(const Device& d,
                   typename TTypes<T, NDIM>::ConstTensor Tparams,
                   typename TTypes<Index>::ConstMatrix Tindices,
                   typename TTypes<T>::Flat Tout);
};

}  // namespace functor
}  // namespace tensorflow

#endif  // TENSORFLOW_KERNELS_GATHER_ND_OP_H_
tensorflow/core/kernels/gather_nd_op_gpu.cu.cc (new file, 105 lines)
@ -0,0 +1,105 @@
/* Copyright 2016 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if GOOGLE_CUDA

#define EIGEN_USE_GPU

#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/gather_nd_op.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;

namespace generator {

template <typename T, typename Index, int NDIM>
class GatherNdGenerator {
 public:
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
  GatherNdGenerator(typename TTypes<const Index, 2>::Tensor32Bit Tindices,
                    typename TTypes<const T, NDIM>::Tensor32Bit Tparams)
      : Tindices_(Tindices), Tparams_(Tparams) {}

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
  operator()(const Eigen::array<int, 1>& loc_array) const {
    int loc = loc_array[0];
    Eigen::array<int, NDIM> ix;
    bool out_of_bounds = false;
    for (int i = 0; i < NDIM; ++i) {
      int ix_i = Tindices_(loc, i);
      ix[i] = ix_i;
      out_of_bounds |= !FastBoundsCheck(ix_i, Tparams_.dimension(i));
    }
    if (out_of_bounds) {
      return T(0);  // TODO(ebrevdo): Pass error back to host.
    } else {
      return Tparams_(ix);
    }
  }

 private:
  typename TTypes<const Index, 2>::Tensor32Bit Tindices_;
  typename TTypes<const T, NDIM>::Tensor32Bit Tparams_;
};

}  // namespace generator

namespace functor {

template <typename T, typename Index, int NDIM>
struct GatherNd<GPUDevice, T, Index, NDIM> {
  Index operator()(const GPUDevice& d,
                   typename TTypes<T, NDIM>::ConstTensor Tparams,
                   typename TTypes<Index>::ConstMatrix Tindices,
                   typename TTypes<T>::Flat Tout) {
    generator::GatherNdGenerator<T, Index, NDIM> gather_nd_generator(
        To32Bit(Tindices), To32Bit(Tparams));
    To32Bit(Tout).device(d) = To32Bit(Tout).generate(gather_nd_generator);

    // TODO(ebrevdo): enable indices validation on GPU.
    // Right now checking for indices out of bound in the kernel would
    // require copying code between GPU/CPU, and is too slow.
    return -1;
  }
};

}  // namespace functor

#define DEFINE_GPU_SPECS_INDEX_NDIM(T, Index, NDIM) \
  template struct functor::GatherNd<GPUDevice, T, Index, NDIM>;

#define DEFINE_GPU_SPECS_INDEX(T, Index)    \
  DEFINE_GPU_SPECS_INDEX_NDIM(T, Index, 1); \
  DEFINE_GPU_SPECS_INDEX_NDIM(T, Index, 2); \
  DEFINE_GPU_SPECS_INDEX_NDIM(T, Index, 3); \
  DEFINE_GPU_SPECS_INDEX_NDIM(T, Index, 4); \
  DEFINE_GPU_SPECS_INDEX_NDIM(T, Index, 5);

#define DEFINE_GPU_SPECS(T)         \
  DEFINE_GPU_SPECS_INDEX(T, int32); \
  DEFINE_GPU_SPECS_INDEX(T, int64);

TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS);

#undef DEFINE_GPU_SPECS
#undef DEFINE_GPU_SPECS_INDEX

}  // namespace tensorflow

#endif  // GOOGLE_CUDA
tensorflow/core/kernels/gather_nd_op_test.cc (new file, 133 lines)
@ -0,0 +1,133 @@
/* Copyright 2015 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <functional>
#include <memory>
#include <vector>

#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/testlib.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"

namespace tensorflow {

namespace test {
namespace graph {

class Node* GatherNd(Graph* g, class Node* in0, class Node* in1) {
  class Node* ret;
  TF_CHECK_OK(NodeBuilder(g->NewName("n"), "GatherNd")
                  .Input(in0)
                  .Input(in1)
                  .Finalize(g, &ret));
  return ret;
}

}  // namespace graph
}  // namespace test

namespace {

class GatherNdOpTest : public OpsTestBase {
 protected:
  void MakeOp(DataType index_type) {
    TF_ASSERT_OK(NodeDefBuilder("myop", "GatherNd")
                     .Input(FakeInput(DT_FLOAT))
                     .Input(FakeInput(index_type))
                     .Finalize(node_def()));
    TF_ASSERT_OK(InitOp());
  }
};

TEST_F(GatherNdOpTest, Simple) {
  MakeOp(DT_INT32);

  // Feed and run
  AddInputFromArray<float>(TensorShape({5}), {0, 1, 2, 8, 4});
  AddInputFromArray<int32>(TensorShape({2, 1}), {3, 4});
  TF_ASSERT_OK(RunOpKernel());

  // Check the output.
  Tensor expected(allocator(), DT_FLOAT, TensorShape({2}));
  test::FillValues<float>(&expected, {8, 4});
  test::ExpectTensorEqual<float>(expected, *GetOutput(0));
}

constexpr int kLookups = 2000;

template <typename Index>
static Graph* GatherNd(int dim) {
  Graph* g = new Graph(OpRegistry::Global());
  // Always use a 512MB buffer.
  // const int kRows = ((512 << 20) / sizeof(float)) / dim;
  Tensor params(DT_FLOAT, TensorShape({dim, 8, 16, 32}));
  params.flat<float>().setRandom();

  random::PhiloxRandom philox(301, 17);
  random::SimplePhilox rnd(&philox);
  Tensor indices(DataTypeToEnum<Index>::value, TensorShape({kLookups, 4}));
  auto indices_mat = indices.matrix<Index>();
  for (int i = 0; i < kLookups; i++) {
    indices_mat(i, 0) = rnd.Uniform(dim);
    indices_mat(i, 1) = rnd.Uniform(8);
    indices_mat(i, 2) = rnd.Uniform(16);
    indices_mat(i, 3) = rnd.Uniform(32);
  }

  test::graph::GatherNd(g, test::graph::Constant(g, params),
                        test::graph::Constant(g, indices));
  return g;
}

#define BM_GATHER_ND(DEVICE, INDEX)                                 \
  static void BM_##DEVICE##_gather_nd_##INDEX(int iters, int dim) { \
    const int64 tot = static_cast<int64>(iters) * kLookups * dim;   \
    testing::ItemsProcessed(tot);                                   \
    testing::BytesProcessed(tot * sizeof(float));                   \
    testing::UseRealTime();                                         \
    test::Benchmark(#DEVICE, GatherNd<INDEX>(dim)).Run(iters);      \
  }                                                                 \
  BENCHMARK(BM_##DEVICE##_gather_nd_##INDEX)                        \
      ->Arg(1)                                                      \
      ->Arg(10)                                                     \
      ->Arg(20)                                                     \
      ->Arg(64)                                                     \
      ->Arg(100)                                                    \
      ->Arg(200)                                                    \
      ->Arg(1000)

BM_GATHER_ND(cpu, int32);
BM_GATHER_ND(gpu, int32);
BM_GATHER_ND(cpu, int64);
BM_GATHER_ND(gpu, int64);

}  // namespace
}  // namespace tensorflow
@ -25,6 +25,9 @@ Status InitializableLookupTable::Find(const Tensor& keys, Tensor* values,
  if (!is_initialized()) {
    return errors::FailedPrecondition("Table not initialized.");
  }
  // Do not let the use migrate before the check; table is used without
  // a lock by the readers.
  std::atomic_thread_fence(std::memory_order_acquire);
  TF_RETURN_IF_ERROR(CheckFindArguments(keys, *values, default_value));
  return DoFind(keys, values, default_value);
}
@ -48,6 +51,10 @@ Status InitializableLookupTable::Initialize(InitTableIterator& iter) {
  if (!errors::IsOutOfRange(iter.status())) {
    return iter.status();
  }

  // Prevent compiler/memory reordering of is_initialized and
  // the initialization itself.
  std::atomic_thread_fence(std::memory_order_release);
  is_initialized_ = true;
  return Status::OK();
}
@ -69,7 +69,11 @@ class HashTable : public InitializableLookupTable {
 public:
  size_t size() const override {
    // return the size of the table only if it's initialized, otherwise 0.
    if (!is_initialized_) {
      return 0;
    }
    std::atomic_thread_fence(std::memory_order_acquire);
    return table_ ? table_->size() : 0;
  }

  DataType key_dtype() const override { return DataTypeToEnum<K>::v(); }
@ -407,6 +407,34 @@ this operation will permute `params` accordingly.
</div>
)doc");

// --------------------------------------------------------------------------
REGISTER_OP("GatherNd")
    .Input("params: Tparams")
    .Input("indices: Tindices")
    .Output("output: Tparams")
    .Attr("Tparams: type")
    .Attr("Tindices: {int32,int64}")
    .Doc(R"doc(
Gather values from `params` according to `indices`.

`indices` must be integer tensor, containing indices into `params`.
It must be shape `[d_0, ..., d_N, R]` where `R` is the rank of `params`.
The innermost dimension of `indices` (with length `R`) corresponds to the
indices of `params`.

Produces an output tensor with shape `[d_0, ..., d_{n-1}]` where:

    output[i, j, k, ...] = params[indices[i, j, k, ..., :]]

e.g. for `indices` a matrix:

    output[i] = params[indices[i, :]]

params: R-D. The tensor from which to gather values.
indices: (N+1)-D. Index tensor having shape `[d_0, ..., d_N, R]`.
output: N-D. Values from `params` gathered from indices given by `indices`.
)doc");

// --------------------------------------------------------------------------
REGISTER_OP("Identity")
    .Input("input: T")
@ -6480,6 +6480,35 @@ op {
    }
  }
}
op {
  name: "GatherNd"
  input_arg {
    name: "params"
    type_attr: "Tparams"
  }
  input_arg {
    name: "indices"
    type_attr: "Tindices"
  }
  output_arg {
    name: "output"
    type_attr: "Tparams"
  }
  attr {
    name: "Tparams"
    type: "type"
  }
  attr {
    name: "Tindices"
    type: "type"
    allowed_values {
      list {
        type: DT_INT32
        type: DT_INT64
      }
    }
  }
}
op {
  name: "Greater"
  input_arg {
@ -4066,6 +4066,40 @@ op {
  summary: "Gather slices from `params` according to `indices`."
  description: "`indices` must be an integer tensor of any dimension (usually 0-D or 1-D).\nProduces an output tensor with shape `indices.shape + params.shape[1:]` where:\n\n # Scalar indices\n output[:, ..., :] = params[indices, :, ... :]\n\n # Vector indices\n output[i, :, ..., :] = params[indices[i], :, ... :]\n\n # Higher rank indices\n output[i, ..., j, :, ... :] = params[indices[i, ..., j], :, ..., :]\n\nIf `indices` is a permutation and `len(indices) == params.shape[0]` then\nthis operation will permute `params` accordingly.\n\n<div style=\"width:70%; margin:auto; margin-bottom:10px; margin-top:20px;\">\n<img style=\"width:100%\" src=\"../../images/Gather.png\" alt>\n</div>"
}
op {
  name: "GatherNd"
  input_arg {
    name: "params"
    description: "R-D. The tensor from which to gather values."
    type_attr: "Tparams"
  }
  input_arg {
    name: "indices"
    description: "(N+1)-D. Index tensor having shape `[d_0, ..., d_N, R]`."
    type_attr: "Tindices"
  }
  output_arg {
    name: "output"
    description: "N-D. Values from `params` gathered from indices given by `indices`."
    type_attr: "Tparams"
  }
  attr {
    name: "Tparams"
    type: "type"
  }
  attr {
    name: "Tindices"
    type: "type"
    allowed_values {
      list {
        type: DT_INT32
        type: DT_INT64
      }
    }
  }
  summary: "Gather values from `params` according to `indices`."
  description: "`indices` must be integer tensor, containing indices into `params`.\nIt must be shape `[d_0, ..., d_N, R]` where `R` is the rank of `params`.\nThe innermost dimension of `indices` (with length `R`) corresponds to the\nindices of `params`.\n\nProduces an output tensor with shape `[d_0, ..., d_{n-1}]` where:\n\n output[i, j, k, ...] = params[indices[i, j, k, ..., :]]\n\ne.g. for `indices` a matrix:\n\n output[i] = params[indices[i, :]]"
}
op {
  name: "Greater"
  input_arg {
tensorflow/g3doc/api_docs/OWNERS (new file, 1 line)
@ -0,0 +1 @@
tensorflow-git-owners
@ -1149,6 +1149,39 @@ this operation will permute `params` accordingly.
  A `Tensor`. Has the same type as `params`.


- - -

### `tf.gather_nd(params, indices, name=None)` {#gather_nd}

Gather values from `params` according to `indices`.

`indices` must be integer tensor, containing indices into `params`.
It must be shape `[d_0, ..., d_N, R]` where `R` is the rank of `params`.
The innermost dimension of `indices` (with length `R`) corresponds to the
indices of `params`.

Produces an output tensor with shape `[d_0, ..., d_{n-1}]` where:

    output[i, j, k, ...] = params[indices[i, j, k, ..., :]]

e.g. for `indices` a matrix:

    output[i] = params[indices[i, :]]

##### Args:


*  <b>`params`</b>: A `Tensor`. R-D. The tensor from which to gather values.
*  <b>`indices`</b>: A `Tensor`. Must be one of the following types: `int32`, `int64`.
    (N+1)-D. Index tensor having shape `[d_0, ..., d_N, R]`.
*  <b>`name`</b>: A name for the operation (optional).

##### Returns:

  A `Tensor`. Has the same type as `params`.
  N-D. Values from `params` gathered from indices given by `indices`.


- - -

### `tf.dynamic_partition(data, partitions, num_partitions, name=None)` {#dynamic_partition}
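A short usage sketch of the documented call, assuming the graph-and-session style API of this release; the tensor values are made up for illustration:

```python
import tensorflow as tf

params = tf.constant([[1, 2], [3, 4], [5, 6]])   # rank R = 2
indices = tf.constant([[0, 0], [2, 1]])          # shape [2, 2]: two full element indices
values = tf.gather_nd(params, indices)           # shape [2]

with tf.Session() as sess:
    print(sess.run(values))  # [1 6]
```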
@ -411,6 +411,30 @@ subtraction, it usually shouldn't hurt much either.
*  <b>`ValueError`</b>: If `regularizer` does not return a scalar output.


- - -

### `tf.contrib.layers.make_all(module_name, doc_string_modules=None)` {#make_all}

Generate `__all__` from the docstring of one or more modules.

Usage: `make_all(__name__)` or
`make_all(__name__, [sys.modules(__name__), other_module])`. The doc string
modules must each have a docstring, and `__all__` will contain all symbols with
`@@` references, where that symbol currently exists in the module named
`module_name`.

##### Args:


*  <b>`module_name`</b>: The name of the module (usually `__name__`).
*  <b>`doc_string_modules`</b>: a list of modules from which to take docstring.
    If None, then a list containing only the module named `module_name` is used.

##### Returns:

  A list suitable for use as `__all__`.


- - -

### `tf.contrib.layers.optimize_loss(loss, global_step, learning_rate, optimizer, clip_gradients=None, moving_average_decay=0.9, learning_rate_decay_fn=None, variables=None)` {#optimize_loss}
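A sketch of how `make_all` is intended to be used inside a module. The import path is inferred from the section title above and may differ in practice; the `@@foo` docstring reference and the function bodies are made-up examples:

```python
"""Example module docstring listing the public API.

@@foo
"""
from tensorflow.contrib.layers import make_all  # path assumed from the docs above

def foo():
  return 42

def _helper():  # not @@-referenced, so it stays out of __all__
  return 0

__all__ = make_all(__name__)  # -> ['foo']
```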
@ -260,202 +260,6 @@ Example 2:

[Removed: the "## Higher Order Operators" section of control_flow_ops.md,
containing the entries for `tf.map_fn`, `tf.foldl`, `tf.foldr`, and `tf.scan`.
Apart from their opening sentences (for example "The map operator on the list
of tensors resulted from unpacking `elems` along the first dimension."), these
entries match the ones added in
tensorflow/g3doc/api_docs/python/functional_ops.md below.]

## Logical Operators

TensorFlow provides several operations that you can use to add logical operators
@ -409,6 +409,31 @@ Returns a list of values in the collection with the given `name`.
  collected.


- - -

#### `tf.Graph.get_collection_ref(name)` {#Graph.get_collection_ref}

Returns a list of values in the collection with the given `name`.

If the collection exists, this returns the list itself, which can
be modified in place to change the collection. If the collection does
not exist, it is created as an empty list and the list is returned.

This is different from `get_collection()` which always returns a copy of
the collection list if it exists and never creates an empty collection.

##### Args:


*  <b>`name`</b>: The key for the collection. For example, the `GraphKeys` class
    contains many standard names for collections.

##### Returns:

  The list of values in the collection with the given `name`, or an empty
  list if no value has been added to that collection.


- - -

@ -1752,6 +1777,29 @@ for more details.
  collected.


- - -

### `tf.get_collection_ref(key)` {#get_collection_ref}

Wrapper for `Graph.get_collection_ref()` using the default graph.

See [`Graph.get_collection_ref()`](../../api_docs/python/framework.md#Graph.get_collection_ref)
for more details.

##### Args:


*  <b>`key`</b>: The key for the collection. For example, the `GraphKeys` class
    contains many standard names for collections.

##### Returns:

  The list of values in the collection with the given `name`, or an empty
  list if no value has been added to that collection. Note that this returns
  the collection list itself, which can be modified in place to change the
  collection.


- - -

### `class tf.GraphKeys` {#GraphKeys}
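A small sketch contrasting the copy returned by `get_collection()` with the live list returned by `get_collection_ref()`, assuming the collection API otherwise behaves as documented above:

```python
import tensorflow as tf

g = tf.Graph()
with g.as_default():
    tf.add_to_collection("my_things", "a")

    copy = g.get_collection("my_things")     # a copy; mutating it changes nothing
    copy.append("b")

    ref = g.get_collection_ref("my_things")  # the collection list itself
    ref.append("c")

    print(g.get_collection("my_things"))     # ['a', 'c']
```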
202
tensorflow/g3doc/api_docs/python/functional_ops.md
Normal file
202
tensorflow/g3doc/api_docs/python/functional_ops.md
Normal file
@ -0,0 +1,202 @@
|
|||||||
|
<!-- This file is machine generated: DO NOT EDIT! -->
|
||||||
|
|
||||||
|
# Higher Order Functions
|
||||||
|
|
||||||
|
Note: Functions taking `Tensor` arguments can also take anything accepted by
|
||||||
|
[`tf.convert_to_tensor`](framework.md#convert_to_tensor).
|
||||||
|
|
||||||
|
[TOC]
|
||||||
|
|
||||||
|
Functional operations.
|
||||||
|
|
||||||
|
## Higher Order Operators
|
||||||
|
|
||||||
|
TensorFlow provides several higher order operators to simplify the common
|
||||||
|
map-reduce programming patterns.
|
||||||
|
|
||||||
|
- - -
|
||||||
|
|
||||||
|
### `tf.map_fn(fn, elems, dtype=None, parallel_iterations=10, back_prop=True, swap_memory=False, name=None)` {#map_fn}
|
||||||
|
|
||||||
|
map on the list of tensors unpacked from `elems` on dimension 0.
|
||||||
|
|
||||||
|
This map operator repeatedly applies the callable `fn` to a sequence of
|
||||||
|
elements from first to last. The elements are made of the tensors unpacked
|
||||||
|
from `elems`. `dtype` is the data type of the return value of `fn`. Users
|
||||||
|
must provide `dtype` if it is different from the data type of `elems`.
|
||||||
|
|
||||||
|
Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
|
||||||
|
of the result tensor is `[len(values)] + fn(values[0]).shape`.
|
||||||
|
|
||||||
|
##### Args:
|
||||||
|
|
||||||
|
|
||||||
|
* <b>`fn`</b>: The callable to be performed.
|
||||||
|
* <b>`elems`</b>: A tensor to be unpacked to apply `fn`.
|
||||||
|
* <b>`dtype`</b>: (optional) The output type of `fn`.
|
||||||
|
* <b>`parallel_iterations`</b>: (optional) The number of iterations allowed to run
|
||||||
|
in parallel.
|
||||||
|
* <b>`back_prop`</b>: (optional) True enables back propagation.
|
||||||
|
* <b>`swap_memory`</b>: (optional) True enables GPU-CPU memory swapping.
|
||||||
|
* <b>`name`</b>: (optional) Name prefix for the returned tensors.
|
||||||
|
|
||||||
|
##### Returns:
|
||||||
|
|
||||||
|
A tensor that packs the results of applying `fn` to the list of tensors
|
||||||
|
unpacked from `elems`, from first to last.
|
||||||
|
|
||||||
|
##### Raises:
|
||||||
|
|
||||||
|
|
||||||
|
* <b>`TypeError`</b>: if `fn` is not callable.
|
||||||
|
|
||||||
|
##### Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
elems = [1, 2, 3, 4, 5, 6]
|
||||||
|
squares = map_fn(lambda x: x * x, elems)
|
||||||
|
# squares == [1, 4, 9, 16, 25, 36]
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
- - -
|
||||||
|
|
||||||
|
### `tf.foldl(fn, elems, initializer=None, parallel_iterations=10, back_prop=True, swap_memory=False, name=None)` {#foldl}
|
||||||
|
|
||||||
|
foldl on the list of tensors unpacked from `elems` on dimension 0.
|
||||||
|
|
||||||
|
This foldl operator repeatedly applies the callable `fn` to a sequence
|
||||||
|
of elements from first to last. The elements are made of the tensors
|
||||||
|
unpacked from `elems` on dimension 0. The callable fn takes two tensors as
|
||||||
|
arguments. The first argument is the accumulated value computed from the
|
||||||
|
preceding invocation of fn. If `initializer` is None, `elems` must contain
|
||||||
|
at least one element, and its first element is used as the initializer.
|
||||||
|
|
||||||
|
Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
|
||||||
|
of the result tensor is fn(initializer, values[0]).shape`.
|
||||||
|
|
||||||
|
##### Args:
|
||||||
|
|
||||||
|
|
||||||
|
* <b>`fn`</b>: The callable to be performed.
|
||||||
|
* <b>`elems`</b>: A tensor to be unpacked on dimension 0.
|
||||||
|
* <b>`initializer`</b>: (optional) The initial value for the accumulator.
|
||||||
|
* <b>`parallel_iterations`</b>: (optional) The number of iterations allowed to run
|
||||||
|
in parallel.
|
||||||
|
* <b>`back_prop`</b>: (optional) True enables back propagation.
|
||||||
|
* <b>`swap_memory`</b>: (optional) True enables GPU-CPU memory swapping.
|
||||||
|
* <b>`name`</b>: (optional) Name prefix for the returned tensors.
|
||||||
|
|
||||||
|
##### Returns:
|
||||||
|
|
||||||
|
A tensor resulting from applying `fn` consecutively to the list of tensors
|
||||||
|
unpacked from `elems`, from first to last.
|
||||||
|
|
||||||
|
##### Raises:
|
||||||
|
|
||||||
|
|
||||||
|
* <b>`TypeError`</b>: if `fn` is not callable.
|
||||||
|
|
||||||
|
##### Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
elems = [1, 2, 3, 4, 5, 6]
|
||||||
|
sum = foldl(lambda a, x: a + x, elems)
|
||||||
|
# sum == 21
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
- - -
|
||||||
|
|
||||||
|
### `tf.foldr(fn, elems, initializer=None, parallel_iterations=10, back_prop=True, swap_memory=False, name=None)` {#foldr}
|
||||||
|
|
||||||
|
foldr on the list of tensors unpacked from `elems` on dimension 0.
|
||||||
|
|
||||||
|
This foldr operator repeatedly applies the callable `fn` to a sequence
|
||||||
|
of elements from last to first. The elements are made of the tensors
|
||||||
|
unpacked from `elems`. The callable fn takes two tensors as arguments.
|
||||||
|
The first argument is the accumulated value computed from the preceding
|
||||||
|
invocation of fn. If `initializer` is None, `elems` must contain at least
|
||||||
|
one element, and its first element is used as the initializer.
|
||||||
|
|
||||||
|
Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
|
||||||
|
of the result tensor is `fn(initializer, values[0]).shape`.
|
||||||
|
|
||||||
|
##### Args:
|
||||||
|
|
||||||
|
|
||||||
|
* <b>`fn`</b>: The callable to be performed.
|
||||||
|
* <b>`elems`</b>: A tensor that is unpacked into a sequence of tensors to apply `fn`.
|
||||||
|
* <b>`initializer`</b>: (optional) The initial value for the accumulator.
|
||||||
|
* <b>`parallel_iterations`</b>: (optional) The number of iterations allowed to run
|
||||||
|
in parallel.
|
||||||
|
* <b>`back_prop`</b>: (optional) True enables back propagation.
|
||||||
|
* <b>`swap_memory`</b>: (optional) True enables GPU-CPU memory swapping.
|
||||||
|
* <b>`name`</b>: (optional) Name prefix for the returned tensors.
|
||||||
|
|
||||||
|
##### Returns:
|
||||||
|
|
||||||
|
A tensor resulting from applying `fn` consecutively to the list of tensors
|
||||||
|
unpacked from `elems`, from last to first.
|
||||||
|
|
||||||
|
##### Raises:
|
||||||
|
|
||||||
|
|
||||||
|
* <b>`TypeError`</b>: if `fn` is not callable.
|
||||||
|
|
||||||
|
##### Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
elems = [1, 2, 3, 4, 5, 6]
|
||||||
|
sum = foldr(lambda a, x: a + x, elems)
|
||||||
|
# sum == 21
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
- - -
|
||||||
|
|
||||||
|
### `tf.scan(fn, elems, initializer=None, parallel_iterations=10, back_prop=True, swap_memory=False, name=None)` {#scan}
|
||||||
|
|
||||||
|
scan on the list of tensors unpacked from `elems` on dimension 0.
|
||||||
|
|
||||||
|
This scan operator repeatedly applies the callable `fn` to a sequence
|
||||||
|
of elements from first to last. The elements are made of the tensors
|
||||||
|
unpacked from `elems` on dimension 0. The callable fn takes two tensors as
|
||||||
|
arguments. The first argument is the accumulated value computed from the
|
||||||
|
preceding invocation of fn. If `initializer` is None, `elems` must contain
|
||||||
|
at least one element, and its first element is used as the initializer.
|
||||||
|
|
||||||
|
Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
|
||||||
|
of the result tensor is `[len(values)] + fn(initializer, values[0]).shape`.
|
||||||
|
|
||||||
|
##### Args:
|
||||||
|
|
||||||
|
|
||||||
|
* <b>`fn`</b>: The callable to be performed.
|
||||||
|
* <b>`elems`</b>: A tensor to be unpacked on dimension 0.
|
||||||
|
* <b>`initializer`</b>: (optional) The initial value for the accumulator.
|
||||||
|
* <b>`parallel_iterations`</b>: (optional) The number of iterations allowed to run
|
||||||
|
in parallel.
|
||||||
|
* <b>`back_prop`</b>: (optional) True enables back propagation.
|
||||||
|
* <b>`swap_memory`</b>: (optional) True enables GPU-CPU memory swapping.
|
||||||
|
* <b>`name`</b>: (optional) Name prefix for the returned tensors.
|
||||||
|
|
||||||
|
##### Returns:
|
||||||
|
|
||||||
|
A tensor that packs the results of applying `fn` to the list of tensors
|
||||||
|
unpacked from `elems`, from first to last.
|
||||||
|
|
||||||
|
##### Raises:
|
||||||
|
|
||||||
|
|
||||||
|
* <b>`TypeError`</b>: if `fn` is not callable.
|
||||||
|
|
||||||
|
##### Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
elems = [1, 2, 3, 4, 5, 6]
|
||||||
|
sum = scan(lambda a, x: a + x, elems)
|
||||||
|
# sum == [1, 3, 6, 10, 15, 21]
|
||||||
|
```
|
||||||
|
|
||||||
|
|
@ -1234,29 +1234,3 @@ false and no bounding boxes are supplied, an error is raised.
Provide as input to `tf.image.draw_bounding_boxes`.


## Other Functions and Classes
- - -

### `tf.contrib.layers.make_all(module_name, doc_string_modules=None)` {#make_all}

Generate `__all__` from the docstring of one or more modules.

Usage: `make_all(__name__)` or
`make_all(__name__, [sys.modules[__name__], other_module])`. The doc string
modules must each have a docstring, and `__all__` will contain all symbols with
`@@` references, where that symbol currently exists in the module named
`module_name`.

##### Args:

* <b>`module_name`</b>: The name of the module (usually `__name__`).
* <b>`doc_string_modules`</b>: a list of modules from which to take docstrings.
    If None, then a list containing only the module named `module_name` is used.

##### Returns:

  A list suitable for use as `__all__`.
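A brief, hedged sketch of typical usage at the bottom of a package's `__init__.py`, assuming `make_all` is importable from `tensorflow.contrib.layers` as the heading above indicates (the module and symbol names here are placeholders, not taken from this commit):

```python
# hypothetical my_package/__init__.py
from tensorflow.contrib.layers import make_all

# The module docstring would reference exported symbols with @@, e.g.:
#   @@fully_connected
#   @@l2_regularizer
# make_all collects every @@-referenced symbol that exists in this module.
__all__ = make_all(__name__)
```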
@ -13,6 +13,7 @@
* [`Dimension`](../../api_docs/python/framework.md#Dimension)
* [`DType`](../../api_docs/python/framework.md#DType)
* [`get_collection`](../../api_docs/python/framework.md#get_collection)
* [`get_collection_ref`](../../api_docs/python/framework.md#get_collection_ref)
* [`get_default_graph`](../../api_docs/python/framework.md#get_default_graph)
* [`get_seed`](../../api_docs/python/framework.md#get_seed)
* [`Graph`](../../api_docs/python/framework.md#Graph)
@ -95,6 +96,7 @@
* [`dynamic_stitch`](../../api_docs/python/array_ops.md#dynamic_stitch)
* [`expand_dims`](../../api_docs/python/array_ops.md#expand_dims)
* [`gather`](../../api_docs/python/array_ops.md#gather)
* [`gather_nd`](../../api_docs/python/array_ops.md#gather_nd)
* [`one_hot`](../../api_docs/python/array_ops.md#one_hot)
* [`pack`](../../api_docs/python/array_ops.md#pack)
* [`pad`](../../api_docs/python/array_ops.md#pad)
@ -229,8 +231,6 @@
* [`cond`](../../api_docs/python/control_flow_ops.md#cond)
* [`count_up_to`](../../api_docs/python/control_flow_ops.md#count_up_to)
* [`equal`](../../api_docs/python/control_flow_ops.md#equal)
* [`foldl`](../../api_docs/python/control_flow_ops.md#foldl)
* [`foldr`](../../api_docs/python/control_flow_ops.md#foldr)
* [`greater`](../../api_docs/python/control_flow_ops.md#greater)
* [`greater_equal`](../../api_docs/python/control_flow_ops.md#greater_equal)
* [`group`](../../api_docs/python/control_flow_ops.md#group)
@ -244,16 +244,20 @@
* [`logical_not`](../../api_docs/python/control_flow_ops.md#logical_not)
* [`logical_or`](../../api_docs/python/control_flow_ops.md#logical_or)
* [`logical_xor`](../../api_docs/python/control_flow_ops.md#logical_xor)
* [`map_fn`](../../api_docs/python/control_flow_ops.md#map_fn)
* [`no_op`](../../api_docs/python/control_flow_ops.md#no_op)
* [`not_equal`](../../api_docs/python/control_flow_ops.md#not_equal)
* [`Print`](../../api_docs/python/control_flow_ops.md#Print)
* [`scan`](../../api_docs/python/control_flow_ops.md#scan)
* [`select`](../../api_docs/python/control_flow_ops.md#select)
* [`tuple`](../../api_docs/python/control_flow_ops.md#tuple)
* [`verify_tensor_all_finite`](../../api_docs/python/control_flow_ops.md#verify_tensor_all_finite)
* [`where`](../../api_docs/python/control_flow_ops.md#where)

* **[Higher Order Functions](../../api_docs/python/functional_ops.md)**:
* [`foldl`](../../api_docs/python/functional_ops.md#foldl)
* [`foldr`](../../api_docs/python/functional_ops.md#foldr)
* [`map_fn`](../../api_docs/python/functional_ops.md#map_fn)
* [`scan`](../../api_docs/python/functional_ops.md#scan)

* **[Images](../../api_docs/python/image.md)**:
* [`adjust_brightness`](../../api_docs/python/image.md#adjust_brightness)
* [`adjust_contrast`](../../api_docs/python/image.md#adjust_contrast)
@ -272,7 +276,6 @@
* [`flip_up_down`](../../api_docs/python/image.md#flip_up_down)
* [`grayscale_to_rgb`](../../api_docs/python/image.md#grayscale_to_rgb)
* [`hsv_to_rgb`](../../api_docs/python/image.md#hsv_to_rgb)
* [`make_all`](../../api_docs/python/image.md#make_all)
* [`pad_to_bounding_box`](../../api_docs/python/image.md#pad_to_bounding_box)
* [`per_image_whitening`](../../api_docs/python/image.md#per_image_whitening)
* [`random_brightness`](../../api_docs/python/image.md#random_brightness)
@ -466,6 +469,7 @@
* [`fully_connected`](../../api_docs/python/contrib.layers.md#fully_connected)
* [`l1_regularizer`](../../api_docs/python/contrib.layers.md#l1_regularizer)
* [`l2_regularizer`](../../api_docs/python/contrib.layers.md#l2_regularizer)
* [`make_all`](../../api_docs/python/contrib.layers.md#make_all)
* [`optimize_loss`](../../api_docs/python/contrib.layers.md#optimize_loss)
* [`sum_regularizer`](../../api_docs/python/contrib.layers.md#sum_regularizer)
* [`summarize_activation`](../../api_docs/python/contrib.layers.md#summarize_activation)
@ -101,6 +101,7 @@ from tensorflow.python.framework import framework_lib
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import constant_op
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import histogram_ops
from tensorflow.python.ops import io_ops
from tensorflow.python.ops import math_ops
@ -118,8 +119,8 @@ _whitelist = set([app, compat, contrib, errors, flags, gfile, image,
# strings of other modules.
__all__ = make_all(__name__,
                   [framework_lib, array_ops, client_lib, constant_op,
                    control_flow_ops, histogram_ops, io_ops, math_ops, nn,
                    control_flow_ops, functional_ops, histogram_ops, io_ops,
                    script_ops, sparse_ops, state_ops, train])
                    math_ops, nn, script_ops, sparse_ops, state_ops, train])

# Symbols whitelisted for export without documentation.
# TODO(cwhipkey): review these and move to contrib, expose through
@ -484,6 +484,9 @@ class Library(Document):
      names = self._members.items()
    else:
      names = inspect.getmembers(self._module)
      all_names = getattr(self._module, "__all__", None)
      if all_names is not None:
        names = [(n, m) for n, m in names if n in all_names]
    leftovers = []
    for name, _ in names:
      if name in self._members and name not in self._documented:
@ -43,6 +43,7 @@

@@add_to_collection
@@get_collection
@@get_collection_ref
@@GraphKeys

## Defining new operations
@ -82,6 +83,7 @@ from tensorflow.python.framework.ops import reset_default_graph
from tensorflow.python.framework.ops import GraphKeys
from tensorflow.python.framework.ops import add_to_collection
from tensorflow.python.framework.ops import get_collection
from tensorflow.python.framework.ops import get_collection_ref
from tensorflow.python.framework.ops import convert_to_tensor
from tensorflow.python.framework.ops import convert_to_tensor_or_indexed_slices
from tensorflow.python.framework.random_seed import get_seed
@ -83,6 +83,7 @@ def all_libraries(module_to_name, members, documented):
              prefix=PREFIX_TEXT),
      library("histogram_ops", "Histograms"),
      library("control_flow_ops", "Control Flow", prefix=PREFIX_TEXT),
      library("functional_ops", "Higher Order Functions", prefix=PREFIX_TEXT),
      library("image", "Images", tf.image, exclude_symbols=["ResizeMethod"],
              prefix=PREFIX_TEXT),
      library("sparse_ops", "Sparse Tensors",
|
|||||||
|
|
||||||
@@add_to_collection
|
@@add_to_collection
|
||||||
@@get_collection
|
@@get_collection
|
||||||
|
@@get_collection_ref
|
||||||
|
|
||||||
@@as_graph_element
|
@@as_graph_element
|
||||||
@@get_operation_by_name
|
@@get_operation_by_name
|
||||||
@ -2396,6 +2397,30 @@ class Graph(object):
|
|||||||
for name in set(names):
|
for name in set(names):
|
||||||
self.add_to_collection(name, value)
|
self.add_to_collection(name, value)
|
||||||
|
|
||||||
|
def get_collection_ref(self, name):
|
||||||
|
"""Returns a list of values in the collection with the given `name`.
|
||||||
|
|
||||||
|
If the collection exists, this returns the list itself, which can
|
||||||
|
be modified in place to change the collection. If the collection does
|
||||||
|
not exist, it is created as an empty list and the list is returned.
|
||||||
|
|
||||||
|
This is different from `get_collection()` which always returns a copy of
|
||||||
|
the collection list if it exists and never creates an empty collection.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: The key for the collection. For example, the `GraphKeys` class
|
||||||
|
contains many standard names for collections.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The list of values in the collection with the given `name`, or an empty
|
||||||
|
list if no value has been added to that collection.
|
||||||
|
"""
|
||||||
|
coll_list = self._collections.get(name, None)
|
||||||
|
if coll_list is None:
|
||||||
|
coll_list = []
|
||||||
|
self._collections[name] = coll_list
|
||||||
|
return coll_list
|
||||||
|
|
||||||
def get_collection(self, name, scope=None):
|
def get_collection(self, name, scope=None):
|
||||||
"""Returns a list of values in the collection with the given `name`.
|
"""Returns a list of values in the collection with the given `name`.
|
||||||
|
|
||||||
@ -2411,11 +2436,14 @@ class Graph(object):
|
|||||||
list contains the values in the order under which they were
|
list contains the values in the order under which they were
|
||||||
collected.
|
collected.
|
||||||
"""
|
"""
|
||||||
|
coll_list = self._collections.get(name, None)
|
||||||
|
if coll_list is None:
|
||||||
|
return []
|
||||||
if scope is None:
|
if scope is None:
|
||||||
return self._collections.get(name, list())
|
return list(coll_list)
|
||||||
else:
|
else:
|
||||||
c = []
|
c = []
|
||||||
for item in self._collections.get(name, list()):
|
for item in coll_list:
|
||||||
if hasattr(item, "name") and item.name.startswith(scope):
|
if hasattr(item, "name") and item.name.startswith(scope):
|
||||||
c.append(item)
|
c.append(item)
|
||||||
return c
|
return c
|
||||||
@ -3547,6 +3575,25 @@ def add_to_collections(names, value):
|
|||||||
get_default_graph().add_to_collections(names, value)
|
get_default_graph().add_to_collections(names, value)
|
||||||
|
|
||||||
|
|
||||||
|
def get_collection_ref(key):
|
||||||
|
"""Wrapper for `Graph.get_collection_ref()` using the default graph.
|
||||||
|
|
||||||
|
See [`Graph.get_collection_ref()`](../../api_docs/python/framework.md#Graph.get_collection_ref)
|
||||||
|
for more details.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: The key for the collection. For example, the `GraphKeys` class
|
||||||
|
contains many standard names for collections.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The list of values in the collection with the given `name`, or an empty
|
||||||
|
list if no value has been added to that collection. Note that this returns
|
||||||
|
the collection list itself, which can be modified in place to change the
|
||||||
|
collection.
|
||||||
|
"""
|
||||||
|
return get_default_graph().get_collection_ref(key)
|
||||||
|
|
||||||
|
|
||||||
def get_collection(key, scope=None):
|
def get_collection(key, scope=None):
|
||||||
"""Wrapper for `Graph.get_collection()` using the default graph.
|
"""Wrapper for `Graph.get_collection()` using the default graph.
|
||||||
|
|
||||||
|
@ -751,12 +751,41 @@ class CollectionTest(test_util.TensorFlowTestCase):
|
|||||||
blank2 = ObjectWithName("junk/foo")
|
blank2 = ObjectWithName("junk/foo")
|
||||||
g.add_to_collection("blah", blank2)
|
g.add_to_collection("blah", blank2)
|
||||||
|
|
||||||
self.assertEqual(["foo"], g.get_collection("other"))
|
|
||||||
self.assertEqual([12, 34], g.get_collection("key"))
|
self.assertEqual([12, 34], g.get_collection("key"))
|
||||||
self.assertEqual([], g.get_collection("nothing"))
|
self.assertEqual([], g.get_collection("nothing"))
|
||||||
self.assertEqual([27, blank1, blank2], g.get_collection("blah"))
|
self.assertEqual([27, blank1, blank2], g.get_collection("blah"))
|
||||||
self.assertEqual([blank1], g.get_collection("blah", "prefix"))
|
self.assertEqual([blank1], g.get_collection("blah", "prefix"))
|
||||||
|
|
||||||
|
# Make sure that get_collection() returns a first-level
|
||||||
|
# copy of the collection, while get_collection_ref() returns
|
||||||
|
# the original list.
|
||||||
|
other_collection_snapshot = g.get_collection("other")
|
||||||
|
other_collection_ref = g.get_collection_ref("other")
|
||||||
|
self.assertEqual(["foo"], other_collection_snapshot)
|
||||||
|
self.assertEqual(["foo"], other_collection_ref)
|
||||||
|
g.add_to_collection("other", "bar")
|
||||||
|
self.assertEqual(["foo"], other_collection_snapshot)
|
||||||
|
self.assertEqual(["foo", "bar"], other_collection_ref)
|
||||||
|
self.assertEqual(["foo", "bar"], g.get_collection("other"))
|
||||||
|
self.assertTrue(other_collection_ref is g.get_collection_ref("other"))
|
||||||
|
|
||||||
|
# Verify that getting an empty collection ref returns a modifiable list.
|
||||||
|
empty_coll_ref = g.get_collection_ref("empty")
|
||||||
|
self.assertEqual([], empty_coll_ref)
|
||||||
|
empty_coll = g.get_collection("empty")
|
||||||
|
self.assertEqual([], empty_coll)
|
||||||
|
self.assertFalse(empty_coll is empty_coll_ref)
|
||||||
|
empty_coll_ref2 = g.get_collection_ref("empty")
|
||||||
|
self.assertTrue(empty_coll_ref2 is empty_coll_ref)
|
||||||
|
# Add to the collection.
|
||||||
|
empty_coll_ref.append("something")
|
||||||
|
self.assertEqual(["something"], empty_coll_ref)
|
||||||
|
self.assertEqual(["something"], empty_coll_ref2)
|
||||||
|
self.assertEqual([], empty_coll)
|
||||||
|
self.assertEqual(["something"], g.get_collection("empty"))
|
||||||
|
empty_coll_ref3 = g.get_collection_ref("empty")
|
||||||
|
self.assertTrue(empty_coll_ref3 is empty_coll_ref)
|
||||||
|
|
||||||
def testDefaulGraph(self):
|
def testDefaulGraph(self):
|
||||||
with ops.Graph().as_default():
|
with ops.Graph().as_default():
|
||||||
ops.add_to_collection("key", 90)
|
ops.add_to_collection("key", 90)
|
||||||
|
@ -20,7 +20,6 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import math
|
import math
|
||||||
import sys
|
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||||
@ -1182,75 +1181,6 @@ class ControlFlowTest(tf.test.TestCase):
|
|||||||
self.assertEqual(0, value_x)
|
self.assertEqual(0, value_x)
|
||||||
self.assertEqual(73, value_x_grad)
|
self.assertEqual(73, value_x_grad)
|
||||||
|
|
||||||
def testFoldl_Simple(self):
|
|
||||||
with self.test_session():
|
|
||||||
elems = tf.constant([1, 2, 3, 4, 5, 6], name="data")
|
|
||||||
|
|
||||||
r = control_flow_ops.foldl(
|
|
||||||
lambda a, x: tf.mul(tf.add(a, x), 2), elems)
|
|
||||||
self.assertAllEqual(208, r.eval())
|
|
||||||
|
|
||||||
r = control_flow_ops.foldl(
|
|
||||||
lambda a, x: tf.mul(tf.add(a, x), 2), elems, initializer=10)
|
|
||||||
self.assertAllEqual(880, r.eval())
|
|
||||||
|
|
||||||
def testFoldr_Simple(self):
|
|
||||||
with self.test_session():
|
|
||||||
elems = tf.constant([1, 2, 3, 4, 5, 6], name="data")
|
|
||||||
|
|
||||||
r = control_flow_ops.foldr(
|
|
||||||
lambda a, x: tf.mul(tf.add(a, x), 2), elems)
|
|
||||||
self.assertAllEqual(450, r.eval())
|
|
||||||
|
|
||||||
r = control_flow_ops.foldr(
|
|
||||||
lambda a, x: tf.mul(tf.add(a, x), 2), elems, initializer=10)
|
|
||||||
self.assertAllEqual(1282, r.eval())
|
|
||||||
|
|
||||||
def testFold_Grad(self):
|
|
||||||
with self.test_session():
|
|
||||||
elems = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data")
|
|
||||||
v = tf.constant(2.0, name="v")
|
|
||||||
|
|
||||||
r = control_flow_ops.foldl(
|
|
||||||
lambda a, x: tf.mul(a, x), elems, initializer=v)
|
|
||||||
r = tf.gradients(r, v)[0]
|
|
||||||
self.assertAllEqual(720.0, r.eval())
|
|
||||||
|
|
||||||
r = control_flow_ops.foldr(
|
|
||||||
lambda a, x: tf.mul(a, x), elems, initializer=v)
|
|
||||||
r = tf.gradients(r, v)[0]
|
|
||||||
self.assertAllEqual(720.0, r.eval())
|
|
||||||
|
|
||||||
def testMap_Simple(self):
|
|
||||||
with self.test_session():
|
|
||||||
nums = [1, 2, 3, 4, 5, 6]
|
|
||||||
elems = tf.constant(nums, name="data")
|
|
||||||
r = control_flow_ops.map_fn(
|
|
||||||
lambda x: tf.mul(tf.add(x, 3), 2), elems)
|
|
||||||
self.assertAllEqual(np.array([(x + 3) * 2 for x in nums]), r.eval())
|
|
||||||
|
|
||||||
def testScan_Simple(self):
|
|
||||||
with self.test_session():
|
|
||||||
elems = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data")
|
|
||||||
v = tf.constant(2.0, name="v")
|
|
||||||
|
|
||||||
r = control_flow_ops.scan(lambda a, x: tf.mul(a, x), elems)
|
|
||||||
self.assertAllEqual([1., 2., 6., 24., 120., 720.], r.eval())
|
|
||||||
|
|
||||||
r = control_flow_ops.scan(
|
|
||||||
lambda a, x: tf.mul(a, x), elems, initializer=v)
|
|
||||||
self.assertAllEqual([2., 4., 12., 48., 240., 1440.], r.eval())
|
|
||||||
|
|
||||||
def testScan_Grad(self):
|
|
||||||
with self.test_session():
|
|
||||||
elems = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data")
|
|
||||||
v = tf.constant(2.0, name="v")
|
|
||||||
|
|
||||||
r = control_flow_ops.scan(
|
|
||||||
lambda a, x: tf.mul(a, x), elems, initializer=v)
|
|
||||||
r = tf.gradients(r, v)[0]
|
|
||||||
self.assertAllEqual(873.0, r.eval())
|
|
||||||
|
|
||||||
def testOneValueCond(self):
|
def testOneValueCond(self):
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
c = tf.placeholder(tf.int32, shape=[])
|
c = tf.placeholder(tf.int32, shape=[])
|
||||||
|
tensorflow/python/kernel_tests/functional_ops_test.py (new file, 94 lines)
@ -0,0 +1,94 @@
# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for tf.foldl, tf.foldr, tf.map_fn and tf.scan."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf


class FunctionalOpsTest(tf.test.TestCase):

  def testFoldl_Simple(self):
    with self.test_session():
      elems = tf.constant([1, 2, 3, 4, 5, 6], name="data")

      r = tf.foldl(lambda a, x: tf.mul(tf.add(a, x), 2), elems)
      self.assertAllEqual(208, r.eval())

      r = tf.foldl(
          lambda a, x: tf.mul(tf.add(a, x), 2), elems, initializer=10)
      self.assertAllEqual(880, r.eval())

  def testFoldr_Simple(self):
    with self.test_session():
      elems = tf.constant([1, 2, 3, 4, 5, 6], name="data")

      r = tf.foldr(lambda a, x: tf.mul(tf.add(a, x), 2), elems)
      self.assertAllEqual(450, r.eval())

      r = tf.foldr(
          lambda a, x: tf.mul(tf.add(a, x), 2), elems, initializer=10)
      self.assertAllEqual(1282, r.eval())

  def testFold_Grad(self):
    with self.test_session():
      elems = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data")
      v = tf.constant(2.0, name="v")

      r = tf.foldl(
          lambda a, x: tf.mul(a, x), elems, initializer=v)
      r = tf.gradients(r, v)[0]
      self.assertAllEqual(720.0, r.eval())

      r = tf.foldr(
          lambda a, x: tf.mul(a, x), elems, initializer=v)
      r = tf.gradients(r, v)[0]
      self.assertAllEqual(720.0, r.eval())

  def testMap_Simple(self):
    with self.test_session():
      nums = [1, 2, 3, 4, 5, 6]
      elems = tf.constant(nums, name="data")
      r = tf.map_fn(lambda x: tf.mul(tf.add(x, 3), 2), elems)
      self.assertAllEqual(np.array([(x + 3) * 2 for x in nums]), r.eval())

  def testScan_Simple(self):
    with self.test_session():
      elems = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data")
      v = tf.constant(2.0, name="v")

      r = tf.scan(lambda a, x: tf.mul(a, x), elems)
      self.assertAllEqual([1., 2., 6., 24., 120., 720.], r.eval())

      r = tf.scan(
          lambda a, x: tf.mul(a, x), elems, initializer=v)
      self.assertAllEqual([2., 4., 12., 48., 240., 1440.], r.eval())

  def testScan_Grad(self):
    with self.test_session():
      elems = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data")
      v = tf.constant(2.0, name="v")

      r = tf.scan(lambda a, x: tf.mul(a, x), elems, initializer=v)
      r = tf.gradients(r, v)[0]
      self.assertAllEqual(873.0, r.eval())


if __name__ == "__main__":
  tf.test.main()
tensorflow/python/kernel_tests/gather_nd_op_test.py (new file, 122 lines)
@ -0,0 +1,122 @@
# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for tensorflow.ops.tf.gather_nd."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time

import numpy as np
import tensorflow as tf


class GatherNdTest(tf.test.TestCase):
  use_gpu = False

  def _testSimpleDtype(self, dtype):
    with self.test_session(use_gpu=self.use_gpu):
      params = tf.constant(np.array([8, 1, 2, 3, 7, 5], dtype=dtype))
      indices = tf.constant([[4], [4], [0]])
      gather_nd_t = tf.gather_nd(params, indices)
      gather_nd_val = gather_nd_t.eval()

    self.assertAllEqual(np.array([7, 7, 8], dtype=dtype), gather_nd_val)
    self.assertEqual([3], gather_nd_t.get_shape())

  def testSimpleDtype(self):
    self._testSimpleDtype(np.float32)
    self._testSimpleDtype(np.float64)
    self._testSimpleDtype(np.int32)
    self._testSimpleDtype(np.int64)
    self._testSimpleDtype(np.complex64)
    self._testSimpleDtype("|S")  # byte strings in python2 + 3

  def testHigherRankParams(self):
    with self.test_session(use_gpu=self.use_gpu):
      shape = (10, 20, 5, 1, 17)
      params = np.random.rand(*shape)
      indices = np.vstack([
          np.random.randint(0, s, size=2000) for s in shape]).T
      gather_nd_t = tf.gather_nd(params, indices)
      gather_nd_val = gather_nd_t.eval()

    expected = params[tuple(indices.T)]
    self.assertAllEqual(expected, gather_nd_val)
    self.assertEqual([2000], gather_nd_t.get_shape())

  def testHigherRankParamsAndIndices(self):
    with self.test_session(use_gpu=self.use_gpu):
      shape = (10, 20, 5, 1, 17)
      params = np.random.rand(*shape)
      indices = np.vstack([
          np.random.randint(0, s, size=2000) for s in shape]).T
      indices_reshaped = indices.reshape([10, 10, 20, 5])
      gather_nd_t = tf.gather_nd(params, indices_reshaped)
      gather_nd_val = gather_nd_t.eval()

    expected = params[tuple(indices.T)]
    self.assertAllEqual(expected.reshape([10, 10, 20]), gather_nd_val)
    self.assertEqual([10, 10, 20], gather_nd_t.get_shape())

  def testUnknownIndices(self):
    params = tf.constant([[0, 1, 2]])
    indices = tf.placeholder(tf.int32)
    gather_nd_t = tf.gather_nd(params, indices)
    shape = gather_nd_t.get_shape()
    self.assertEqual(shape.ndims, None)
    self.assertEqual(shape[0].value, None)

  def testBadIndices(self):
    with self.test_session(use_gpu=False):
      params = [0, 1, 2]
      indices = [[[0], [7]]]  # Make this one higher rank
      gather_nd = tf.gather_nd(params, indices)
      with self.assertRaisesOpError(
          r"flat indices\[1, :\] = \[7\] does not index into param "
          r"\(shape: \[3\]\)"):
        gather_nd.eval()


class GatherNdGpuTest(GatherNdTest):
  use_gpu = True


class GatherNdOpBenchmark(tf.test.Benchmark):

  def benchmark_gather_nd_op(self):
    shape = (100, 47, 18, 170, 13)
    np.random.seed(127)
    params = np.random.rand(*shape)
    indices = np.vstack([
        np.random.randint(0, s, size=10000) for s in shape]).T

    with tf.Session():
      t_params = tf.Variable(params)
      t_indices = tf.Variable(indices)
      gather_op = tf.gather_nd(t_params, t_indices)
      tf.initialize_all_variables().run()
      for _ in range(10):
        gather_op.eval()
      t1 = time.time()
      for _ in range(1000):
        gather_op.eval()
      t2 = time.time()
      self.report_benchmark(iters=1000, wall_time=(t2-t1)/1000.0)


if __name__ == "__main__":
  tf.test.main()
@ -193,6 +193,11 @@ def _GatherGrad(op, grad):
  return [ops.IndexedSlices(values, indices, dense_shape), None]


@ops.RegisterGradient("GatherNd")
def _GatherNdGrad(unused_op, unused_grad):
  raise NotImplementedError("Gradient for gather_nd is not implemented.")


@ops.RegisterGradient("Identity")
def _IdGrad(_, grad):
  return grad
@ -57,6 +57,7 @@ or join multiple tensors together.
@@space_to_depth
@@depth_to_space
@@gather
@@gather_nd
@@dynamic_partition
@@dynamic_stitch
@@boolean_mask
@ -879,6 +880,16 @@ def _GatherShape(op):
  return [indices_shape.concatenate(params_shape[1:])]


@ops.RegisterShape("GatherNd")
def _GatherNdShape(op):
  """Shape function for array_ops.gather_nd."""
  params_shape = op.inputs[0].get_shape()
  indices_shape = op.inputs[1].get_shape().with_rank_at_least(2)
  if indices_shape.ndims is not None:
    indices_shape[-1].merge_with(params_shape.ndims)
  return [indices_shape[:-1]]

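To make the shape rule above concrete, a small hedged example using the same values as the simple-dtype test earlier in this commit: each row of `indices` supplies a full coordinate into `params`, and the result shape is `indices.shape[:-1]`:

```python
import tensorflow as tf

params = tf.constant([8, 1, 2, 3, 7, 5])   # shape [6], rank 1
indices = tf.constant([[4], [4], [0]])     # shape [3, 1]; last dim equals the rank of params
result = tf.gather_nd(params, indices)     # static shape [3] == indices.shape[:-1]
# evaluates to [7, 7, 8]
```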
@ops.RegisterShape("Unique")
def _UniqueShape(op):
  """Shape function for array_ops.Unique."""
@ -26,16 +26,6 @@ the execution of operations and add conditional dependencies to your graph.
@@cond
@@case

## Higher Order Operators

TensorFlow provides several higher order operators to simplify the common
map-reduce programming patterns.

@@map_fn
@@foldl
@@foldr
@@scan

## Logical Operators

TensorFlow provides several operations that you can use to add logical operators
@ -1785,269 +1775,6 @@ def tuple(tensors, name=None, control_inputs=None):
  return tpl


# TODO(yuanbyu, mrry): Handle stride to support sliding windows.
def foldl(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
          swap_memory=False, name=None):
  """The foldl operator on the list of tensors resulted from unpacking `elems`
  along the first dimension.

  This foldl operator repeatedly applies the callable `fn` to a sequence
  of elements from first to last. The elements are made of the tensors
  unpacked from `elems` on dimension 0. The callable fn takes two tensors as
  arguments. The first argument is the accumulated value computed from the
  preceding invocation of fn. If `initializer` is None, `elems` must contain
  at least one element, and its first element is used as the initializer.

  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
  of the result tensor is `fn(initializer, values[0]).shape`.

  Args:
    fn: The callable to be performed.
    elems: A tensor to be unpacked on dimension 0.
    initializer: (optional) The initial value for the accumulator.
    parallel_iterations: (optional) The number of iterations allowed to run
      in parallel.
    back_prop: (optional) True enables back propagation.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    name: (optional) Name prefix for the returned tensors.

  Returns:
    A tensor resulting from applying `fn` consecutively to the list of tensors
    unpacked from `elems`, from first to last.

  Raises:
    TypeError: if `fn` is not callable.

  Example:
    ```python
    elems = [1, 2, 3, 4, 5, 6]
    sum = foldl(lambda a, x: a + x, elems)
    # sum == 21
    ```
  """
  with ops.op_scope([elems], name, "foldl") as name:
    if not callable(fn):
      raise TypeError("fn must be callable.")

    # Convert elems to tensor array.
    n = array_ops.shape(elems)[0]
    elems_ta = tensor_array_ops.TensorArray(dtype=elems.dtype, size=n,
                                            dynamic_size=False)
    elems_ta = elems_ta.unpack(elems)

    if initializer is None:
      a = elems_ta.read(0)
      i = constant_op.constant(1)
    else:
      a = ops.convert_to_tensor(initializer)
      i = constant_op.constant(0)

    def compute(i, a):
      a = fn(a, elems_ta.read(i))
      return [i + 1, a]
    _, r_a = While(lambda i, a: i < n, compute, [i, a],
                   parallel_iterations=parallel_iterations,
                   back_prop=back_prop, swap_memory=swap_memory)
    return r_a


def foldr(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
          swap_memory=False, name=None):
  """The foldr operator on the list of tensors resulted from unpacking `elems`
  along the first dimension.

  This foldr operator repeatedly applies the callable `fn` to a sequence
  of elements from last to first. The elements are made of the tensors
  unpacked from `elems`. The callable fn takes two tensors as arguments.
  The first argument is the accumulated value computed from the preceding
  invocation of fn. If `initializer` is None, `elems` must contain at least
  one element, and its first element is used as the initializer.

  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
  of the result tensor is `fn(initializer, values[0]).shape`.

  Args:
    fn: The callable to be performed.
    elems: A tensor that is unpacked into a sequence of tensors to apply `fn`.
    initializer: (optional) The initial value for the accumulator.
    parallel_iterations: (optional) The number of iterations allowed to run
      in parallel.
    back_prop: (optional) True enables back propagation.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    name: (optional) Name prefix for the returned tensors.

  Returns:
    A tensor resulting from applying `fn` consecutively to the list of tensors
    unpacked from `elems`, from last to first.

  Raises:
    TypeError: if `fn` is not callable.

  Example:
    ```python
    elems = [1, 2, 3, 4, 5, 6]
    sum = foldr(lambda a, x: a + x, elems)
    # sum == 21
    ```
  """
  with ops.op_scope([elems], name, "foldr") as name:
    if not callable(fn):
      raise TypeError("fn must be callable.")

    # Convert elems to tensor array.
    n = array_ops.shape(elems)[0]
    elems_ta = tensor_array_ops.TensorArray(dtype=elems.dtype, size=n,
                                            dynamic_size=False)
    elems_ta = elems_ta.unpack(elems)

    if initializer is None:
      i = n - 1
      a = elems_ta.read(i)
    else:
      i = n
      a = ops.convert_to_tensor(initializer)
    def compute(i, a):
      i -= 1
      a = fn(a, elems_ta.read(i))
      return [i, a]
    _, r_a = While(lambda i, a: i > 0, compute, [i, a],
                   parallel_iterations=parallel_iterations,
                   back_prop=back_prop, swap_memory=swap_memory)
    return r_a


def map_fn(fn, elems, dtype=None, parallel_iterations=10, back_prop=True,
           swap_memory=False, name=None):
  """The map operator on the list of tensors resulted from unpacking `elems`
  along the first dimension.

  This map operator repeatedly applies the callable `fn` to a sequence of
  elements from first to last. The elements are made of the tensors unpacked
  from `elems`. `dtype` is the data type of the return value of `fn`. Users
  must provide `dtype` if it is different from the data type of `elems`.

  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
  of the result tensor is `[len(values)] + fn(values[0]).shape`.

  Args:
    fn: The callable to be performed.
    elems: A tensor to be unpacked to apply `fn`.
    dtype: (optional) The output type of `fn`.
    parallel_iterations: (optional) The number of iterations allowed to run
      in parallel.
    back_prop: (optional) True enables back propagation.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    name: (optional) Name prefix for the returned tensors.

  Returns:
    A tensor that packs the results of applying `fn` to the list of tensors
    unpacked from `elems`, from first to last.

  Raises:
    TypeError: if `fn` is not callable.

  Example:
    ```python
    elems = [1, 2, 3, 4, 5, 6]
    squares = map_fn(lambda x: x * x, elems)
    # squares == [1, 4, 9, 16, 25, 36]
    ```
  """
  with ops.op_scope([elems], name, "map") as name:
    if not callable(fn):
      raise TypeError("fn must be callable.")
    dtype = dtype if dtype else elems.dtype

    # Convert elems to tensor array.
    n = array_ops.shape(elems)[0]
    elems_ta = tensor_array_ops.TensorArray(dtype=elems.dtype, size=n,
                                            dynamic_size=False)
    elems_ta = elems_ta.unpack(elems)

    i = constant_op.constant(0)
    acc_ta = tensor_array_ops.TensorArray(dtype=dtype, size=n,
                                          dynamic_size=False)
    def compute(i, ta):
      ta = ta.write(i, fn(elems_ta.read(i)))
      return [i + 1, ta]
    _, r_a = While(lambda i, a: i < n, compute, [i, acc_ta],
                   parallel_iterations=parallel_iterations,
                   back_prop=back_prop, swap_memory=swap_memory)
    return r_a.pack()


def scan(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
         swap_memory=False, name=None):
  """The scan operator on the list of tensors resulted from unpacking `elems`
  along the first dimension.

  This scan operator repeatedly applies the callable `fn` to a sequence
  of elements from first to last. The elements are made of the tensors
  unpacked from `elems` on dimension 0. The callable fn takes two tensors as
  arguments. The first argument is the accumulated value computed from the
  preceding invocation of fn. If `initializer` is None, `elems` must contain
  at least one element, and its first element is used as the initializer.

  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
  of the result tensor is `[len(values)] + fn(initializer, values[0]).shape`.

  Args:
    fn: The callable to be performed.
    elems: A tensor to be unpacked on dimension 0.
    initializer: (optional) The initial value for the accumulator.
    parallel_iterations: (optional) The number of iterations allowed to run
      in parallel.
    back_prop: (optional) True enables back propagation.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    name: (optional) Name prefix for the returned tensors.

  Returns:
    A tensor that packs the results of applying `fn` to the list of tensors
    unpacked from `elems`, from first to last.

  Raises:
    TypeError: if `fn` is not callable.

  Example:
    ```python
    elems = [1, 2, 3, 4, 5, 6]
    sum = scan(lambda a, x: a + x, elems)
    # sum == [1, 3, 6, 10, 15, 21]
    ```
  """
  with ops.op_scope([elems], name, "scan") as name:
    if not callable(fn):
      raise TypeError("fn must be callable.")

    # Convert elems to tensor array.
    n = array_ops.shape(elems)[0]
    elems_ta = tensor_array_ops.TensorArray(dtype=elems.dtype, size=n,
                                            dynamic_size=False)
    elems_ta = elems_ta.unpack(elems)

    if initializer is None:
      a = elems_ta.read(0)
      i = constant_op.constant(1)
    else:
      a = ops.convert_to_tensor(initializer)
      i = constant_op.constant(0)

    # Create a tensor array to store the intermediate values.
    acc_ta = tensor_array_ops.TensorArray(dtype=a.dtype, size=n,
                                          dynamic_size=False)
    if initializer is None:
      acc_ta = acc_ta.write(0, a)

    def compute(i, a, ta):
      a = fn(a, elems_ta.read(i))
      ta = ta.write(i, a)
      return [i + 1, a, ta]
    _, _, r_a = While(lambda i, a, ta: i < n, compute, [i, a, acc_ta],
                      parallel_iterations=parallel_iterations,
                      back_prop=back_prop, swap_memory=swap_memory)
    return r_a.pack()


def case(pred_fn_pairs, default, exclusive=False, name="case"):
  """Create a case operation.
@ -13,13 +13,28 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
"""Functional operations."""
|
"""Functional operations.
|
||||||
|
|
||||||
|
## Higher Order Operators
|
||||||
|
|
||||||
|
TensorFlow provides several higher order operators to simplify the common
|
||||||
|
map-reduce programming patterns.
|
||||||
|
|
||||||
|
@@map_fn
|
||||||
|
@@foldl
|
||||||
|
@@foldr
|
||||||
|
@@scan
|
||||||
|
"""
|
||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.ops import array_ops
|
||||||
|
from tensorflow.python.ops import constant_op
|
||||||
|
from tensorflow.python.ops import control_flow_ops
|
||||||
|
from tensorflow.python.ops import tensor_array_ops
|
||||||
# pylint: disable=wildcard-import
|
# pylint: disable=wildcard-import
|
||||||
from tensorflow.python.ops.gen_functional_ops import *
|
from tensorflow.python.ops.gen_functional_ops import *
|
||||||
# pylint: enable=wildcard-import
|
# pylint: enable=wildcard-import
|
||||||
@ -28,6 +43,269 @@ from tensorflow.python.ops.gen_functional_ops import _symbolic_gradient
|
|||||||
# pylint: enable=unused-import
|
# pylint: enable=unused-import
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(yuanbyu, mrry): Handle stride to support sliding windows.
|
||||||
|
def foldl(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
|
||||||
|
swap_memory=False, name=None):
|
||||||
|
"""foldl on the list of tensors unpacked from `elems` on dimension 0.
|
||||||
|
|
||||||
|
This foldl operator repeatedly applies the callable `fn` to a sequence
|
||||||
|
of elements from first to last. The elements are made of the tensors
|
||||||
|
unpacked from `elems` on dimension 0. The callable fn takes two tensors as
|
||||||
|
arguments. The first argument is the accumulated value computed from the
|
||||||
|
preceding invocation of fn. If `initializer` is None, `elems` must contain
|
||||||
|
at least one element, and its first element is used as the initializer.
|
||||||
|
|
||||||
|
Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
|
||||||
|
of the result tensor is fn(initializer, values[0]).shape`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fn: The callable to be performed.
|
||||||
|
elems: A tensor to be unpacked on dimension 0.
|
||||||
|
initializer: (optional) The initial value for the accumulator.
|
||||||
|
parallel_iterations: (optional) The number of iterations allowed to run
|
||||||
|
in parallel.
|
||||||
|
back_prop: (optional) True enables back propagation.
|
||||||
|
swap_memory: (optional) True enables GPU-CPU memory swapping.
|
||||||
|
name: (optional) Name prefix for the returned tensors.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tensor resulting from applying `fn` consecutively to the list of tensors
|
||||||
|
unpacked from `elems`, from first to last.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: if `fn` is not callable.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
elems = [1, 2, 3, 4, 5, 6]
|
||||||
|
sum = foldl(lambda a, x: a + x, elems)
|
||||||
|
# sum == 21
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
with ops.op_scope([elems], name, "foldl") as name:
|
||||||
|
if not callable(fn):
|
||||||
|
raise TypeError("fn must be callable.")
|
||||||
|
|
||||||
|
# Convert elems to tensor array.
|
||||||
|
n = array_ops.shape(elems)[0]
|
||||||
|
elems_ta = tensor_array_ops.TensorArray(dtype=elems.dtype, size=n,
|
||||||
|
dynamic_size=False)
|
||||||
|
elems_ta = elems_ta.unpack(elems)
|
||||||
|
|
||||||
|
if initializer is None:
|
||||||
|
a = elems_ta.read(0)
|
||||||
|
i = constant_op.constant(1)
|
||||||
|
else:
|
||||||
|
a = ops.convert_to_tensor(initializer)
|
||||||
|
i = constant_op.constant(0)
|
||||||
|
|
||||||
|
def compute(i, a):
|
||||||
|
a = fn(a, elems_ta.read(i))
|
||||||
|
return [i + 1, a]
|
||||||
|
_, r_a = control_flow_ops.While(lambda i, a: i < n, compute, [i, a],
|
||||||
|
parallel_iterations=parallel_iterations,
|
||||||
|
back_prop=back_prop,
|
||||||
|
swap_memory=swap_memory)
|
||||||
|
return r_a
|
||||||
|
|
||||||
|
|
||||||
|
def foldr(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
|
||||||
|
swap_memory=False, name=None):
|
||||||
|
"""foldr on the list of tensors unpacked from `elems` on dimension 0.
|
||||||
|
|
||||||
|
This foldr operator repeatedly applies the callable `fn` to a sequence
|
||||||
|
of elements from last to first. The elements are made of the tensors
|
||||||
|
unpacked from `elems`. The callable fn takes two tensors as arguments.
|
||||||
|
The first argument is the accumulated value computed from the preceding
|
||||||
|
invocation of fn. If `initializer` is None, `elems` must contain at least
|
||||||
|
one element, and its first element is used as the initializer.
|
||||||
|
|
||||||
|
Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
|
||||||
|
of the result tensor is `fn(initializer, values[0]).shape`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fn: The callable to be performed.
|
||||||
|
elems: A tensor that is unpacked into a sequence of tensors to apply `fn`.
|
||||||
|
initializer: (optional) The initial value for the accumulator.
|
||||||
|
parallel_iterations: (optional) The number of iterations allowed to run
|
||||||
|
in parallel.
|
||||||
|
back_prop: (optional) True enables back propagation.
|
||||||
|
swap_memory: (optional) True enables GPU-CPU memory swapping.
|
||||||
|
name: (optional) Name prefix for the returned tensors.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tensor resulting from applying `fn` consecutively to the list of tensors
|
||||||
|
unpacked from `elems`, from last to first.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: if `fn` is not callable.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
elems = [1, 2, 3, 4, 5, 6]
|
||||||
|
sum = foldr(lambda a, x: a + x, elems)
|
||||||
|
# sum == 21
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
with ops.op_scope([elems], name, "foldr") as name:
|
||||||
|
if not callable(fn):
|
||||||
|
raise TypeError("fn must be callable.")
|
||||||
|
|
||||||
|
# Convert elems to tensor array.
|
||||||
|
n = array_ops.shape(elems)[0]
|
||||||
|
elems_ta = tensor_array_ops.TensorArray(dtype=elems.dtype, size=n,
|
||||||
|
dynamic_size=False)
|
||||||
|
elems_ta = elems_ta.unpack(elems)
|
||||||
|
|
||||||
|
if initializer is None:
|
||||||
|
i = n - 1
|
||||||
|
a = elems_ta.read(i)
|
||||||
|
else:
|
||||||
|
i = n
|
||||||
|
a = ops.convert_to_tensor(initializer)
|
||||||
|
def compute(i, a):
|
||||||
|
i -= 1
|
||||||
|
a = fn(a, elems_ta.read(i))
|
||||||
|
return [i, a]
|
||||||
|
_, r_a = control_flow_ops.While(lambda i, a: i > 0, compute, [i, a],
|
||||||
|
parallel_iterations=parallel_iterations,
|
||||||
|
back_prop=back_prop,
|
||||||
|
swap_memory=swap_memory)
|
||||||
|
return r_a
|
||||||
|
|
||||||
|
|
||||||
|
def map_fn(fn, elems, dtype=None, parallel_iterations=10, back_prop=True,
|
||||||
|
swap_memory=False, name=None):
|
||||||
|
"""map on the list of tensors unpacked from `elems` on dimension 0.
|
||||||
|
|
||||||
|
This map operator repeatedly applies the callable `fn` to a sequence of
|
||||||
|
elements from first to last. The elements are made of the tensors unpacked
|
||||||
|
from `elems`. `dtype` is the data type of the return value of `fn`. Users
|
||||||
|
must provide `dtype` if it is different from the data type of `elems`.
|
||||||
|
|
||||||
|
Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
|
||||||
|
of the result tensor is `[len(values)] + fn(values[0]).shape`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fn: The callable to be performed.
|
||||||
|
elems: A tensor to be unpacked to apply `fn`.
|
||||||
|
dtype: (optional) The output type of `fn`.
|
||||||
|
parallel_iterations: (optional) The number of iterations allowed to run
|
||||||
|
in parallel.
|
||||||
|
back_prop: (optional) True enables back propagation.
|
||||||
|
swap_memory: (optional) True enables GPU-CPU memory swapping.
|
||||||
|
name: (optional) Name prefix for the returned tensors.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A tensor that packs the results of applying `fn` to the list of tensors
|
||||||
|
unpacked from `elems`, from first to last.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: if `fn` is not callable.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
elems = [1, 2, 3, 4, 5, 6]
|
||||||
|
squares = map_fn(lambda x: x * x, elems)
|
||||||
|
# squares == [1, 4, 9, 16, 25, 36]
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
with ops.op_scope([elems], name, "map") as name:
|
||||||
|
if not callable(fn):
|
||||||
|
raise TypeError("fn must be callable.")
|
||||||
|
dtype = dtype if dtype else elems.dtype
|
||||||
|
|
||||||
|
# Convert elems to tensor array.
|
||||||
|
n = array_ops.shape(elems)[0]
|
||||||
|
elems_ta = tensor_array_ops.TensorArray(dtype=elems.dtype, size=n,
|
||||||
|
dynamic_size=False)
|
||||||
|
elems_ta = elems_ta.unpack(elems)
|
||||||
|
|
||||||
|
i = constant_op.constant(0)
|
||||||
|
acc_ta = tensor_array_ops.TensorArray(dtype=dtype, size=n,
|
||||||
|
dynamic_size=False)
|
||||||
|
def compute(i, ta):
|
||||||
|
ta = ta.write(i, fn(elems_ta.read(i)))
|
||||||
|
return [i + 1, ta]
|
||||||
|
_, r_a = control_flow_ops.While(lambda i, a: i < n, compute, [i, acc_ta],
|
||||||
|
parallel_iterations=parallel_iterations,
|
||||||
|
back_prop=back_prop,
|
||||||
|
swap_memory=swap_memory)
|
||||||
|
return r_a.pack()
|
||||||
|
|
||||||
|
|
||||||
|
def scan(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
         swap_memory=False, name=None):
  """scan on the list of tensors unpacked from `elems` on dimension 0.

  This scan operator repeatedly applies the callable `fn` to a sequence
  of elements from first to last. The elements are made of the tensors
  unpacked from `elems` on dimension 0. The callable fn takes two tensors as
  arguments. The first argument is the accumulated value computed from the
  preceding invocation of fn. If `initializer` is None, `elems` must contain
  at least one element, and its first element is used as the initializer.

  Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
  of the result tensor is `[len(values)] + fn(initializer, values[0]).shape`.

  Args:
    fn: The callable to be performed.
    elems: A tensor to be unpacked on dimension 0.
    initializer: (optional) The initial value for the accumulator.
    parallel_iterations: (optional) The number of iterations allowed to run
      in parallel.
    back_prop: (optional) True enables back propagation.
    swap_memory: (optional) True enables GPU-CPU memory swapping.
    name: (optional) Name prefix for the returned tensors.

  Returns:
    A tensor that packs the results of applying `fn` to the list of tensors
    unpacked from `elems`, from first to last.

  Raises:
    TypeError: if `fn` is not callable.

  Example:
    ```python
    elems = [1, 2, 3, 4, 5, 6]
    sum = scan(lambda a, x: a + x, elems)
    # sum == [1, 3, 6, 10, 15, 21]
    ```
  """
  with ops.op_scope([elems], name, "scan") as name:
    if not callable(fn):
      raise TypeError("fn must be callable.")

    # Convert elems to tensor array.
    n = array_ops.shape(elems)[0]
    elems_ta = tensor_array_ops.TensorArray(dtype=elems.dtype, size=n,
                                            dynamic_size=False)
    elems_ta = elems_ta.unpack(elems)

    if initializer is None:
      a = elems_ta.read(0)
      i = constant_op.constant(1)
    else:
      a = ops.convert_to_tensor(initializer)
      i = constant_op.constant(0)

    # Create a tensor array to store the intermediate values.
    acc_ta = tensor_array_ops.TensorArray(dtype=a.dtype, size=n,
                                          dynamic_size=False)
    if initializer is None:
      acc_ta = acc_ta.write(0, a)

    def compute(i, a, ta):
      a = fn(a, elems_ta.read(i))
      ta = ta.write(i, a)
      return [i + 1, a, ta]
    _, _, r_a = control_flow_ops.While(
        lambda i, a, ta: i < n, compute, [i, a, acc_ta],
        parallel_iterations=parallel_iterations,
        back_prop=back_prop, swap_memory=swap_memory)
    return r_a.pack()

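A companion sketch for `scan` with an explicit `initializer`: every element, including the first, is folded into the accumulator, and the result dtype and shape follow the initializer. Module path and session usage are assumed as in the previous sketch:

```python
import tensorflow as tf
from tensorflow.python.ops import functional_ops

elems = tf.constant([1.0, 2.0, 3.0, 4.0])
running = functional_ops.scan(lambda a, x: a + x, elems,
                              initializer=tf.constant(10.0))

with tf.Session() as sess:
  print(sess.run(running))  # [11. 13. 16. 20.]
```
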
@ops.RegisterShape("SymbolicGradient")
def _symbolic_gradient_shape(op):
  # Say, (u, v) = f(x, y, z), _symbolic_gradient(f) is a function of
@@ -37,11 +37,8 @@ from tensorflow.python.ops.control_flow_ops import no_op
 from tensorflow.python.ops.control_flow_ops import tuple
 from tensorflow.python.ops.control_flow_ops import cond
 from tensorflow.python.ops.control_flow_ops import case
-from tensorflow.python.ops.control_flow_ops import foldl
-from tensorflow.python.ops.control_flow_ops import foldr
-from tensorflow.python.ops.control_flow_ops import map_fn
-from tensorflow.python.ops.control_flow_ops import scan
 from tensorflow.python.ops.data_flow_ops import *
+from tensorflow.python.ops.functional_ops import *
 from tensorflow.python.ops.gradients import *
 from tensorflow.python.ops.histogram_ops import *
 from tensorflow.python.ops.init_ops import *
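The hunk above swaps the explicit `control_flow_ops` re-exports for a wildcard import from `functional_ops`, so the public names are unchanged. A hedged sanity check, assuming the module layout this commit introduces:

```python
from tensorflow.python.ops import functional_ops

# foldl, foldr, map_fn and scan now live in functional_ops and are re-exported
# through the wildcard import above, so the user-facing names do not change.
for name in ("foldl", "foldr", "map_fn", "scan"):
  assert hasattr(functional_ops, name), name
```
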
@@ -359,7 +359,8 @@ def _pure_variable_scope(name_or_scope, reuse=None, initializer=None,
   """
   get_variable_scope()  # Ensure that a default exists, then get a pointer.
-  default_varscope = ops.get_collection(_VARSCOPE_KEY)
+  # Get the reference to the collection as we want to modify it in place.
+  default_varscope = ops.get_collection_ref(_VARSCOPE_KEY)
   try:
     old = default_varscope[0]
     reuse = reuse or old.reuse  # Re-using is inherited by sub-scopes.
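The switch from `get_collection` to `get_collection_ref` matters because the former returns a copy while the latter returns the list actually stored on the graph, which can be mutated in place. A small sketch using the public `tf.Graph` counterparts (assumed to mirror the `ops`-module helpers used in the hunk):

```python
import tensorflow as tf

g = tf.Graph()
with g.as_default():
  snapshot = g.get_collection("my_key")       # a copy; edits are not visible
  snapshot.append("ignored")
  assert g.get_collection("my_key") == []

  live = g.get_collection_ref("my_key")       # the list stored on the graph
  live.append("kept")
  assert g.get_collection("my_key") == ["kept"]
```
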
@@ -256,7 +256,7 @@ class SessionManager(object):
       self._safe_close(sess)
       logging.info("Waiting for model to be ready: %s", not_ready)
       time.sleep(self._recovery_wait_secs)
-      sess = session.Session(master, graph=self._graph)
+      sess = session.Session(target, graph=self._graph, config=config)

     return sess
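The updated call passes the `target` and `config` arguments through to the session constructor. A hedged illustration of that constructor form (an empty target means an in-process session; the names here are illustrative):

```python
import tensorflow as tf

g = tf.Graph()
with g.as_default():
  c = tf.constant(3)

config = tf.ConfigProto(allow_soft_placement=True)
sess = tf.Session("", graph=g, config=config)  # target, graph, config
print(sess.run(c))  # 3
```
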
@@ -104,6 +104,13 @@ class SyncReplicasOptimizer(optimizer.Optimizer):
   # Now you can call `minimize()` or `compute_gradients()` and
   # `apply_gradients()` normally
   grads = opt.minimize(total_loss, global_step=self.global_step)

+  # You can now call get_init_tokens_op() and get_chief_queue_runner().
+  # Note that get_init_tokens_op() must be called before creating session
+  # because it modifies the graph.
+  init_token_op = opt.get_init_tokens_op()
+  chief_queue_runner = opt.get_chief_queue_runner()
   ```

   In the training program, every worker will run the train_op as if not
@@ -114,9 +121,9 @@ class SyncReplicasOptimizer(optimizer.Optimizer):
   # After the session is created by the supervisor and before the main while
   # loop:
   if is_chief and FLAGS.sync_replicas:
-    sv.start_queue_runners(sess, [opt.get_chief_queue_runner()])
+    sv.start_queue_runners(sess, [chief_queue_runner])
     # Insert initial tokens to the queue.
-    sess.run(opt.get_init_tokens_op())
+    sess.run(init_token_op)
   ```

@@ -89,37 +89,41 @@ const DOMAIN_EDGE_WIDTH_SCALE = [1, 5E6];
 /**
  * Parameters that affect how the graph is rendered on the screen.
  */
-export interface RenderGraphParams {
+const PARAMS = {
   /**
    * Whether to extract high degree nodes from the core part of the graph.
    */
-  enableExtraction: boolean;
+  enableExtraction: true,
   /**
    * Maximum in-degree that a node can have without being considered as
    * high in-degree node.
    */
-  maxInDegree: number;
+  maxInDegree: 4,
   /**
    * Maximum out-degree that a node can have without being considered as
    * high out-degree node.
    */
-  maxOutDegree: number;
+  maxOutDegree: 4,
  /**
    * Maximum number of control edges a node can have before they aren't
    * displayed.
    */
-  maxControlDegree: number;
+  maxControlDegree: 4,
   /**
    * Types patterns for predefined out-extract nodes, which are
    * sink-like nodes that will be extracted from the main graph.
    */
-  outExtractTypes: string[];
+  outExtractTypes: [
+    "NoOp" // NoOps are sink-like used for managing control dependencies.
+  ],
+
   /**
    * Types patterns for predefined in-extract nodes, which are
    * source-like nodes that will be extracted from the main graph.
    */
-  inExtractTypes: string[];
+  inExtractTypes: [
+    "Variable"
+  ],
+
   /**
    * When removing edges from a high degree node, remove all of its edges if
@@ -127,32 +131,33 @@ export interface RenderGraphParams {
    * the node has high in-degree, or all out-edges if the node has high
    * out-degree.
    */
-  detachAllEdgesForHighDegree: boolean;
+  detachAllEdgesForHighDegree: true,

   /**
    * After extracting high in/out degree nodes and predefined
    * source-like/sink-like, extract isolated nodes to the side
    * if this extractIsolatedNodesWithAnnotationsOnOneSide is true.
    */
-  extractIsolatedNodesWithAnnotationsOnOneSide: boolean;
+  extractIsolatedNodesWithAnnotationsOnOneSide: true,

   /**
    * Whether to add bridge nodes and edges to the core when building the
    * subhierarchy of an expanded metanode. See buildSubhierarchy().
    */
-  enableBridgegraph: boolean;
+  enableBridgegraph: true,

   /**
    * 2 colors, for the minimum and maximum value respectively, whenever we
    * have a gradient scale.
    */
-  minMaxColors: string[];
+  minMaxColors: ["#fff5f0", "#fb6a4a"],

   /**
-   * Maximum number of annotations to be displayed on a node.
+   * Maximum number of annotations to be displayed on a node before an
+   * ellipsis is used.
    */
-  maxAnnotations: number;
-}
+  maxAnnotations: 5
+};

 /**
  * Stores the rendering information, such as x and y coordinates,
@@ -161,7 +166,6 @@ export interface RenderGraphParams {
 export class RenderGraphInfo {
   private hierarchy: hierarchy.Hierarchy;
   private index: {[nodeName: string]: RenderNodeInfo};
-  private params: RenderGraphParams;
   private deviceColorMap: d3.scale.Ordinal<string, string>;
   private memoryUsageScale: d3.scale.Linear<string, string>;
   private computeTimeScale: d3.scale.Linear<string, string>;
@@ -173,7 +177,7 @@ export class RenderGraphInfo {
   private hasSubhierarchy: {[nodeName: string]: boolean};
   root: RenderGroupNodeInfo;

-  constructor(hierarchy: hierarchy.Hierarchy, params: RenderGraphParams) {
+  constructor(hierarchy: hierarchy.Hierarchy) {
     this.hierarchy = hierarchy;
     this.index = {};
     this.deviceColorMap = d3.scale.ordinal<string>()
@@ -199,7 +203,7 @@ export class RenderGraphInfo {
     });
     this.memoryUsageScale = d3.scale.linear<string, string>()
         .domain(memoryExtent)
-        .range(params.minMaxColors);
+        .range(PARAMS.minMaxColors);

     // Find also the minimum and maximum compute time.
     let computeTimeExtent = d3.extent(topLevelGraph.nodes(),
@@ -212,12 +216,11 @@ export class RenderGraphInfo {
     });
     this.computeTimeScale = d3.scale.linear<string, string>()
         .domain(computeTimeExtent)
-        .range(params.minMaxColors);
+        .range(PARAMS.minMaxColors);

     // Maps node name to whether the rendering hierarchy was already
     // constructed.
     this.hasSubhierarchy = {};
-    this.params = params;
     this.root = new RenderGroupNodeInfo(hierarchy.root);
     this.index[hierarchy.root.name] = this.root;
     this.buildSubhierarchy(hierarchy.root.name);
@@ -373,13 +376,13 @@ export class RenderGraphInfo {
       _.each((<OpNode>childNode).inEmbeddings, embedding => {
         let renderMetaedgeInfo = new RenderMetaedgeInfo(null);
         addInAnnotation(childRenderInfo, embedding, null, renderMetaedgeInfo,
-            AnnotationType.CONSTANT, this.params);
+            AnnotationType.CONSTANT);
         this.index[embedding.name] = new RenderNodeInfo(embedding);
       });
       _.each((<OpNode>childNode).outEmbeddings, embedding => {
         let renderMetaedgeInfo = new RenderMetaedgeInfo(null);
         addOutAnnotation(childRenderInfo, embedding, null, renderMetaedgeInfo,
-            AnnotationType.SUMMARY, this.params);
+            AnnotationType.SUMMARY);
         this.index[embedding.name] = new RenderNodeInfo(embedding);
       });
     }
@@ -393,9 +396,9 @@ export class RenderGraphInfo {
       coreGraph.setEdge(edgeObj.v, edgeObj.w, renderMetaedgeInfo);
     });

-    if (this.params.enableExtraction &&
+    if (PARAMS.enableExtraction &&
         renderGroupNodeInfo.node.type === NodeType.META) {
-      extractHighDegrees(renderGroupNodeInfo, this.params);
+      extractHighDegrees(renderGroupNodeInfo);
     }

     // Record that we constructed the rendering hierarchy for this node, so we
@@ -469,7 +472,7 @@ export class RenderGraphInfo {
     // either node is high-degree with respect to control edges. This will
     // be a signal to show it as an annotation instead of a bridge edge.
     let isHighDegreeControlEdge = !bridgeMetaedge.numRegularEdges &&
-      otherCounts.control[otherName] > this.params.maxControlDegree;
+      otherCounts.control[otherName] > PARAMS.maxControlDegree;

     let [, childAnnotations] =
       inbound ?
@@ -478,8 +481,8 @@ export class RenderGraphInfo {

     let isOtherHighDegree =
       inbound ?
-        otherCounts.out[otherName] > this.params.maxOutDegree :
-        otherCounts.in[otherName] > this.params.maxInDegree;
+        otherCounts.out[otherName] > PARAMS.maxOutDegree :
+        otherCounts.in[otherName] > PARAMS.maxInDegree;

     // The adjoining render metaedge info from the parent's coreGraph, if any.
     // It will either be a Metaedge involving this node directly, if it
@@ -493,7 +496,7 @@ export class RenderGraphInfo {
     // - the child is in the core (not extracted for being high-degree), and
     // - there's a path (in the traversal sense) between child and other.
     let canDrawBridgePath = false;
-    if (this.params.enableBridgegraph &&
+    if (PARAMS.enableBridgegraph &&
         !isOtherHighDegree &&
         !isHighDegreeControlEdge &&
         childRenderInfo.isInCore()) {
@@ -581,7 +584,7 @@ export class RenderGraphInfo {
           otherRenderInfo,
           new RenderMetaedgeInfo(bridgeMetaedge),
           AnnotationType.SHORTCUT,
-          inbound), this.params);
+          inbound));
       return;
     }

@@ -869,13 +872,13 @@ export class AnnotationList {
   * Append an annotation to the list, or a stand-in ellipsis annotation instead
   * if this would make it too many.
   */
-  push(annotation: Annotation, params: RenderGraphParams): void {
+  push(annotation: Annotation): void {
     if (annotation.node.name in this.nodeNames) {
       return; // Skip duplicate annotation.
     }
     this.nodeNames[annotation.node.name] = true;

-    if (this.list.length < params.maxAnnotations) {
+    if (this.list.length < PARAMS.maxAnnotations) {
       this.list.push(annotation);
       return;
     }
@@ -1101,19 +1104,18 @@ export class RenderMetaedgeInfo {

 function addInAnnotation(node: RenderNodeInfo, predecessor: Node,
     predecessorRenderInfo: RenderNodeInfo,
-    edge: RenderMetaedgeInfo, type: AnnotationType,
-    params: RenderGraphParams): void {
+    edge: RenderMetaedgeInfo, type: AnnotationType): void {
   let annotation = new Annotation(predecessor, predecessorRenderInfo, edge,
       type, true);
-  node.inAnnotations.push(annotation, params);
+  node.inAnnotations.push(annotation);
 }

 function addOutAnnotation(node: RenderNodeInfo, successor: Node,
     successorRenderInfo: RenderNodeInfo, edge: RenderMetaedgeInfo,
-    type: AnnotationType, params: RenderGraphParams): void {
+    type: AnnotationType): void {
   let annotation = new Annotation(successor, successorRenderInfo, edge,
       type, false);
-  node.outAnnotations.push(annotation, params);
+  node.outAnnotations.push(annotation);
 }

 function setGraphDepth(graph: graphlib.Graph<RenderNodeInfo, any>,
@@ -1181,7 +1183,7 @@ function setGroupNodeDepth(renderInfo: RenderGroupNodeInfo,
  */
 function createShortcut(
     graph: graphlib.Graph<RenderNodeInfo, RenderMetaedgeInfo>,
-    v: string, w: string, params: RenderGraphParams) {
+    v: string, w: string) {
   let src = graph.node(v);
   let sink = graph.node(w);
   let edge = graph.edge(v, w);
@@ -1197,8 +1199,8 @@ function createShortcut(
   }

   // Add each annotation.
-  addOutAnnotation(src, sink.node, sink, edge, AnnotationType.SHORTCUT, params);
-  addInAnnotation(sink, src.node, src, edge, AnnotationType.SHORTCUT, params);
+  addOutAnnotation(src, sink.node, sink, edge, AnnotationType.SHORTCUT);
+  addInAnnotation(sink, src.node, src, edge, AnnotationType.SHORTCUT);

   // Remove the edge from the core graph.
   graph.removeEdge(v, w);
@@ -1212,18 +1214,18 @@ function createShortcut(
  * edges. Otherwise, only extract all in-edges.
  */
 function makeOutExtract(renderNode: RenderGroupNodeInfo, n: string,
-    params: RenderGraphParams, forceDetach?: boolean) {
+    forceDetach?: boolean) {
   let graph = renderNode.coreGraph;
   let child = graph.node(n);
   child.isOutExtract = true;

   _.each(graph.predecessors(n), (p, index) => {
-    createShortcut(graph, p, n, params);
+    createShortcut(graph, p, n);
   });

-  if (params.detachAllEdgesForHighDegree || forceDetach) {
+  if (PARAMS.detachAllEdgesForHighDegree || forceDetach) {
     _.each(graph.successors(n), (s, index) => {
-      createShortcut(graph, n, s, params);
+      createShortcut(graph, n, s);
     });
   }

@@ -1243,18 +1245,18 @@ function makeOutExtract(renderNode: RenderGroupNodeInfo, n: string,
  * edges. Otherwise, only remove all out-edges.
  */
 export function makeInExtract(renderNode: RenderGroupNodeInfo, n: string,
-    params: RenderGraphParams, forceDetach?: boolean) {
+    forceDetach?: boolean) {
   let graph = renderNode.coreGraph;
   let child = graph.node(n);
   child.isInExtract = true;

   _.each(graph.successors(n), (s, index) => {
-    createShortcut(graph, n, s, params);
+    createShortcut(graph, n, s);
   });

-  if (params.detachAllEdgesForHighDegree || forceDetach) {
+  if (PARAMS.detachAllEdgesForHighDegree || forceDetach) {
     _.each(graph.predecessors(n), (p, index) => {
-      createShortcut(graph, p, n, params);
+      createShortcut(graph, p, n);
     });
   }

@@ -1289,40 +1291,37 @@ function hasTypeIn(node: Node, types: string[]): boolean {
 }

 /** Move nodes that are specified to be excluded out of the core graph. */
-function extractSpecifiedNodes(renderNode: RenderGroupNodeInfo,
-    params: RenderGraphParams) {
+function extractSpecifiedNodes(renderNode: RenderGroupNodeInfo) {
   let graph = renderNode.coreGraph;
   _.each(graph.nodes(), n => {
     let renderInfo = graph.node(n);
     if (renderInfo.node.include === InclusionType.EXCLUDE) {
       if (renderNode.coreGraph.outEdges(n).length >
           renderNode.coreGraph.inEdges(n).length) {
-        makeOutExtract(renderNode, n, params, true);
+        makeOutExtract(renderNode, n, true);
       } else {
-        makeInExtract(renderNode, n, params, true);
+        makeInExtract(renderNode, n, true);
       }
     }
   });
 }

 /** Remove edges from pre-defined out-extract patterns */
-function extractPredefinedSink(renderNode: RenderGroupNodeInfo,
-    params: RenderGraphParams) {
+function extractPredefinedSink(renderNode: RenderGroupNodeInfo) {
   let graph = renderNode.coreGraph;
   _.each(graph.nodes(), n => {
     let renderInfo = graph.node(n);
     if (renderInfo.node.include !== InclusionType.UNSPECIFIED) {
       return;
     }
-    if (hasTypeIn(renderInfo.node, params.outExtractTypes)) {
-      makeOutExtract(renderNode, n, params);
+    if (hasTypeIn(renderInfo.node, PARAMS.outExtractTypes)) {
+      makeOutExtract(renderNode, n);
     }
   });
 }

 /** Remove edges from pre-defined in-extract patterns */
-function extractPredefinedSource(renderNode: RenderGroupNodeInfo,
-    params: RenderGraphParams) {
+function extractPredefinedSource(renderNode: RenderGroupNodeInfo) {
   let graph = renderNode.coreGraph;

   _.each(graph.nodes(), n => {
|
|||||||
if (renderInfo.node.include !== InclusionType.UNSPECIFIED) {
|
if (renderInfo.node.include !== InclusionType.UNSPECIFIED) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (hasTypeIn(renderInfo.node, params.inExtractTypes)) {
|
if (hasTypeIn(renderInfo.node, PARAMS.inExtractTypes)) {
|
||||||
makeInExtract(renderNode, n, params);
|
makeInExtract(renderNode, n);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Extract from nodes with in-degree > maxInDegree */
|
/** Extract from nodes with in-degree > maxInDegree */
|
||||||
function extractHighInDegree(renderNode: RenderGroupNodeInfo,
|
function extractHighInDegree(renderNode: RenderGroupNodeInfo) {
|
||||||
params: RenderGraphParams) {
|
|
||||||
let graph = renderNode.coreGraph;
|
let graph = renderNode.coreGraph;
|
||||||
let maxInDegree = params.maxInDegree;
|
let maxInDegree = PARAMS.maxInDegree;
|
||||||
|
|
||||||
// detect first so degrees don't get affected by other removal
|
// detect first so degrees don't get affected by other removal
|
||||||
let highInDegreeNames = _.filter(graph.nodes(), n => {
|
let highInDegreeNames = _.filter(graph.nodes(), n => {
|
||||||
@@ -1363,15 +1361,14 @@ function extractHighInDegree(renderNode: RenderGroupNodeInfo,
   });

   _.each(highInDegreeNames, n => {
-    makeOutExtract(renderNode, n, params);
+    makeOutExtract(renderNode, n);
   });
 }

 /** Extract nodes with out-degree > maxOutDegree */
-function extractHighOutDegree(renderNode: RenderGroupNodeInfo,
-    params: RenderGraphParams) {
+function extractHighOutDegree(renderNode: RenderGroupNodeInfo) {
   let graph = renderNode.coreGraph;
-  let maxOutDegree = params.maxOutDegree;
+  let maxOutDegree = PARAMS.maxOutDegree;

   // detect first so degrees don't get affected by other removal
   let highOutDegreeNames = _.filter(graph.nodes(), n => {
@@ -1394,13 +1391,12 @@ function extractHighOutDegree(renderNode: RenderGroupNodeInfo,
   });

   _.each(highOutDegreeNames, n => {
-    makeInExtract(renderNode, n, params);
+    makeInExtract(renderNode, n);
   });
 }

 /** Remove control edges from nodes that have too many control edges */
-function removeControlEdges(renderNode: RenderGroupNodeInfo,
-    params: RenderGraphParams) {
+function removeControlEdges(renderNode: RenderGroupNodeInfo) {
   let graph = renderNode.coreGraph;

   // Collect control edges into a map by node name.
@@ -1414,8 +1410,8 @@ function removeControlEdges(renderNode: RenderGroupNodeInfo,

   // For each node with too many control edges, turn them into annotations.
   _.each(map, (edges, nodeName) => {
-    if (edges.length > params.maxControlDegree) {
-      _.each(edges, e => createShortcut(graph, e.v, e.w, params));
+    if (edges.length > PARAMS.maxControlDegree) {
+      _.each(edges, e => createShortcut(graph, e.v, e.w));
     }
   });
 }
@@ -1445,38 +1441,35 @@ export function mapIndexToHue(id: number): number {
  * screw up the graph layout.
  *
  * @param {Render.Node} renderNode Node to manipulate.
- * @param {Object} params render Graph construction parameters. See
- *     <tf-graph-params>'s output
  */
-function extractHighDegrees(renderNode: RenderGroupNodeInfo,
-    params: RenderGraphParams) {
+function extractHighDegrees(renderNode: RenderGroupNodeInfo) {

-  extractSpecifiedNodes(renderNode, params);
+  extractSpecifiedNodes(renderNode);

-  if (params.outExtractTypes) {
-    extractPredefinedSink(renderNode, params);
+  if (PARAMS.outExtractTypes) {
+    extractPredefinedSink(renderNode);
   }

   // This has to come before extract high in-degree to protect the core part
   // that takes many variables.
-  if (params.inExtractTypes) {
-    extractPredefinedSource(renderNode, params);
+  if (PARAMS.inExtractTypes) {
+    extractPredefinedSource(renderNode);
   }

   // This has to come before extract high out-degree to protect the core part
   // that output to many places as there are more high-degree sinks than
   // sources.
-  if (params.maxInDegree) {
-    extractHighInDegree(renderNode, params);
+  if (PARAMS.maxInDegree) {
+    extractHighInDegree(renderNode);
   }

-  if (params.maxOutDegree) {
-    extractHighOutDegree(renderNode, params);
+  if (PARAMS.maxOutDegree) {
+    extractHighOutDegree(renderNode);
   }

-  if (params.maxControlDegree) {
-    removeControlEdges(renderNode, params);
+  if (PARAMS.maxControlDegree) {
+    removeControlEdges(renderNode);
   }

   // Extract isolated nodes, which can be
@@ -1513,7 +1506,7 @@ function extractHighDegrees(renderNode: RenderGroupNodeInfo,
       renderNode.isolatedOutExtract.push(child);
       child.node.include = InclusionType.EXCLUDE;
       graph.removeNode(n);
-    } else if (params.extractIsolatedNodesWithAnnotationsOnOneSide) {
+    } else if (PARAMS.extractIsolatedNodesWithAnnotationsOnOneSide) {
       if (hasOutAnnotations && !hasInAnnotations) {
         child.isInExtract = true; // for ones with high out-annotations
         renderNode.isolatedInExtract.push(child);
@@ -1,113 +0,0 @@
-<link rel="import" href="../polymer/polymer.html">
-<!--
-Module for adjusting render graph building parameter.
--->
-<dom-module id="tf-graph-params">
-</dom-module>
-<script>
-Polymer({
-
-  is: 'tf-graph-params',
-
-  properties: {
-    // PARAMETERS
-
-    enableExtraction: {
-      type: Boolean,
-      value: true
-    },
-
-    /** Maximum in-degree that a node can have without being considered as
-     * high in-degree node. */
-    maxInDegree: {
-      type: Number,
-      value: 4
-    },
-    /** Maximum out-degree that a node can have without being considered as
-     * high out-degree node. */
-    maxOutDegree: {
-      type: Number,
-      value: 4
-    },
-    /** Maximum number of control edges a node can have before they aren't
-     * displayed. */
-    maxControlDegree: {
-      type: Number,
-      value: 4
-    },
-
-    /**
-     * Types patterns for predefined out-extract nodes, which are
-     * sink-like nodes that will be extracted from the main graph.
-     */
-    outExtractTypes: {
-      type: Array,
-      value: function() {
-        return [
-          'NoOp' // for "sgd", "momentum" group
-        ];
-      }
-    },
-
-    /**
-     * Types patterns for predefined in-extract nodes, which are
-     * source-like nodes that will be extracted from the main graph.
-     */
-    inExtractTypes: {
-      type: Array,
-      value: function() {
-        return ['Variable'];
-      }
-    },
-
-    /**
-     * When removing edges from a high degree node, remove all of its edges if
-     * detachAllEdgesForHighDegree is true. Otherwise remove all in-edges if
-     * the node has high in-degree, or all out-edges if the node has high
-     * out-degree.
-     */
-    detachAllEdgesForHighDegree: {
-      type: Boolean,
-      value: true
-    },
-
-    /**
-     * After extracting high in/out degree nodes and predefined
-     * source-like/sink-like, extract isolated nodes to the side
-     * if this extractIsolatedNodesWithAnnotationsOnOneSide is true.
-     */
-    extractIsolatedNodesWithAnnotationsOnOneSide: {
-      type: Boolean,
-      value: true
-    },
-
-    /**
-     * Whether to draw bridge paths inside of expanded group nodes.
-     */
-    enableBridgegraph: {
-      type: Boolean,
-      value: true
-    },
-
-    /**
-     * Colors for the minimum and maximum values whenever we have a gradient
-     * scale.
-     */
-    minMaxColors: {
-      type: Array,
-      value: function() {
-        return ["#fff5f0", "#fb6a4a"];
-      }
-    },
-
-    /**
-     * Maximum number of annotations to be displayed on a node before an
-     * ellipsis is used.
-     */
-    maxAnnotations: {
-      type: Number,
-      value: 5
-    }
-  }
-});
-</script>
@@ -6,7 +6,6 @@
 <link rel="import" href="../paper-toggle-button/paper-toggle-button.html">
 <link rel="import" href="../tf-graph-common/tf-graph-common.html">
 <link rel="import" href="tf-graph-scene.html">
-<link rel="import" href="tf-graph-params.html">
 <dom-module id="tf-graph">
 <template>
 <style>
@@ -37,7 +36,6 @@ paper-button {
 }
 </style>
 <div class="container">
-<tf-graph-params id="graphParams"></tf-graph-params>
 <div class="vertical">
   <h2>[[title]]</h2>
   <tf-graph-scene id="scene" class="auto"
@@ -91,13 +89,6 @@ Polymer({
     readOnly: true,
     notify: true,
   },
-  // internal properties
-  _graphParams: {
-    type: Object,
-    value: function() {
-      return this.$.graphParams;
-    }
-  },
   _renderDepth: {
     type: Number,
     value: 1
@@ -108,9 +99,9 @@ Polymer({
     },
   },
   observers: [
-    '_buildRenderHierarchy(graphHierarchy, _graphParams)'
+    '_buildRenderHierarchy(graphHierarchy)'
   ],
-  _buildRenderHierarchy: function(graphHierarchy, params) {
+  _buildRenderHierarchy: function(graphHierarchy) {
     tf.time('new tf.graph.render.Hierarchy', function() {
       if (graphHierarchy.root.type !== tf.graph.NodeType.META) {
         // root must be metanode but sometimes Polymer's dom-if has not
@@ -118,8 +109,7 @@ Polymer({
         // and thus mistakenly pass non-metanode to this module.
         return;
       }
-      var renderGraph = new tf.graph.render.RenderGraphInfo(graphHierarchy,
-          params);
+      var renderGraph = new tf.graph.render.RenderGraphInfo(graphHierarchy);
       // Producing the 'color by' parameters to be consumed
       // by the tf-graph-controls panel. It contains information about the
       // min and max values and their respective colors, as well as list
@@ -252,7 +242,7 @@ Polymer({
     }

     // Rebuild the render hierarchy.
-    this._buildRenderHierarchy(this.graphHierarchy, this._graphParams);
+    this._buildRenderHierarchy(this.graphHierarchy);
   },
   _nodeToggleSeriesGroup: function(event) {
     // Toggle the group setting of the specified node appropriately.
@@ -270,7 +260,7 @@ Polymer({
     tf.graph.hierarchy.build(this.basicGraph, this.hierarchyParams, hierarchyTracker)
     .then(function(graphHierarchy) {
       this.set('graphHierarchy', graphHierarchy);
-      this._buildRenderHierarchy(this.graphHierarchy, this._graphParams);
+      this._buildRenderHierarchy(this.graphHierarchy);
     }.bind(this));
   },
   not: function(x) {
@@ -13,8 +13,8 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):

   native.new_http_archive(
     name = "eigen_archive",
-    url = "https://bitbucket.org/eigen/eigen/get/36b0586de49f.tar.gz",
-    sha256 = "86da9dd97c91b6587a257add70d9478f4c463d6697d487564c5bfe83c4a0e8e0",
+    url = "https://bitbucket.org/eigen/eigen/get/3d9f227afae2.tar.gz",
+    sha256 = "bf2638b7e1085de0b430b000c07e090dc71c83dd7f5b934a06f68b7db02676bf",
     build_file = path_prefix + "eigen.BUILD",
   )

third_party/eigen3/Eigen/Cholesky
@@ -1 +1 @@
-#include "eigen-eigen-36b0586de49f/Eigen/Cholesky"
+#include "eigen-eigen-3d9f227afae2/Eigen/Cholesky"

third_party/eigen3/Eigen/Core
@@ -1 +1 @@
-#include "eigen-eigen-36b0586de49f/Eigen/Core"
+#include "eigen-eigen-3d9f227afae2/Eigen/Core"

third_party/eigen3/Eigen/Eigenvalues
@@ -1 +1 @@
-#include "eigen-eigen-36b0586de49f/Eigen/Eigenvalues"
+#include "eigen-eigen-3d9f227afae2/Eigen/Eigenvalues"

third_party/eigen3/Eigen/LU
@@ -1 +1 @@
-#include "eigen-eigen-36b0586de49f/Eigen/LU"
+#include "eigen-eigen-3d9f227afae2/Eigen/LU"

third_party/eigen3/Eigen/QR
@@ -1 +1 @@
-#include "eigen-eigen-36b0586de49f/Eigen/QR"
+#include "eigen-eigen-3d9f227afae2/Eigen/QR"

@@ -1 +1 @@
-#include "eigen-eigen-36b0586de49f/unsupported/Eigen/CXX11/Tensor"
+#include "eigen-eigen-3d9f227afae2/unsupported/Eigen/CXX11/Tensor"
|
|||||||
|
|
||||||
build --force_python=py$PYTHON_MAJOR_VERSION
|
build --force_python=py$PYTHON_MAJOR_VERSION
|
||||||
build --python$PYTHON_MAJOR_VERSION_path=$PYTHON_BINARY
|
build --python$PYTHON_MAJOR_VERSION_path=$PYTHON_BINARY
|
||||||
|
build --define=use_fast_cpp_protos=true
|
||||||
|
build --define=allow_oversize_protos=true
|
||||||
|
|
||||||
build --spawn_strategy=standalone
|
build --spawn_strategy=standalone
|
||||||
test --spawn_strategy=standalone
|
test --spawn_strategy=standalone
|
||||||
|