Merge commit for internal changes
This commit is contained in:
commit
71320c0909
eigen.BUILD
tensorflow
core
BUILD
distributed_runtime/rpc
framework
kernels
BUILDbounds_check.hcwise_ops.heigen_backward_spatial_convolutions.hgather_nd_op.ccgather_nd_op.hgather_nd_op_gpu.cu.ccgather_nd_op_test.ccinitializable_lookup_table.cclookup_table_op.cc
ops
g3doc/api_docs
OWNERS
python
python
tensorboard/components
workspace.bzlthird_party/eigen3
tools
@ -1,6 +1,6 @@
|
||||
package(default_visibility = ["//visibility:public"])
|
||||
|
||||
archive_dir = "eigen-eigen-36b0586de49f"
|
||||
archive_dir = "eigen-eigen-3d9f227afae2"
|
||||
|
||||
cc_library(
|
||||
name = "eigen",
|
||||
|
@ -619,14 +619,19 @@ ANDROID_TF_COPTS = [
|
||||
"-std=c++11",
|
||||
"-DMIN_LOG_LEVEL=0",
|
||||
"-DTF_LEAN_BINARY",
|
||||
"-O2",
|
||||
]
|
||||
|
||||
# Native library support for Android applications.
|
||||
# Does not contain operators, use :android_tensorflow_lib if you want full
|
||||
# operator support.
|
||||
# Compiles to a trivial library on non-android to prevent irrelevant
|
||||
# build errors.
|
||||
#
|
||||
# Compiles to a trivial library on non-Android to prevent irrelevant
|
||||
# build errors. If not building this as part of an android_binary,
|
||||
# a command such as the following must be used:
|
||||
# bazel build -c opt tensorflow/core:android_tensorflow_lib \
|
||||
# --crosstool_top=//third_party/java/android/android_ndk_linux/toolchains:everything \
|
||||
# --cpu=armeabi-v7a \
|
||||
# --host_crosstool_top=@bazel_tools//tools/cpp:toolchain
|
||||
cc_library(
|
||||
name = "android_tensorflow_lib_lite",
|
||||
srcs = select({
|
||||
@ -643,7 +648,7 @@ cc_library(
|
||||
"public/session.h",
|
||||
],
|
||||
copts = select({
|
||||
":android": ANDROID_TF_COPTS,
|
||||
":android": ANDROID_TF_COPTS + ["-Os"],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
tags = [
|
||||
@ -656,19 +661,23 @@ cc_library(
|
||||
":protos_cc",
|
||||
"//third_party/eigen3",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
# Full Tensorflow library with operator support. Use this unless reducing
|
||||
# binary size (by packaging a reduced operator set) is a concern.
|
||||
cc_library(
|
||||
name = "android_tensorflow_lib",
|
||||
srcs = [
|
||||
":android_op_registrations_and_gradients",
|
||||
"//tensorflow/core/kernels:android_core_ops",
|
||||
"//tensorflow/core/kernels:android_extended_ops",
|
||||
],
|
||||
srcs = select({
|
||||
":android": [
|
||||
":android_op_registrations_and_gradients",
|
||||
"//tensorflow/core/kernels:android_core_ops",
|
||||
"//tensorflow/core/kernels:android_extended_ops",
|
||||
],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
copts = select({
|
||||
":android": ANDROID_TF_COPTS,
|
||||
":android": ANDROID_TF_COPTS + ["-O2"],
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
tags = [
|
||||
@ -682,6 +691,7 @@ cc_library(
|
||||
":protos_cc",
|
||||
"//third_party/eigen3",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
filegroup(
|
||||
|
@ -124,7 +124,7 @@ class UntypedCall : public core::RefCounted {
|
||||
}
|
||||
|
||||
private:
|
||||
UntypedCall* call_; // `this` owns one reference.
|
||||
UntypedCall* const call_; // `this` owns one reference.
|
||||
Callback callback_;
|
||||
};
|
||||
};
|
||||
@ -149,30 +149,14 @@ class Call : public UntypedCall<Service> {
|
||||
Call<Service, GrpcService, RequestMessage, ResponseMessage>*);
|
||||
|
||||
Call(HandleRequestFunction handle_request_function)
|
||||
: handle_request_function_(handle_request_function),
|
||||
responder_(&ctx_),
|
||||
cancel_tag_(new typename UntypedCall<Service>::Tag(
|
||||
this, &UntypedCall<Service>::RequestCancelled)) {
|
||||
// The `ctx_` borrows the `cancel_tag_` until
|
||||
// `this->RequestReceived()` is called.
|
||||
ctx_.AsyncNotifyWhenDone(cancel_tag_.get());
|
||||
}
|
||||
: handle_request_function_(handle_request_function), responder_(&ctx_) {}
|
||||
|
||||
virtual ~Call() {}
|
||||
|
||||
void RequestReceived(Service* service, bool ok) override {
|
||||
if (ok) {
|
||||
// At this point, the `cancel_tag_` becomes owned by the
|
||||
// completion queue.
|
||||
cancel_tag_.release();
|
||||
|
||||
this->Ref();
|
||||
(service->*handle_request_function_)(this);
|
||||
} else {
|
||||
// `!ok` implies we never received a request for this call, and
|
||||
// the `cancel_tag_` will never be added to the completion
|
||||
// queue, so we free it here.
|
||||
cancel_tag_.reset();
|
||||
}
|
||||
}
|
||||
|
||||
@ -190,6 +174,9 @@ class Call : public UntypedCall<Service> {
|
||||
cancel_callback_();
|
||||
}
|
||||
}
|
||||
// NOTE(mrry): This can be called before or after RequestReceived, so we
|
||||
// release `cancel_tag_` (in order to allow the event loop to free it).
|
||||
cancel_tag_.release();
|
||||
}
|
||||
|
||||
// Registers `callback` as the function that should be called if and when this
|
||||
@ -213,9 +200,13 @@ class Call : public UntypedCall<Service> {
|
||||
static void EnqueueRequest(GrpcService* grpc_service,
|
||||
::grpc::ServerCompletionQueue* cq,
|
||||
EnqueueFunction enqueue_function,
|
||||
HandleRequestFunction handle_request_function) {
|
||||
HandleRequestFunction handle_request_function,
|
||||
bool supports_cancel) {
|
||||
auto call = new Call<Service, GrpcService, RequestMessage, ResponseMessage>(
|
||||
handle_request_function);
|
||||
if (supports_cancel) {
|
||||
call->RegisterCancellationHandler();
|
||||
}
|
||||
|
||||
(grpc_service->*enqueue_function)(
|
||||
&call->ctx_, &call->request, &call->responder_, cq, cq,
|
||||
@ -228,6 +219,15 @@ class Call : public UntypedCall<Service> {
|
||||
ResponseMessage response;
|
||||
|
||||
private:
|
||||
// Creates a completion queue tag for handling cancellation by the client.
|
||||
// NOTE: This method must be called before this call is enqueued on a
|
||||
// completion queue.
|
||||
void RegisterCancellationHandler() {
|
||||
cancel_tag_.reset(new typename UntypedCall<Service>::Tag(
|
||||
this, &UntypedCall<Service>::RequestCancelled));
|
||||
ctx_.AsyncNotifyWhenDone(cancel_tag_.get());
|
||||
}
|
||||
|
||||
HandleRequestFunction handle_request_function_;
|
||||
::grpc::ServerContext ctx_;
|
||||
::grpc::ServerAsyncResponseWriter<ResponseMessage> responder_;
|
||||
|
@ -88,7 +88,7 @@ class GrpcMasterService : public AsyncServiceInterface {
|
||||
// The implementation of the request handler for each RPC method
|
||||
// must ensure that it calls ENQUEUE_REQUEST() for that RPC method,
|
||||
// to keep accepting new requests.
|
||||
#define ENQUEUE_REQUEST(method) \
|
||||
#define ENQUEUE_REQUEST(method, supports_cancel) \
|
||||
do { \
|
||||
mutex_lock l(mu_); \
|
||||
if (!is_shutdown_) { \
|
||||
@ -96,19 +96,20 @@ class GrpcMasterService : public AsyncServiceInterface {
|
||||
method##Request, method##Response>:: \
|
||||
EnqueueRequest(&master_service_, cq_, \
|
||||
&grpc::MasterService::AsyncService::Request##method, \
|
||||
&GrpcMasterService::method##Handler); \
|
||||
&GrpcMasterService::method##Handler, \
|
||||
(supports_cancel)); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
void HandleRPCsLoop() {
|
||||
ENQUEUE_REQUEST(CreateSession);
|
||||
ENQUEUE_REQUEST(ExtendSession);
|
||||
ENQUEUE_REQUEST(CreateSession, true);
|
||||
ENQUEUE_REQUEST(ExtendSession, false);
|
||||
for (int i = 0; i < 100; ++i) {
|
||||
ENQUEUE_REQUEST(RunStep);
|
||||
ENQUEUE_REQUEST(RunStep, true);
|
||||
}
|
||||
ENQUEUE_REQUEST(CloseSession);
|
||||
ENQUEUE_REQUEST(ListDevices);
|
||||
ENQUEUE_REQUEST(Reset);
|
||||
ENQUEUE_REQUEST(CloseSession, false);
|
||||
ENQUEUE_REQUEST(ListDevices, false);
|
||||
ENQUEUE_REQUEST(Reset, false);
|
||||
|
||||
void* tag;
|
||||
bool ok;
|
||||
@ -146,7 +147,7 @@ class GrpcMasterService : public AsyncServiceInterface {
|
||||
[call](const Status& status) {
|
||||
call->SendResponse(ToGrpcStatus(status));
|
||||
});
|
||||
ENQUEUE_REQUEST(CreateSession);
|
||||
ENQUEUE_REQUEST(CreateSession, true);
|
||||
}
|
||||
|
||||
// RPC handler for extending a session.
|
||||
@ -156,7 +157,7 @@ class GrpcMasterService : public AsyncServiceInterface {
|
||||
[call](const Status& status) {
|
||||
call->SendResponse(ToGrpcStatus(status));
|
||||
});
|
||||
ENQUEUE_REQUEST(ExtendSession);
|
||||
ENQUEUE_REQUEST(ExtendSession, false);
|
||||
}
|
||||
|
||||
// RPC handler for running one step in a session.
|
||||
@ -169,7 +170,7 @@ class GrpcMasterService : public AsyncServiceInterface {
|
||||
delete call_opts;
|
||||
call->SendResponse(ToGrpcStatus(status));
|
||||
});
|
||||
ENQUEUE_REQUEST(RunStep);
|
||||
ENQUEUE_REQUEST(RunStep, true);
|
||||
}
|
||||
|
||||
// RPC handler for deleting a session.
|
||||
@ -179,7 +180,7 @@ class GrpcMasterService : public AsyncServiceInterface {
|
||||
[call](const Status& status) {
|
||||
call->SendResponse(ToGrpcStatus(status));
|
||||
});
|
||||
ENQUEUE_REQUEST(CloseSession);
|
||||
ENQUEUE_REQUEST(CloseSession, false);
|
||||
}
|
||||
|
||||
// RPC handler for listing devices.
|
||||
@ -189,7 +190,7 @@ class GrpcMasterService : public AsyncServiceInterface {
|
||||
[call](const Status& status) {
|
||||
call->SendResponse(ToGrpcStatus(status));
|
||||
});
|
||||
ENQUEUE_REQUEST(ListDevices);
|
||||
ENQUEUE_REQUEST(ListDevices, false);
|
||||
}
|
||||
|
||||
// RPC handler for resetting all sessions.
|
||||
@ -198,7 +199,7 @@ class GrpcMasterService : public AsyncServiceInterface {
|
||||
[call](const Status& status) {
|
||||
call->SendResponse(ToGrpcStatus(status));
|
||||
});
|
||||
ENQUEUE_REQUEST(Reset);
|
||||
ENQUEUE_REQUEST(Reset, false);
|
||||
}
|
||||
#undef ENQUEUE_REQUEST
|
||||
|
||||
|
@ -95,7 +95,7 @@ class GrpcWorkerService : public AsyncServiceInterface {
|
||||
// The implementation of the request handler for each RPC method
|
||||
// must ensure that it calls ENQUEUE_REQUEST() for that RPC method,
|
||||
// to keep accepting new requests.
|
||||
#define ENQUEUE_REQUEST(method) \
|
||||
#define ENQUEUE_REQUEST(method, supports_cancel) \
|
||||
do { \
|
||||
mutex_lock l(shutdown_mu_); \
|
||||
if (!is_shutdown_) { \
|
||||
@ -103,7 +103,8 @@ class GrpcWorkerService : public AsyncServiceInterface {
|
||||
method##Request, method##Response>:: \
|
||||
EnqueueRequest(&worker_service_, cq_, \
|
||||
&grpc::WorkerService::AsyncService::Request##method, \
|
||||
&GrpcWorkerService::method##Handler); \
|
||||
&GrpcWorkerService::method##Handler, \
|
||||
(supports_cancel)); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
@ -116,18 +117,18 @@ class GrpcWorkerService : public AsyncServiceInterface {
|
||||
// method, by re-enqueuing a request before the previous one
|
||||
// completes, and we may decide to bound some of the request
|
||||
// types.
|
||||
ENQUEUE_REQUEST(GetStatus);
|
||||
ENQUEUE_REQUEST(CleanupAll);
|
||||
ENQUEUE_REQUEST(RegisterGraph);
|
||||
ENQUEUE_REQUEST(DeregisterGraph);
|
||||
ENQUEUE_REQUEST(GetStatus, false);
|
||||
ENQUEUE_REQUEST(CleanupAll, false);
|
||||
ENQUEUE_REQUEST(RegisterGraph, false);
|
||||
ENQUEUE_REQUEST(DeregisterGraph, false);
|
||||
|
||||
// TODO(mrry): Consider enqueuing more of these request types.
|
||||
ENQUEUE_REQUEST(RecvTensor);
|
||||
ENQUEUE_REQUEST(RunGraph);
|
||||
ENQUEUE_REQUEST(RecvTensor, true);
|
||||
ENQUEUE_REQUEST(RunGraph, true);
|
||||
|
||||
ENQUEUE_REQUEST(CleanupGraph);
|
||||
ENQUEUE_REQUEST(Logging);
|
||||
ENQUEUE_REQUEST(Tracing);
|
||||
ENQUEUE_REQUEST(CleanupGraph, false);
|
||||
ENQUEUE_REQUEST(Logging, false);
|
||||
ENQUEUE_REQUEST(Tracing, false);
|
||||
|
||||
void* tag;
|
||||
bool ok;
|
||||
@ -181,7 +182,7 @@ class GrpcWorkerService : public AsyncServiceInterface {
|
||||
}
|
||||
call->SendResponse(::grpc::Status::OK);
|
||||
});
|
||||
ENQUEUE_REQUEST(GetStatus);
|
||||
ENQUEUE_REQUEST(GetStatus, false);
|
||||
}
|
||||
|
||||
void CleanupAllHandler(
|
||||
@ -192,7 +193,7 @@ class GrpcWorkerService : public AsyncServiceInterface {
|
||||
env_->device_mgr->ClearContainers(containers);
|
||||
call->SendResponse(::grpc::Status::OK);
|
||||
});
|
||||
ENQUEUE_REQUEST(CleanupAll);
|
||||
ENQUEUE_REQUEST(CleanupAll, false);
|
||||
}
|
||||
|
||||
void RegisterGraphHandler(
|
||||
@ -203,7 +204,7 @@ class GrpcWorkerService : public AsyncServiceInterface {
|
||||
call->request.graph_options(), call->response.mutable_graph_handle());
|
||||
call->SendResponse(ToGrpcStatus(s));
|
||||
});
|
||||
ENQUEUE_REQUEST(RegisterGraph);
|
||||
ENQUEUE_REQUEST(RegisterGraph, false);
|
||||
}
|
||||
|
||||
void DeregisterGraphHandler(
|
||||
@ -212,18 +213,18 @@ class GrpcWorkerService : public AsyncServiceInterface {
|
||||
Status s = env_->graph_mgr->Deregister(call->request.graph_handle());
|
||||
call->SendResponse(ToGrpcStatus(s));
|
||||
});
|
||||
ENQUEUE_REQUEST(DeregisterGraph);
|
||||
ENQUEUE_REQUEST(DeregisterGraph, false);
|
||||
}
|
||||
|
||||
void RunGraphHandler(WorkerCall<RunGraphRequest, RunGraphResponse>* call) {
|
||||
env_->compute_pool->Schedule([this, call]() { DoRunGraph(call); });
|
||||
ENQUEUE_REQUEST(RunGraph);
|
||||
ENQUEUE_REQUEST(RunGraph, true);
|
||||
}
|
||||
|
||||
void RecvTensorHandler(
|
||||
WorkerCall<RecvTensorRequest, RecvTensorResponse>* call) {
|
||||
env_->compute_pool->Schedule([this, call]() { DoRecvTensor(call); });
|
||||
ENQUEUE_REQUEST(RecvTensor);
|
||||
ENQUEUE_REQUEST(RecvTensor, true);
|
||||
}
|
||||
|
||||
void CleanupGraphHandler(
|
||||
@ -233,7 +234,7 @@ class GrpcWorkerService : public AsyncServiceInterface {
|
||||
env_->rendezvous_mgr->Cleanup(step_id);
|
||||
call->SendResponse(::grpc::Status::OK);
|
||||
});
|
||||
ENQUEUE_REQUEST(CleanupGraph);
|
||||
ENQUEUE_REQUEST(CleanupGraph, false);
|
||||
}
|
||||
|
||||
void LoggingHandler(WorkerCall<LoggingRequest, LoggingResponse>* call) {
|
||||
@ -241,7 +242,7 @@ class GrpcWorkerService : public AsyncServiceInterface {
|
||||
Status s = DoLogging(call);
|
||||
call->SendResponse(ToGrpcStatus(s));
|
||||
});
|
||||
ENQUEUE_REQUEST(Logging);
|
||||
ENQUEUE_REQUEST(Logging, false);
|
||||
}
|
||||
|
||||
void TracingHandler(WorkerCall<TracingRequest, TracingResponse>* call) {
|
||||
@ -249,7 +250,7 @@ class GrpcWorkerService : public AsyncServiceInterface {
|
||||
Status s = DoTracing(call);
|
||||
call->SendResponse(ToGrpcStatus(s));
|
||||
});
|
||||
ENQUEUE_REQUEST(Tracing);
|
||||
ENQUEUE_REQUEST(Tracing, false);
|
||||
}
|
||||
#undef ENQUEUE_REQUEST
|
||||
|
||||
|
@ -89,7 +89,8 @@ Status MemoryTypesForNode(const OpRegistryInterface* op_registry,
|
||||
|
||||
// Look up the Kernel registered for this node def.
|
||||
const KernelDef* kdef = nullptr;
|
||||
status = FindKernelDef(device_type, ndef, &kdef);
|
||||
status =
|
||||
FindKernelDef(device_type, ndef, &kdef, nullptr /* kernel_class_name */);
|
||||
|
||||
if (!status.ok() || HasTypeList(*op_def)) {
|
||||
// When there is no kernel def for this op or the op's arg is a
|
||||
|
@ -594,10 +594,11 @@ Status OpKernelContext::MatchSignature(const DataTypeSlice expected_inputs,
|
||||
// OpKernel registration ------------------------------------------------------
|
||||
|
||||
struct KernelRegistration {
|
||||
KernelRegistration(const KernelDef& d,
|
||||
KernelRegistration(const KernelDef& d, StringPiece c,
|
||||
kernel_factory::OpKernelRegistrar::Factory f)
|
||||
: def(d), factory(f) {}
|
||||
: def(d), kernel_class_name(c.ToString()), factory(f) {}
|
||||
const KernelDef def;
|
||||
const string kernel_class_name;
|
||||
const kernel_factory::OpKernelRegistrar::Factory factory;
|
||||
};
|
||||
|
||||
@ -624,12 +625,13 @@ static string Key(StringPiece op_type, DeviceType device_type,
|
||||
namespace kernel_factory {
|
||||
|
||||
void OpKernelRegistrar::InitInternal(const KernelDef* kernel_def,
|
||||
StringPiece kernel_class_name,
|
||||
Factory factory) {
|
||||
const string key =
|
||||
Key(kernel_def->op(), DeviceType(kernel_def->device_type()),
|
||||
kernel_def->label());
|
||||
GlobalKernelRegistryTyped()->insert(
|
||||
std::make_pair(key, KernelRegistration(*kernel_def, factory)));
|
||||
GlobalKernelRegistryTyped()->insert(std::make_pair(
|
||||
key, KernelRegistration(*kernel_def, kernel_class_name, factory)));
|
||||
delete kernel_def;
|
||||
}
|
||||
|
||||
@ -724,7 +726,7 @@ Status FindKernelRegistration(DeviceType device_type, const NodeDef& node_def,
|
||||
} // namespace
|
||||
|
||||
Status FindKernelDef(DeviceType device_type, const NodeDef& node_def,
|
||||
const KernelDef** def) {
|
||||
const KernelDef** def, string* kernel_class_name) {
|
||||
const KernelRegistration* reg = nullptr;
|
||||
TF_RETURN_IF_ERROR(FindKernelRegistration(device_type, node_def, ®));
|
||||
if (reg == nullptr) {
|
||||
@ -733,7 +735,8 @@ Status FindKernelDef(DeviceType device_type, const NodeDef& node_def,
|
||||
" devices compatible with node ",
|
||||
SummarizeNodeDef(node_def));
|
||||
}
|
||||
*def = ®->def;
|
||||
if (def != nullptr) *def = ®->def;
|
||||
if (kernel_class_name != nullptr) *kernel_class_name = reg->kernel_class_name;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -1025,7 +1025,6 @@ namespace register_kernel {
|
||||
typedef ::tensorflow::KernelDefBuilder Name;
|
||||
} // namespace register_kernel
|
||||
|
||||
|
||||
#define REGISTER_KERNEL_BUILDER(kernel_builder, ...) \
|
||||
REGISTER_KERNEL_BUILDER_UNIQ_HELPER(__COUNTER__, kernel_builder, __VA_ARGS__)
|
||||
|
||||
@ -1035,18 +1034,20 @@ typedef ::tensorflow::KernelDefBuilder Name;
|
||||
#define REGISTER_KERNEL_BUILDER_UNIQ(ctr, kernel_builder, ...) \
|
||||
static ::tensorflow::kernel_factory::OpKernelRegistrar \
|
||||
registrar__body__##ctr##__object( \
|
||||
SHOULD_REGISTER_OP_KERNEL(__FILE__) \
|
||||
SHOULD_REGISTER_OP_KERNEL(#__VA_ARGS__) \
|
||||
? ::tensorflow::register_kernel::kernel_builder.Build() \
|
||||
: nullptr, \
|
||||
#__VA_ARGS__, \
|
||||
[](::tensorflow::OpKernelConstruction* context) \
|
||||
-> ::tensorflow::OpKernel* { return new __VA_ARGS__(context); })
|
||||
|
||||
void* GlobalKernelRegistry();
|
||||
|
||||
// If node_def has a corresponding kernel registered on device_type,
|
||||
// returns OK and fill in the kernel def.
|
||||
// returns OK and fill in the kernel def and kernel_class_name. <def> and
|
||||
// <kernel_class_name> may be null.
|
||||
Status FindKernelDef(DeviceType device_type, const NodeDef& node_def,
|
||||
const KernelDef** def);
|
||||
const KernelDef** def, string* kernel_class_name);
|
||||
|
||||
// Treats 'registry_ptr' as a pointer to KernelRegistry. For each kernel 'k'
|
||||
// registered with the current library's global kernel registry (obtained by
|
||||
@ -1058,16 +1059,19 @@ namespace kernel_factory {
|
||||
class OpKernelRegistrar {
|
||||
public:
|
||||
typedef OpKernel* (*Factory)(OpKernelConstruction*);
|
||||
OpKernelRegistrar(const KernelDef* kernel_def, Factory factory) {
|
||||
|
||||
OpKernelRegistrar(const KernelDef* kernel_def, StringPiece kernel_class_name,
|
||||
Factory factory) {
|
||||
// Perform the check in the header to allow compile-time optimization
|
||||
// to a no-op, allowing the linker to remove the kernel symbols.
|
||||
if (kernel_def != nullptr) {
|
||||
InitInternal(kernel_def, factory);
|
||||
InitInternal(kernel_def, kernel_class_name, factory);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
void InitInternal(const KernelDef* kernel_def, Factory factory);
|
||||
void InitInternal(const KernelDef* kernel_def, StringPiece kernel_class_name,
|
||||
Factory factory);
|
||||
};
|
||||
|
||||
} // namespace kernel_factory
|
||||
|
@ -422,6 +422,27 @@ class OpKernelBuilderTest : public ::testing::Test {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
string GetKernelClassName(const string& op_type, DeviceType device_type,
|
||||
const std::vector<string>& attrs,
|
||||
DataTypeSlice input_types = {}) {
|
||||
NodeDef def = CreateNodeDef(op_type, attrs);
|
||||
for (size_t i = 0; i < input_types.size(); ++i) {
|
||||
def.add_input("a:0");
|
||||
}
|
||||
|
||||
const KernelDef* kernel_def = nullptr;
|
||||
string kernel_class_name;
|
||||
const Status status =
|
||||
FindKernelDef(device_type, def, &kernel_def, &kernel_class_name);
|
||||
if (status.ok()) {
|
||||
return kernel_class_name;
|
||||
} else if (errors::IsNotFound(status)) {
|
||||
return "not found";
|
||||
} else {
|
||||
return status.ToString();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_OP("BuildCPU");
|
||||
@ -429,7 +450,9 @@ REGISTER_KERNEL_BUILDER(Name("BuildCPU").Device(DEVICE_CPU), DummyKernel);
|
||||
|
||||
TEST_F(OpKernelBuilderTest, BuilderCPU) {
|
||||
ExpectSuccess("BuildCPU", DEVICE_CPU, {});
|
||||
EXPECT_EQ("DummyKernel", GetKernelClassName("BuildCPU", DEVICE_CPU, {}));
|
||||
ExpectFailure("BuildCPU", DEVICE_GPU, {}, error::NOT_FOUND);
|
||||
EXPECT_EQ("not found", GetKernelClassName("BuildCPU", DEVICE_GPU, {}));
|
||||
}
|
||||
|
||||
REGISTER_OP("BuildGPU");
|
||||
@ -472,12 +495,26 @@ REGISTER_KERNEL_BUILDER(Name("BuildTypeListAttr")
|
||||
|
||||
TEST_F(OpKernelBuilderTest, BuilderTypeListAttr) {
|
||||
ExpectSuccess("BuildTypeListAttr", DEVICE_CPU, {"T|list(type)|[]"});
|
||||
EXPECT_EQ("DummyKernel", GetKernelClassName("BuildTypeListAttr", DEVICE_CPU,
|
||||
{"T|list(type)|[]"}));
|
||||
|
||||
ExpectSuccess("BuildTypeListAttr", DEVICE_CPU, {"T|list(type)|[DT_BOOL]"});
|
||||
EXPECT_EQ("DummyKernel", GetKernelClassName("BuildTypeListAttr", DEVICE_CPU,
|
||||
{"T|list(type)|[]"}));
|
||||
|
||||
ExpectSuccess("BuildTypeListAttr", DEVICE_CPU,
|
||||
{"T|list(type)|[DT_BOOL, DT_BOOL]"});
|
||||
|
||||
ExpectFailure("BuildTypeListAttr", DEVICE_CPU, {"T|list(type)|[DT_FLOAT]"},
|
||||
error::NOT_FOUND);
|
||||
EXPECT_EQ("not found", GetKernelClassName("BuildTypeListAttr", DEVICE_CPU,
|
||||
{"T|list(type)|[DT_FLOAT]"}));
|
||||
|
||||
ExpectFailure("BuildTypeListAttr", DEVICE_CPU, {}, error::INVALID_ARGUMENT);
|
||||
EXPECT_TRUE(
|
||||
StringPiece(GetKernelClassName("BuildTypeListAttr", DEVICE_CPU, {}))
|
||||
.contains("Invalid argument: "));
|
||||
|
||||
ExpectFailure("BuildTypeListAttr", DEVICE_CPU, {"T|int|7"},
|
||||
error::INVALID_ARGUMENT);
|
||||
}
|
||||
@ -776,6 +813,9 @@ TEST_F(LabelTest, Default) {
|
||||
ExpectSuccess("LabeledKernel", DEVICE_CPU, {});
|
||||
auto* get_labeled_kernel = static_cast<BaseKernel*>(op_kernel.get());
|
||||
EXPECT_EQ(0, get_labeled_kernel->Which());
|
||||
|
||||
EXPECT_EQ("LabeledKernel<0>",
|
||||
GetKernelClassName("LabeledKernel", DEVICE_CPU, {}));
|
||||
}
|
||||
|
||||
TEST_F(LabelTest, Specified) {
|
||||
@ -783,6 +823,8 @@ TEST_F(LabelTest, Specified) {
|
||||
ExpectSuccess("LabeledKernel", DEVICE_CPU, {"_kernel|string|'one'"});
|
||||
auto* get_labeled_kernel = static_cast<BaseKernel*>(op_kernel.get());
|
||||
EXPECT_EQ(1, get_labeled_kernel->Which());
|
||||
EXPECT_EQ("LabeledKernel<1>", GetKernelClassName("LabeledKernel", DEVICE_CPU,
|
||||
{"_kernel|string|'one'"}));
|
||||
}
|
||||
|
||||
TEST_F(LabelTest, Duplicate) {
|
||||
|
@ -34,13 +34,12 @@ limitations under the License.
|
||||
// out.
|
||||
#include "ops_to_register.h"
|
||||
|
||||
// Files which are not included in the whitelist provided by this
|
||||
// graph-specific header file will not be allowed to register their
|
||||
// operator kernels.
|
||||
#define SHOULD_REGISTER_OP_KERNEL(filename) \
|
||||
(strstr(kNecessaryOpFiles, filename) != nullptr)
|
||||
// Op kernel classes for which ShouldRegisterOpKernel returns false will not be
|
||||
// registered.
|
||||
#define SHOULD_REGISTER_OP_KERNEL(clz) \
|
||||
(strstr(kNecessaryOpKernelClasses, "," clz ",") != nullptr)
|
||||
|
||||
// Ops for which ShouldRegisterOp return false will no be registered.
|
||||
// Ops for which ShouldRegisterOp returns false will not be registered.
|
||||
#define SHOULD_REGISTER_OP(op) ShouldRegisterOp(op)
|
||||
|
||||
// If kRequiresSymbolicGradients is false, then no gradient ops are registered.
|
||||
|
@ -260,6 +260,7 @@ tf_kernel_libraries(
|
||||
"constant_op",
|
||||
"diag_op",
|
||||
"edit_distance_op",
|
||||
"gather_nd_op",
|
||||
"gather_op",
|
||||
"identity_op",
|
||||
"immutable_constant_op",
|
||||
|
@ -27,7 +27,8 @@ namespace tensorflow {
|
||||
// that 0 <= limit if Index is signed. Intended for use in performance
|
||||
// critical contexts where 0 <= index < limit is almost always true.
|
||||
template <typename Ta, typename Tb>
|
||||
EIGEN_ALWAYS_INLINE bool FastBoundsCheck(const Ta index, const Tb limit) {
|
||||
EIGEN_ALWAYS_INLINE EIGEN_DEVICE_FUNC bool FastBoundsCheck(const Ta index,
|
||||
const Tb limit) {
|
||||
static_assert(std::is_integral<Ta>::value && std::is_integral<Tb>::value,
|
||||
"FastBoundsCheck can only be used on integer types.");
|
||||
typedef typename std::make_unsigned<decltype(index + limit)>::type UIndex;
|
||||
|
@ -22,22 +22,27 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/numeric_types.h"
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
|
||||
// The following functors (sign, tanh, sigmoid, etc.) are not defined
|
||||
// by Eigen. When their equivalent are added into the Eigen, we can
|
||||
// replace them using type aliases.
|
||||
|
||||
namespace Eigen {
|
||||
namespace internal {
|
||||
|
||||
// TODO(rmlarsen): Get rid of fmod2 once fmod is upstreamed to Eigen.
|
||||
template <typename T>
|
||||
struct scalar_fmod2_op {
|
||||
EIGEN_EMPTY_STRUCT_CTOR(scalar_fmod2_op)
|
||||
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& a,
|
||||
const T& b) const {
|
||||
return fmod(a, b);
|
||||
return std::fmod(a, b);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct functor_traits<scalar_fmod2_op<T>> {
|
||||
enum {
|
||||
Cost = 13, // Reciprocal throughput of FPREM on Haswell.
|
||||
PacketAccess = false,
|
||||
};
|
||||
};
|
||||
|
||||
// scalar_left and scalar_right are template helpers to partially
|
||||
// apply a binary function.
|
||||
//
|
||||
@ -489,7 +494,7 @@ template <typename T>
|
||||
struct fmod : base<T, Eigen::internal::scalar_fmod2_op<T> > {};
|
||||
|
||||
template <typename T>
|
||||
struct mod : base<T, Eigen::internal::scalar_mod2_op<T> > {};
|
||||
struct mod : base<T, Eigen::internal::scalar_mod2_op<T>> {};
|
||||
|
||||
template <typename T>
|
||||
struct pow : base<T, Eigen::internal::scalar_binary_pow_op<T, T> > {};
|
||||
|
@ -36,7 +36,6 @@ namespace Eigen {
|
||||
* It is possible to swap the order of the width and height dimensions provided that the same order is used in the input, the kernel, and the output.
|
||||
*
|
||||
*/
|
||||
|
||||
template <typename OutputBackward, typename Kernel>
|
||||
EIGEN_ALWAYS_INLINE static const typename internal::conditional<
|
||||
internal::traits<OutputBackward>::Layout == ColMajor,
|
||||
@ -45,14 +44,18 @@ EIGEN_ALWAYS_INLINE static const typename internal::conditional<
|
||||
internal::traits<OutputBackward>::NumDimensions>,
|
||||
const TensorContractionOp<
|
||||
const array<
|
||||
IndexPair<typename internal::traits<OutputBackward>::Index>, 2>,
|
||||
IndexPair<typename internal::traits<OutputBackward>::Index>, 1>,
|
||||
const Eigen::TensorForcedEvalOp<const TensorReshapingOp<
|
||||
const DSizes<typename internal::traits<OutputBackward>::Index,
|
||||
3>,
|
||||
const TensorReverseOp<const array<bool, 4>, const Kernel> > >,
|
||||
2>,
|
||||
const TensorShufflingOp<
|
||||
const array<
|
||||
typename internal::traits<OutputBackward>::Index, 4>,
|
||||
const TensorReverseOp<const array<bool, 4>,
|
||||
const Kernel> > > >,
|
||||
const TensorReshapingOp<
|
||||
const DSizes<typename internal::traits<OutputBackward>::Index,
|
||||
3>,
|
||||
2>,
|
||||
const TensorImagePatchOp<Dynamic, Dynamic,
|
||||
const OutputBackward> > > >,
|
||||
TensorReshapingOp<
|
||||
@ -60,17 +63,20 @@ EIGEN_ALWAYS_INLINE static const typename internal::conditional<
|
||||
internal::traits<OutputBackward>::NumDimensions>,
|
||||
const TensorContractionOp<
|
||||
const array<
|
||||
IndexPair<typename internal::traits<OutputBackward>::Index>, 2>,
|
||||
IndexPair<typename internal::traits<OutputBackward>::Index>, 1>,
|
||||
const TensorReshapingOp<
|
||||
const DSizes<typename internal::traits<OutputBackward>::Index,
|
||||
3>,
|
||||
2>,
|
||||
const TensorImagePatchOp<Dynamic, Dynamic,
|
||||
const OutputBackward> >,
|
||||
const Eigen::TensorForcedEvalOp<const TensorReshapingOp<
|
||||
const DSizes<typename internal::traits<OutputBackward>::Index,
|
||||
3>,
|
||||
const TensorReverseOp<const array<bool, 4>,
|
||||
const Kernel> > > > > >::type
|
||||
2>,
|
||||
const TensorShufflingOp<
|
||||
const array<
|
||||
typename internal::traits<OutputBackward>::Index, 4>,
|
||||
const TensorReverseOp<const array<bool, 4>,
|
||||
const Kernel> > > > > > >::type
|
||||
SpatialConvolutionBackwardInput(
|
||||
const Kernel& kernel, const OutputBackward& output_backward,
|
||||
typename internal::traits<OutputBackward>::Index inputRows,
|
||||
@ -134,49 +140,57 @@ SpatialConvolutionBackwardInput(
|
||||
kernel_reverse[3] = false;
|
||||
}
|
||||
|
||||
DSizes<TensorIndex, 3> kernel_dims;
|
||||
// Reorder the dimensions to filters X patch_rows X patch_cols X channels
|
||||
array<TensorIndex, 4> kernel_shuffle;
|
||||
if (isColMajor) {
|
||||
kernel_dims[0] = kernelFilters;
|
||||
kernel_dims[1] = kernelChannels;
|
||||
kernel_dims[2] = kernelRows * kernelCols;
|
||||
kernel_shuffle[0] = 0;
|
||||
kernel_shuffle[1] = 2;
|
||||
kernel_shuffle[2] = 3;
|
||||
kernel_shuffle[3] = 1;
|
||||
} else {
|
||||
kernel_dims[0] = kernelRows * kernelCols;
|
||||
kernel_shuffle[0] = 2;
|
||||
kernel_shuffle[1] = 0;
|
||||
kernel_shuffle[2] = 1;
|
||||
kernel_shuffle[3] = 3;
|
||||
}
|
||||
|
||||
// Collapse the dims
|
||||
DSizes<TensorIndex, 2> kernel_dims;
|
||||
if (isColMajor) {
|
||||
kernel_dims[0] = kernelFilters * kernelRows * kernelCols;
|
||||
kernel_dims[1] = kernelChannels;
|
||||
kernel_dims[2] = kernelFilters;
|
||||
} else {
|
||||
kernel_dims[1] = kernelFilters * kernelRows * kernelCols;
|
||||
kernel_dims[0] = kernelChannels;
|
||||
}
|
||||
|
||||
// The output_backward has dimensions out_depth X out_rows X out_cols X OTHERS
|
||||
// When we extract the image patches from output_backward, it will have dimensions
|
||||
// out_depth X (patch_rows * patch_cols) X (input_rows * input_cols * OTHERS)
|
||||
DSizes<TensorIndex, 3> pre_contract_dims;
|
||||
DSizes<TensorIndex, 2> pre_contract_dims;
|
||||
if (isColMajor) {
|
||||
pre_contract_dims[0] = kernelFilters;
|
||||
pre_contract_dims[1] = kernelRows * kernelCols;
|
||||
pre_contract_dims[2] = inputRows * inputCols;
|
||||
pre_contract_dims[0] = kernelFilters * kernelRows * kernelCols;
|
||||
pre_contract_dims[1] = inputRows * inputCols;
|
||||
for (int i = 3; i < NumDims; ++i) {
|
||||
pre_contract_dims[2] *= out.dimension(i);
|
||||
pre_contract_dims[1] *= out.dimension(i);
|
||||
}
|
||||
} else {
|
||||
pre_contract_dims[2] = kernelFilters;
|
||||
pre_contract_dims[1] = kernelRows * kernelCols;
|
||||
pre_contract_dims[1] = kernelFilters * kernelRows * kernelCols;
|
||||
pre_contract_dims[0] = inputRows * inputCols;
|
||||
for (int i = 0; i < NumDims - 3; ++i) {
|
||||
pre_contract_dims[0] *= out.dimension(i);
|
||||
}
|
||||
}
|
||||
|
||||
// We will contract along dimensions (0, 2) in kernel and (0, 1) in
|
||||
// output_backward, if this is col-major, and
|
||||
// dimensions (0, 2) in kernel and (1, 2) in output_backward, if this row-major.
|
||||
array<IndexPair<TensorIndex>, 2> contract_dims;
|
||||
// We will contract along the fused dimension that contains the kernelFilters,
|
||||
// the kernelRows and the kernelCols.
|
||||
array<IndexPair<TensorIndex>, 1> contract_dims;
|
||||
if (isColMajor) {
|
||||
// col-major: kernel.contract(output.patches)
|
||||
contract_dims[0] = IndexPair<TensorIndex>(0, 0);
|
||||
contract_dims[1] = IndexPair<TensorIndex>(2, 1);
|
||||
} else {
|
||||
// row-major: output.patches.contract(kernel)
|
||||
contract_dims[0] = IndexPair<TensorIndex>(1, 0);
|
||||
contract_dims[1] = IndexPair<TensorIndex>(2, 2);
|
||||
contract_dims[0] = IndexPair<TensorIndex>(1, 1);
|
||||
}
|
||||
|
||||
// Post contraction, the dimensions of the input_backprop is
|
||||
@ -201,6 +215,7 @@ SpatialConvolutionBackwardInput(
|
||||
return choose(
|
||||
Cond<internal::traits<OutputBackward>::Layout == ColMajor>(),
|
||||
kernel.reverse(kernel_reverse)
|
||||
.shuffle(kernel_shuffle)
|
||||
.reshape(kernel_dims)
|
||||
.eval()
|
||||
.contract(output_backward
|
||||
@ -217,7 +232,10 @@ SpatialConvolutionBackwardInput(
|
||||
padding_bottom, padding_left, padding_right,
|
||||
OutScalar(0))
|
||||
.reshape(pre_contract_dims)
|
||||
.contract(kernel.reverse(kernel_reverse).reshape(kernel_dims).eval(),
|
||||
.contract(kernel.reverse(kernel_reverse)
|
||||
.shuffle(kernel_shuffle)
|
||||
.reshape(kernel_dims)
|
||||
.eval(),
|
||||
contract_dims)
|
||||
.reshape(post_contract_dims));
|
||||
}
|
||||
@ -239,15 +257,43 @@ SpatialConvolutionBackwardInput(
|
||||
* It is possible to swap the order of the width and height dimensions provided that the same order is used in the input, the kernel, and the output.
|
||||
*
|
||||
*/
|
||||
// TODO(gpapan): Resolve a bug in TensorContractionInputMapper at SpatialConvolutions.h that yangke circumvented by using .reshape().reshape().
|
||||
// This can significantly accelerate SpatialConvolutionBackwardKernel.
|
||||
|
||||
template <typename OutputBackward, typename Input>
|
||||
EIGEN_ALWAYS_INLINE
|
||||
static const typename internal::conditional<
|
||||
EIGEN_ALWAYS_INLINE static const typename internal::conditional<
|
||||
internal::traits<OutputBackward>::Layout == ColMajor,
|
||||
const TensorShufflingOp<const array<typename internal::traits<OutputBackward>::Index, 4>, const TensorReverseOp<const array<bool, 4>, const TensorReshapingOp<const DSizes<typename internal::traits<OutputBackward>::Index, 4>, const TensorContractionOp<const array<IndexPair<typename internal::traits<Input>::Index>, 2>, const TensorReshapingOp<const DSizes<typename internal::traits<Input>::Index, 3>, const Input>, const TensorReshapingOp<const DSizes<typename internal::traits<OutputBackward>::Index, 4>, const TensorReshapingOp<const DSizes<typename internal::traits<OutputBackward>::Index, 4>, const TensorImagePatchOp<Dynamic, Dynamic, const OutputBackward> > > > > > >,
|
||||
const TensorShufflingOp<const array<typename internal::traits<OutputBackward>::Index, 4>, const TensorReverseOp<const array<bool, 4>, const TensorReshapingOp<const DSizes<typename internal::traits<OutputBackward>::Index, 4>, const TensorContractionOp<const array<IndexPair<typename internal::traits<Input>::Index>, 2>, const TensorReshapingOp<const DSizes<typename internal::traits<OutputBackward>::Index, 4>, const TensorReshapingOp<const DSizes<typename internal::traits<OutputBackward>::Index, 4>, const TensorImagePatchOp<Dynamic, Dynamic, const OutputBackward> > >, const TensorReshapingOp<const DSizes<typename internal::traits<Input>::Index, 3>, const Input> > > > > >::type
|
||||
TensorReshapingOp<
|
||||
const DSizes<typename internal::traits<Input>::Index, 4>,
|
||||
const TensorContractionOp<
|
||||
const array<IndexPair<typename internal::traits<Input>::Index>, 1>,
|
||||
const TensorReshapingOp<
|
||||
const DSizes<typename internal::traits<Input>::Index, 2>,
|
||||
const OutputBackward>,
|
||||
const TensorShufflingOp<
|
||||
const array<typename internal::traits<OutputBackward>::Index, 2>,
|
||||
const TensorReshapingOp<
|
||||
const DSizes<typename internal::traits<Input>::Index, 2>,
|
||||
const TensorImagePatchOp<Dynamic, Dynamic, const Input>
|
||||
>
|
||||
>
|
||||
>
|
||||
>,
|
||||
TensorReshapingOp<
|
||||
const DSizes<typename internal::traits<Input>::Index, 4>,
|
||||
const TensorContractionOp<
|
||||
const array<IndexPair<typename internal::traits<Input>::Index>, 1>,
|
||||
const TensorShufflingOp<
|
||||
const array<typename internal::traits<OutputBackward>::Index, 2>,
|
||||
const TensorReshapingOp<
|
||||
const DSizes<typename internal::traits<Input>::Index, 2>,
|
||||
const TensorImagePatchOp<Dynamic, Dynamic, const Input>
|
||||
>
|
||||
>,
|
||||
const TensorReshapingOp<
|
||||
const DSizes<typename internal::traits<Input>::Index, 2>,
|
||||
const OutputBackward>
|
||||
>
|
||||
>
|
||||
>::type
|
||||
SpatialConvolutionBackwardKernel(const Input& input, const OutputBackward& output_backward, typename internal::traits<Input>::Index kernelRows, typename internal::traits<Input>::Index kernelCols, const DenseIndex stride = 1, const DenseIndex in_stride = 1) {
|
||||
|
||||
typedef typename internal::traits<Input>::Index TensorIndex;
|
||||
@ -283,127 +329,93 @@ SpatialConvolutionBackwardKernel(const Input& input, const OutputBackward& outpu
|
||||
const TensorIndex kernelColsEff = kernelCols + (kernelCols - 1) * (in_stride - 1);
|
||||
|
||||
// Computing the forward padding
|
||||
const TensorIndex forward_pad_top = ((outputRows - 1) * stride + kernelRowsEff - inputRows) / 2;
|
||||
const TensorIndex forward_pad_left = ((outputCols - 1) * stride + kernelColsEff - inputCols) / 2;
|
||||
const TensorIndex padRows = numext::maxi<Index>(
|
||||
0, (outputRows - 1) * stride + kernelRowsEff - inputRows);
|
||||
const TensorIndex padCols = numext::maxi<Index>(
|
||||
0, (outputCols - 1) * stride + kernelColsEff - inputCols);
|
||||
const TensorIndex padding_top = padRows / 2;
|
||||
const TensorIndex padding_bottom = padRows - padding_top;
|
||||
const TensorIndex padding_left = padCols / 2;
|
||||
const TensorIndex padding_right = padCols - padding_left;
|
||||
|
||||
// TODO: factor out the padding computation.
|
||||
const TensorIndex padding_top = kernelRowsEff - 1 - forward_pad_top;
|
||||
const TensorIndex padding_left = kernelColsEff - 1 - forward_pad_left;
|
||||
const TensorIndex padding_bottom = inputRows + kernelRowsEff - 1 - (outputRows - 1) * stride - 1 - padding_top;
|
||||
const TensorIndex padding_right = inputCols + kernelColsEff - 1 - (outputCols - 1) * stride - 1 - padding_left;
|
||||
|
||||
eigen_assert(padding_top >= 0);
|
||||
eigen_assert(padding_left >= 0);
|
||||
eigen_assert(padding_bottom >= 0);
|
||||
eigen_assert(padding_right >= 0);
|
||||
|
||||
// The output_backward has dimensions out_depth X out_rows X out_cols X OTHERS
|
||||
// When we extract the image patches from output_backward (with input as the
|
||||
// kernel), it will have dimensions
|
||||
// (out_depth) X (input_rows * input_cols) X (kernel_rows * kernel_cols) X OTHERS
|
||||
DSizes<TensorIndex, 4> pre_contract_dims;
|
||||
// Reshaped out
|
||||
DSizes<TensorIndex, 2> output_dims;
|
||||
if (isColMajor) {
|
||||
pre_contract_dims[0] = kernelFilters;
|
||||
pre_contract_dims[1] = inputRows * inputCols;
|
||||
pre_contract_dims[2] = kernelRows * kernelCols;
|
||||
pre_contract_dims[3] = 1;
|
||||
output_dims[0] = kernelFilters;
|
||||
output_dims[1] = outputRows * outputCols;
|
||||
for (int i = 3; i < NumDims; ++i) {
|
||||
pre_contract_dims[3] *= out.dimension(i);
|
||||
output_dims[1] *= out.dimension(i);
|
||||
}
|
||||
} else {
|
||||
pre_contract_dims[3] = kernelFilters;
|
||||
pre_contract_dims[2] = inputRows * inputCols;
|
||||
pre_contract_dims[1] = kernelRows * kernelCols;
|
||||
pre_contract_dims[0] = 1;
|
||||
output_dims[1] = kernelFilters;
|
||||
output_dims[0] = outputCols * outputRows;
|
||||
for (int i = 0; i < NumDims - 3; ++i) {
|
||||
pre_contract_dims[0] *= out.dimension(i);
|
||||
output_dims[0] *= out.dimension(i);
|
||||
}
|
||||
}
|
||||
|
||||
// The input has dimensions in_depth X (input_rows * input_cols) X OTHERS
|
||||
DSizes<TensorIndex, 3> input_dims;
|
||||
// Reshaped extract_image_patches(in)
|
||||
DSizes<TensorIndex, 2> pre_contract_dims;
|
||||
if (isColMajor) {
|
||||
input_dims[0] = kernelChannels;
|
||||
input_dims[1] = inputRows * inputCols;
|
||||
input_dims[2] = 1;
|
||||
pre_contract_dims[0] = kernelChannels * kernelRows * kernelCols;
|
||||
pre_contract_dims[1] = outputRows * outputCols;
|
||||
for (int i = 3; i < NumDims; ++i) {
|
||||
input_dims[2] *= in.dimension(i);
|
||||
pre_contract_dims[1] *= in.dimension(i);
|
||||
}
|
||||
eigen_assert(input_dims[2] == pre_contract_dims[3]);
|
||||
eigen_assert(output_dims[1] == pre_contract_dims[1]);
|
||||
} else {
|
||||
input_dims[2] = kernelChannels;
|
||||
input_dims[1] = inputRows * inputCols;
|
||||
input_dims[0] = 1;
|
||||
pre_contract_dims[1] = kernelCols * kernelRows * kernelChannels;
|
||||
pre_contract_dims[0] = outputRows * outputCols;
|
||||
for (int i = 0; i < NumDims - 3; ++i) {
|
||||
input_dims[0] *= in.dimension(i);
|
||||
pre_contract_dims[0] *= in.dimension(i);
|
||||
}
|
||||
eigen_assert(input_dims[0] == pre_contract_dims[0]);
|
||||
eigen_assert(output_dims[0] == pre_contract_dims[0]);
|
||||
}
|
||||
|
||||
// We will contract along dimensions (1, 2) in in and (1, 3) in out, if
|
||||
// this is col-major.
|
||||
// For row-major, it's dimensions (0, 1) in in and (0, 2) in out.
|
||||
array<IndexPair<TensorIndex>, 2> contract_dims;
|
||||
if (isColMajor) {
|
||||
// col-major: in.contract(output.patches)
|
||||
contract_dims[0] = IndexPair<TensorIndex>(1, 1);
|
||||
contract_dims[1] = IndexPair<TensorIndex>(2, 3);
|
||||
} else {
|
||||
// row-major: output.patches.contract(in)
|
||||
contract_dims[0] = IndexPair<TensorIndex>(0, 0);
|
||||
contract_dims[1] = IndexPair<TensorIndex>(2, 1);
|
||||
}
|
||||
array<TensorIndex, 2> shuffle_dims;
|
||||
shuffle_dims[0] = 1;
|
||||
shuffle_dims[1] = 0;
|
||||
|
||||
// After the contraction, the kernel will have dimension
|
||||
// in_depth X out_depth X kernel_rows X kernel_cols
|
||||
// We will need to shuffle the first two dimensions and reverse the latter
|
||||
// two dimensions.
|
||||
// The end shape is
|
||||
array<IndexPair<TensorIndex>, 1> contract_dims;
|
||||
contract_dims[0] = IndexPair<TensorIndex>(1, 0);
|
||||
|
||||
// After the contraction, the kernel will have the desired shape
|
||||
// out_depth X in_shape X kernel_rows X kernel_cols
|
||||
|
||||
// This is the shape of the kernel *before* the shuffling.
|
||||
DSizes<TensorIndex, 4> kernel_dims;
|
||||
if (isColMajor) {
|
||||
kernel_dims[0] = kernelChannels;
|
||||
kernel_dims[1] = kernelFilters;
|
||||
kernel_dims[0] = kernelFilters;
|
||||
kernel_dims[1] = kernelChannels;
|
||||
kernel_dims[2] = kernelRows;
|
||||
kernel_dims[3] = kernelCols;
|
||||
} else {
|
||||
kernel_dims[0] = kernelCols;
|
||||
kernel_dims[3] = kernelFilters;
|
||||
kernel_dims[2] = kernelChannels;
|
||||
kernel_dims[1] = kernelRows;
|
||||
kernel_dims[2] = kernelFilters;
|
||||
kernel_dims[3] = kernelChannels;
|
||||
kernel_dims[0] = kernelCols;
|
||||
}
|
||||
|
||||
array<TensorIndex, 4> kernel_shuffle;
|
||||
if (isColMajor) {
|
||||
kernel_shuffle[0] = 1;
|
||||
kernel_shuffle[1] = 0;
|
||||
kernel_shuffle[2] = 2;
|
||||
kernel_shuffle[3] = 3;
|
||||
} else {
|
||||
kernel_shuffle[0] = 0;
|
||||
kernel_shuffle[1] = 1;
|
||||
kernel_shuffle[2] = 3;
|
||||
kernel_shuffle[3] = 2;
|
||||
}
|
||||
|
||||
array<bool, 4> kernel_reverse;
|
||||
if (isColMajor) {
|
||||
kernel_reverse[0] = false;
|
||||
kernel_reverse[1] = false;
|
||||
kernel_reverse[2] = true;
|
||||
kernel_reverse[3] = true;
|
||||
} else {
|
||||
kernel_reverse[0] = true;
|
||||
kernel_reverse[1] = true;
|
||||
kernel_reverse[2] = false;
|
||||
kernel_reverse[3] = false;
|
||||
}
|
||||
|
||||
return choose(Cond<internal::traits<Input>::Layout == ColMajor>(),
|
||||
input.reshape(input_dims).contract(output_backward.extract_image_patches(inputRows, inputCols, in_stride, in_stride, 1, 1, stride, stride, padding_top, padding_bottom, padding_left, padding_right, OutScalar(0)).reshape(pre_contract_dims).reshape(pre_contract_dims), contract_dims).reshape(kernel_dims).reverse(kernel_reverse).shuffle(kernel_shuffle),
|
||||
output_backward.extract_image_patches(inputRows, inputCols, in_stride, in_stride, 1, 1, stride, stride, padding_top, padding_bottom, padding_left, padding_right, OutScalar(0)).reshape(pre_contract_dims).reshape(pre_contract_dims).contract(input.reshape(input_dims), contract_dims).reshape(kernel_dims).reverse(kernel_reverse).shuffle(kernel_shuffle));
|
||||
return choose(
|
||||
Cond<internal::traits<Input>::Layout == ColMajor>(),
|
||||
output_backward.reshape(output_dims)
|
||||
.contract(
|
||||
input.extract_image_patches(
|
||||
kernelRows, kernelCols, stride, stride,
|
||||
in_stride, in_stride, 1, 1, padding_top, padding_bottom,
|
||||
padding_left, padding_right, OutScalar(0))
|
||||
.reshape(pre_contract_dims)
|
||||
.shuffle(shuffle_dims),
|
||||
contract_dims)
|
||||
.reshape(kernel_dims),
|
||||
input.extract_image_patches(
|
||||
kernelRows, kernelCols, stride, stride,
|
||||
in_stride, in_stride, 1, 1, padding_top, padding_bottom,
|
||||
padding_left, padding_right, OutScalar(0))
|
||||
.reshape(pre_contract_dims)
|
||||
.shuffle(shuffle_dims)
|
||||
.contract(
|
||||
output_backward.reshape(output_dims),
|
||||
contract_dims)
|
||||
.reshape(kernel_dims));
|
||||
}
|
||||
|
||||
} // end namespace Eigen
|
||||
|
234
tensorflow/core/kernels/gather_nd_op.cc
Normal file
234
tensorflow/core/kernels/gather_nd_op.cc
Normal file
@ -0,0 +1,234 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// See docs in ../ops/array_ops.cc.
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
#include <atomic>
|
||||
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/kernels/bounds_check.h"
|
||||
#include "tensorflow/core/kernels/gather_nd_op.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/mem.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/util/util.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||
typedef Eigen::GpuDevice GPUDevice;
|
||||
|
||||
template <typename Device, typename T, typename Index>
|
||||
class GatherNdOp : public OpKernel {
|
||||
public:
|
||||
explicit GatherNdOp(OpKernelConstruction* c) : OpKernel(c) {
|
||||
const DataType dt = DataTypeToEnum<T>::v();
|
||||
const DataType index_t = DataTypeToEnum<Index>::v();
|
||||
OP_REQUIRES_OK(c, c->MatchSignature({dt, index_t}, {dt}));
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* c) override {
|
||||
const Tensor& params = c->input(0);
|
||||
const Tensor& indices = c->input(1);
|
||||
OP_REQUIRES(c, TensorShapeUtils::IsVectorOrHigher(params.shape()),
|
||||
errors::InvalidArgument("params must be at least a vector"));
|
||||
OP_REQUIRES(c, TensorShapeUtils::IsMatrixOrHigher(indices.shape()),
|
||||
errors::InvalidArgument("indices must be at least a matrix"));
|
||||
OP_REQUIRES(
|
||||
c, indices.dim_size(indices.dims() - 1) == params.dims(),
|
||||
errors::InvalidArgument(
|
||||
"index innermost dimension length must equal params rank; saw: ",
|
||||
indices.dim_size(indices.dims() - 1), " vs. ", params.dims()));
|
||||
|
||||
// Check that we have enough index space
|
||||
const int64 N_big = indices.NumElements() / params.dims();
|
||||
OP_REQUIRES(c, N_big <= std::numeric_limits<int>::max(),
|
||||
errors::InvalidArgument(
|
||||
"indices has too many elements for int indexing: ", N_big,
|
||||
" > ", std::numeric_limits<int>::max()));
|
||||
const int N = indices.NumElements() / params.dims();
|
||||
OP_REQUIRES(
|
||||
c, params.NumElements() <= std::numeric_limits<Index>::max(),
|
||||
errors::InvalidArgument("params.NumElements() too large for ",
|
||||
DataTypeString(DataTypeToEnum<Index>::v()),
|
||||
" indexing: ", params.NumElements(), " > ",
|
||||
std::numeric_limits<Index>::max()));
|
||||
|
||||
// The result shape is indices.shape[:-1]
|
||||
TensorShape result_shape(indices.shape());
|
||||
result_shape.RemoveDim(result_shape.dims() - 1);
|
||||
|
||||
Tensor* out = nullptr;
|
||||
OP_REQUIRES_OK(c, c->allocate_output(0, result_shape, &out));
|
||||
if (N > 0) {
|
||||
auto indices_mat = indices.flat_inner_dims<Index>();
|
||||
auto out_flat = out->flat<T>();
|
||||
|
||||
Index bad_i = -1;
|
||||
switch (params.dims()) {
|
||||
#define PARAMS_CASE(NDIM) \
|
||||
case NDIM: { \
|
||||
functor::GatherNd<Device, T, Index, NDIM> functor; \
|
||||
auto params_tensor = params.tensor<T, NDIM>(); \
|
||||
bad_i = functor(c->eigen_device<Device>(), params_tensor, indices_mat, \
|
||||
out_flat); \
|
||||
} break
|
||||
|
||||
PARAMS_CASE(1);
|
||||
PARAMS_CASE(2);
|
||||
PARAMS_CASE(3);
|
||||
PARAMS_CASE(4);
|
||||
PARAMS_CASE(5);
|
||||
default:
|
||||
OP_REQUIRES(c, false,
|
||||
errors::InvalidArgument(
|
||||
"Only param tensors with ranks between 1 and 5 "
|
||||
"are currently supported. Tensor rank: ",
|
||||
params.dims()));
|
||||
}
|
||||
|
||||
OP_REQUIRES(c, bad_i < 0,
|
||||
errors::InvalidArgument(
|
||||
"flat indices[", bad_i, ", :] = [",
|
||||
str_util::Join(gtl::ArraySlice<Index>(
|
||||
&indices_mat(bad_i, 0), params.dims()),
|
||||
", "),
|
||||
"] does not index into param (shape: ",
|
||||
params.shape().DebugString(), ")."));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Specialization of GatherNd to CPU
|
||||
namespace generator {
|
||||
|
||||
template <typename T, typename Index, int NDIM>
|
||||
class GatherNdGenerator {
|
||||
public:
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
|
||||
GatherNdGenerator(typename TTypes<Index>::ConstMatrix Tindices,
|
||||
typename TTypes<T, NDIM>::ConstTensor Tparams,
|
||||
Index* error_loc)
|
||||
: Tindices_(Tindices), Tparams_(Tparams), error_loc_(*error_loc) {}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
|
||||
operator()(const Eigen::array<Eigen::DenseIndex, 1>& loc_array) const {
|
||||
Index loc = loc_array[0];
|
||||
Eigen::array<Eigen::DenseIndex, NDIM> ix;
|
||||
bool out_of_bounds = false;
|
||||
for (int i = 0; i < NDIM; ++i) {
|
||||
Index ix_i = Tindices_(loc, i);
|
||||
ix[i] = ix_i;
|
||||
out_of_bounds |= !FastBoundsCheck(ix_i, Tparams_.dimension(i));
|
||||
}
|
||||
if (out_of_bounds) {
|
||||
error_loc_ = loc;
|
||||
return T(); // Return 0, 0.0, or '', etc.
|
||||
} else {
|
||||
return Tparams_(ix);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
typename TTypes<Index>::ConstMatrix Tindices_;
|
||||
typename TTypes<T, NDIM>::ConstTensor Tparams_;
|
||||
Index& error_loc_;
|
||||
};
|
||||
|
||||
} // namespace generator
|
||||
|
||||
namespace functor {
|
||||
|
||||
template <typename T, typename Index, int NDIM>
|
||||
struct GatherNd<CPUDevice, T, Index, NDIM> {
|
||||
Index operator()(const CPUDevice& d,
|
||||
typename TTypes<T, NDIM>::ConstTensor Tparams,
|
||||
typename TTypes<Index>::ConstMatrix Tindices,
|
||||
typename TTypes<T>::Flat Tout) {
|
||||
Index error_loc(-1);
|
||||
generator::GatherNdGenerator<T, Index, NDIM> gather_nd_generator(Tindices,
|
||||
Tparams,
|
||||
&error_loc);
|
||||
Tout.device(d) = Tout.generate(gather_nd_generator);
|
||||
|
||||
// error_loc() returns -1 if there's no out-of-bounds index,
|
||||
// otherwise it returns the location of an OOB index in Tindices.
|
||||
return error_loc;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace functor
|
||||
|
||||
#define REGISTER_GATHER_ND_FULL(dev, type, index_type) \
|
||||
REGISTER_KERNEL_BUILDER(Name("GatherNd") \
|
||||
.Device(DEVICE_##dev) \
|
||||
.TypeConstraint<type>("Tparams") \
|
||||
.TypeConstraint<index_type>("Tindices"), \
|
||||
GatherNdOp<dev##Device, type, index_type>)
|
||||
|
||||
#define REGISTER_GATHER_ND_ALL_INDICES(dev, type) \
|
||||
REGISTER_GATHER_ND_FULL(dev, type, int32); \
|
||||
REGISTER_GATHER_ND_FULL(dev, type, int64)
|
||||
|
||||
#define REGISTER_GATHER_ND_CPU(type) REGISTER_GATHER_ND_ALL_INDICES(CPU, type)
|
||||
|
||||
TF_CALL_ALL_TYPES(REGISTER_GATHER_ND_CPU);
|
||||
|
||||
#undef REGISTER_GATHER_ND_CPU
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
// Forward declarations of the functor specializations for GPU.
|
||||
namespace functor {
|
||||
#define DECLARE_GPU_SPECS_INDEX_NDIM(T, Index, NDIM) \
|
||||
template <> \
|
||||
Index GatherNd<GPUDevice, T, Index, NDIM>::operator()( \
|
||||
const GPUDevice& d, typename TTypes<T, NDIM>::ConstTensor Tparams, \
|
||||
typename TTypes<Index>::ConstMatrix Tindices, \
|
||||
typename TTypes<T>::Flat Tout); \
|
||||
extern template struct GatherNd<GPUDevice, T, Index, NDIM>
|
||||
|
||||
#define DECLARE_GPU_SPECS_INDEX(T, Index) \
|
||||
DECLARE_GPU_SPECS_INDEX_NDIM(T, Index, 1); \
|
||||
DECLARE_GPU_SPECS_INDEX_NDIM(T, Index, 2); \
|
||||
DECLARE_GPU_SPECS_INDEX_NDIM(T, Index, 3); \
|
||||
DECLARE_GPU_SPECS_INDEX_NDIM(T, Index, 4); \
|
||||
DECLARE_GPU_SPECS_INDEX_NDIM(T, Index, 5)
|
||||
|
||||
#define DECLARE_GPU_SPECS(T) \
|
||||
DECLARE_GPU_SPECS_INDEX(T, int32); \
|
||||
DECLARE_GPU_SPECS_INDEX(T, int64)
|
||||
|
||||
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
|
||||
|
||||
#undef DECLARE_GPU_SPECS
|
||||
#undef DECLARE_GPU_SPECS_INDEX
|
||||
} // namespace functor
|
||||
|
||||
// Registration of the GPU implementations.
|
||||
#define REGISTER_GATHER_ND_GPU(type) REGISTER_GATHER_ND_ALL_INDICES(GPU, type)
|
||||
|
||||
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GATHER_ND_GPU);
|
||||
|
||||
#undef REGISTER_GATHER_ND_GPU
|
||||
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
#undef REGISTER_GATHER_ND_ALL_INDICES
|
||||
#undef REGISTER_GATHER_ND_FULL
|
||||
|
||||
} // namespace tensorflow
|
43
tensorflow/core/kernels/gather_nd_op.h
Normal file
43
tensorflow/core/kernels/gather_nd_op.h
Normal file
@ -0,0 +1,43 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_KERNELS_GATHER_ND_OP_H_
|
||||
#define TENSORFLOW_KERNELS_GATHER_ND_OP_H_
|
||||
// Functor definition for GatherOp, must be compilable by nvcc.
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
#include "tensorflow/core/kernels/bounds_check.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class OpKernelContext;
|
||||
|
||||
namespace functor {
|
||||
template <typename Device, typename T, typename Index, int NDIM>
|
||||
struct GatherNd {
|
||||
// Performs gather op on (Tparams, Tindices), writing to Tout.
|
||||
// Returns an index to Tindices if the value at that index is out of range.
|
||||
// Returns -1 if all values of Tindices are in range.
|
||||
Index operator()(const Device& d,
|
||||
typename TTypes<T, NDIM>::ConstTensor Tparams,
|
||||
typename TTypes<Index>::ConstMatrix Tindices,
|
||||
typename TTypes<T>::Flat Tout);
|
||||
};
|
||||
|
||||
} // namespace functor
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_KERNELS_GATHER_ND_OP_H_
|
105
tensorflow/core/kernels/gather_nd_op_gpu.cu.cc
Normal file
105
tensorflow/core/kernels/gather_nd_op_gpu.cu.cc
Normal file
@ -0,0 +1,105 @@
|
||||
/* Copyright 2016 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
|
||||
#define EIGEN_USE_GPU
|
||||
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/kernels/gather_nd_op.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/util/cuda_kernel_helper.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
typedef Eigen::GpuDevice GPUDevice;
|
||||
|
||||
namespace generator {
|
||||
|
||||
template <typename T, typename Index, int NDIM>
|
||||
class GatherNdGenerator {
|
||||
public:
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
|
||||
GatherNdGenerator(typename TTypes<const Index, 2>::Tensor32Bit Tindices,
|
||||
typename TTypes<const T, NDIM>::Tensor32Bit Tparams)
|
||||
: Tindices_(Tindices), Tparams_(Tparams) {}
|
||||
|
||||
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
|
||||
operator()(const Eigen::array<int, 1>& loc_array) const {
|
||||
int loc = loc_array[0];
|
||||
Eigen::array<int, NDIM> ix;
|
||||
bool out_of_bounds = false;
|
||||
for (int i = 0; i < NDIM; ++i) {
|
||||
int ix_i = Tindices_(loc, i);
|
||||
ix[i] = ix_i;
|
||||
out_of_bounds |= !FastBoundsCheck(ix_i, Tparams_.dimension(i));
|
||||
}
|
||||
if (out_of_bounds) {
|
||||
return T(0); // TODO(ebrevdo): Pass error back to host.
|
||||
} else {
|
||||
return Tparams_(ix);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
typename TTypes<const Index, 2>::Tensor32Bit Tindices_;
|
||||
typename TTypes<const T, NDIM>::Tensor32Bit Tparams_;
|
||||
};
|
||||
|
||||
} // namespace generator
|
||||
|
||||
namespace functor {
|
||||
|
||||
template <typename T, typename Index, int NDIM>
|
||||
struct GatherNd<GPUDevice, T, Index, NDIM> {
|
||||
Index operator()(const GPUDevice& d,
|
||||
typename TTypes<T, NDIM>::ConstTensor Tparams,
|
||||
typename TTypes<Index>::ConstMatrix Tindices,
|
||||
typename TTypes<T>::Flat Tout) {
|
||||
generator::GatherNdGenerator<T, Index, NDIM> gather_nd_generator(
|
||||
To32Bit(Tindices), To32Bit(Tparams));
|
||||
To32Bit(Tout).device(d) = To32Bit(Tout).generate(gather_nd_generator);
|
||||
|
||||
// TODO(ebrevdo): enable indices validation on GPU.
|
||||
// Right now checking for indicies out of bound in the kernel would
|
||||
// require copying code between GPU/CPU, and is too slow.
|
||||
return -1;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace functor
|
||||
|
||||
#define DEFINE_GPU_SPECS_INDEX_NDIM(T, Index, NDIM) \
|
||||
template struct functor::GatherNd<GPUDevice, T, Index, NDIM>;
|
||||
|
||||
#define DEFINE_GPU_SPECS_INDEX(T, Index) \
|
||||
DEFINE_GPU_SPECS_INDEX_NDIM(T, Index, 1); \
|
||||
DEFINE_GPU_SPECS_INDEX_NDIM(T, Index, 2); \
|
||||
DEFINE_GPU_SPECS_INDEX_NDIM(T, Index, 3); \
|
||||
DEFINE_GPU_SPECS_INDEX_NDIM(T, Index, 4); \
|
||||
DEFINE_GPU_SPECS_INDEX_NDIM(T, Index, 5);
|
||||
|
||||
#define DEFINE_GPU_SPECS(T) \
|
||||
DEFINE_GPU_SPECS_INDEX(T, int32); \
|
||||
DEFINE_GPU_SPECS_INDEX(T, int64);
|
||||
|
||||
TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS);
|
||||
|
||||
#undef DEFINE_GPU_SPECS
|
||||
#undef DEFINE_GPU_SPECS_INDEX
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // GOOGLE_CUDA
|
133
tensorflow/core/kernels/gather_nd_op_test.cc
Normal file
133
tensorflow/core/kernels/gather_nd_op_test.cc
Normal file
@ -0,0 +1,133 @@
|
||||
/* Copyright 2015 Google Inc. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
|
||||
#include "tensorflow/core/framework/allocator.h"
|
||||
#include "tensorflow/core/framework/fake_input.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/node_def_builder.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/graph/node_builder.h"
|
||||
#include "tensorflow/core/graph/testlib.h"
|
||||
#include "tensorflow/core/kernels/ops_testutil.h"
|
||||
#include "tensorflow/core/kernels/ops_util.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/lib/gtl/array_slice.h"
|
||||
#include "tensorflow/core/lib/random/simple_philox.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/test_benchmark.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace test {
|
||||
namespace graph {
|
||||
|
||||
class Node* GatherNd(Graph* g, class Node* in0, class Node* in1) {
|
||||
class Node* ret;
|
||||
TF_CHECK_OK(NodeBuilder(g->NewName("n"), "GatherNd")
|
||||
.Input(in0)
|
||||
.Input(in1)
|
||||
.Finalize(g, &ret));
|
||||
return ret;
|
||||
}
|
||||
|
||||
} // namespace graph
|
||||
} // namespace test
|
||||
|
||||
namespace {
|
||||
|
||||
class GatherNdOpTest : public OpsTestBase {
|
||||
protected:
|
||||
void MakeOp(DataType index_type) {
|
||||
TF_ASSERT_OK(NodeDefBuilder("myop", "GatherNd")
|
||||
.Input(FakeInput(DT_FLOAT))
|
||||
.Input(FakeInput(index_type))
|
||||
.Finalize(node_def()));
|
||||
TF_ASSERT_OK(InitOp());
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(GatherNdOpTest, Simple) {
|
||||
MakeOp(DT_INT32);
|
||||
|
||||
// Feed and run
|
||||
AddInputFromArray<float>(TensorShape({5}), {0, 1, 2, 8, 4});
|
||||
AddInputFromArray<int32>(TensorShape({2, 1}), {3, 4});
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
// Check the output.
|
||||
Tensor expected(allocator(), DT_FLOAT, TensorShape({2}));
|
||||
test::FillValues<float>(&expected, {8, 4});
|
||||
test::ExpectTensorEqual<float>(expected, *GetOutput(0));
|
||||
}
|
||||
|
||||
constexpr int kLookups = 2000;
|
||||
|
||||
template <typename Index>
|
||||
static Graph* GatherNd(int dim) {
|
||||
Graph* g = new Graph(OpRegistry::Global());
|
||||
// Always use a 512MB buffer.
|
||||
//const int kRows = ((512 << 20) / sizeof(float)) / dim;
|
||||
Tensor params(DT_FLOAT, TensorShape({dim, 8, 16, 32}));
|
||||
params.flat<float>().setRandom();
|
||||
|
||||
random::PhiloxRandom philox(301, 17);
|
||||
random::SimplePhilox rnd(&philox);
|
||||
Tensor indices(DataTypeToEnum<Index>::value, TensorShape({kLookups, 4}));
|
||||
auto indices_mat = indices.matrix<Index>();
|
||||
for (int i = 0; i < kLookups; i++) {
|
||||
indices_mat(i, 0) = rnd.Uniform(dim);
|
||||
indices_mat(i, 1) = rnd.Uniform(8);
|
||||
indices_mat(i, 2) = rnd.Uniform(16);
|
||||
indices_mat(i, 3) = rnd.Uniform(32);
|
||||
}
|
||||
|
||||
test::graph::GatherNd(g, test::graph::Constant(g, params),
|
||||
test::graph::Constant(g, indices));
|
||||
return g;
|
||||
}
|
||||
|
||||
// Defines and registers a GatherNd benchmark for DEVICE ("cpu"/"gpu") with
// INDEX-typed indices.  The Arg value is the params tensor's first dimension.
#define BM_GATHER_ND(DEVICE, INDEX)                                  \
  static void BM_##DEVICE##_gather_nd_##INDEX(int iters, int dim) {  \
    const int64 items = static_cast<int64>(iters) * kLookups * dim;  \
    testing::ItemsProcessed(items);                                  \
    testing::BytesProcessed(items * sizeof(float));                  \
    testing::UseRealTime();                                          \
    test::Benchmark(#DEVICE, GatherNd<INDEX>(dim)).Run(iters);       \
  }                                                                  \
  BENCHMARK(BM_##DEVICE##_gather_nd_##INDEX)                         \
      ->Arg(1)                                                       \
      ->Arg(10)                                                      \
      ->Arg(20)                                                      \
      ->Arg(64)                                                      \
      ->Arg(100)                                                     \
      ->Arg(200)                                                     \
      ->Arg(1000)
|
||||
|
||||
// Instantiate the benchmark for both index types on both device types.
BM_GATHER_ND(cpu, int32);
BM_GATHER_ND(cpu, int64);
BM_GATHER_ND(gpu, int32);
BM_GATHER_ND(gpu, int64);
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
@ -25,6 +25,9 @@ Status InitializableLookupTable::Find(const Tensor& keys, Tensor* values,
|
||||
if (!is_initialized()) {
|
||||
return errors::FailedPrecondition("Table not initialized.");
|
||||
}
|
||||
// Do not let the use migrate before the check; table is used without
|
||||
// a lock by the readers.
|
||||
std::atomic_thread_fence(std::memory_order_acquire);
|
||||
TF_RETURN_IF_ERROR(CheckFindArguments(keys, *values, default_value));
|
||||
return DoFind(keys, values, default_value);
|
||||
}
|
||||
@ -48,6 +51,10 @@ Status InitializableLookupTable::Initialize(InitTableIterator& iter) {
|
||||
if (!errors::IsOutOfRange(iter.status())) {
|
||||
return iter.status();
|
||||
}
|
||||
|
||||
// Prevent compiler/memory reordering of is_initialized and
|
||||
// the initialization itself.
|
||||
std::atomic_thread_fence(std::memory_order_release);
|
||||
is_initialized_ = true;
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -69,7 +69,11 @@ class HashTable : public InitializableLookupTable {
|
||||
public:
|
||||
size_t size() const override {
|
||||
// return the size of the table only if it's initialized, otherwise 0.
|
||||
return table_ && is_initialized_ ? table_->size() : 0;
|
||||
if (!is_initialized_) {
|
||||
return 0;
|
||||
}
|
||||
std::atomic_thread_fence(std::memory_order_acquire);
|
||||
return table_ ? table_->size() : 0;
|
||||
}
|
||||
|
||||
DataType key_dtype() const override { return DataTypeToEnum<K>::v(); }
|
||||
|
@ -407,6 +407,34 @@ this operation will permute `params` accordingly.
|
||||
</div>
|
||||
)doc");
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
// Registers the GatherNd op: gathers slices of `params` addressed by full-rank
// index rows in `indices`.  (Doc text below is user-facing; kept verbatim.)
REGISTER_OP("GatherNd")
    .Input("params: Tparams")
    .Input("indices: Tindices")
    .Output("output: Tparams")
    .Attr("Tparams: type")
    .Attr("Tindices: {int32,int64}")
    .Doc(R"doc(
Gather values from `params` according to `indices`.

`indices` must be integer tensor, containing indices into `params`.
It must be shape `[d_0, ..., d_N, R]` where `R` is the rank of `params`.
The innermost dimension of `indices` (with length `R`) corresponds to the
indices of `params`.

Produces an output tensor with shape `[d_0, ..., d_{n-1}]` where:

    output[i, j, k, ...] = params[indices[i, j, k, ..., :]]

e.g. for `indices` a matrix:

    output[i] = params[indices[i, :]]

params: R-D. The tensor from which to gather values.
indices: (N+1)-D. Index tensor having shape `[d_0, ..., d_N, R]`.
output: N-D. Values from `params` gathered from indices given by `indices`.
)doc");
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
REGISTER_OP("Identity")
|
||||
.Input("input: T")
|
||||
|
@ -6480,6 +6480,35 @@ op {
|
||||
}
|
||||
}
|
||||
}
|
||||
op {
|
||||
name: "GatherNd"
|
||||
input_arg {
|
||||
name: "params"
|
||||
type_attr: "Tparams"
|
||||
}
|
||||
input_arg {
|
||||
name: "indices"
|
||||
type_attr: "Tindices"
|
||||
}
|
||||
output_arg {
|
||||
name: "output"
|
||||
type_attr: "Tparams"
|
||||
}
|
||||
attr {
|
||||
name: "Tparams"
|
||||
type: "type"
|
||||
}
|
||||
attr {
|
||||
name: "Tindices"
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_INT32
|
||||
type: DT_INT64
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
op {
|
||||
name: "Greater"
|
||||
input_arg {
|
||||
|
@ -4066,6 +4066,40 @@ op {
|
||||
summary: "Gather slices from `params` according to `indices`."
|
||||
description: "`indices` must be an integer tensor of any dimension (usually 0-D or 1-D).\nProduces an output tensor with shape `indices.shape + params.shape[1:]` where:\n\n # Scalar indices\n output[:, ..., :] = params[indices, :, ... :]\n\n # Vector indices\n output[i, :, ..., :] = params[indices[i], :, ... :]\n\n # Higher rank indices\n output[i, ..., j, :, ... :] = params[indices[i, ..., j], :, ..., :]\n\nIf `indices` is a permutation and `len(indices) == params.shape[0]` then\nthis operation will permute `params` accordingly.\n\n<div style=\"width:70%; margin:auto; margin-bottom:10px; margin-top:20px;\">\n<img style=\"width:100%\" src=\"../../images/Gather.png\" alt>\n</div>"
|
||||
}
|
||||
op {
|
||||
name: "GatherNd"
|
||||
input_arg {
|
||||
name: "params"
|
||||
description: "R-D. The tensor from which to gather values."
|
||||
type_attr: "Tparams"
|
||||
}
|
||||
input_arg {
|
||||
name: "indices"
|
||||
description: "(N+1)-D. Index tensor having shape `[d_0, ..., d_N, R]`."
|
||||
type_attr: "Tindices"
|
||||
}
|
||||
output_arg {
|
||||
name: "output"
|
||||
description: "N-D. Values from `params` gathered from indices given by `indices`."
|
||||
type_attr: "Tparams"
|
||||
}
|
||||
attr {
|
||||
name: "Tparams"
|
||||
type: "type"
|
||||
}
|
||||
attr {
|
||||
name: "Tindices"
|
||||
type: "type"
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_INT32
|
||||
type: DT_INT64
|
||||
}
|
||||
}
|
||||
}
|
||||
summary: "Gather values from `params` according to `indices`."
|
||||
description: "`indices` must be integer tensor, containing indices into `params`.\nIt must be shape `[d_0, ..., d_N, R]` where `R` is the rank of `params`.\nThe innermost dimension of `indices` (with length `R`) corresponds to the\nindices of `params`.\n\nProduces an output tensor with shape `[d_0, ..., d_{n-1}]` where:\n\n output[i, j, k, ...] = params[indices[i, j, k, ..., :]]\n\ne.g. for `indices` a matrix:\n\n output[i] = params[indices[i, :]]"
|
||||
}
|
||||
op {
|
||||
name: "Greater"
|
||||
input_arg {
|
||||
|
1
tensorflow/g3doc/api_docs/OWNERS
Normal file
1
tensorflow/g3doc/api_docs/OWNERS
Normal file
@ -0,0 +1 @@
|
||||
tensorflow-git-owners
|
@ -1149,6 +1149,39 @@ this operation will permute `params` accordingly.
|
||||
A `Tensor`. Has the same type as `params`.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
### `tf.gather_nd(params, indices, name=None)` {#gather_nd}
|
||||
|
||||
Gather values from `params` according to `indices`.
|
||||
|
||||
`indices` must be integer tensor, containing indices into `params`.
|
||||
It must be shape `[d_0, ..., d_N, R]` where `R` is the rank of `params`.
|
||||
The innermost dimension of `indices` (with length `R`) corresponds to the
|
||||
indices of `params`.
|
||||
|
||||
Produces an output tensor with shape `[d_0, ..., d_{n-1}]` where:
|
||||
|
||||
output[i, j, k, ...] = params[indices[i, j, k, ..., :]]
|
||||
|
||||
e.g. for `indices` a matrix:
|
||||
|
||||
output[i] = params[indices[i, :]]
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`params`</b>: A `Tensor`. R-D. The tensor from which to gather values.
|
||||
* <b>`indices`</b>: A `Tensor`. Must be one of the following types: `int32`, `int64`.
|
||||
(N+1)-D. Index tensor having shape `[d_0, ..., d_N, R]`.
|
||||
* <b>`name`</b>: A name for the operation (optional).
|
||||
|
||||
##### Returns:
|
||||
|
||||
A `Tensor`. Has the same type as `params`.
|
||||
N-D. Values from `params` gathered from indices given by `indices`.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
### `tf.dynamic_partition(data, partitions, num_partitions, name=None)` {#dynamic_partition}
|
||||
|
@ -411,6 +411,30 @@ subtraction, it usually shouldn't hurt much either.
|
||||
* <b>`ValueError`</b>: If `regularizer` does not return a scalar output.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
### `tf.contrib.layers.make_all(module_name, doc_string_modules=None)` {#make_all}
|
||||
|
||||
Generate `__all__` from the docstring of one or more modules.
|
||||
|
||||
Usage: `make_all(__name__)` or
|
||||
`make_all(__name__, [sys.modules(__name__), other_module])`. The doc string
|
||||
modules must each have a docstring, and `__all__` will contain all symbols with
|
||||
`@@` references, where that symbol currently exists in the module named
|
||||
`module_name`.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`module_name`</b>: The name of the module (usually `__name__`).
|
||||
* <b>`doc_string_modules`</b>: a list of modules from which to take docstring.
|
||||
If None, then a list containing only the module named `module_name` is used.
|
||||
|
||||
##### Returns:
|
||||
|
||||
A list suitable for use as `__all__`.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
### `tf.contrib.layers.optimize_loss(loss, global_step, learning_rate, optimizer, clip_gradients=None, moving_average_decay=0.9, learning_rate_decay_fn=None, variables=None)` {#optimize_loss}
|
||||
|
@ -260,202 +260,6 @@ Example 2:
|
||||
|
||||
|
||||
|
||||
## Higher Order Operators
|
||||
|
||||
TensorFlow provides several higher order operators to simplify the common
|
||||
map-reduce programming patterns.
|
||||
|
||||
- - -
|
||||
|
||||
### `tf.map_fn(fn, elems, dtype=None, parallel_iterations=10, back_prop=True, swap_memory=False, name=None)` {#map_fn}
|
||||
|
||||
The map operator on the list of tensors resulted from unpacking `elems`
|
||||
along the first dimension.
|
||||
|
||||
This map operator repeatedly applies the callable `fn` to a sequence of
|
||||
elements from first to last. The elements are made of the tensors unpacked
|
||||
from `elems`. `dtype` is the data type of the return value of `fn`. Users
|
||||
must provide `dtype` if it is different from the data type of `elems`.
|
||||
|
||||
Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
|
||||
of the result tensor is `[len(values)] + fn(values[0]).shape`.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`fn`</b>: The callable to be performed.
|
||||
* <b>`elems`</b>: A tensor to be unpacked to apply `fn`.
|
||||
* <b>`dtype`</b>: (optional) The output type of `fn`.
|
||||
* <b>`parallel_iterations`</b>: (optional) The number of iterations allowed to run
|
||||
in parallel.
|
||||
* <b>`back_prop`</b>: (optional) True enables back propagation.
|
||||
* <b>`swap_memory`</b>: (optional) True enables GPU-CPU memory swapping.
|
||||
* <b>`name`</b>: (optional) Name prefix for the returned tensors.
|
||||
|
||||
##### Returns:
|
||||
|
||||
A tensor that packs the results of applying `fn` to the list of tensors
|
||||
unpacked from `elems`, from first to last.
|
||||
|
||||
##### Raises:
|
||||
|
||||
|
||||
* <b>`TypeError`</b>: if `fn` is not callable.
|
||||
|
||||
##### Example:
|
||||
|
||||
```python
|
||||
elems = [1, 2, 3, 4, 5, 6]
|
||||
squares = map_fn(lambda x: x * x, elems)
|
||||
# squares == [1, 4, 9, 16, 25, 36]
|
||||
```
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
### `tf.foldl(fn, elems, initializer=None, parallel_iterations=10, back_prop=True, swap_memory=False, name=None)` {#foldl}
|
||||
|
||||
The foldl operator on the list of tensors resulted from unpacking `elems`
|
||||
along the first dimension.
|
||||
|
||||
This foldl operator repeatedly applies the callable `fn` to a sequence
|
||||
of elements from first to last. The elements are made of the tensors
|
||||
unpacked from `elems` on dimension 0. The callable fn takes two tensors as
|
||||
arguments. The first argument is the accumulated value computed from the
|
||||
preceding invocation of fn. If `initializer` is None, `elems` must contain
|
||||
at least one element, and its first element is used as the initializer.
|
||||
|
||||
Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
|
||||
of the result tensor is `fn(initializer, values[0]).shape`.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`fn`</b>: The callable to be performed.
|
||||
* <b>`elems`</b>: A tensor to be unpacked on dimension 0.
|
||||
* <b>`initializer`</b>: (optional) The initial value for the accumulator.
|
||||
* <b>`parallel_iterations`</b>: (optional) The number of iterations allowed to run
|
||||
in parallel.
|
||||
* <b>`back_prop`</b>: (optional) True enables back propagation.
|
||||
* <b>`swap_memory`</b>: (optional) True enables GPU-CPU memory swapping.
|
||||
* <b>`name`</b>: (optional) Name prefix for the returned tensors.
|
||||
|
||||
##### Returns:
|
||||
|
||||
A tensor resulting from applying `fn` consecutively to the list of tensors
|
||||
unpacked from `elems`, from first to last.
|
||||
|
||||
##### Raises:
|
||||
|
||||
|
||||
* <b>`TypeError`</b>: if `fn` is not callable.
|
||||
|
||||
##### Example:
|
||||
|
||||
```python
|
||||
elems = [1, 2, 3, 4, 5, 6]
|
||||
sum = foldl(lambda a, x: a + x, elems)
|
||||
# sum == 21
|
||||
```
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
### `tf.foldr(fn, elems, initializer=None, parallel_iterations=10, back_prop=True, swap_memory=False, name=None)` {#foldr}
|
||||
|
||||
The foldr operator on the list of tensors resulted from unpacking `elems`
|
||||
along the first dimension.
|
||||
|
||||
This foldr operator repeatedly applies the callable `fn` to a sequence
|
||||
of elements from last to first. The elements are made of the tensors
|
||||
unpacked from `elems`. The callable fn takes two tensors as arguments.
|
||||
The first argument is the accumulated value computed from the preceding
|
||||
invocation of fn. If `initializer` is None, `elems` must contain at least
|
||||
one element, and its first element is used as the initializer.
|
||||
|
||||
Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
|
||||
of the result tensor is `fn(initializer, values[0]).shape`.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`fn`</b>: The callable to be performed.
|
||||
* <b>`elems`</b>: A tensor that is unpacked into a sequence of tensors to apply `fn`.
|
||||
* <b>`initializer`</b>: (optional) The initial value for the accumulator.
|
||||
* <b>`parallel_iterations`</b>: (optional) The number of iterations allowed to run
|
||||
in parallel.
|
||||
* <b>`back_prop`</b>: (optional) True enables back propagation.
|
||||
* <b>`swap_memory`</b>: (optional) True enables GPU-CPU memory swapping.
|
||||
* <b>`name`</b>: (optional) Name prefix for the returned tensors.
|
||||
|
||||
##### Returns:
|
||||
|
||||
A tensor resulting from applying `fn` consecutively to the list of tensors
|
||||
unpacked from `elems`, from last to first.
|
||||
|
||||
##### Raises:
|
||||
|
||||
|
||||
* <b>`TypeError`</b>: if `fn` is not callable.
|
||||
|
||||
##### Example:
|
||||
|
||||
```python
|
||||
elems = [1, 2, 3, 4, 5, 6]
|
||||
sum = foldr(lambda a, x: a + x, elems)
|
||||
# sum == 21
|
||||
```
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
### `tf.scan(fn, elems, initializer=None, parallel_iterations=10, back_prop=True, swap_memory=False, name=None)` {#scan}
|
||||
|
||||
The scan operator on the list of tensors resulted from unpacking `elems`
|
||||
along the first dimension.
|
||||
|
||||
This scan operator repeatedly applies the callable `fn` to a sequence
|
||||
of elements from first to last. The elements are made of the tensors
|
||||
unpacked from `elems` on dimension 0. The callable fn takes two tensors as
|
||||
arguments. The first argument is the accumulated value computed from the
|
||||
preceding invocation of fn. If `initializer` is None, `elems` must contain
|
||||
at least one element, and its first element is used as the initializer.
|
||||
|
||||
Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
|
||||
of the result tensor is `[len(values)] + fn(initializer, values[0]).shape`.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`fn`</b>: The callable to be performed.
|
||||
* <b>`elems`</b>: A tensor to be unpacked on dimension 0.
|
||||
* <b>`initializer`</b>: (optional) The initial value for the accumulator.
|
||||
* <b>`parallel_iterations`</b>: (optional) The number of iterations allowed to run
|
||||
in parallel.
|
||||
* <b>`back_prop`</b>: (optional) True enables back propagation.
|
||||
* <b>`swap_memory`</b>: (optional) True enables GPU-CPU memory swapping.
|
||||
* <b>`name`</b>: (optional) Name prefix for the returned tensors.
|
||||
|
||||
##### Returns:
|
||||
|
||||
A tensor that packs the results of applying `fn` to the list of tensors
|
||||
unpacked from `elems`, from first to last.
|
||||
|
||||
##### Raises:
|
||||
|
||||
|
||||
* <b>`TypeError`</b>: if `fn` is not callable.
|
||||
|
||||
##### Example:
|
||||
|
||||
```python
|
||||
elems = [1, 2, 3, 4, 5, 6]
|
||||
sum = scan(lambda a, x: a + x, elems)
|
||||
# sum == [1, 3, 6, 10, 15, 21]
|
||||
```
|
||||
|
||||
|
||||
|
||||
## Logical Operators
|
||||
|
||||
TensorFlow provides several operations that you can use to add logical operators
|
||||
|
@ -409,6 +409,31 @@ Returns a list of values in the collection with the given `name`.
|
||||
collected.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
#### `tf.Graph.get_collection_ref(name)` {#Graph.get_collection_ref}
|
||||
|
||||
Returns a list of values in the collection with the given `name`.
|
||||
|
||||
If the collection exists, this returns the list itself, which can
|
||||
be modified in place to change the collection. If the collection does
|
||||
not exist, it is created as an empty list and the list is returned.
|
||||
|
||||
This is different from `get_collection()` which always returns a copy of
|
||||
the collection list if it exists and never creates an empty collection.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`name`</b>: The key for the collection. For example, the `GraphKeys` class
|
||||
contains many standard names for collections.
|
||||
|
||||
##### Returns:
|
||||
|
||||
The list of values in the collection with the given `name`, or an empty
|
||||
list if no value has been added to that collection.
|
||||
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
@ -1752,6 +1777,29 @@ for more details.
|
||||
collected.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
### `tf.get_collection_ref(key)` {#get_collection_ref}
|
||||
|
||||
Wrapper for `Graph.get_collection_ref()` using the default graph.
|
||||
|
||||
See [`Graph.get_collection_ref()`](../../api_docs/python/framework.md#Graph.get_collection_ref)
|
||||
for more details.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`key`</b>: The key for the collection. For example, the `GraphKeys` class
|
||||
contains many standard names for collections.
|
||||
|
||||
##### Returns:
|
||||
|
||||
The list of values in the collection with the given `name`, or an empty
|
||||
list if no value has been added to that collection. Note that this returns
|
||||
the collection list itself, which can be modified in place to change the
|
||||
collection.
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
### `class tf.GraphKeys` {#GraphKeys}
|
||||
|
202
tensorflow/g3doc/api_docs/python/functional_ops.md
Normal file
202
tensorflow/g3doc/api_docs/python/functional_ops.md
Normal file
@ -0,0 +1,202 @@
|
||||
<!-- This file is machine generated: DO NOT EDIT! -->
|
||||
|
||||
# Higher Order Functions
|
||||
|
||||
Note: Functions taking `Tensor` arguments can also take anything accepted by
|
||||
[`tf.convert_to_tensor`](framework.md#convert_to_tensor).
|
||||
|
||||
[TOC]
|
||||
|
||||
Functional operations.
|
||||
|
||||
## Higher Order Operators
|
||||
|
||||
TensorFlow provides several higher order operators to simplify the common
|
||||
map-reduce programming patterns.
|
||||
|
||||
- - -
|
||||
|
||||
### `tf.map_fn(fn, elems, dtype=None, parallel_iterations=10, back_prop=True, swap_memory=False, name=None)` {#map_fn}
|
||||
|
||||
map on the list of tensors unpacked from `elems` on dimension 0.
|
||||
|
||||
This map operator repeatedly applies the callable `fn` to a sequence of
|
||||
elements from first to last. The elements are made of the tensors unpacked
|
||||
from `elems`. `dtype` is the data type of the return value of `fn`. Users
|
||||
must provide `dtype` if it is different from the data type of `elems`.
|
||||
|
||||
Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
|
||||
of the result tensor is `[len(values)] + fn(values[0]).shape`.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`fn`</b>: The callable to be performed.
|
||||
* <b>`elems`</b>: A tensor to be unpacked to apply `fn`.
|
||||
* <b>`dtype`</b>: (optional) The output type of `fn`.
|
||||
* <b>`parallel_iterations`</b>: (optional) The number of iterations allowed to run
|
||||
in parallel.
|
||||
* <b>`back_prop`</b>: (optional) True enables back propagation.
|
||||
* <b>`swap_memory`</b>: (optional) True enables GPU-CPU memory swapping.
|
||||
* <b>`name`</b>: (optional) Name prefix for the returned tensors.
|
||||
|
||||
##### Returns:
|
||||
|
||||
A tensor that packs the results of applying `fn` to the list of tensors
|
||||
unpacked from `elems`, from first to last.
|
||||
|
||||
##### Raises:
|
||||
|
||||
|
||||
* <b>`TypeError`</b>: if `fn` is not callable.
|
||||
|
||||
##### Example:
|
||||
|
||||
```python
|
||||
elems = [1, 2, 3, 4, 5, 6]
|
||||
squares = map_fn(lambda x: x * x, elems)
|
||||
# squares == [1, 4, 9, 16, 25, 36]
|
||||
```
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
### `tf.foldl(fn, elems, initializer=None, parallel_iterations=10, back_prop=True, swap_memory=False, name=None)` {#foldl}
|
||||
|
||||
foldl on the list of tensors unpacked from `elems` on dimension 0.
|
||||
|
||||
This foldl operator repeatedly applies the callable `fn` to a sequence
|
||||
of elements from first to last. The elements are made of the tensors
|
||||
unpacked from `elems` on dimension 0. The callable fn takes two tensors as
|
||||
arguments. The first argument is the accumulated value computed from the
|
||||
preceding invocation of fn. If `initializer` is None, `elems` must contain
|
||||
at least one element, and its first element is used as the initializer.
|
||||
|
||||
Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
|
||||
of the result tensor is `fn(initializer, values[0]).shape`.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`fn`</b>: The callable to be performed.
|
||||
* <b>`elems`</b>: A tensor to be unpacked on dimension 0.
|
||||
* <b>`initializer`</b>: (optional) The initial value for the accumulator.
|
||||
* <b>`parallel_iterations`</b>: (optional) The number of iterations allowed to run
|
||||
in parallel.
|
||||
* <b>`back_prop`</b>: (optional) True enables back propagation.
|
||||
* <b>`swap_memory`</b>: (optional) True enables GPU-CPU memory swapping.
|
||||
* <b>`name`</b>: (optional) Name prefix for the returned tensors.
|
||||
|
||||
##### Returns:
|
||||
|
||||
A tensor resulting from applying `fn` consecutively to the list of tensors
|
||||
unpacked from `elems`, from first to last.
|
||||
|
||||
##### Raises:
|
||||
|
||||
|
||||
* <b>`TypeError`</b>: if `fn` is not callable.
|
||||
|
||||
##### Example:
|
||||
|
||||
```python
|
||||
elems = [1, 2, 3, 4, 5, 6]
|
||||
sum = foldl(lambda a, x: a + x, elems)
|
||||
# sum == 21
|
||||
```
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
### `tf.foldr(fn, elems, initializer=None, parallel_iterations=10, back_prop=True, swap_memory=False, name=None)` {#foldr}
|
||||
|
||||
foldr on the list of tensors unpacked from `elems` on dimension 0.
|
||||
|
||||
This foldr operator repeatedly applies the callable `fn` to a sequence
|
||||
of elements from last to first. The elements are made of the tensors
|
||||
unpacked from `elems`. The callable fn takes two tensors as arguments.
|
||||
The first argument is the accumulated value computed from the preceding
|
||||
invocation of fn. If `initializer` is None, `elems` must contain at least
|
||||
one element, and its first element is used as the initializer.
|
||||
|
||||
Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
|
||||
of the result tensor is `fn(initializer, values[0]).shape`.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`fn`</b>: The callable to be performed.
|
||||
* <b>`elems`</b>: A tensor that is unpacked into a sequence of tensors to apply `fn`.
|
||||
* <b>`initializer`</b>: (optional) The initial value for the accumulator.
|
||||
* <b>`parallel_iterations`</b>: (optional) The number of iterations allowed to run
|
||||
in parallel.
|
||||
* <b>`back_prop`</b>: (optional) True enables back propagation.
|
||||
* <b>`swap_memory`</b>: (optional) True enables GPU-CPU memory swapping.
|
||||
* <b>`name`</b>: (optional) Name prefix for the returned tensors.
|
||||
|
||||
##### Returns:
|
||||
|
||||
A tensor resulting from applying `fn` consecutively to the list of tensors
|
||||
unpacked from `elems`, from last to first.
|
||||
|
||||
##### Raises:
|
||||
|
||||
|
||||
* <b>`TypeError`</b>: if `fn` is not callable.
|
||||
|
||||
##### Example:
|
||||
|
||||
```python
|
||||
elems = [1, 2, 3, 4, 5, 6]
|
||||
sum = foldr(lambda a, x: a + x, elems)
|
||||
# sum == 21
|
||||
```
|
||||
|
||||
|
||||
- - -
|
||||
|
||||
### `tf.scan(fn, elems, initializer=None, parallel_iterations=10, back_prop=True, swap_memory=False, name=None)` {#scan}
|
||||
|
||||
scan on the list of tensors unpacked from `elems` on dimension 0.
|
||||
|
||||
This scan operator repeatedly applies the callable `fn` to a sequence
|
||||
of elements from first to last. The elements are made of the tensors
|
||||
unpacked from `elems` on dimension 0. The callable fn takes two tensors as
|
||||
arguments. The first argument is the accumulated value computed from the
|
||||
preceding invocation of fn. If `initializer` is None, `elems` must contain
|
||||
at least one element, and its first element is used as the initializer.
|
||||
|
||||
Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
|
||||
of the result tensor is `[len(values)] + fn(initializer, values[0]).shape`.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`fn`</b>: The callable to be performed.
|
||||
* <b>`elems`</b>: A tensor to be unpacked on dimension 0.
|
||||
* <b>`initializer`</b>: (optional) The initial value for the accumulator.
|
||||
* <b>`parallel_iterations`</b>: (optional) The number of iterations allowed to run
|
||||
in parallel.
|
||||
* <b>`back_prop`</b>: (optional) True enables back propagation.
|
||||
* <b>`swap_memory`</b>: (optional) True enables GPU-CPU memory swapping.
|
||||
* <b>`name`</b>: (optional) Name prefix for the returned tensors.
|
||||
|
||||
##### Returns:
|
||||
|
||||
A tensor that packs the results of applying `fn` to the list of tensors
|
||||
unpacked from `elems`, from first to last.
|
||||
|
||||
##### Raises:
|
||||
|
||||
|
||||
* <b>`TypeError`</b>: if `fn` is not callable.
|
||||
|
||||
##### Example:
|
||||
|
||||
```python
|
||||
elems = [1, 2, 3, 4, 5, 6]
|
||||
sum = scan(lambda a, x: a + x, elems)
|
||||
# sum == [1, 3, 6, 10, 15, 21]
|
||||
```
|
||||
|
||||
|
@ -1234,29 +1234,3 @@ false and no bounding boxes are supplied, an error is raised.
|
||||
Provide as input to `tf.image.draw_bounding_boxes`.
|
||||
|
||||
|
||||
|
||||
## Other Functions and Classes
|
||||
- - -
|
||||
|
||||
### `tf.contrib.layers.make_all(module_name, doc_string_modules=None)` {#make_all}
|
||||
|
||||
Generate `__all__` from the docstring of one or more modules.
|
||||
|
||||
Usage: `make_all(__name__)` or
|
||||
`make_all(__name__, [sys.modules(__name__), other_module])`. The doc string
|
||||
modules must each have a docstring, and `__all__` will contain all symbols with
|
||||
`@@` references, where that symbol currently exists in the module named
|
||||
`module_name`.
|
||||
|
||||
##### Args:
|
||||
|
||||
|
||||
* <b>`module_name`</b>: The name of the module (usually `__name__`).
|
||||
* <b>`doc_string_modules`</b>: a list of modules from which to take docstring.
|
||||
If None, then a list containing only the module named `module_name` is used.
|
||||
|
||||
##### Returns:
|
||||
|
||||
A list suitable for use as `__all__`.
|
||||
|
||||
|
||||
|
@ -13,6 +13,7 @@
|
||||
* [`Dimension`](../../api_docs/python/framework.md#Dimension)
|
||||
* [`DType`](../../api_docs/python/framework.md#DType)
|
||||
* [`get_collection`](../../api_docs/python/framework.md#get_collection)
|
||||
* [`get_collection_ref`](../../api_docs/python/framework.md#get_collection_ref)
|
||||
* [`get_default_graph`](../../api_docs/python/framework.md#get_default_graph)
|
||||
* [`get_seed`](../../api_docs/python/framework.md#get_seed)
|
||||
* [`Graph`](../../api_docs/python/framework.md#Graph)
|
||||
@ -95,6 +96,7 @@
|
||||
* [`dynamic_stitch`](../../api_docs/python/array_ops.md#dynamic_stitch)
|
||||
* [`expand_dims`](../../api_docs/python/array_ops.md#expand_dims)
|
||||
* [`gather`](../../api_docs/python/array_ops.md#gather)
|
||||
* [`gather_nd`](../../api_docs/python/array_ops.md#gather_nd)
|
||||
* [`one_hot`](../../api_docs/python/array_ops.md#one_hot)
|
||||
* [`pack`](../../api_docs/python/array_ops.md#pack)
|
||||
* [`pad`](../../api_docs/python/array_ops.md#pad)
|
||||
@ -229,8 +231,6 @@
|
||||
* [`cond`](../../api_docs/python/control_flow_ops.md#cond)
|
||||
* [`count_up_to`](../../api_docs/python/control_flow_ops.md#count_up_to)
|
||||
* [`equal`](../../api_docs/python/control_flow_ops.md#equal)
|
||||
* [`foldl`](../../api_docs/python/control_flow_ops.md#foldl)
|
||||
* [`foldr`](../../api_docs/python/control_flow_ops.md#foldr)
|
||||
* [`greater`](../../api_docs/python/control_flow_ops.md#greater)
|
||||
* [`greater_equal`](../../api_docs/python/control_flow_ops.md#greater_equal)
|
||||
* [`group`](../../api_docs/python/control_flow_ops.md#group)
|
||||
@ -244,16 +244,20 @@
|
||||
* [`logical_not`](../../api_docs/python/control_flow_ops.md#logical_not)
|
||||
* [`logical_or`](../../api_docs/python/control_flow_ops.md#logical_or)
|
||||
* [`logical_xor`](../../api_docs/python/control_flow_ops.md#logical_xor)
|
||||
* [`map_fn`](../../api_docs/python/control_flow_ops.md#map_fn)
|
||||
* [`no_op`](../../api_docs/python/control_flow_ops.md#no_op)
|
||||
* [`not_equal`](../../api_docs/python/control_flow_ops.md#not_equal)
|
||||
* [`Print`](../../api_docs/python/control_flow_ops.md#Print)
|
||||
* [`scan`](../../api_docs/python/control_flow_ops.md#scan)
|
||||
* [`select`](../../api_docs/python/control_flow_ops.md#select)
|
||||
* [`tuple`](../../api_docs/python/control_flow_ops.md#tuple)
|
||||
* [`verify_tensor_all_finite`](../../api_docs/python/control_flow_ops.md#verify_tensor_all_finite)
|
||||
* [`where`](../../api_docs/python/control_flow_ops.md#where)
|
||||
|
||||
* **[Higher Order Functions](../../api_docs/python/functional_ops.md)**:
|
||||
* [`foldl`](../../api_docs/python/functional_ops.md#foldl)
|
||||
* [`foldr`](../../api_docs/python/functional_ops.md#foldr)
|
||||
* [`map_fn`](../../api_docs/python/functional_ops.md#map_fn)
|
||||
* [`scan`](../../api_docs/python/functional_ops.md#scan)
|
||||
|
||||
* **[Images](../../api_docs/python/image.md)**:
|
||||
* [`adjust_brightness`](../../api_docs/python/image.md#adjust_brightness)
|
||||
* [`adjust_contrast`](../../api_docs/python/image.md#adjust_contrast)
|
||||
@ -272,7 +276,6 @@
|
||||
* [`flip_up_down`](../../api_docs/python/image.md#flip_up_down)
|
||||
* [`grayscale_to_rgb`](../../api_docs/python/image.md#grayscale_to_rgb)
|
||||
* [`hsv_to_rgb`](../../api_docs/python/image.md#hsv_to_rgb)
|
||||
* [`make_all`](../../api_docs/python/image.md#make_all)
|
||||
* [`pad_to_bounding_box`](../../api_docs/python/image.md#pad_to_bounding_box)
|
||||
* [`per_image_whitening`](../../api_docs/python/image.md#per_image_whitening)
|
||||
* [`random_brightness`](../../api_docs/python/image.md#random_brightness)
|
||||
@ -466,6 +469,7 @@
|
||||
* [`fully_connected`](../../api_docs/python/contrib.layers.md#fully_connected)
|
||||
* [`l1_regularizer`](../../api_docs/python/contrib.layers.md#l1_regularizer)
|
||||
* [`l2_regularizer`](../../api_docs/python/contrib.layers.md#l2_regularizer)
|
||||
* [`make_all`](../../api_docs/python/contrib.layers.md#make_all)
|
||||
* [`optimize_loss`](../../api_docs/python/contrib.layers.md#optimize_loss)
|
||||
* [`sum_regularizer`](../../api_docs/python/contrib.layers.md#sum_regularizer)
|
||||
* [`summarize_activation`](../../api_docs/python/contrib.layers.md#summarize_activation)
|
||||
|
@ -101,6 +101,7 @@ from tensorflow.python.framework import framework_lib
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import constant_op
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import functional_ops
|
||||
from tensorflow.python.ops import histogram_ops
|
||||
from tensorflow.python.ops import io_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
@ -118,8 +119,8 @@ _whitelist = set([app, compat, contrib, errors, flags, gfile, image,
|
||||
# strings of other modules.
|
||||
__all__ = make_all(__name__,
|
||||
[framework_lib, array_ops, client_lib, constant_op,
|
||||
control_flow_ops, histogram_ops, io_ops, math_ops, nn,
|
||||
script_ops, sparse_ops, state_ops, train])
|
||||
control_flow_ops, functional_ops, histogram_ops, io_ops,
|
||||
math_ops, nn, script_ops, sparse_ops, state_ops, train])
|
||||
|
||||
# Symbols whitelisted for export without documentation.
|
||||
# TODO(cwhipkey): review these and move to contrib, expose through
|
||||
|
@ -484,6 +484,9 @@ class Library(Document):
|
||||
names = self._members.items()
|
||||
else:
|
||||
names = inspect.getmembers(self._module)
|
||||
all_names = getattr(self._module, "__all__", None)
|
||||
if all_names is not None:
|
||||
names = [(n, m) for n, m in names if n in all_names]
|
||||
leftovers = []
|
||||
for name, _ in names:
|
||||
if name in self._members and name not in self._documented:
|
||||
|
@ -43,6 +43,7 @@
|
||||
|
||||
@@add_to_collection
|
||||
@@get_collection
|
||||
@@get_collection_ref
|
||||
@@GraphKeys
|
||||
|
||||
## Defining new operations
|
||||
@ -82,6 +83,7 @@ from tensorflow.python.framework.ops import reset_default_graph
|
||||
from tensorflow.python.framework.ops import GraphKeys
|
||||
from tensorflow.python.framework.ops import add_to_collection
|
||||
from tensorflow.python.framework.ops import get_collection
|
||||
from tensorflow.python.framework.ops import get_collection_ref
|
||||
from tensorflow.python.framework.ops import convert_to_tensor
|
||||
from tensorflow.python.framework.ops import convert_to_tensor_or_indexed_slices
|
||||
from tensorflow.python.framework.random_seed import get_seed
|
||||
|
@ -83,6 +83,7 @@ def all_libraries(module_to_name, members, documented):
|
||||
prefix=PREFIX_TEXT),
|
||||
library("histogram_ops", "Histograms"),
|
||||
library("control_flow_ops", "Control Flow", prefix=PREFIX_TEXT),
|
||||
library("functional_ops", "Higher Order Functions", prefix=PREFIX_TEXT),
|
||||
library("image", "Images", tf.image, exclude_symbols=["ResizeMethod"],
|
||||
prefix=PREFIX_TEXT),
|
||||
library("sparse_ops", "Sparse Tensors",
|
||||
|
@ -1787,6 +1787,7 @@ class Graph(object):
|
||||
|
||||
@@add_to_collection
|
||||
@@get_collection
|
||||
@@get_collection_ref
|
||||
|
||||
@@as_graph_element
|
||||
@@get_operation_by_name
|
||||
@ -2396,6 +2397,30 @@ class Graph(object):
|
||||
for name in set(names):
|
||||
self.add_to_collection(name, value)
|
||||
|
||||
def get_collection_ref(self, name):
|
||||
"""Returns a list of values in the collection with the given `name`.
|
||||
|
||||
If the collection exists, this returns the list itself, which can
|
||||
be modified in place to change the collection. If the collection does
|
||||
not exist, it is created as an empty list and the list is returned.
|
||||
|
||||
This is different from `get_collection()` which always returns a copy of
|
||||
the collection list if it exists and never creates an empty collection.
|
||||
|
||||
Args:
|
||||
name: The key for the collection. For example, the `GraphKeys` class
|
||||
contains many standard names for collections.
|
||||
|
||||
Returns:
|
||||
The list of values in the collection with the given `name`, or an empty
|
||||
list if no value has been added to that collection.
|
||||
"""
|
||||
coll_list = self._collections.get(name, None)
|
||||
if coll_list is None:
|
||||
coll_list = []
|
||||
self._collections[name] = coll_list
|
||||
return coll_list
|
||||
|
||||
def get_collection(self, name, scope=None):
|
||||
"""Returns a list of values in the collection with the given `name`.
|
||||
|
||||
@ -2411,11 +2436,14 @@ class Graph(object):
|
||||
list contains the values in the order under which they were
|
||||
collected.
|
||||
"""
|
||||
coll_list = self._collections.get(name, None)
|
||||
if coll_list is None:
|
||||
return []
|
||||
if scope is None:
|
||||
return self._collections.get(name, list())
|
||||
return list(coll_list)
|
||||
else:
|
||||
c = []
|
||||
for item in self._collections.get(name, list()):
|
||||
for item in coll_list:
|
||||
if hasattr(item, "name") and item.name.startswith(scope):
|
||||
c.append(item)
|
||||
return c
|
||||
@ -3547,6 +3575,25 @@ def add_to_collections(names, value):
|
||||
get_default_graph().add_to_collections(names, value)
|
||||
|
||||
|
||||
def get_collection_ref(key):
|
||||
"""Wrapper for `Graph.get_collection_ref()` using the default graph.
|
||||
|
||||
See [`Graph.get_collection_ref()`](../../api_docs/python/framework.md#Graph.get_collection_ref)
|
||||
for more details.
|
||||
|
||||
Args:
|
||||
key: The key for the collection. For example, the `GraphKeys` class
|
||||
contains many standard names for collections.
|
||||
|
||||
Returns:
|
||||
The list of values in the collection with the given `name`, or an empty
|
||||
list if no value has been added to that collection. Note that this returns
|
||||
the collection list itself, which can be modified in place to change the
|
||||
collection.
|
||||
"""
|
||||
return get_default_graph().get_collection_ref(key)
|
||||
|
||||
|
||||
def get_collection(key, scope=None):
|
||||
"""Wrapper for `Graph.get_collection()` using the default graph.
|
||||
|
||||
|
@ -751,12 +751,41 @@ class CollectionTest(test_util.TensorFlowTestCase):
|
||||
blank2 = ObjectWithName("junk/foo")
|
||||
g.add_to_collection("blah", blank2)
|
||||
|
||||
self.assertEqual(["foo"], g.get_collection("other"))
|
||||
self.assertEqual([12, 34], g.get_collection("key"))
|
||||
self.assertEqual([], g.get_collection("nothing"))
|
||||
self.assertEqual([27, blank1, blank2], g.get_collection("blah"))
|
||||
self.assertEqual([blank1], g.get_collection("blah", "prefix"))
|
||||
|
||||
# Make sure that get_collection() returns a first-level
|
||||
# copy of the collection, while get_collection_ref() returns
|
||||
# the original list.
|
||||
other_collection_snapshot = g.get_collection("other")
|
||||
other_collection_ref = g.get_collection_ref("other")
|
||||
self.assertEqual(["foo"], other_collection_snapshot)
|
||||
self.assertEqual(["foo"], other_collection_ref)
|
||||
g.add_to_collection("other", "bar")
|
||||
self.assertEqual(["foo"], other_collection_snapshot)
|
||||
self.assertEqual(["foo", "bar"], other_collection_ref)
|
||||
self.assertEqual(["foo", "bar"], g.get_collection("other"))
|
||||
self.assertTrue(other_collection_ref is g.get_collection_ref("other"))
|
||||
|
||||
# Verify that getting an empty collection ref returns a modifiable list.
|
||||
empty_coll_ref = g.get_collection_ref("empty")
|
||||
self.assertEqual([], empty_coll_ref)
|
||||
empty_coll = g.get_collection("empty")
|
||||
self.assertEqual([], empty_coll)
|
||||
self.assertFalse(empty_coll is empty_coll_ref)
|
||||
empty_coll_ref2 = g.get_collection_ref("empty")
|
||||
self.assertTrue(empty_coll_ref2 is empty_coll_ref)
|
||||
# Add to the collection.
|
||||
empty_coll_ref.append("something")
|
||||
self.assertEqual(["something"], empty_coll_ref)
|
||||
self.assertEqual(["something"], empty_coll_ref2)
|
||||
self.assertEqual([], empty_coll)
|
||||
self.assertEqual(["something"], g.get_collection("empty"))
|
||||
empty_coll_ref3 = g.get_collection_ref("empty")
|
||||
self.assertTrue(empty_coll_ref3 is empty_coll_ref)
|
||||
|
||||
def testDefaulGraph(self):
|
||||
with ops.Graph().as_default():
|
||||
ops.add_to_collection("key", 90)
|
||||
|
@ -20,7 +20,6 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
import sys
|
||||
|
||||
import numpy as np
|
||||
from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
@ -1182,75 +1181,6 @@ class ControlFlowTest(tf.test.TestCase):
|
||||
self.assertEqual(0, value_x)
|
||||
self.assertEqual(73, value_x_grad)
|
||||
|
||||
def testFoldl_Simple(self):
|
||||
with self.test_session():
|
||||
elems = tf.constant([1, 2, 3, 4, 5, 6], name="data")
|
||||
|
||||
r = control_flow_ops.foldl(
|
||||
lambda a, x: tf.mul(tf.add(a, x), 2), elems)
|
||||
self.assertAllEqual(208, r.eval())
|
||||
|
||||
r = control_flow_ops.foldl(
|
||||
lambda a, x: tf.mul(tf.add(a, x), 2), elems, initializer=10)
|
||||
self.assertAllEqual(880, r.eval())
|
||||
|
||||
def testFoldr_Simple(self):
|
||||
with self.test_session():
|
||||
elems = tf.constant([1, 2, 3, 4, 5, 6], name="data")
|
||||
|
||||
r = control_flow_ops.foldr(
|
||||
lambda a, x: tf.mul(tf.add(a, x), 2), elems)
|
||||
self.assertAllEqual(450, r.eval())
|
||||
|
||||
r = control_flow_ops.foldr(
|
||||
lambda a, x: tf.mul(tf.add(a, x), 2), elems, initializer=10)
|
||||
self.assertAllEqual(1282, r.eval())
|
||||
|
||||
def testFold_Grad(self):
|
||||
with self.test_session():
|
||||
elems = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data")
|
||||
v = tf.constant(2.0, name="v")
|
||||
|
||||
r = control_flow_ops.foldl(
|
||||
lambda a, x: tf.mul(a, x), elems, initializer=v)
|
||||
r = tf.gradients(r, v)[0]
|
||||
self.assertAllEqual(720.0, r.eval())
|
||||
|
||||
r = control_flow_ops.foldr(
|
||||
lambda a, x: tf.mul(a, x), elems, initializer=v)
|
||||
r = tf.gradients(r, v)[0]
|
||||
self.assertAllEqual(720.0, r.eval())
|
||||
|
||||
def testMap_Simple(self):
|
||||
with self.test_session():
|
||||
nums = [1, 2, 3, 4, 5, 6]
|
||||
elems = tf.constant(nums, name="data")
|
||||
r = control_flow_ops.map_fn(
|
||||
lambda x: tf.mul(tf.add(x, 3), 2), elems)
|
||||
self.assertAllEqual(np.array([(x + 3) * 2 for x in nums]), r.eval())
|
||||
|
||||
def testScan_Simple(self):
|
||||
with self.test_session():
|
||||
elems = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data")
|
||||
v = tf.constant(2.0, name="v")
|
||||
|
||||
r = control_flow_ops.scan(lambda a, x: tf.mul(a, x), elems)
|
||||
self.assertAllEqual([1., 2., 6., 24., 120., 720.], r.eval())
|
||||
|
||||
r = control_flow_ops.scan(
|
||||
lambda a, x: tf.mul(a, x), elems, initializer=v)
|
||||
self.assertAllEqual([2., 4., 12., 48., 240., 1440.], r.eval())
|
||||
|
||||
def testScan_Grad(self):
|
||||
with self.test_session():
|
||||
elems = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data")
|
||||
v = tf.constant(2.0, name="v")
|
||||
|
||||
r = control_flow_ops.scan(
|
||||
lambda a, x: tf.mul(a, x), elems, initializer=v)
|
||||
r = tf.gradients(r, v)[0]
|
||||
self.assertAllEqual(873.0, r.eval())
|
||||
|
||||
def testOneValueCond(self):
|
||||
with self.test_session():
|
||||
c = tf.placeholder(tf.int32, shape=[])
|
||||
|
94
tensorflow/python/kernel_tests/functional_ops_test.py
Normal file
94
tensorflow/python/kernel_tests/functional_ops_test.py
Normal file
@ -0,0 +1,94 @@
|
||||
# Copyright 2015 Google Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Tests for tensorflow.kernels.bcast_ops."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
class FunctionalOpsTest(tf.test.TestCase):
|
||||
|
||||
def testFoldl_Simple(self):
|
||||
with self.test_session():
|
||||
elems = tf.constant([1, 2, 3, 4, 5, 6], name="data")
|
||||
|
||||
r = tf.foldl(lambda a, x: tf.mul(tf.add(a, x), 2), elems)
|
||||
self.assertAllEqual(208, r.eval())
|
||||
|
||||
r = tf.foldl(
|
||||
lambda a, x: tf.mul(tf.add(a, x), 2), elems, initializer=10)
|
||||
self.assertAllEqual(880, r.eval())
|
||||
|
||||
def testFoldr_Simple(self):
|
||||
with self.test_session():
|
||||
elems = tf.constant([1, 2, 3, 4, 5, 6], name="data")
|
||||
|
||||
r = tf.foldr(lambda a, x: tf.mul(tf.add(a, x), 2), elems)
|
||||
self.assertAllEqual(450, r.eval())
|
||||
|
||||
r = tf.foldr(
|
||||
lambda a, x: tf.mul(tf.add(a, x), 2), elems, initializer=10)
|
||||
self.assertAllEqual(1282, r.eval())
|
||||
|
||||
def testFold_Grad(self):
|
||||
with self.test_session():
|
||||
elems = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data")
|
||||
v = tf.constant(2.0, name="v")
|
||||
|
||||
r = tf.foldl(
|
||||
lambda a, x: tf.mul(a, x), elems, initializer=v)
|
||||
r = tf.gradients(r, v)[0]
|
||||
self.assertAllEqual(720.0, r.eval())
|
||||
|
||||
r = tf.foldr(
|
||||
lambda a, x: tf.mul(a, x), elems, initializer=v)
|
||||
r = tf.gradients(r, v)[0]
|
||||
self.assertAllEqual(720.0, r.eval())
|
||||
|
||||
def testMap_Simple(self):
|
||||
with self.test_session():
|
||||
nums = [1, 2, 3, 4, 5, 6]
|
||||
elems = tf.constant(nums, name="data")
|
||||
r = tf.map_fn(lambda x: tf.mul(tf.add(x, 3), 2), elems)
|
||||
self.assertAllEqual(np.array([(x + 3) * 2 for x in nums]), r.eval())
|
||||
|
||||
def testScan_Simple(self):
|
||||
with self.test_session():
|
||||
elems = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data")
|
||||
v = tf.constant(2.0, name="v")
|
||||
|
||||
r = tf.scan(lambda a, x: tf.mul(a, x), elems)
|
||||
self.assertAllEqual([1., 2., 6., 24., 120., 720.], r.eval())
|
||||
|
||||
r = tf.scan(
|
||||
lambda a, x: tf.mul(a, x), elems, initializer=v)
|
||||
self.assertAllEqual([2., 4., 12., 48., 240., 1440.], r.eval())
|
||||
|
||||
def testScan_Grad(self):
|
||||
with self.test_session():
|
||||
elems = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name="data")
|
||||
v = tf.constant(2.0, name="v")
|
||||
|
||||
r = tf.scan(lambda a, x: tf.mul(a, x), elems, initializer=v)
|
||||
r = tf.gradients(r, v)[0]
|
||||
self.assertAllEqual(873.0, r.eval())
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.test.main()
|
122
tensorflow/python/kernel_tests/gather_nd_op_test.py
Normal file
122
tensorflow/python/kernel_tests/gather_nd_op_test.py
Normal file
@ -0,0 +1,122 @@
|
||||
# Copyright 2015 Google Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
"""Tests for tensorflow.ops.tf.gather_nd."""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
class GatherNdTest(tf.test.TestCase):
|
||||
use_gpu = False
|
||||
|
||||
def _testSimpleDtype(self, dtype):
|
||||
with self.test_session(use_gpu=self.use_gpu):
|
||||
params = tf.constant(np.array([8, 1, 2, 3, 7, 5], dtype=dtype))
|
||||
indices = tf.constant([[4], [4], [0]])
|
||||
gather_nd_t = tf.gather_nd(params, indices)
|
||||
gather_nd_val = gather_nd_t.eval()
|
||||
|
||||
self.assertAllEqual(np.array([7, 7, 8], dtype=dtype), gather_nd_val)
|
||||
self.assertEqual([3], gather_nd_t.get_shape())
|
||||
|
||||
def testSimpleDtype(self):
|
||||
self._testSimpleDtype(np.float32)
|
||||
self._testSimpleDtype(np.float64)
|
||||
self._testSimpleDtype(np.int32)
|
||||
self._testSimpleDtype(np.int64)
|
||||
self._testSimpleDtype(np.complex64)
|
||||
self._testSimpleDtype("|S") # byte strings in python2 + 3
|
||||
|
||||
def testHigherRankParams(self):
|
||||
with self.test_session(use_gpu=self.use_gpu):
|
||||
shape = (10, 20, 5, 1, 17)
|
||||
params = np.random.rand(*shape)
|
||||
indices = np.vstack([
|
||||
np.random.randint(0, s, size=2000) for s in shape]).T
|
||||
gather_nd_t = tf.gather_nd(params, indices)
|
||||
gather_nd_val = gather_nd_t.eval()
|
||||
|
||||
expected = params[tuple(indices.T)]
|
||||
self.assertAllEqual(expected, gather_nd_val)
|
||||
self.assertEqual([2000], gather_nd_t.get_shape())
|
||||
|
||||
def testHigherRankParamsAndIndices(self):
|
||||
with self.test_session(use_gpu=self.use_gpu):
|
||||
shape = (10, 20, 5, 1, 17)
|
||||
params = np.random.rand(*shape)
|
||||
indices = np.vstack([
|
||||
np.random.randint(0, s, size=2000) for s in shape]).T
|
||||
indices_reshaped = indices.reshape([10, 10, 20, 5])
|
||||
gather_nd_t = tf.gather_nd(params, indices_reshaped)
|
||||
gather_nd_val = gather_nd_t.eval()
|
||||
|
||||
expected = params[tuple(indices.T)]
|
||||
self.assertAllEqual(expected.reshape([10, 10, 20]), gather_nd_val)
|
||||
self.assertEqual([10, 10, 20], gather_nd_t.get_shape())
|
||||
|
||||
def testUnknownIndices(self):
|
||||
params = tf.constant([[0, 1, 2]])
|
||||
indices = tf.placeholder(tf.int32)
|
||||
gather_nd_t = tf.gather_nd(params, indices)
|
||||
shape = gather_nd_t.get_shape()
|
||||
self.assertEqual(shape.ndims, None)
|
||||
self.assertEqual(shape[0].value, None)
|
||||
|
||||
def testBadIndices(self):
|
||||
with self.test_session(use_gpu=False):
|
||||
params = [0, 1, 2]
|
||||
indices = [[[0], [7]]] # Make this one higher rank
|
||||
gather_nd = tf.gather_nd(params, indices)
|
||||
with self.assertRaisesOpError(
|
||||
r"flat indices\[1, :\] = \[7\] does not index into param "
|
||||
r"\(shape: \[3\]\)"):
|
||||
gather_nd.eval()
|
||||
|
||||
|
||||
class GatherNdGpuTest(GatherNdTest):
|
||||
use_gpu = True
|
||||
|
||||
|
||||
class GatherNdOpBenchmark(tf.test.Benchmark):
|
||||
|
||||
def benchmark_gather_nd_op(self):
|
||||
shape = (100, 47, 18, 170, 13)
|
||||
np.random.seed(127)
|
||||
params = np.random.rand(*shape)
|
||||
indices = np.vstack([
|
||||
np.random.randint(0, s, size=10000) for s in shape]).T
|
||||
|
||||
with tf.Session():
|
||||
t_params = tf.Variable(params)
|
||||
t_indices = tf.Variable(indices)
|
||||
gather_op = tf.gather_nd(t_params, t_indices)
|
||||
tf.initialize_all_variables().run()
|
||||
for _ in range(10):
|
||||
gather_op.eval()
|
||||
t1 = time.time()
|
||||
for _ in range(1000):
|
||||
gather_op.eval()
|
||||
t2 = time.time()
|
||||
self.report_benchmark(iters=1000, wall_time=(t2-t1)/1000.0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.test.main()
|
@ -193,6 +193,11 @@ def _GatherGrad(op, grad):
|
||||
return [ops.IndexedSlices(values, indices, dense_shape), None]
|
||||
|
||||
|
||||
@ops.RegisterGradient("GatherNd")
|
||||
def _GatherNdGrad(unused_op, unused_grad):
|
||||
raise NotImplementedError("Gradient for gather_nd is not implemented.")
|
||||
|
||||
|
||||
@ops.RegisterGradient("Identity")
|
||||
def _IdGrad(_, grad):
|
||||
return grad
|
||||
|
@ -57,6 +57,7 @@ or join multiple tensors together.
|
||||
@@space_to_depth
|
||||
@@depth_to_space
|
||||
@@gather
|
||||
@@gather_nd
|
||||
@@dynamic_partition
|
||||
@@dynamic_stitch
|
||||
@@boolean_mask
|
||||
@ -879,6 +880,16 @@ def _GatherShape(op):
|
||||
return [indices_shape.concatenate(params_shape[1:])]
|
||||
|
||||
|
||||
@ops.RegisterShape("GatherNd")
|
||||
def _GatherNdShape(op):
|
||||
"""Shape function for array_ops.gather_nd."""
|
||||
params_shape = op.inputs[0].get_shape()
|
||||
indices_shape = op.inputs[1].get_shape().with_rank_at_least(2)
|
||||
if indices_shape.ndims is not None:
|
||||
indices_shape[-1].merge_with(params_shape.ndims)
|
||||
return [indices_shape[:-1]]
|
||||
|
||||
|
||||
@ops.RegisterShape("Unique")
|
||||
def _UniqueShape(op):
|
||||
"""Shape function for array_ops.Unique."""
|
||||
|
@ -26,16 +26,6 @@ the execution of operations and add conditional dependencies to your graph.
|
||||
@@cond
|
||||
@@case
|
||||
|
||||
## Higher Order Operators
|
||||
|
||||
TensorFlow provides several higher order operators to simplify the common
|
||||
map-reduce programming patterns.
|
||||
|
||||
@@map_fn
|
||||
@@foldl
|
||||
@@foldr
|
||||
@@scan
|
||||
|
||||
## Logical Operators
|
||||
|
||||
TensorFlow provides several operations that you can use to add logical operators
|
||||
@ -1785,269 +1775,6 @@ def tuple(tensors, name=None, control_inputs=None):
|
||||
return tpl
|
||||
|
||||
|
||||
# TODO(yuanbyu, mrry): Handle stride to support sliding windows.
|
||||
def foldl(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
|
||||
swap_memory=False, name=None):
|
||||
"""The foldl operator on the list of tensors resulted from unpacking `elems`
|
||||
along the first dimension.
|
||||
|
||||
This foldl operator repeatedly applies the callable `fn` to a sequence
|
||||
of elements from first to last. The elements are made of the tensors
|
||||
unpacked from `elems` on dimension 0. The callable fn takes two tensors as
|
||||
arguments. The first argument is the accumulated value computed from the
|
||||
preceding invocation of fn. If `initializer` is None, `elems` must contain
|
||||
at least one element, and its first element is used as the initializer.
|
||||
|
||||
Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
|
||||
of the result tensor is fn(initializer, values[0]).shape`.
|
||||
|
||||
Args:
|
||||
fn: The callable to be performed.
|
||||
elems: A tensor to be unpacked on dimension 0.
|
||||
initializer: (optional) The initial value for the accumulator.
|
||||
parallel_iterations: (optional) The number of iterations allowed to run
|
||||
in parallel.
|
||||
back_prop: (optional) True enables back propagation.
|
||||
swap_memory: (optional) True enables GPU-CPU memory swapping.
|
||||
name: (optional) Name prefix for the returned tensors.
|
||||
|
||||
Returns:
|
||||
A tensor resulting from applying `fn` consecutively to the list of tensors
|
||||
unpacked from `elems`, from first to last.
|
||||
|
||||
Raises:
|
||||
TypeError: if `fn` is not callable.
|
||||
|
||||
Example:
|
||||
```python
|
||||
elems = [1, 2, 3, 4, 5, 6]
|
||||
sum = foldl(lambda a, x: a + x, elems)
|
||||
# sum == 21
|
||||
```
|
||||
"""
|
||||
with ops.op_scope([elems], name, "foldl") as name:
|
||||
if not callable(fn):
|
||||
raise TypeError("fn must be callable.")
|
||||
|
||||
# Convert elems to tensor array.
|
||||
n = array_ops.shape(elems)[0]
|
||||
elems_ta = tensor_array_ops.TensorArray(dtype=elems.dtype, size=n,
|
||||
dynamic_size=False)
|
||||
elems_ta = elems_ta.unpack(elems)
|
||||
|
||||
if initializer is None:
|
||||
a = elems_ta.read(0)
|
||||
i = constant_op.constant(1)
|
||||
else:
|
||||
a = ops.convert_to_tensor(initializer)
|
||||
i = constant_op.constant(0)
|
||||
|
||||
def compute(i, a):
|
||||
a = fn(a, elems_ta.read(i))
|
||||
return [i + 1, a]
|
||||
_, r_a = While(lambda i, a: i < n, compute, [i, a],
|
||||
parallel_iterations=parallel_iterations,
|
||||
back_prop=back_prop, swap_memory=swap_memory)
|
||||
return r_a
|
||||
|
||||
|
||||
def foldr(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
|
||||
swap_memory=False, name=None):
|
||||
"""The foldr operator on the list of tensors resulted from unpacking `elems`
|
||||
along the first dimension.
|
||||
|
||||
This foldr operator repeatedly applies the callable `fn` to a sequence
|
||||
of elements from last to first. The elements are made of the tensors
|
||||
unpacked from `elems`. The callable fn takes two tensors as arguments.
|
||||
The first argument is the accumulated value computed from the preceding
|
||||
invocation of fn. If `initializer` is None, `elems` must contain at least
|
||||
one element, and its first element is used as the initializer.
|
||||
|
||||
Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
|
||||
of the result tensor is `fn(initializer, values[0]).shape`.
|
||||
|
||||
Args:
|
||||
fn: The callable to be performed.
|
||||
elems: A tensor that is unpacked into a sequence of tensors to apply `fn`.
|
||||
initializer: (optional) The initial value for the accumulator.
|
||||
parallel_iterations: (optional) The number of iterations allowed to run
|
||||
in parallel.
|
||||
back_prop: (optional) True enables back propagation.
|
||||
swap_memory: (optional) True enables GPU-CPU memory swapping.
|
||||
name: (optional) Name prefix for the returned tensors.
|
||||
|
||||
Returns:
|
||||
A tensor resulting from applying `fn` consecutively to the list of tensors
|
||||
unpacked from `elems`, from last to first.
|
||||
|
||||
Raises:
|
||||
TypeError: if `fn` is not callable.
|
||||
|
||||
Example:
|
||||
```python
|
||||
elems = [1, 2, 3, 4, 5, 6]
|
||||
sum = foldr(lambda a, x: a + x, elems)
|
||||
# sum == 21
|
||||
```
|
||||
"""
|
||||
with ops.op_scope([elems], name, "foldr") as name:
|
||||
if not callable(fn):
|
||||
raise TypeError("fn must be callable.")
|
||||
|
||||
# Convert elems to tensor array.
|
||||
n = array_ops.shape(elems)[0]
|
||||
elems_ta = tensor_array_ops.TensorArray(dtype=elems.dtype, size=n,
|
||||
dynamic_size=False)
|
||||
elems_ta = elems_ta.unpack(elems)
|
||||
|
||||
if initializer is None:
|
||||
i = n - 1
|
||||
a = elems_ta.read(i)
|
||||
else:
|
||||
i = n
|
||||
a = ops.convert_to_tensor(initializer)
|
||||
def compute(i, a):
|
||||
i -= 1
|
||||
a = fn(a, elems_ta.read(i))
|
||||
return [i, a]
|
||||
_, r_a = While(lambda i, a: i > 0, compute, [i, a],
|
||||
parallel_iterations=parallel_iterations,
|
||||
back_prop=back_prop, swap_memory=swap_memory)
|
||||
return r_a
|
||||
|
||||
|
||||
def map_fn(fn, elems, dtype=None, parallel_iterations=10, back_prop=True,
|
||||
swap_memory=False, name=None):
|
||||
"""The map operator on the list of tensors resulted from unpacking `elems`
|
||||
along the first dimension.
|
||||
|
||||
This map operator repeatedly applies the callable `fn` to a sequence of
|
||||
elements from first to last. The elements are made of the tensors unpacked
|
||||
from `elems`. `dtype` is the data type of the return value of `fn`. Users
|
||||
must provide `dtype` if it is different from the data type of `elems`.
|
||||
|
||||
Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
|
||||
of the result tensor is `[len(values)] + fn(values[0]).shape`.
|
||||
|
||||
Args:
|
||||
fn: The callable to be performed.
|
||||
elems: A tensor to be unpacked to apply `fn`.
|
||||
dtype: (optional) The output type of `fn`.
|
||||
parallel_iterations: (optional) The number of iterations allowed to run
|
||||
in parallel.
|
||||
back_prop: (optional) True enables back propagation.
|
||||
swap_memory: (optional) True enables GPU-CPU memory swapping.
|
||||
name: (optional) Name prefix for the returned tensors.
|
||||
|
||||
Returns:
|
||||
A tensor that packs the results of applying `fn` to the list of tensors
|
||||
unpacked from `elems`, from first to last.
|
||||
|
||||
Raises:
|
||||
TypeError: if `fn` is not callable.
|
||||
|
||||
Example:
|
||||
```python
|
||||
elems = [1, 2, 3, 4, 5, 6]
|
||||
squares = map_fn(lambda x: x * x, elems)
|
||||
# squares == [1, 4, 9, 16, 25, 36]
|
||||
```
|
||||
"""
|
||||
with ops.op_scope([elems], name, "map") as name:
|
||||
if not callable(fn):
|
||||
raise TypeError("fn must be callable.")
|
||||
dtype = dtype if dtype else elems.dtype
|
||||
|
||||
# Convert elems to tensor array.
|
||||
n = array_ops.shape(elems)[0]
|
||||
elems_ta = tensor_array_ops.TensorArray(dtype=elems.dtype, size=n,
|
||||
dynamic_size=False)
|
||||
elems_ta = elems_ta.unpack(elems)
|
||||
|
||||
i = constant_op.constant(0)
|
||||
acc_ta = tensor_array_ops.TensorArray(dtype=dtype, size=n,
|
||||
dynamic_size=False)
|
||||
def compute(i, ta):
|
||||
ta = ta.write(i, fn(elems_ta.read(i)))
|
||||
return [i + 1, ta]
|
||||
_, r_a = While(lambda i, a: i < n, compute, [i, acc_ta],
|
||||
parallel_iterations=parallel_iterations,
|
||||
back_prop=back_prop, swap_memory=swap_memory)
|
||||
return r_a.pack()
|
||||
|
||||
|
||||
def scan(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
|
||||
swap_memory=False, name=None):
|
||||
"""The scan operator on the list of tensors resulted from unpacking `elems`
|
||||
along the first dimension.
|
||||
|
||||
This scan operator repeatedly applies the callable `fn` to a sequence
|
||||
of elements from first to last. The elements are made of the tensors
|
||||
unpacked from `elems` on dimension 0. The callable fn takes two tensors as
|
||||
arguments. The first argument is the accumulated value computed from the
|
||||
preceding invocation of fn. If `initializer` is None, `elems` must contain
|
||||
at least one element, and its first element is used as the initializer.
|
||||
|
||||
Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
|
||||
of the result tensor is `[len(values)] + fn(initializer, values[0]).shape`.
|
||||
|
||||
Args:
|
||||
fn: The callable to be performed.
|
||||
elems: A tensor to be unpacked on dimension 0.
|
||||
initializer: (optional) The initial value for the accumulator.
|
||||
parallel_iterations: (optional) The number of iterations allowed to run
|
||||
in parallel.
|
||||
back_prop: (optional) True enables back propagation.
|
||||
swap_memory: (optional) True enables GPU-CPU memory swapping.
|
||||
name: (optional) Name prefix for the returned tensors.
|
||||
|
||||
Returns:
|
||||
A tensor that packs the results of applying `fn` to the list of tensors
|
||||
unpacked from `elems`, from first to last.
|
||||
|
||||
Raises:
|
||||
TypeError: if `fn` is not callable.
|
||||
|
||||
Example:
|
||||
```python
|
||||
elems = [1, 2, 3, 4, 5, 6]
|
||||
sum = scan(lambda a, x: a + x, elems)
|
||||
# sum == [1, 3, 6, 10, 15, 21]
|
||||
```
|
||||
"""
|
||||
with ops.op_scope([elems], name, "scan") as name:
|
||||
if not callable(fn):
|
||||
raise TypeError("fn must be callable.")
|
||||
|
||||
# Convert elems to tensor array.
|
||||
n = array_ops.shape(elems)[0]
|
||||
elems_ta = tensor_array_ops.TensorArray(dtype=elems.dtype, size=n,
|
||||
dynamic_size=False)
|
||||
elems_ta = elems_ta.unpack(elems)
|
||||
|
||||
if initializer is None:
|
||||
a = elems_ta.read(0)
|
||||
i = constant_op.constant(1)
|
||||
else:
|
||||
a = ops.convert_to_tensor(initializer)
|
||||
i = constant_op.constant(0)
|
||||
|
||||
# Create a tensor array to store the intermediate values.
|
||||
acc_ta = tensor_array_ops.TensorArray(dtype=a.dtype, size=n,
|
||||
dynamic_size=False)
|
||||
if initializer is None:
|
||||
acc_ta = acc_ta.write(0, a)
|
||||
|
||||
def compute(i, a, ta):
|
||||
a = fn(a, elems_ta.read(i))
|
||||
ta = ta.write(i, a)
|
||||
return [i + 1, a, ta]
|
||||
_, _, r_a = While(lambda i, a, ta: i < n, compute, [i, a, acc_ta],
|
||||
parallel_iterations=parallel_iterations,
|
||||
back_prop=back_prop, swap_memory=swap_memory)
|
||||
return r_a.pack()
|
||||
|
||||
|
||||
def case(pred_fn_pairs, default, exclusive=False, name="case"):
|
||||
"""Create a case operation.
|
||||
|
||||
|
@ -13,13 +13,28 @@
|
||||
# limitations under the License.
|
||||
# =============================================================================
|
||||
|
||||
"""Functional operations."""
|
||||
"""Functional operations.
|
||||
|
||||
## Higher Order Operators
|
||||
|
||||
TensorFlow provides several higher order operators to simplify the common
|
||||
map-reduce programming patterns.
|
||||
|
||||
@@map_fn
|
||||
@@foldl
|
||||
@@foldr
|
||||
@@scan
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import constant_op
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import tensor_array_ops
|
||||
# pylint: disable=wildcard-import
|
||||
from tensorflow.python.ops.gen_functional_ops import *
|
||||
# pylint: enable=wildcard-import
|
||||
@ -28,6 +43,269 @@ from tensorflow.python.ops.gen_functional_ops import _symbolic_gradient
|
||||
# pylint: enable=unused-import
|
||||
|
||||
|
||||
# TODO(yuanbyu, mrry): Handle stride to support sliding windows.
|
||||
def foldl(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
|
||||
swap_memory=False, name=None):
|
||||
"""foldl on the list of tensors unpacked from `elems` on dimension 0.
|
||||
|
||||
This foldl operator repeatedly applies the callable `fn` to a sequence
|
||||
of elements from first to last. The elements are made of the tensors
|
||||
unpacked from `elems` on dimension 0. The callable fn takes two tensors as
|
||||
arguments. The first argument is the accumulated value computed from the
|
||||
preceding invocation of fn. If `initializer` is None, `elems` must contain
|
||||
at least one element, and its first element is used as the initializer.
|
||||
|
||||
Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
|
||||
of the result tensor is fn(initializer, values[0]).shape`.
|
||||
|
||||
Args:
|
||||
fn: The callable to be performed.
|
||||
elems: A tensor to be unpacked on dimension 0.
|
||||
initializer: (optional) The initial value for the accumulator.
|
||||
parallel_iterations: (optional) The number of iterations allowed to run
|
||||
in parallel.
|
||||
back_prop: (optional) True enables back propagation.
|
||||
swap_memory: (optional) True enables GPU-CPU memory swapping.
|
||||
name: (optional) Name prefix for the returned tensors.
|
||||
|
||||
Returns:
|
||||
A tensor resulting from applying `fn` consecutively to the list of tensors
|
||||
unpacked from `elems`, from first to last.
|
||||
|
||||
Raises:
|
||||
TypeError: if `fn` is not callable.
|
||||
|
||||
Example:
|
||||
```python
|
||||
elems = [1, 2, 3, 4, 5, 6]
|
||||
sum = foldl(lambda a, x: a + x, elems)
|
||||
# sum == 21
|
||||
```
|
||||
"""
|
||||
with ops.op_scope([elems], name, "foldl") as name:
|
||||
if not callable(fn):
|
||||
raise TypeError("fn must be callable.")
|
||||
|
||||
# Convert elems to tensor array.
|
||||
n = array_ops.shape(elems)[0]
|
||||
elems_ta = tensor_array_ops.TensorArray(dtype=elems.dtype, size=n,
|
||||
dynamic_size=False)
|
||||
elems_ta = elems_ta.unpack(elems)
|
||||
|
||||
if initializer is None:
|
||||
a = elems_ta.read(0)
|
||||
i = constant_op.constant(1)
|
||||
else:
|
||||
a = ops.convert_to_tensor(initializer)
|
||||
i = constant_op.constant(0)
|
||||
|
||||
def compute(i, a):
|
||||
a = fn(a, elems_ta.read(i))
|
||||
return [i + 1, a]
|
||||
_, r_a = control_flow_ops.While(lambda i, a: i < n, compute, [i, a],
|
||||
parallel_iterations=parallel_iterations,
|
||||
back_prop=back_prop,
|
||||
swap_memory=swap_memory)
|
||||
return r_a
|
||||
|
||||
|
||||
def foldr(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
|
||||
swap_memory=False, name=None):
|
||||
"""foldr on the list of tensors unpacked from `elems` on dimension 0.
|
||||
|
||||
This foldr operator repeatedly applies the callable `fn` to a sequence
|
||||
of elements from last to first. The elements are made of the tensors
|
||||
unpacked from `elems`. The callable fn takes two tensors as arguments.
|
||||
The first argument is the accumulated value computed from the preceding
|
||||
invocation of fn. If `initializer` is None, `elems` must contain at least
|
||||
one element, and its first element is used as the initializer.
|
||||
|
||||
Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
|
||||
of the result tensor is `fn(initializer, values[0]).shape`.
|
||||
|
||||
Args:
|
||||
fn: The callable to be performed.
|
||||
elems: A tensor that is unpacked into a sequence of tensors to apply `fn`.
|
||||
initializer: (optional) The initial value for the accumulator.
|
||||
parallel_iterations: (optional) The number of iterations allowed to run
|
||||
in parallel.
|
||||
back_prop: (optional) True enables back propagation.
|
||||
swap_memory: (optional) True enables GPU-CPU memory swapping.
|
||||
name: (optional) Name prefix for the returned tensors.
|
||||
|
||||
Returns:
|
||||
A tensor resulting from applying `fn` consecutively to the list of tensors
|
||||
unpacked from `elems`, from last to first.
|
||||
|
||||
Raises:
|
||||
TypeError: if `fn` is not callable.
|
||||
|
||||
Example:
|
||||
```python
|
||||
elems = [1, 2, 3, 4, 5, 6]
|
||||
sum = foldr(lambda a, x: a + x, elems)
|
||||
# sum == 21
|
||||
```
|
||||
"""
|
||||
with ops.op_scope([elems], name, "foldr") as name:
|
||||
if not callable(fn):
|
||||
raise TypeError("fn must be callable.")
|
||||
|
||||
# Convert elems to tensor array.
|
||||
n = array_ops.shape(elems)[0]
|
||||
elems_ta = tensor_array_ops.TensorArray(dtype=elems.dtype, size=n,
|
||||
dynamic_size=False)
|
||||
elems_ta = elems_ta.unpack(elems)
|
||||
|
||||
if initializer is None:
|
||||
i = n - 1
|
||||
a = elems_ta.read(i)
|
||||
else:
|
||||
i = n
|
||||
a = ops.convert_to_tensor(initializer)
|
||||
def compute(i, a):
|
||||
i -= 1
|
||||
a = fn(a, elems_ta.read(i))
|
||||
return [i, a]
|
||||
_, r_a = control_flow_ops.While(lambda i, a: i > 0, compute, [i, a],
|
||||
parallel_iterations=parallel_iterations,
|
||||
back_prop=back_prop,
|
||||
swap_memory=swap_memory)
|
||||
return r_a
|
||||
|
||||
|
||||
def map_fn(fn, elems, dtype=None, parallel_iterations=10, back_prop=True,
|
||||
swap_memory=False, name=None):
|
||||
"""map on the list of tensors unpacked from `elems` on dimension 0.
|
||||
|
||||
This map operator repeatedly applies the callable `fn` to a sequence of
|
||||
elements from first to last. The elements are made of the tensors unpacked
|
||||
from `elems`. `dtype` is the data type of the return value of `fn`. Users
|
||||
must provide `dtype` if it is different from the data type of `elems`.
|
||||
|
||||
Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
|
||||
of the result tensor is `[len(values)] + fn(values[0]).shape`.
|
||||
|
||||
Args:
|
||||
fn: The callable to be performed.
|
||||
elems: A tensor to be unpacked to apply `fn`.
|
||||
dtype: (optional) The output type of `fn`.
|
||||
parallel_iterations: (optional) The number of iterations allowed to run
|
||||
in parallel.
|
||||
back_prop: (optional) True enables back propagation.
|
||||
swap_memory: (optional) True enables GPU-CPU memory swapping.
|
||||
name: (optional) Name prefix for the returned tensors.
|
||||
|
||||
Returns:
|
||||
A tensor that packs the results of applying `fn` to the list of tensors
|
||||
unpacked from `elems`, from first to last.
|
||||
|
||||
Raises:
|
||||
TypeError: if `fn` is not callable.
|
||||
|
||||
Example:
|
||||
```python
|
||||
elems = [1, 2, 3, 4, 5, 6]
|
||||
squares = map_fn(lambda x: x * x, elems)
|
||||
# squares == [1, 4, 9, 16, 25, 36]
|
||||
```
|
||||
"""
|
||||
with ops.op_scope([elems], name, "map") as name:
|
||||
if not callable(fn):
|
||||
raise TypeError("fn must be callable.")
|
||||
dtype = dtype if dtype else elems.dtype
|
||||
|
||||
# Convert elems to tensor array.
|
||||
n = array_ops.shape(elems)[0]
|
||||
elems_ta = tensor_array_ops.TensorArray(dtype=elems.dtype, size=n,
|
||||
dynamic_size=False)
|
||||
elems_ta = elems_ta.unpack(elems)
|
||||
|
||||
i = constant_op.constant(0)
|
||||
acc_ta = tensor_array_ops.TensorArray(dtype=dtype, size=n,
|
||||
dynamic_size=False)
|
||||
def compute(i, ta):
|
||||
ta = ta.write(i, fn(elems_ta.read(i)))
|
||||
return [i + 1, ta]
|
||||
_, r_a = control_flow_ops.While(lambda i, a: i < n, compute, [i, acc_ta],
|
||||
parallel_iterations=parallel_iterations,
|
||||
back_prop=back_prop,
|
||||
swap_memory=swap_memory)
|
||||
return r_a.pack()
|
||||
|
||||
|
||||
def scan(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
|
||||
swap_memory=False, name=None):
|
||||
"""scan on the list of tensors unpacked from `elems` on dimension 0.
|
||||
|
||||
This scan operator repeatedly applies the callable `fn` to a sequence
|
||||
of elements from first to last. The elements are made of the tensors
|
||||
unpacked from `elems` on dimension 0. The callable fn takes two tensors as
|
||||
arguments. The first argument is the accumulated value computed from the
|
||||
preceding invocation of fn. If `initializer` is None, `elems` must contain
|
||||
at least one element, and its first element is used as the initializer.
|
||||
|
||||
Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
|
||||
of the result tensor is `[len(values)] + fn(initializer, values[0]).shape`.
|
||||
|
||||
Args:
|
||||
fn: The callable to be performed.
|
||||
elems: A tensor to be unpacked on dimension 0.
|
||||
initializer: (optional) The initial value for the accumulator.
|
||||
parallel_iterations: (optional) The number of iterations allowed to run
|
||||
in parallel.
|
||||
back_prop: (optional) True enables back propagation.
|
||||
swap_memory: (optional) True enables GPU-CPU memory swapping.
|
||||
name: (optional) Name prefix for the returned tensors.
|
||||
|
||||
Returns:
|
||||
A tensor that packs the results of applying `fn` to the list of tensors
|
||||
unpacked from `elems`, from first to last.
|
||||
|
||||
Raises:
|
||||
TypeError: if `fn` is not callable.
|
||||
|
||||
Example:
|
||||
```python
|
||||
elems = [1, 2, 3, 4, 5, 6]
|
||||
sum = scan(lambda a, x: a + x, elems)
|
||||
# sum == [1, 3, 6, 10, 15, 21]
|
||||
```
|
||||
"""
|
||||
with ops.op_scope([elems], name, "scan") as name:
|
||||
if not callable(fn):
|
||||
raise TypeError("fn must be callable.")
|
||||
|
||||
# Convert elems to tensor array.
|
||||
n = array_ops.shape(elems)[0]
|
||||
elems_ta = tensor_array_ops.TensorArray(dtype=elems.dtype, size=n,
|
||||
dynamic_size=False)
|
||||
elems_ta = elems_ta.unpack(elems)
|
||||
|
||||
if initializer is None:
|
||||
a = elems_ta.read(0)
|
||||
i = constant_op.constant(1)
|
||||
else:
|
||||
a = ops.convert_to_tensor(initializer)
|
||||
i = constant_op.constant(0)
|
||||
|
||||
# Create a tensor array to store the intermediate values.
|
||||
acc_ta = tensor_array_ops.TensorArray(dtype=a.dtype, size=n,
|
||||
dynamic_size=False)
|
||||
if initializer is None:
|
||||
acc_ta = acc_ta.write(0, a)
|
||||
|
||||
def compute(i, a, ta):
|
||||
a = fn(a, elems_ta.read(i))
|
||||
ta = ta.write(i, a)
|
||||
return [i + 1, a, ta]
|
||||
_, _, r_a = control_flow_ops.While(
|
||||
lambda i, a, ta: i < n, compute, [i, a, acc_ta],
|
||||
parallel_iterations=parallel_iterations,
|
||||
back_prop=back_prop, swap_memory=swap_memory)
|
||||
return r_a.pack()
|
||||
|
||||
|
||||
@ops.RegisterShape("SymbolicGradient")
|
||||
def _symbolic_gradient_shape(op):
|
||||
# Say, (u, v) = f(x, y, z), _symbolic_gradient(f) is a function of
|
||||
|
@ -37,11 +37,8 @@ from tensorflow.python.ops.control_flow_ops import no_op
|
||||
from tensorflow.python.ops.control_flow_ops import tuple
|
||||
from tensorflow.python.ops.control_flow_ops import cond
|
||||
from tensorflow.python.ops.control_flow_ops import case
|
||||
from tensorflow.python.ops.control_flow_ops import foldl
|
||||
from tensorflow.python.ops.control_flow_ops import foldr
|
||||
from tensorflow.python.ops.control_flow_ops import map_fn
|
||||
from tensorflow.python.ops.control_flow_ops import scan
|
||||
from tensorflow.python.ops.data_flow_ops import *
|
||||
from tensorflow.python.ops.functional_ops import *
|
||||
from tensorflow.python.ops.gradients import *
|
||||
from tensorflow.python.ops.histogram_ops import *
|
||||
from tensorflow.python.ops.init_ops import *
|
||||
|
@ -359,7 +359,8 @@ def _pure_variable_scope(name_or_scope, reuse=None, initializer=None,
|
||||
|
||||
"""
|
||||
get_variable_scope() # Ensure that a default exists, then get a pointer.
|
||||
default_varscope = ops.get_collection(_VARSCOPE_KEY)
|
||||
# Get the reference to the collection as we want to modify it in place.
|
||||
default_varscope = ops.get_collection_ref(_VARSCOPE_KEY)
|
||||
try:
|
||||
old = default_varscope[0]
|
||||
reuse = reuse or old.reuse # Re-using is inherited by sub-scopes.
|
||||
|
@ -256,7 +256,7 @@ class SessionManager(object):
|
||||
self._safe_close(sess)
|
||||
logging.info("Waiting for model to be ready: %s", not_ready)
|
||||
time.sleep(self._recovery_wait_secs)
|
||||
sess = session.Session(master, graph=self._graph)
|
||||
sess = session.Session(target, graph=self._graph, config=config)
|
||||
|
||||
return sess
|
||||
|
||||
|
@ -104,6 +104,13 @@ class SyncReplicasOptimizer(optimizer.Optimizer):
|
||||
# Now you can call `minimize()` or `compute_gradients()` and
|
||||
# `apply_gradients()` normally
|
||||
grads = opt.minimize(total_loss, global_step=self.global_step)
|
||||
|
||||
|
||||
# You can now call get_init_tokens_op() and get_chief_queue_runner().
|
||||
# Note that get_init_tokens_op() must be called before creating session
|
||||
# because it modifies the graph.
|
||||
init_token_op = opt.get_init_tokens_op()
|
||||
chief_queue_runner = opt.get_chief_queue_runner()
|
||||
```
|
||||
|
||||
In the training program, every worker will run the train_op as if not
|
||||
@ -114,9 +121,9 @@ class SyncReplicasOptimizer(optimizer.Optimizer):
|
||||
# After the session is created by the superviser and before the main while
|
||||
# loop:
|
||||
if is_chief and FLAGS.sync_replicas:
|
||||
sv.start_queue_runners(sess, [opt.get_chief_queue_runner()])
|
||||
sv.start_queue_runners(sess, [chief_queue_runner])
|
||||
# Insert initial tokens to the queue.
|
||||
sess.run(opt.get_init_tokens_op())
|
||||
sess.run(init_token_op)
|
||||
```
|
||||
|
||||
@@__init__
|
||||
|
@ -89,37 +89,41 @@ const DOMAIN_EDGE_WIDTH_SCALE = [1, 5E6];
|
||||
/**
|
||||
* Parameters that affect how the graph is rendered on the screen.
|
||||
*/
|
||||
export interface RenderGraphParams {
|
||||
const PARAMS = {
|
||||
/**
|
||||
* Whether to extract high degree nodes from the core part of the graph.
|
||||
*/
|
||||
enableExtraction: boolean;
|
||||
enableExtraction: true,
|
||||
/**
|
||||
* Maximum in-degree that a node can have without being considered as
|
||||
* high in-degree node.
|
||||
*/
|
||||
maxInDegree: number;
|
||||
maxInDegree: 4,
|
||||
/**
|
||||
* Maximum out-degree that a node can have without being considered as
|
||||
* high out-degree node.
|
||||
*/
|
||||
maxOutDegree: number;
|
||||
maxOutDegree: 4,
|
||||
/**
|
||||
* Maximum number of control edges a node can have before they aren't
|
||||
* displayed.
|
||||
*/
|
||||
maxControlDegree: number;
|
||||
maxControlDegree: 4,
|
||||
/**
|
||||
* Types patterns for predefined out-extract nodes, which are
|
||||
* sink-like nodes that will be extracted from the main graph.
|
||||
*/
|
||||
outExtractTypes: string[];
|
||||
outExtractTypes: [
|
||||
"NoOp" // NoOps are sink-like used for managing control dependencies.
|
||||
],
|
||||
|
||||
/**
|
||||
* Types patterns for predefined in-extract nodes, which are
|
||||
* source-like nodes that will be extracted from the main graph.
|
||||
*/
|
||||
inExtractTypes: string[];
|
||||
inExtractTypes: [
|
||||
"Variable"
|
||||
],
|
||||
|
||||
/**
|
||||
* When removing edges from a high degree node, remove all of its edges if
|
||||
@ -127,32 +131,33 @@ export interface RenderGraphParams {
|
||||
* the node has high in-degree, or all out-edges if the node has high
|
||||
* out-degree.
|
||||
*/
|
||||
detachAllEdgesForHighDegree: boolean;
|
||||
detachAllEdgesForHighDegree: true,
|
||||
|
||||
/**
|
||||
* After extracting high in/out degree nodes and predefined
|
||||
* source-like/sink-like, extract isolated nodes to the side
|
||||
* if this extractIsolatedNodesWithAnnotationsOnOneSide is true.
|
||||
*/
|
||||
extractIsolatedNodesWithAnnotationsOnOneSide: boolean;
|
||||
extractIsolatedNodesWithAnnotationsOnOneSide: true,
|
||||
|
||||
/**
|
||||
* Whether to add bridge nodes and edges to the core when building the
|
||||
* subhierarchy of an expanded metanode. See buildSubhierarchy().
|
||||
*/
|
||||
enableBridgegraph: boolean;
|
||||
enableBridgegraph: true,
|
||||
|
||||
/**
|
||||
* 2 colors, for the minimum and maximum value respectively, whenever we
|
||||
* have a gradient scale.
|
||||
*/
|
||||
minMaxColors: string[];
|
||||
minMaxColors: ["#fff5f0", "#fb6a4a"],
|
||||
|
||||
/**
|
||||
* Maximum number of annotations to be displayed on a node.
|
||||
* Maximum number of annotations to be displayed on a node before an
|
||||
* ellipsis is used.
|
||||
*/
|
||||
maxAnnotations: number;
|
||||
}
|
||||
maxAnnotations: 5
|
||||
};
|
||||
|
||||
/**
|
||||
* Stores the rendering information, such as x and y coordinates,
|
||||
@ -161,7 +166,6 @@ export interface RenderGraphParams {
|
||||
export class RenderGraphInfo {
|
||||
private hierarchy: hierarchy.Hierarchy;
|
||||
private index: {[nodeName: string]: RenderNodeInfo};
|
||||
private params: RenderGraphParams;
|
||||
private deviceColorMap: d3.scale.Ordinal<string, string>;
|
||||
private memoryUsageScale: d3.scale.Linear<string, string>;
|
||||
private computeTimeScale: d3.scale.Linear<string, string>;
|
||||
@ -173,7 +177,7 @@ export class RenderGraphInfo {
|
||||
private hasSubhierarchy: {[nodeName: string]: boolean};
|
||||
root: RenderGroupNodeInfo;
|
||||
|
||||
constructor(hierarchy: hierarchy.Hierarchy, params: RenderGraphParams) {
|
||||
constructor(hierarchy: hierarchy.Hierarchy) {
|
||||
this.hierarchy = hierarchy;
|
||||
this.index = {};
|
||||
this.deviceColorMap = d3.scale.ordinal<string>()
|
||||
@ -199,7 +203,7 @@ export class RenderGraphInfo {
|
||||
});
|
||||
this.memoryUsageScale = d3.scale.linear<string, string>()
|
||||
.domain(memoryExtent)
|
||||
.range(params.minMaxColors);
|
||||
.range(PARAMS.minMaxColors);
|
||||
|
||||
// Find also the minimum and maximum compute time.
|
||||
let computeTimeExtent = d3.extent(topLevelGraph.nodes(),
|
||||
@ -212,12 +216,11 @@ export class RenderGraphInfo {
|
||||
});
|
||||
this.computeTimeScale = d3.scale.linear<string, string>()
|
||||
.domain(computeTimeExtent)
|
||||
.range(params.minMaxColors);
|
||||
.range(PARAMS.minMaxColors);
|
||||
|
||||
// Maps node name to whether the rendering hierarchy was already
|
||||
// constructed.
|
||||
this.hasSubhierarchy = {};
|
||||
this.params = params;
|
||||
this.root = new RenderGroupNodeInfo(hierarchy.root);
|
||||
this.index[hierarchy.root.name] = this.root;
|
||||
this.buildSubhierarchy(hierarchy.root.name);
|
||||
@ -373,13 +376,13 @@ export class RenderGraphInfo {
|
||||
_.each((<OpNode>childNode).inEmbeddings, embedding => {
|
||||
let renderMetaedgeInfo = new RenderMetaedgeInfo(null);
|
||||
addInAnnotation(childRenderInfo, embedding, null, renderMetaedgeInfo,
|
||||
AnnotationType.CONSTANT, this.params);
|
||||
AnnotationType.CONSTANT);
|
||||
this.index[embedding.name] = new RenderNodeInfo(embedding);
|
||||
});
|
||||
_.each((<OpNode>childNode).outEmbeddings, embedding => {
|
||||
let renderMetaedgeInfo = new RenderMetaedgeInfo(null);
|
||||
addOutAnnotation(childRenderInfo, embedding, null, renderMetaedgeInfo,
|
||||
AnnotationType.SUMMARY, this.params);
|
||||
AnnotationType.SUMMARY);
|
||||
this.index[embedding.name] = new RenderNodeInfo(embedding);
|
||||
});
|
||||
}
|
||||
@ -393,9 +396,9 @@ export class RenderGraphInfo {
|
||||
coreGraph.setEdge(edgeObj.v, edgeObj.w, renderMetaedgeInfo);
|
||||
});
|
||||
|
||||
if (this.params.enableExtraction &&
|
||||
if (PARAMS.enableExtraction &&
|
||||
renderGroupNodeInfo.node.type === NodeType.META) {
|
||||
extractHighDegrees(renderGroupNodeInfo, this.params);
|
||||
extractHighDegrees(renderGroupNodeInfo);
|
||||
}
|
||||
|
||||
// Record that we constructed the rendering hierarchy for this node, so we
|
||||
@ -469,7 +472,7 @@ export class RenderGraphInfo {
|
||||
// either node is high-degree with respect to control edges. This will
|
||||
// be a signal to show it as an annotation instead of a bridge edge.
|
||||
let isHighDegreeControlEdge = !bridgeMetaedge.numRegularEdges &&
|
||||
otherCounts.control[otherName] > this.params.maxControlDegree;
|
||||
otherCounts.control[otherName] > PARAMS.maxControlDegree;
|
||||
|
||||
let [, childAnnotations] =
|
||||
inbound ?
|
||||
@ -478,8 +481,8 @@ export class RenderGraphInfo {
|
||||
|
||||
let isOtherHighDegree =
|
||||
inbound ?
|
||||
otherCounts.out[otherName] > this.params.maxOutDegree :
|
||||
otherCounts.in[otherName] > this.params.maxInDegree;
|
||||
otherCounts.out[otherName] > PARAMS.maxOutDegree :
|
||||
otherCounts.in[otherName] > PARAMS.maxInDegree;
|
||||
|
||||
// The adjoining render metaedge info from the parent's coreGraph, if any.
|
||||
// It will either be a Metaedge involving this node directly, if it
|
||||
@ -493,7 +496,7 @@ export class RenderGraphInfo {
|
||||
// - the child is in the core (not extracted for being high-degree), and
|
||||
// - there's a path (in the traversal sense) between child and other.
|
||||
let canDrawBridgePath = false;
|
||||
if (this.params.enableBridgegraph &&
|
||||
if (PARAMS.enableBridgegraph &&
|
||||
!isOtherHighDegree &&
|
||||
!isHighDegreeControlEdge &&
|
||||
childRenderInfo.isInCore()) {
|
||||
@ -581,7 +584,7 @@ export class RenderGraphInfo {
|
||||
otherRenderInfo,
|
||||
new RenderMetaedgeInfo(bridgeMetaedge),
|
||||
AnnotationType.SHORTCUT,
|
||||
inbound), this.params);
|
||||
inbound));
|
||||
return;
|
||||
}
|
||||
|
||||
@ -869,13 +872,13 @@ export class AnnotationList {
|
||||
* Append an annotation to the list, or a stand-in ellipsis annotation instead
|
||||
* if this would make it too many.
|
||||
*/
|
||||
push(annotation: Annotation, params: RenderGraphParams): void {
|
||||
push(annotation: Annotation): void {
|
||||
if (annotation.node.name in this.nodeNames) {
|
||||
return; // Skip duplicate annotation.
|
||||
}
|
||||
this.nodeNames[annotation.node.name] = true;
|
||||
|
||||
if (this.list.length < params.maxAnnotations) {
|
||||
if (this.list.length < PARAMS.maxAnnotations) {
|
||||
this.list.push(annotation);
|
||||
return;
|
||||
}
|
||||
@ -1101,19 +1104,18 @@ export class RenderMetaedgeInfo {
|
||||
|
||||
function addInAnnotation(node: RenderNodeInfo, predecessor: Node,
|
||||
predecessorRenderInfo: RenderNodeInfo,
|
||||
edge: RenderMetaedgeInfo, type: AnnotationType,
|
||||
params: RenderGraphParams): void {
|
||||
edge: RenderMetaedgeInfo, type: AnnotationType): void {
|
||||
let annotation = new Annotation(predecessor, predecessorRenderInfo, edge,
|
||||
type, true);
|
||||
node.inAnnotations.push(annotation, params);
|
||||
node.inAnnotations.push(annotation);
|
||||
}
|
||||
|
||||
function addOutAnnotation(node: RenderNodeInfo, successor: Node,
|
||||
successorRenderInfo: RenderNodeInfo, edge: RenderMetaedgeInfo,
|
||||
type: AnnotationType, params: RenderGraphParams): void {
|
||||
type: AnnotationType): void {
|
||||
let annotation = new Annotation(successor, successorRenderInfo, edge,
|
||||
type, false);
|
||||
node.outAnnotations.push(annotation, params);
|
||||
node.outAnnotations.push(annotation);
|
||||
}
|
||||
|
||||
function setGraphDepth(graph: graphlib.Graph<RenderNodeInfo, any>,
|
||||
@ -1181,7 +1183,7 @@ function setGroupNodeDepth(renderInfo: RenderGroupNodeInfo,
|
||||
*/
|
||||
function createShortcut(
|
||||
graph: graphlib.Graph<RenderNodeInfo, RenderMetaedgeInfo>,
|
||||
v: string, w: string, params: RenderGraphParams) {
|
||||
v: string, w: string) {
|
||||
let src = graph.node(v);
|
||||
let sink = graph.node(w);
|
||||
let edge = graph.edge(v, w);
|
||||
@ -1197,8 +1199,8 @@ function createShortcut(
|
||||
}
|
||||
|
||||
// Add each annotation.
|
||||
addOutAnnotation(src, sink.node, sink, edge, AnnotationType.SHORTCUT, params);
|
||||
addInAnnotation(sink, src.node, src, edge, AnnotationType.SHORTCUT, params);
|
||||
addOutAnnotation(src, sink.node, sink, edge, AnnotationType.SHORTCUT);
|
||||
addInAnnotation(sink, src.node, src, edge, AnnotationType.SHORTCUT);
|
||||
|
||||
// Remove the edge from the core graph.
|
||||
graph.removeEdge(v, w);
|
||||
@ -1212,18 +1214,18 @@ function createShortcut(
|
||||
* edges. Otherwise, only extract all in-edges.
|
||||
*/
|
||||
function makeOutExtract(renderNode: RenderGroupNodeInfo, n: string,
|
||||
params: RenderGraphParams, forceDetach?: boolean) {
|
||||
forceDetach?: boolean) {
|
||||
let graph = renderNode.coreGraph;
|
||||
let child = graph.node(n);
|
||||
child.isOutExtract = true;
|
||||
|
||||
_.each(graph.predecessors(n), (p, index) => {
|
||||
createShortcut(graph, p, n, params);
|
||||
createShortcut(graph, p, n);
|
||||
});
|
||||
|
||||
if (params.detachAllEdgesForHighDegree || forceDetach) {
|
||||
if (PARAMS.detachAllEdgesForHighDegree || forceDetach) {
|
||||
_.each(graph.successors(n), (s, index) => {
|
||||
createShortcut(graph, n, s, params);
|
||||
createShortcut(graph, n, s);
|
||||
});
|
||||
}
|
||||
|
||||
@ -1243,18 +1245,18 @@ function makeOutExtract(renderNode: RenderGroupNodeInfo, n: string,
|
||||
* edges. Otherwise, only remove all out-edges.
|
||||
*/
|
||||
export function makeInExtract(renderNode: RenderGroupNodeInfo, n: string,
|
||||
params: RenderGraphParams, forceDetach?: boolean) {
|
||||
forceDetach?: boolean) {
|
||||
let graph = renderNode.coreGraph;
|
||||
let child = graph.node(n);
|
||||
child.isInExtract = true;
|
||||
|
||||
_.each(graph.successors(n), (s, index) => {
|
||||
createShortcut(graph, n, s, params);
|
||||
createShortcut(graph, n, s);
|
||||
});
|
||||
|
||||
if (params.detachAllEdgesForHighDegree || forceDetach) {
|
||||
if (PARAMS.detachAllEdgesForHighDegree || forceDetach) {
|
||||
_.each(graph.predecessors(n), (p, index) => {
|
||||
createShortcut(graph, p, n, params);
|
||||
createShortcut(graph, p, n);
|
||||
});
|
||||
}
|
||||
|
||||
@ -1289,40 +1291,37 @@ function hasTypeIn(node: Node, types: string[]): boolean {
|
||||
}
|
||||
|
||||
/** Move nodes that are specified to be excluded out of the core graph. */
|
||||
function extractSpecifiedNodes(renderNode: RenderGroupNodeInfo,
|
||||
params: RenderGraphParams) {
|
||||
function extractSpecifiedNodes(renderNode: RenderGroupNodeInfo) {
|
||||
let graph = renderNode.coreGraph;
|
||||
_.each(graph.nodes(), n => {
|
||||
let renderInfo = graph.node(n);
|
||||
if (renderInfo.node.include === InclusionType.EXCLUDE) {
|
||||
if (renderNode.coreGraph.outEdges(n).length >
|
||||
renderNode.coreGraph.inEdges(n).length) {
|
||||
makeOutExtract(renderNode, n, params, true);
|
||||
makeOutExtract(renderNode, n, true);
|
||||
} else {
|
||||
makeInExtract(renderNode, n, params, true);
|
||||
makeInExtract(renderNode, n, true);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/** Remove edges from pre-defined out-extract patterns */
|
||||
function extractPredefinedSink(renderNode: RenderGroupNodeInfo,
|
||||
params: RenderGraphParams) {
|
||||
function extractPredefinedSink(renderNode: RenderGroupNodeInfo) {
|
||||
let graph = renderNode.coreGraph;
|
||||
_.each(graph.nodes(), n => {
|
||||
let renderInfo = graph.node(n);
|
||||
if (renderInfo.node.include !== InclusionType.UNSPECIFIED) {
|
||||
return;
|
||||
}
|
||||
if (hasTypeIn(renderInfo.node, params.outExtractTypes)) {
|
||||
makeOutExtract(renderNode, n, params);
|
||||
if (hasTypeIn(renderInfo.node, PARAMS.outExtractTypes)) {
|
||||
makeOutExtract(renderNode, n);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/** Remove edges from pre-defined in-extract patterns */
|
||||
function extractPredefinedSource(renderNode: RenderGroupNodeInfo,
|
||||
params: RenderGraphParams) {
|
||||
function extractPredefinedSource(renderNode: RenderGroupNodeInfo) {
|
||||
let graph = renderNode.coreGraph;
|
||||
|
||||
_.each(graph.nodes(), n => {
|
||||
@ -1330,17 +1329,16 @@ function extractPredefinedSource(renderNode: RenderGroupNodeInfo,
|
||||
if (renderInfo.node.include !== InclusionType.UNSPECIFIED) {
|
||||
return;
|
||||
}
|
||||
if (hasTypeIn(renderInfo.node, params.inExtractTypes)) {
|
||||
makeInExtract(renderNode, n, params);
|
||||
if (hasTypeIn(renderInfo.node, PARAMS.inExtractTypes)) {
|
||||
makeInExtract(renderNode, n);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/** Extract from nodes with in-degree > maxInDegree */
|
||||
function extractHighInDegree(renderNode: RenderGroupNodeInfo,
|
||||
params: RenderGraphParams) {
|
||||
function extractHighInDegree(renderNode: RenderGroupNodeInfo) {
|
||||
let graph = renderNode.coreGraph;
|
||||
let maxInDegree = params.maxInDegree;
|
||||
let maxInDegree = PARAMS.maxInDegree;
|
||||
|
||||
// detect first so degrees don't get affected by other removal
|
||||
let highInDegreeNames = _.filter(graph.nodes(), n => {
|
||||
@ -1363,15 +1361,14 @@ function extractHighInDegree(renderNode: RenderGroupNodeInfo,
|
||||
});
|
||||
|
||||
_.each(highInDegreeNames, n => {
|
||||
makeOutExtract(renderNode, n, params);
|
||||
makeOutExtract(renderNode, n);
|
||||
});
|
||||
}
|
||||
|
||||
/** Extract nodes with out-degree > maxOutDegree */
|
||||
function extractHighOutDegree(renderNode: RenderGroupNodeInfo,
|
||||
params: RenderGraphParams) {
|
||||
function extractHighOutDegree(renderNode: RenderGroupNodeInfo) {
|
||||
let graph = renderNode.coreGraph;
|
||||
let maxOutDegree = params.maxOutDegree;
|
||||
let maxOutDegree = PARAMS.maxOutDegree;
|
||||
|
||||
// detect first so degrees don't get affected by other removal
|
||||
let highOutDegreeNames = _.filter(graph.nodes(), n => {
|
||||
@ -1394,13 +1391,12 @@ function extractHighOutDegree(renderNode: RenderGroupNodeInfo,
|
||||
});
|
||||
|
||||
_.each(highOutDegreeNames, n => {
|
||||
makeInExtract(renderNode, n, params);
|
||||
makeInExtract(renderNode, n);
|
||||
});
|
||||
}
|
||||
|
||||
/** Remove control edges from nodes that have too many control edges */
|
||||
function removeControlEdges(renderNode: RenderGroupNodeInfo,
|
||||
params: RenderGraphParams) {
|
||||
function removeControlEdges(renderNode: RenderGroupNodeInfo) {
|
||||
let graph = renderNode.coreGraph;
|
||||
|
||||
// Collect control edges into a map by node name.
|
||||
@ -1414,8 +1410,8 @@ function removeControlEdges(renderNode: RenderGroupNodeInfo,
|
||||
|
||||
// For each node with too many control edges, turn them into annotations.
|
||||
_.each(map, (edges, nodeName) => {
|
||||
if (edges.length > params.maxControlDegree) {
|
||||
_.each(edges, e => createShortcut(graph, e.v, e.w, params));
|
||||
if (edges.length > PARAMS.maxControlDegree) {
|
||||
_.each(edges, e => createShortcut(graph, e.v, e.w));
|
||||
}
|
||||
});
|
||||
}
|
||||
@ -1445,38 +1441,35 @@ export function mapIndexToHue(id: number): number {
|
||||
* screw up the graph layout.
|
||||
*
|
||||
* @param {Render.Node} renderNode Node to manipulate.
|
||||
* @param {Object} params render Graph construction parameters. See
|
||||
* <tf-graph-params>'s output
|
||||
*/
|
||||
function extractHighDegrees(renderNode: RenderGroupNodeInfo,
|
||||
params: RenderGraphParams) {
|
||||
function extractHighDegrees(renderNode: RenderGroupNodeInfo) {
|
||||
|
||||
extractSpecifiedNodes(renderNode, params);
|
||||
extractSpecifiedNodes(renderNode);
|
||||
|
||||
if (params.outExtractTypes) {
|
||||
extractPredefinedSink(renderNode, params);
|
||||
if (PARAMS.outExtractTypes) {
|
||||
extractPredefinedSink(renderNode);
|
||||
}
|
||||
|
||||
// This has to come before extract high in-degree to protect the core part
|
||||
// that takes many variables.
|
||||
if (params.inExtractTypes) {
|
||||
extractPredefinedSource(renderNode, params);
|
||||
if (PARAMS.inExtractTypes) {
|
||||
extractPredefinedSource(renderNode);
|
||||
}
|
||||
|
||||
// This has to come before extract high out-degree to protect the core part
|
||||
// that output to many places as there are more high-degree sinks than
|
||||
// sources.
|
||||
|
||||
if (params.maxInDegree) {
|
||||
extractHighInDegree(renderNode, params);
|
||||
if (PARAMS.maxInDegree) {
|
||||
extractHighInDegree(renderNode);
|
||||
}
|
||||
|
||||
if (params.maxOutDegree) {
|
||||
extractHighOutDegree(renderNode, params);
|
||||
if (PARAMS.maxOutDegree) {
|
||||
extractHighOutDegree(renderNode);
|
||||
}
|
||||
|
||||
if (params.maxControlDegree) {
|
||||
removeControlEdges(renderNode, params);
|
||||
if (PARAMS.maxControlDegree) {
|
||||
removeControlEdges(renderNode);
|
||||
}
|
||||
|
||||
// Extract isolated nodes, which can be
|
||||
@ -1513,7 +1506,7 @@ function extractHighDegrees(renderNode: RenderGroupNodeInfo,
|
||||
renderNode.isolatedOutExtract.push(child);
|
||||
child.node.include = InclusionType.EXCLUDE;
|
||||
graph.removeNode(n);
|
||||
} else if (params.extractIsolatedNodesWithAnnotationsOnOneSide) {
|
||||
} else if (PARAMS.extractIsolatedNodesWithAnnotationsOnOneSide) {
|
||||
if (hasOutAnnotations && !hasInAnnotations) {
|
||||
child.isInExtract = true; // for ones with high out-annotations
|
||||
renderNode.isolatedInExtract.push(child);
|
||||
|
@ -1,113 +0,0 @@
|
||||
<link rel="import" href="../polymer/polymer.html">
|
||||
<!--
|
||||
Module for adjusting render graph building parameter.
|
||||
-->
|
||||
<dom-module id="tf-graph-params">
|
||||
</dom-module>
|
||||
<script>
|
||||
Polymer({
|
||||
|
||||
is: 'tf-graph-params',
|
||||
|
||||
properties: {
|
||||
// PARAMETERS
|
||||
|
||||
enableExtraction: {
|
||||
type: Boolean,
|
||||
value: true
|
||||
},
|
||||
|
||||
/** Maximum in-degree that a node can have without being considered as
|
||||
* high in-degree node. */
|
||||
maxInDegree: {
|
||||
type: Number,
|
||||
value: 4
|
||||
},
|
||||
/** Maximum out-degree that a node can have without being considered as
|
||||
* high out-degree node. */
|
||||
maxOutDegree: {
|
||||
type: Number,
|
||||
value: 4
|
||||
},
|
||||
/** Maximum number of control edges a node can have before they aren't
|
||||
* displayed. */
|
||||
maxControlDegree: {
|
||||
type: Number,
|
||||
value: 4
|
||||
},
|
||||
|
||||
/**
|
||||
* Types patterns for predefined out-extract nodes, which are
|
||||
* sink-like nodes that will be extracted from the main graph.
|
||||
*/
|
||||
outExtractTypes: {
|
||||
type: Array,
|
||||
value: function() {
|
||||
return [
|
||||
'NoOp' // for "sgd", "momentum" group
|
||||
];
|
||||
}
|
||||
},
|
||||
|
||||
/**
|
||||
* Types patterns for predefined in-extract nodes, which are
|
||||
* source-like nodes that will be extracted from the main graph.
|
||||
*/
|
||||
inExtractTypes: {
|
||||
type: Array,
|
||||
value: function() {
|
||||
return ['Variable'];
|
||||
}
|
||||
},
|
||||
|
||||
/**
|
||||
* When removing edges from a high degree node, remove all of its edges if
|
||||
* detachAllEdgesForHighDegree is true. Otherwise remove all in-edges if
|
||||
* the node has high in-degree, or all out-edges if the node has high
|
||||
* out-degree.
|
||||
*/
|
||||
detachAllEdgesForHighDegree: {
|
||||
type: Boolean,
|
||||
value: true
|
||||
},
|
||||
|
||||
/**
|
||||
* After extracting high in/out degree nodes and predefined
|
||||
* source-like/sink-like, extract isolated nodes to the side
|
||||
* if this extractIsolatedNodesWithAnnotationsOnOneSide is true.
|
||||
*/
|
||||
extractIsolatedNodesWithAnnotationsOnOneSide: {
|
||||
type: Boolean,
|
||||
value: true
|
||||
},
|
||||
|
||||
/**
|
||||
* Whether to draw bridge paths inside of expanded group nodes.
|
||||
*/
|
||||
enableBridgegraph: {
|
||||
type: Boolean,
|
||||
value: true
|
||||
},
|
||||
|
||||
/**
|
||||
* Colors for the minimum and maximum values whenever we have a gradient
|
||||
* scale.
|
||||
*/
|
||||
minMaxColors: {
|
||||
type: Array,
|
||||
value: function() {
|
||||
return ["#fff5f0", "#fb6a4a"];
|
||||
}
|
||||
},
|
||||
|
||||
/**
|
||||
* Maximum number of annotations to be displayed on a node before an
|
||||
* ellipsis is used.
|
||||
*/
|
||||
maxAnnotations: {
|
||||
type: Number,
|
||||
value: 5
|
||||
}
|
||||
}
|
||||
});
|
||||
</script>
|
@ -6,7 +6,6 @@
|
||||
<link rel="import" href="../paper-toggle-button/paper-toggle-button.html">
|
||||
<link rel="import" href="../tf-graph-common/tf-graph-common.html">
|
||||
<link rel="import" href="tf-graph-scene.html">
|
||||
<link rel="import" href="tf-graph-params.html">
|
||||
<dom-module id="tf-graph">
|
||||
<template>
|
||||
<style>
|
||||
@ -37,7 +36,6 @@ paper-button {
|
||||
}
|
||||
</style>
|
||||
<div class="container">
|
||||
<tf-graph-params id="graphParams"></tf-graph-params>
|
||||
<div class="vertical">
|
||||
<h2>[[title]]</h2>
|
||||
<tf-graph-scene id="scene" class="auto"
|
||||
@ -91,13 +89,6 @@ Polymer({
|
||||
readOnly: true,
|
||||
notify: true,
|
||||
},
|
||||
// internal properties
|
||||
_graphParams: {
|
||||
type: Object,
|
||||
value: function() {
|
||||
return this.$.graphParams;
|
||||
}
|
||||
},
|
||||
_renderDepth: {
|
||||
type: Number,
|
||||
value: 1
|
||||
@ -108,9 +99,9 @@ Polymer({
|
||||
},
|
||||
},
|
||||
observers: [
|
||||
'_buildRenderHierarchy(graphHierarchy, _graphParams)'
|
||||
'_buildRenderHierarchy(graphHierarchy)'
|
||||
],
|
||||
_buildRenderHierarchy: function(graphHierarchy, params) {
|
||||
_buildRenderHierarchy: function(graphHierarchy) {
|
||||
tf.time('new tf.graph.render.Hierarchy', function() {
|
||||
if (graphHierarchy.root.type !== tf.graph.NodeType.META) {
|
||||
// root must be metanode but sometimes Polymer's dom-if has not
|
||||
@ -118,8 +109,7 @@ Polymer({
|
||||
// and thus mistakenly pass non-metanode to this module.
|
||||
return;
|
||||
}
|
||||
var renderGraph = new tf.graph.render.RenderGraphInfo(graphHierarchy,
|
||||
params);
|
||||
var renderGraph = new tf.graph.render.RenderGraphInfo(graphHierarchy);
|
||||
// Producing the 'color by' parameters to be consumed
|
||||
// by the tf-graph-controls panel. It contains information about the
|
||||
// min and max values and their respective colors, as well as list
|
||||
@ -252,7 +242,7 @@ Polymer({
|
||||
}
|
||||
|
||||
// Rebuild the render hierarchy.
|
||||
this._buildRenderHierarchy(this.graphHierarchy, this._graphParams);
|
||||
this._buildRenderHierarchy(this.graphHierarchy);
|
||||
},
|
||||
_nodeToggleSeriesGroup: function(event) {
|
||||
// Toggle the group setting of the specified node appropriately.
|
||||
@ -270,7 +260,7 @@ Polymer({
|
||||
tf.graph.hierarchy.build(this.basicGraph, this.hierarchyParams, hierarchyTracker)
|
||||
.then(function(graphHierarchy) {
|
||||
this.set('graphHierarchy', graphHierarchy);
|
||||
this._buildRenderHierarchy(this.graphHierarchy, this._graphParams);
|
||||
this._buildRenderHierarchy(this.graphHierarchy);
|
||||
}.bind(this));
|
||||
},
|
||||
not: function(x) {
|
||||
|
@ -13,8 +13,8 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
|
||||
|
||||
native.new_http_archive(
|
||||
name = "eigen_archive",
|
||||
url = "https://bitbucket.org/eigen/eigen/get/36b0586de49f.tar.gz",
|
||||
sha256 = "86da9dd97c91b6587a257add70d9478f4c463d6697d487564c5bfe83c4a0e8e0",
|
||||
url = "https://bitbucket.org/eigen/eigen/get/3d9f227afae2.tar.gz",
|
||||
sha256 = "bf2638b7e1085de0b430b000c07e090dc71c83dd7f5b934a06f68b7db02676bf",
|
||||
build_file = path_prefix + "eigen.BUILD",
|
||||
)
|
||||
|
||||
|
2
third_party/eigen3/Eigen/Cholesky
vendored
2
third_party/eigen3/Eigen/Cholesky
vendored
@ -1 +1 @@
|
||||
#include "eigen-eigen-36b0586de49f/Eigen/Cholesky"
|
||||
#include "eigen-eigen-3d9f227afae2/Eigen/Cholesky"
|
||||
|
2
third_party/eigen3/Eigen/Core
vendored
2
third_party/eigen3/Eigen/Core
vendored
@ -1 +1 @@
|
||||
#include "eigen-eigen-36b0586de49f/Eigen/Core"
|
||||
#include "eigen-eigen-3d9f227afae2/Eigen/Core"
|
||||
|
2
third_party/eigen3/Eigen/Eigenvalues
vendored
2
third_party/eigen3/Eigen/Eigenvalues
vendored
@ -1 +1 @@
|
||||
#include "eigen-eigen-36b0586de49f/Eigen/Eigenvalues"
|
||||
#include "eigen-eigen-3d9f227afae2/Eigen/Eigenvalues"
|
||||
|
2
third_party/eigen3/Eigen/LU
vendored
2
third_party/eigen3/Eigen/LU
vendored
@ -1 +1 @@
|
||||
#include "eigen-eigen-36b0586de49f/Eigen/LU"
|
||||
#include "eigen-eigen-3d9f227afae2/Eigen/LU"
|
||||
|
2
third_party/eigen3/Eigen/QR
vendored
2
third_party/eigen3/Eigen/QR
vendored
@ -1 +1 @@
|
||||
#include "eigen-eigen-36b0586de49f/Eigen/QR"
|
||||
#include "eigen-eigen-3d9f227afae2/Eigen/QR"
|
||||
|
@ -1 +1 @@
|
||||
#include "eigen-eigen-36b0586de49f/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "eigen-eigen-3d9f227afae2/unsupported/Eigen/CXX11/Tensor"
|
||||
|
@ -3,6 +3,8 @@ build:cuda --define=using_cuda=true
|
||||
|
||||
build --force_python=py$PYTHON_MAJOR_VERSION
|
||||
build --python$PYTHON_MAJOR_VERSION_path=$PYTHON_BINARY
|
||||
build --define=use_fast_cpp_protos=true
|
||||
build --define=allow_oversize_protos=true
|
||||
|
||||
build --spawn_strategy=standalone
|
||||
test --spawn_strategy=standalone
|
||||
|
Loading…
Reference in New Issue
Block a user