Merge commit for internal changes

Vijay Vasudevan 2017-10-20 13:31:37 -07:00
commit cf336b3365
128 changed files with 6825 additions and 2227 deletions

View File

@@ -5,7 +5,7 @@ http_archive(
     sha256 = "110fe68753413777944b473c25eed6368c4a0487cee23a7bac1b13cc49d3e257",
     strip_prefix = "rules_closure-4af89ef1db659eb41f110df189b67d4cf14073e1",
     urls = [
-        "http://mirror.bazel.build/github.com/bazelbuild/rules_closure/archive/4af89ef1db659eb41f110df189b67d4cf14073e1.tar.gz",
+        "https://mirror.bazel.build/github.com/bazelbuild/rules_closure/archive/4af89ef1db659eb41f110df189b67d4cf14073e1.tar.gz",
        "https://github.com/bazelbuild/rules_closure/archive/4af89ef1db659eb41f110df189b67d4cf14073e1.tar.gz",  # 2017-08-28
     ],
 )

View File

@@ -348,6 +348,7 @@ filegroup(
         "//tensorflow/compiler/xla/service/llvm_ir:all_files",
         "//tensorflow/compiler/xla/tests:all_files",
         "//tensorflow/compiler/xla/tools:all_files",
+        "//tensorflow/compiler/xla/tools/parser:all_files",
         "//tensorflow/contrib:all_files",
         "//tensorflow/contrib/all_reduce:all_files",
         "//tensorflow/contrib/android:all_files",
@@ -421,7 +422,6 @@ filegroup(
         "//tensorflow/contrib/remote_fused_graph/pylib:all_files",
         "//tensorflow/contrib/resampler:all_files",
         "//tensorflow/contrib/rnn:all_files",
-        "//tensorflow/contrib/s3:all_files",
         "//tensorflow/contrib/saved_model:all_files",
         "//tensorflow/contrib/saved_model/cc/saved_model:all_files",
         "//tensorflow/contrib/seq2seq:all_files",
@@ -475,6 +475,7 @@ filegroup(
         "//tensorflow/core/platform/cloud:all_files",
         "//tensorflow/core/platform/default/build_config:all_files",
         "//tensorflow/core/platform/hadoop:all_files",
+        "//tensorflow/core/platform/s3:all_files",
         "//tensorflow/core/profiler:all_files",
         "//tensorflow/core/profiler/internal:all_files",
         "//tensorflow/core/profiler/internal/advisor:all_files",

View File

@@ -3,6 +3,7 @@ licenses(["notice"])  # Apache 2.0
 load(
     "//tensorflow:tensorflow.bzl",
+    "tf_cuda_cc_test",
     "tf_cc_test",
     "tf_copts",
     "tf_cuda_library",
@@ -50,7 +51,7 @@ tf_cuda_library(
     ],
 )
 
-tf_cc_test(
+tf_cuda_cc_test(
     name = "c_api_test",
     srcs = ["c_api_test.cc"],
     deps = [

View File

@@ -54,9 +54,23 @@ string DeviceName(tensorflow::Device* d) {
 
 extern "C" {
 
-TFE_Context* TFE_NewContext(const TF_SessionOptions* opts, TF_Status* status) {
+TFE_ContextOptions* TFE_NewContextOptions() { return new TFE_ContextOptions; }
+
+void TFE_ContextOptionsSetConfig(TFE_ContextOptions* options, const void* proto,
+                                 size_t proto_len, TF_Status* status) {
+  TF_SetConfig(&options->session_options, proto, proto_len, status);
+}
+
+void TFE_ContextOptionsSetDevicePlacementPolicy(
+    TFE_ContextOptions* options, TFE_ContextDevicePlacementPolicy policy) {
+  options->policy = policy;
+}
+
+void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
+
+TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
   TF_Graph* graph = TF_NewGraph();
-  TF_Session* session = TF_NewSession(graph, opts, status);
+  TF_Session* session = TF_NewSession(graph, &opts->session_options, status);
   if (status->status.ok()) {
     if (session->device_mgr == nullptr || session->devices.empty()) {
       status->status = tensorflow::errors::InvalidArgument(
@@ -71,9 +85,10 @@ TFE_Context* TFE_NewContext(const TF_SessionOptions* opts, TF_Status* status) {
   }
 
   TFE_Context* ret = new TFE_Context(session);
+  ret->policy = opts->policy;
   ret->pflr.reset(new tensorflow::ProcessFunctionLibraryRuntime(
-      ret->session->device_mgr, opts->options.env, TF_GRAPH_DEF_VERSION,
-      &ret->func_lib_def, {}));
+      ret->session->device_mgr, opts->session_options.options.env,
+      TF_GRAPH_DEF_VERSION, &ret->func_lib_def, {}));
   ret->rendezvous =
       new tensorflow::IntraProcessRendezvous(ret->session->device_mgr);
@@ -408,8 +423,10 @@ void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name,
 namespace {
 
 tensorflow::Status ValidateInputTypeAndPlacement(
-    tensorflow::Device* host_device, tensorflow::Device* op_device, TFE_Op* op,
-    const tensorflow::OpKernel* kernel) {
+    TFE_Context* ctx, tensorflow::Device* host_device,
+    tensorflow::Device* op_device, TFE_Op* op,
+    const tensorflow::OpKernel* kernel,
+    std::vector<TFE_TensorHandle*>* copied_tensors) {
   const tensorflow::MemoryTypeVector& memtypes = kernel->input_memory_types();
   if (memtypes.size() != op->inputs.size()) {
     return tensorflow::errors::InvalidArgument(
@@ -421,11 +438,42 @@ tensorflow::Status ValidateInputTypeAndPlacement(
     const tensorflow::Device* actual_device =
         op->input_devices[i] == nullptr ? host_device : op->input_devices[i];
     if (expected_device != actual_device) {
-      return tensorflow::errors::InvalidArgument(
-          "cannot compute ", op->name, " as input #", i,
-          " was expected to be on ", expected_device->name(),
-          " but is actually on ", actual_device->name(),
-          " (operation running on ", op_device->name(), ")");
+      switch (ctx->policy) {
+        case TFE_DEVICE_PLACEMENT_EXPLICIT:
+          return tensorflow::errors::InvalidArgument(
+              "cannot compute ", op->name, " as input #", i,
+              " was expected to be on ", expected_device->name(),
+              " but is actually on ", actual_device->name(),
+              " (operation running on ", op_device->name(), ")");
+        case TFE_DEVICE_PLACEMENT_WARN:
+          LOG(WARNING) << "before computing " << op->name << " input #" << i
+                       << " was expected to be on " << expected_device->name()
+                       << " but is actually on " << actual_device->name()
+                       << " (operation running on " << op_device->name()
+                       << "). This triggers a copy which can be a performance "
+                          "bottleneck.";
+          break;
+        case TFE_DEVICE_PLACEMENT_SILENT:  // Do nothing.
+          break;
+      }
+      // We are only here if the policy is warn or silent copies, so we should
+      // trigger a copy.
+      TFE_TensorHandle original{op->inputs[i], op->input_devices[i]};
+      TF_Status* s = TF_NewStatus();
+      TFE_TensorHandle* copied_tensor = TFE_TensorHandleCopyToDevice(
+          &original, ctx, expected_device->name().c_str(), s);
+      if (!s->status.ok()) {
+        tensorflow::Status status = s->status;
+        delete s;
+        return tensorflow::errors::Internal(
+            "Failed copying input tensor from ", actual_device->name(), " to ",
+            expected_device->name(), " in order to run ", op->name, ": ",
+            status.error_message());
+      }
+      op->inputs[i] = copied_tensor->t;
+      copied_tensors->push_back(copied_tensor);
+      op->input_devices[i] = copied_tensor->d;
+      delete s;
     }
     if (op->inputs[i].dtype() != kernel->input_type(i)) {
       return tensorflow::errors::InvalidArgument(
@@ -468,10 +516,14 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
     }
     tensorflow::gtl::InsertOrUpdate(&(ctx->kernel_cache), cache_key, kernel);
   }
-  status->status = ValidateInputTypeAndPlacement(ctx->devices()[0], device, op,
-                                                 kernel->kernel());
+  std::vector<TFE_TensorHandle*> copied_tensors;
+  status->status = ValidateInputTypeAndPlacement(
+      ctx, ctx->devices()[0], device, op, kernel->kernel(), &copied_tensors);
   output_memory_types = &kernel->kernel()->output_memory_types();
   if (!status->status.ok()) {
+    for (auto* t : copied_tensors) {
+      TFE_DeleteTensorHandle(t);
+    }
     return;
   }
   // WARNING: kernel->Run utilizes the FunctionLibraryRuntime
@@ -483,6 +535,9 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
   // sense for FunctionLibraryRuntime to ensure thread-safe access to
   // FunctionLibraryDefinition?).
   status->status = kernel->Run(&op->inputs, &outputs);
+  for (auto* t : copied_tensors) {
+    TFE_DeleteTensorHandle(t);
+  }
   if (!status->status.ok()) return;
   *num_retvals = std::min<int>(*num_retvals, outputs.size());
   for (int i = 0; i < *num_retvals; ++i) {

View File

@@ -43,14 +43,46 @@ limitations under the License.
 extern "C" {
 #endif
 
+typedef struct TFE_ContextOptions TFE_ContextOptions;
+
+// Return a new options object.
+TF_CAPI_EXPORT extern TFE_ContextOptions* TFE_NewContextOptions();
+
+// Set the config in TF_ContextOptions.options.
+// config should be a serialized tensorflow.ConfigProto proto.
+// If config was not parsed successfully as a ConfigProto, record the
+// error information in *status.
+TF_CAPI_EXPORT extern void TFE_ContextOptionsSetConfig(
+    TFE_ContextOptions* options, const void* proto, size_t proto_len,
+    TF_Status* status);
+
+// Controls how to act when we try to run an operation on a given device but
+// some input tensors are not on that device.
+typedef enum TFE_ContextDevicePlacementPolicy {
+  // The default: running operations with input tensors on the wrong device
+  // will fail.
+  TFE_DEVICE_PLACEMENT_EXPLICIT = 0,
+  // Copy the tensor to the right device but log a warning.
+  TFE_DEVICE_PLACEMENT_WARN = 1,
+  // Silently copy the tensor, which has a performance cost since the
+  // operation will be blocked till the copy completes.
+  TFE_DEVICE_PLACEMENT_SILENT = 2,
+} TFE_ContextDevicePlacementPolicy;
+
+TF_CAPI_EXPORT extern void TFE_ContextOptionsSetDevicePlacementPolicy(
+    TFE_ContextOptions*, TFE_ContextDevicePlacementPolicy);
+
+// Destroy an options object.
+TF_CAPI_EXPORT extern void TFE_DeleteContextOptions(TFE_ContextOptions*);
+
 // "Context" under which operations/functions are executed. It encapsulates
 // things like the available devices, resource manager etc.
 //
 // TODO(ashankar): Merge with TF_Session?
 typedef struct TFE_Context TFE_Context;
-TF_CAPI_EXPORT extern TFE_Context* TFE_NewContext(const TF_SessionOptions* opts,
-                                                  TF_Status* status);
+TF_CAPI_EXPORT extern TFE_Context* TFE_NewContext(
+    const TFE_ContextOptions* opts, TF_Status* status);
 TF_CAPI_EXPORT extern void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status);
 TF_CAPI_EXPORT extern TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx,
                                                             TF_Status* status);
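For orientation, a minimal sketch of how a client drives the new options API (it mirrors the updated tests later in this diff; error handling abbreviated):

  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  // Opt in to automatic copies instead of hard placement errors.
  TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
  TFE_Context* ctx = TFE_NewContext(opts, status);
  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TFE_DeleteContextOptions(opts);  // The context keeps its own copy of the policy.
  // ... run ops with ctx ...
  TFE_DeleteContext(ctx, status);
  TF_DeleteStatus(status);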

View File

@@ -35,9 +35,16 @@ limitations under the License.
 #include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/thread_annotations.h"
 
+struct TFE_ContextOptions {
+  TF_SessionOptions session_options;
+  TFE_ContextDevicePlacementPolicy policy{TFE_DEVICE_PLACEMENT_EXPLICIT};
+};
+
 struct TFE_Context {
   explicit TFE_Context(TF_Session* s) : session(s) {}
 
+  TFE_ContextDevicePlacementPolicy policy;
+
   // TFE_Context is an extension of TF_Session. And TF_Session needs a TF_Graph.
   TF_Session* session;
   tensorflow::Rendezvous* rendezvous;

View File

@@ -62,10 +62,10 @@ TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) {
 void BM_InitOp(int iters) {
   tensorflow::testing::StopTiming();
   TF_Status* status = TF_NewStatus();
-  TF_SessionOptions* opts = TF_NewSessionOptions();
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
   TFE_Context* ctx = TFE_NewContext(opts, status);
   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
-  TF_DeleteSessionOptions(opts);
+  TFE_DeleteContextOptions(opts);
 
   TFE_TensorHandle* m = TestMatrixTensorHandle();
   tensorflow::testing::StartTiming();
@@ -84,10 +84,10 @@ BENCHMARK(BM_InitOp);
 void BM_Execute(int iters) {
   tensorflow::testing::StopTiming();
   TF_Status* status = TF_NewStatus();
-  TF_SessionOptions* opts = TF_NewSessionOptions();
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
   TFE_Context* ctx = TFE_NewContext(opts, status);
   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
-  TF_DeleteSessionOptions(opts);
+  TFE_DeleteContextOptions(opts);
 
   TFE_TensorHandle* m = TestMatrixTensorHandle();
   TFE_Op* matmul = MatMulOp(ctx, m, m);
@@ -109,9 +109,9 @@ BENCHMARK(BM_Execute);
 
 TEST(CAPI, Context) {
   TF_Status* status = TF_NewStatus();
-  TF_SessionOptions* opts = TF_NewSessionOptions();
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
   TFE_Context* ctx = TFE_NewContext(opts, status);
-  TF_DeleteSessionOptions(opts);
+  TFE_DeleteContextOptions(opts);
 
   TF_DeviceList* devices = TFE_ContextListDevices(ctx, status);
   EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
@@ -150,9 +150,9 @@ TEST(CAPI, TensorHandle) {
 TEST(CAPI, TensorHandleCopyBetweenDevices) {
   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
       TF_NewStatus(), TF_DeleteStatus);
-  TF_SessionOptions* opts = TF_NewSessionOptions();
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
   TFE_Context* ctx = TFE_NewContext(opts, status.get());
-  TF_DeleteSessionOptions(opts);
+  TFE_DeleteContextOptions(opts);
   ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
 
   TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
@@ -216,12 +216,58 @@ TEST(CAPI, TensorHandleCopyBetweenDevices) {
   EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
 }
 
+TEST(CAPI, TensorHandleSilentCopy) {
+  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+      TF_NewStatus(), TF_DeleteStatus);
+
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
+  TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
+  TFE_Context* ctx = TFE_NewContext(opts, status.get());
+  TFE_DeleteContextOptions(opts);
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+
+  TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
+  TF_Tensor* t = TFE_TensorHandleResolve(hcpu, status.get());
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+
+  TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get());
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+  const int num_devices = TF_DeviceListCount(devices);
+
+  // Disable the test if no GPU is present.
+  if (num_devices > 1) {
+    const int device_to_use = 1;
+    const string name(TF_DeviceListName(devices, device_to_use, status.get()));
+    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+    TFE_TensorHandle* hgpu =
+        TFE_TensorHandleCopyToDevice(hcpu, ctx, name.c_str(), status.get());
+    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+    TFE_Op* matmul = MatMulOp(ctx, hcpu, hgpu);
+    TFE_OpSetDevice(matmul, name.c_str(), status.get());
+    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+    TFE_TensorHandle* retvals[1];
+    int num_retvals = 1;
+    TFE_Execute(matmul, &retvals[0], &num_retvals, status.get());
+    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+    TFE_DeleteOp(matmul);
+    TFE_DeleteTensorHandle(retvals[0]);
+    TFE_DeleteTensorHandle(hgpu);
+  }
+
+  TF_DeleteDeviceList(devices);
+  TF_DeleteTensor(t);
+  TFE_DeleteTensorHandle(hcpu);
+  TFE_DeleteContext(ctx, status.get());
+  EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+}
+
 TEST(CAPI, Execute) {
   TF_Status* status = TF_NewStatus();
-  TF_SessionOptions* opts = TF_NewSessionOptions();
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
   TFE_Context* ctx = TFE_NewContext(opts, status);
   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
-  TF_DeleteSessionOptions(opts);
+  TFE_DeleteContextOptions(opts);
 
   TFE_TensorHandle* m = TestMatrixTensorHandle();
   TFE_Op* matmul = MatMulOp(ctx, m, m);
@@ -285,10 +331,10 @@ string MatMulFunction() {
 TEST(CAPI, FunctionDefAndExecute) {
   TF_Status* status = TF_NewStatus();
-  TF_SessionOptions* opts = TF_NewSessionOptions();
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
   TFE_Context* ctx = TFE_NewContext(opts, status);
   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
-  TF_DeleteSessionOptions(opts);
+  TFE_DeleteContextOptions(opts);
 
   string function_def = MatMulFunction();
   TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
@@ -326,10 +372,10 @@ TEST(CAPI, FunctionDefAndExecute) {
 void BM_ExecuteFunction(int iters) {
   tensorflow::testing::StopTiming();
   TF_Status* status = TF_NewStatus();
-  TF_SessionOptions* opts = TF_NewSessionOptions();
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
   TFE_Context* ctx = TFE_NewContext(opts, status);
   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
-  TF_DeleteSessionOptions(opts);
+  TFE_DeleteContextOptions(opts);
 
   string function_def = MatMulFunction();
   TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
@@ -406,10 +452,10 @@ TEST(CAPI, Variables) {
   // Variables use resource handles, so this is really a test for resource
   // tensor handling.
   TF_Status* status = TF_NewStatus();
-  TF_SessionOptions* opts = TF_NewSessionOptions();
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
   TFE_Context* ctx = TFE_NewContext(opts, status);
   ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
-  TF_DeleteSessionOptions(opts);
+  TFE_DeleteContextOptions(opts);
 
   TFE_TensorHandle* var_handle = CreateVariable(ctx, 12.0, status);
   ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
@@ -446,10 +492,10 @@ TEST(CAPI, Variables) {
 void BM_ReadVariable(int iters) {
   tensorflow::testing::StopTiming();
   TF_Status* status = TF_NewStatus();
-  TF_SessionOptions* opts = TF_NewSessionOptions();
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
   TFE_Context* ctx = TFE_NewContext(opts, status);
   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
-  TF_DeleteSessionOptions(opts);
+  TFE_DeleteContextOptions(opts);
 
   TFE_TensorHandle* var_handle = CreateVariable(ctx, 5.0, status);
   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

View File

@@ -138,6 +138,11 @@ class ComputationBuilder {
   ComputationDataHandle ConstantR2(
       std::initializer_list<std::initializer_list<NativeT>> values);
   template <typename NativeT>
+  ComputationDataHandle ConstantFromArrayWithLayout(
+      const Array<NativeT>& values, const Layout& layout);
+  template <typename NativeT>
+  ComputationDataHandle ConstantFromArray(const Array<NativeT>& values);
+  template <typename NativeT>
   ComputationDataHandle ConstantR2FromArray2DWithLayout(
       const Array2D<NativeT>& values, const Layout& layout);
   template <typename NativeT>
@@ -909,49 +914,55 @@ ComputationDataHandle ComputationBuilder::ConstantR2(
       [&values](Literal* literal) { literal->PopulateR2(values); });
 }
 
+template <typename NativeT>
+ComputationDataHandle ComputationBuilder::ConstantFromArrayWithLayout(
+    const Array<NativeT>& values, const Layout& layout) {
+  return ConstantOp([&values, &layout](Literal* literal) {
+    literal->PopulateFromArrayWithLayout(values, layout);
+  });
+}
+
+template <typename NativeT>
+ComputationDataHandle ComputationBuilder::ConstantFromArray(
+    const Array<NativeT>& values) {
+  return ConstantOp(
+      [&values](Literal* literal) { literal->PopulateFromArray(values); });
+}
+
 template <typename NativeT>
 ComputationDataHandle ComputationBuilder::ConstantR2FromArray2DWithLayout(
     const Array2D<NativeT>& values, const Layout& layout) {
-  return ConstantOp([&values, &layout](Literal* literal) {
-    literal->PopulateR2FromArray2DWithLayout(values, layout);
-  });
+  return ConstantFromArrayWithLayout(values, layout);
 }
 
 template <typename NativeT>
 ComputationDataHandle ComputationBuilder::ConstantR2FromArray2D(
     const Array2D<NativeT>& values) {
-  return ConstantOp(
-      [&values](Literal* literal) { literal->PopulateR2FromArray2D(values); });
+  return ConstantFromArray(values);
 }
 
 template <typename NativeT>
 ComputationDataHandle ComputationBuilder::ConstantR3FromArray3DWithLayout(
     const Array3D<NativeT>& values, const Layout& layout) {
-  return ConstantOp([&values, &layout](Literal* literal) {
-    literal->PopulateR3FromArray3DWithLayout(values, layout);
-  });
+  return ConstantFromArrayWithLayout(values, layout);
 }
 
 template <typename NativeT>
 ComputationDataHandle ComputationBuilder::ConstantR3FromArray3D(
     const Array3D<NativeT>& values) {
-  return ConstantOp(
-      [&values](Literal* literal) { literal->PopulateR3FromArray3D(values); });
+  return ConstantFromArray(values);
 }
 
 template <typename NativeT>
 ComputationDataHandle ComputationBuilder::ConstantR4FromArray4DWithLayout(
     const Array4D<NativeT>& values, const Layout& layout) {
-  return ConstantOp([&values, &layout](Literal* literal) {
-    literal->PopulateR4FromArray4DWithLayout(values, layout);
-  });
+  return ConstantFromArrayWithLayout(values, layout);
 }
 
 template <typename NativeT>
 ComputationDataHandle ComputationBuilder::ConstantR4FromArray4D(
     const Array4D<NativeT>& values) {
-  return ConstantOp(
-      [&values](Literal* literal) { literal->PopulateR4FromArray4D(values); });
+  return ConstantFromArray(values);
 }
 
 }  // namespace xla
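The net effect: the rank-specific constant builders become thin wrappers over the rank-generic Array<NativeT> overloads. A short sketch of the equivalence (the surrounding ComputationBuilder named 'builder' is assumed from typical XLA client code):

  xla::Array2D<float> values(2, 2);
  values(0, 0) = 1.0f; values(0, 1) = 2.0f;
  values(1, 0) = 3.0f; values(1, 1) = 4.0f;
  // These two calls now build identical constants, since Array2D is an Array.
  xla::ComputationDataHandle c1 = builder.ConstantR2FromArray2D<float>(values);
  xla::ComputationDataHandle c2 = builder.ConstantFromArray<float>(values);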

View File

@@ -83,6 +83,10 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
   return CreateDefaultLayoutForRank(shape.dimensions_size());
 }
 
+/* static */ Layout LayoutUtil::GetDefaultLayoutForRank(int64 rank) {
+  return CreateDefaultLayoutForRank(rank);
+}
+
 /* static */ Layout LayoutUtil::GetDefaultLayoutForR2() {
   return CreateDefaultLayoutForRank(2);
 }

View File

@@ -40,6 +40,7 @@ class LayoutUtil {
   static Layout GetDefaultLayoutForShape(const Shape& shape);
 
   // Helper functions that create default layouts for various ranks.
+  static Layout GetDefaultLayoutForRank(int64 rank);
   static Layout GetDefaultLayoutForR2();
   static Layout GetDefaultLayoutForR3();
   static Layout GetDefaultLayoutForR4();
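A one-line illustration of the new helper, which generalizes the fixed-rank variants:

  // Equivalent ways to get the default layout for a rank-3 shape.
  xla::Layout a = xla::LayoutUtil::GetDefaultLayoutForRank(3);
  xla::Layout b = xla::LayoutUtil::GetDefaultLayoutForR3();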

View File

@@ -206,9 +206,9 @@ void AllocateFlags() {
           flag_values->xla_gpu_disable_multi_streaming(),
           "If true, multi-streaming in the GPU backend is disabled."),
       tensorflow::Flag(
-          "xla_dump_debug_json_to",
-          flag_values->mutable_xla_dump_debug_json_to(),
-          "Dump compilation artifacts as JSON into this directory."),
+          "xla_dump_hlo_proto_to",
+          flag_values->mutable_xla_dump_hlo_proto_to(),
+          "Dump compilation artifacts as proto binary into this directory."),
       tensorflow::Flag(
           "xla_test_all_output_layouts",
          bool_setter_for(&DebugOptions::set_xla_test_all_output_layouts),
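With the rename, the flag is passed as --xla_dump_hlo_proto_to=/some/dir on the command line. A sketch of setting it programmatically; the set_* call is assumed to be the standard protobuf-generated setter for the string field, inferred from the mutable_* accessor above:

  xla::DebugOptions debug_options;
  debug_options.set_xla_dump_hlo_proto_to("/tmp/hlo_dumps");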

View File

@@ -334,6 +334,11 @@ class Literal {
   // WithLayout use the default XLA layout for the literal's linear
   // representation in memory.
   template <typename NativeT>
+  static std::unique_ptr<Literal> CreateFromArray(const Array<NativeT>& values);
+  template <typename NativeT>
+  static std::unique_ptr<Literal> CreateFromArrayWithLayout(
+      const Array<NativeT>& values, const Layout& layout);
+  template <typename NativeT>
   static std::unique_ptr<Literal> CreateR2FromArray2D(
       const Array2D<NativeT>& values);
   template <typename NativeT>
@@ -481,6 +486,11 @@ class Literal {
       std::initializer_list<std::initializer_list<NativeT>> values,
       const Layout& layout);
   template <typename NativeT>
+  void PopulateFromArray(const Array<NativeT>& values);
+  template <typename NativeT>
+  void PopulateFromArrayWithLayout(const Array<NativeT>& values,
+                                   const Layout& layout);
+  template <typename NativeT>
   void PopulateR2FromArray2D(const Array2D<NativeT>& values);
   template <typename NativeT>
   void PopulateR2FromArray2DWithLayout(const Array2D<NativeT>& values,
@@ -815,34 +825,43 @@ template <typename NativeT>
   return CreateR4WithLayout(values, LayoutUtil::GetDefaultLayoutForR4());
 }
 
+template <typename NativeT>
+/* static */ std::unique_ptr<Literal> Literal::CreateFromArrayWithLayout(
+    const Array<NativeT>& values, const Layout& layout) {
+  auto literal = MakeUnique<Literal>();
+  literal->PopulateFromArrayWithLayout(values, layout);
+  return literal;
+}
+
+template <typename NativeT>
+/* static */ std::unique_ptr<Literal> Literal::CreateFromArray(
+    const Array<NativeT>& values) {
+  return CreateFromArrayWithLayout(
+      values, LayoutUtil::GetDefaultLayoutForRank(values.num_dimensions()));
+}
+
 template <typename NativeT>
 /* static */ std::unique_ptr<Literal> Literal::CreateR2FromArray2DWithLayout(
     const Array2D<NativeT>& values, const Layout& layout) {
-  auto literal = MakeUnique<Literal>();
-  literal->PopulateR2FromArray2DWithLayout(values, layout);
-  return literal;
+  return CreateFromArrayWithLayout(values, layout);
 }
 
 template <typename NativeT>
 /* static */ std::unique_ptr<Literal> Literal::CreateR2FromArray2D(
     const Array2D<NativeT>& values) {
-  return CreateR2FromArray2DWithLayout(values,
-                                       LayoutUtil::GetDefaultLayoutForR2());
+  return CreateFromArray(values);
 }
 
 template <typename NativeT>
 /* static */ std::unique_ptr<Literal> Literal::CreateR3FromArray3DWithLayout(
     const Array3D<NativeT>& values, const Layout& layout) {
-  auto literal = MakeUnique<Literal>();
-  literal->PopulateR3FromArray3DWithLayout(values, layout);
-  return literal;
+  return CreateFromArrayWithLayout(values, layout);
 }
 
 template <typename NativeT>
 /* static */ std::unique_ptr<Literal> Literal::CreateR3FromArray3D(
     const Array3D<NativeT>& values) {
-  return CreateR3FromArray3DWithLayout(values,
-                                       LayoutUtil::GetDefaultLayoutForR3());
+  return CreateFromArray(values);
 }
 
 template <typename NativeT>
@@ -901,16 +920,13 @@ template <typename NativeT>
 template <typename NativeT>
 /* static */ std::unique_ptr<Literal> Literal::CreateR4FromArray4D(
     const Array4D<NativeT>& values) {
-  return CreateR4FromArray4DWithLayout(values,
-                                       LayoutUtil::GetDefaultLayoutForR4());
+  return CreateFromArray(values);
 }
 
 template <typename NativeT>
 /* static */ std::unique_ptr<Literal> Literal::CreateR4FromArray4DWithLayout(
     const Array4D<NativeT>& values, const Layout& layout) {
-  auto literal = MakeUnique<Literal>();
-  literal->PopulateR4FromArray4DWithLayout(values, layout);
-  return literal;
+  return CreateFromArrayWithLayout(values, layout);
 }
 
 template <typename NativeT>
@@ -1069,83 +1085,54 @@ void Literal::PopulateR2(
   PopulateR2WithLayout(values, LayoutUtil::GetDefaultLayoutForR2());
 }
 
+template <typename NativeT>
+void Literal::PopulateFromArrayWithLayout(const Array<NativeT>& values,
+                                          const Layout& layout) {
+  *mutable_shape() = ShapeUtil::MakeShapeWithLayout(
+      primitive_util::NativeToPrimitiveType<NativeT>(), values.dimensions(),
+      AsInt64Slice(layout.minor_to_major()));
+  Reserve(values.num_elements());
+  values.Each([this](tensorflow::gtl::ArraySlice<int64> indices,
+                     NativeT value) { this->Set(indices, value); });
+}
+
+template <typename NativeT>
+void Literal::PopulateFromArray(const Array<NativeT>& values) {
+  PopulateFromArrayWithLayout(
+      values, LayoutUtil::GetDefaultLayoutForRank(values.num_dimensions()));
+}
+
 template <typename NativeT>
 void Literal::PopulateR2FromArray2DWithLayout(const Array2D<NativeT>& values,
                                               const Layout& layout) {
-  *mutable_shape() = ShapeUtil::MakeShapeWithLayout(
-      primitive_util::NativeToPrimitiveType<NativeT>(),
-      {values.height(), values.width()}, AsInt64Slice(layout.minor_to_major()));
-
-  const int64 dim1_size = values.width();
-  const int64 dim0_size = values.height();
-  CHECK_EQ(dim0_size, shape().dimensions(0));
-  CHECK_EQ(dim1_size, shape().dimensions(1));
-  Reserve(dim1_size * dim0_size);
-  for (int64 dim0 = 0; dim0 < dim0_size; ++dim0) {
-    for (int64 dim1 = 0; dim1 < dim1_size; ++dim1) {
-      Set({dim0, dim1}, values(dim0, dim1));
-    }
-  }
+  PopulateFromArrayWithLayout(values, layout);
 }
 
 template <typename NativeT>
 void Literal::PopulateR2FromArray2D(const Array2D<NativeT>& values) {
-  PopulateR2FromArray2DWithLayout(values, LayoutUtil::GetDefaultLayoutForR2());
+  PopulateFromArray(values);
 }
 
 template <typename NativeT>
 void Literal::PopulateR3FromArray3DWithLayout(const Array3D<NativeT>& values,
                                               const Layout& layout) {
-  *mutable_shape() = ShapeUtil::MakeShapeWithLayout(
-      primitive_util::NativeToPrimitiveType<NativeT>(),
-      {values.n1(), values.n2(), values.n3()},
-      AsInt64Slice(layout.minor_to_major()));
-
-  CHECK_EQ(values.n1(), shape().dimensions(0));
-  CHECK_EQ(values.n2(), shape().dimensions(1));
-  CHECK_EQ(values.n3(), shape().dimensions(2));
-  Reserve(values.n1() * values.n2() * values.n3());
-  for (int64 dim0 = 0; dim0 < values.n1(); ++dim0) {
-    for (int64 dim1 = 0; dim1 < values.n2(); ++dim1) {
-      for (int64 dim2 = 0; dim2 < values.n3(); ++dim2) {
-        Set({dim0, dim1, dim2}, values(dim0, dim1, dim2));
-      }
-    }
-  }
+  PopulateFromArrayWithLayout(values, layout);
 }
 
 template <typename NativeT>
 void Literal::PopulateR3FromArray3D(const Array3D<NativeT>& values) {
-  PopulateR3FromArray3DWithLayout(values, LayoutUtil::GetDefaultLayoutForR3());
+  PopulateFromArray(values);
 }
 
 template <typename NativeT>
 void Literal::PopulateR4FromArray4DWithLayout(const Array4D<NativeT>& values,
                                               const Layout& layout) {
-  *mutable_shape() = ShapeUtil::MakeShapeWithLayout(
-      primitive_util::NativeToPrimitiveType<NativeT>(),
-      {values.planes(), values.depth(), values.height(), values.width()},
-      AsInt64Slice(layout.minor_to_major()));
-
-  CHECK_EQ(values.n1(), shape().dimensions(0));
-  CHECK_EQ(values.n2(), shape().dimensions(1));
-  CHECK_EQ(values.n3(), shape().dimensions(2));
-  CHECK_EQ(values.n4(), shape().dimensions(3));
-  Reserve(values.n1() * values.n2() * values.n3() * values.n4());
-  for (int64 dim0 = 0; dim0 < values.n1(); ++dim0) {
-    for (int64 dim1 = 0; dim1 < values.n2(); ++dim1) {
-      for (int64 dim2 = 0; dim2 < values.n3(); ++dim2) {
-        for (int64 dim3 = 0; dim3 < values.n4(); ++dim3) {
-          Set({dim0, dim1, dim2, dim3}, values(dim0, dim1, dim2, dim3));
-        }
-      }
-    }
-  }
+  PopulateFromArrayWithLayout(values, layout);
 }
 
 template <typename NativeT>
 void Literal::PopulateR4FromArray4D(const Array4D<NativeT>& values) {
-  PopulateR4FromArray4DWithLayout(values, LayoutUtil::GetDefaultLayoutForR4());
+  PopulateFromArray(values);
 }
 
 template <typename NativeT, typename FnType>
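A brief sketch of the consolidated literal API; since the rank-specific creators now delegate to the Array<NativeT> overloads, these two calls produce the same literal (Fill is assumed from the Array base class):

  xla::Array3D<float> arr(2, 3, 4);
  arr.Fill(1.0f);
  auto via_r3 = xla::Literal::CreateR3FromArray3D<float>(arr);
  auto via_generic = xla::Literal::CreateFromArray<float>(arr);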

View File

@@ -37,20 +37,6 @@ bool ProtobufEquals(const tensorflow::protobuf::Message& m1,
   return (serialized1 == serialized2);
 }
 
-StatusOr<string> ToJson(const tensorflow::protobuf::Message& message) {
-  string json_output;
-  tensorflow::protobuf::util::JsonPrintOptions json_options;
-  json_options.add_whitespace = true;
-  json_options.always_print_primitive_fields = true;
-  auto status = tensorflow::protobuf::util::MessageToJsonString(
-      message, &json_output, json_options);
-  if (!status.ok()) {
-    return InternalError("MessageToJsonString failed: %s",
-                         status.error_message().data());
-  }
-  return json_output;
-}
-
 namespace {
 
 string SanitizeFilename(const string& file_name) {
@@ -65,17 +51,6 @@ string SanitizeFilename(const string& file_name) {
 
 }  // namespace
 
-Status DumpJsonToDirectory(const tensorflow::protobuf::Message& message,
-                           const string& directory, const string& file_name) {
-  TF_ASSIGN_OR_RETURN(const string json_output, ToJson(message));
-
-  tensorflow::Env* env = tensorflow::Env::Default();
-  TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(directory));
-  string safe_file_name = SanitizeFileName(file_name) + ".json";
-  const string path = tensorflow::io::JoinPath(directory, safe_file_name);
-  return tensorflow::WriteStringToFile(env, path, json_output);
-}
-
 Status DumpProtoToDirectory(const tensorflow::protobuf::Message& message,
                             const string& directory, const string& file_name) {
   tensorflow::Env* env = tensorflow::Env::Default();

View File

@@ -32,17 +32,12 @@ namespace protobuf_util {
 extern bool ProtobufEquals(const tensorflow::protobuf::Message& m1,
                            const tensorflow::protobuf::Message& m2);
 
-// Returns 'message' as a JSON string.
-StatusOr<string> ToJson(const tensorflow::protobuf::Message& message);
-
-// Writes the given message in binary proto or JSON format to the path formed by
-// joining 'directory/file_name.pb' (or file_name.json). The 'directory' is
-// recursively created if it doesn't already exist, and the 'file_name' is
-// sanitized by replacing illegal characters with underscore '_'.
+// Writes the given message in binary proto to the path formed by joining
+// 'directory/file_name.pb'. The 'directory' is recursively created if it
+// doesn't already exist, and the 'file_name' is sanitized by replacing
+// illegal characters with underscore '_'.
 Status DumpProtoToDirectory(const tensorflow::protobuf::Message& message,
                             const string& directory, const string& file_name);
 
-Status DumpJsonToDirectory(const tensorflow::protobuf::Message& message,
-                           const string& directory, const string& file_name);
-
 }  // namespace protobuf_util
 }  // namespace xla

View File

@@ -2064,6 +2064,29 @@ tf_cc_test(
     ],
 )
 
+cc_library(
+    name = "hlo_runner",
+    srcs = ["hlo_runner.cc"],
+    hdrs = ["hlo_runner.h"],
+    deps = [
+        ":executable",
+        ":hlo",
+        ":transfer_manager",
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:status_macros",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla:types",
+        "//tensorflow/compiler/xla:util",
+        "//tensorflow/compiler/xla:xla_data_proto",
+        "//tensorflow/compiler/xla/service:backend",
+        "//tensorflow/compiler/xla/service:compiler",
+        "//tensorflow/core:core_cpu_internal",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:stream_executor_no_cuda",
+        "//third_party/eigen3",
+    ],
+)
+
 # -----------------------------------------------------------------------------
 
 filegroup(

View File

@@ -475,8 +475,8 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::Compile(
   // ownership is std::moved.
   const bool embed_ir_in_executable =
       module->config().debug_options().xla_embed_ir_in_executable();
-  const string dump_debug_json_to =
-      module->config().debug_options().xla_dump_debug_json_to();
+  const string xla_dump_hlo_proto_to =
+      module->config().debug_options().xla_dump_hlo_proto_to();
 
   if (options::CpuParallelBackendRequested(module->config())) {
     VLOG(1) << "Using parallel cpu backend";
@@ -496,10 +496,10 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::Compile(
     // print one ourselves.
     XLA_VLOG_LINES(2, assignment->ToString());
 
-    if (!dump_debug_json_to.empty()) {
+    if (!xla_dump_hlo_proto_to.empty()) {
       HloProto proto = MakeHloProto(*module, *assignment);
-      TF_RETURN_IF_ERROR(protobuf_util::DumpJsonToDirectory(
-          proto, dump_debug_json_to, module->name()));
+      TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory(
+          proto, xla_dump_hlo_proto_to, module->name()));
     }
 
     // If we are using the parallel CPU backend, we need to create map from
@@ -603,12 +603,11 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::Compile(
     // print one ourselves.
     XLA_VLOG_LINES(2, assignment->ToString());
 
-    if (!dump_debug_json_to.empty()) {
+    if (!xla_dump_hlo_proto_to.empty()) {
       HloProto proto = MakeHloProto(*module, *assignment);
-      TF_RETURN_IF_ERROR(protobuf_util::DumpJsonToDirectory(
-          proto, dump_debug_json_to, module->name()));
+      TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory(
+          proto, xla_dump_hlo_proto_to, module->name()));
     }
 
     // Each computation is a single function. Emit all embedded computations
     // before the entry computation. The order of computations returned from
     // GetEmbeddedComputations guarantees that a called computation occurs
@@ -775,12 +774,12 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
     // print one ourselves.
     XLA_VLOG_LINES(2, assignment->ToString());
 
-    const string dump_debug_json_to =
-        module->config().debug_options().xla_dump_debug_json_to();
-    if (!dump_debug_json_to.empty()) {
+    const string xla_dump_hlo_proto_to =
+        module->config().debug_options().xla_dump_hlo_proto_to();
+    if (!xla_dump_hlo_proto_to.empty()) {
       HloProto proto = MakeHloProto(*module, *assignment);
-      TF_RETURN_IF_ERROR(protobuf_util::DumpJsonToDirectory(
-          proto, dump_debug_json_to, module->name()));
+      TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory(
+          proto, xla_dump_hlo_proto_to, module->name()));
     }
 
     IrEmitter ir_emitter(*module, *assignment, &llvm_module,

View File

@@ -136,6 +136,8 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount(
       instruction->opcode() == HloOpcode::kCall ||
       instruction->opcode() == HloOpcode::kCustomCall ||
       instruction->opcode() == HloOpcode::kSelectAndScatter ||
+      instruction->opcode() == HloOpcode::kGetTupleElement ||
+      instruction->opcode() == HloOpcode::kBitcast ||
       (instruction->opcode() == HloOpcode::kConvolution &&
        PotentiallyImplementedAsEigenConvolution(*instruction)) ||
       PotentiallyImplementedAsEigenDot(*instruction) ||

View File

@@ -318,12 +318,12 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::Compile(
   // print one ourselves.
   XLA_VLOG_LINES(2, buffer_assignment->ToString());
 
-  const string dump_debug_json_to =
-      module->config().debug_options().xla_dump_debug_json_to();
-  if (!dump_debug_json_to.empty()) {
+  const string xla_dump_hlo_proto_to =
+      module->config().debug_options().xla_dump_hlo_proto_to();
+  if (!xla_dump_hlo_proto_to.empty()) {
     HloProto proto = MakeHloProto(*module, *buffer_assignment);
-    TF_RETURN_IF_ERROR(protobuf_util::DumpJsonToDirectory(
-        proto, dump_debug_json_to, module->name()));
+    TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory(
+        proto, xla_dump_hlo_proto_to, module->name()));
   }
 
   IrEmitterContext ir_emitter_context(module.get(), buffer_assignment.get(),

View File

@@ -373,8 +373,8 @@ string HloComputation::ToString(int nested_level) const {
   for (int i = 0; i < nested_level; i++) {
     s << "  ";
   }
-  s << name() << " " << ShapeUtil::HumanString(ComputeProgramShape())
-    << " { \n";
+  s << "%" << name() << " " << ShapeUtil::HumanString(ComputeProgramShape())
+    << " {\n";
   for (const HloInstruction* instruction : MakeInstructionPostOrder()) {
     for (int i = 0; i < nested_level; i++) {
       s << "  ";

View File

@@ -0,0 +1,199 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_runner.h"
+
+#include <set>
+#include <string>
+#include <utility>
+
+#define EIGEN_USE_THREADS
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/backend.h"
+#include "tensorflow/compiler/xla/service/executable.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/transfer_manager.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace se = ::perftools::gputools;
+
+namespace xla {
+
+/*static*/ StatusOr<std::unique_ptr<HloModule>>
+HloRunner::ReadModuleFromHloProtoFile(const char* filename,
+                                      const DebugOptions& debug_options) {
+  HloProto proto;
+  TF_RETURN_IF_ERROR(tensorflow::ReadBinaryProto(tensorflow::Env::Default(),
+                                                 filename, &proto));
+  HloModuleConfig config;
+  config.set_debug_options(debug_options);
+  TF_ASSIGN_OR_RETURN(auto module, HloModule::CreateFromProto(
+                                       proto.hlo_module(),
+                                       VersionedComputationHandle(), config));
+  return std::move(module);
+}
+
+// Define this in .cc file to avoid having to include eigen or forward declare
+// these types in the header.
+struct HloRunner::EigenThreadPoolWrapper {
+  std::unique_ptr<EigenThreadPoolWrapper> pool;
+  std::unique_ptr<Eigen::ThreadPoolDevice> device;
+};
+
+HloRunner::HloRunner() {}
+
+HloRunner::HloRunner(se::Platform* platform) {
+  BackendOptions backend_options;
+  backend_options.set_platform(platform);
+  backend_ = Backend::CreateBackend(backend_options).ConsumeValueOrDie();
+  VLOG(1) << "Created HloRunner for platform: " << platform->Name();
+}
+
+HloRunner::~HloRunner() {
+  // Deallocate all the memory allocated during the tests.
+  for (auto& allocation : allocations_) {
+    backend().default_stream_executor()->Deallocate(&allocation);
+  }
+}
+
+StatusOr<se::DeviceMemoryBase> HloRunner::Execute(
+    std::unique_ptr<HloModule> module,
+    tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments,
+    Shape* result_shape) {
+  TF_ASSIGN_OR_RETURN(
+      std::unique_ptr<Executable> executable,
+      backend().compiler()->Compile(std::move(module),
+                                    backend().default_stream_executor()));
+
+  se::Stream stream(backend().default_stream_executor());
+  stream.Init();
+
+  ExecutableRunOptions run_options;
+  run_options.set_stream(&stream);
+  run_options.set_allocator(backend().memory_allocator());
+  run_options.set_inter_op_thread_pool(backend().inter_op_thread_pool());
+  run_options.set_intra_op_thread_pool(
+      backend().eigen_intra_op_thread_pool_device());
+
+  HloExecutionProfile hlo_execution_profile;
+  ServiceExecutableRunOptions service_run_options(
+      run_options, backend().StreamBorrower(),
+      backend().inter_op_thread_pool());
+  TF_ASSIGN_OR_RETURN(
+      se::DeviceMemoryBase result,
+      executable->ExecuteOnStream(&service_run_options, arguments,
+                                  &hlo_execution_profile));
+  TF_RET_CHECK(stream.BlockHostUntilDone());
+
+  allocations_.push_back(result);
+
+  *result_shape = executable->result_shape();
+
+  if (ShapeUtil::IsTuple(*result_shape)) {
+    // We must record element buffers of tuples as well to avoid leaks.
+    DCHECK(!ShapeUtil::IsNestedTuple(*result_shape));
+    TF_ASSIGN_OR_RETURN(
+        std::vector<se::DeviceMemoryBase> element_buffers,
+        backend().transfer_manager()->ShallowCopyTupleFromDevice(
+            backend().default_stream_executor(), result, *result_shape));
+
+    // A tuple may contain the same buffer in more than one element. Keep track
+    // of the buffers already added to avoid duplicates in allocations_.
+    std::set<void*> added_opaques;
+    for (auto element_buffer : element_buffers) {
+      if (added_opaques.count(element_buffer.opaque()) == 0) {
+        CHECK(element_buffer.opaque() != nullptr);
+        added_opaques.insert(element_buffer.opaque());
+        allocations_.push_back(element_buffer);
+      }
+    }
+  }
+
+  return result;
+}
+
+se::DeviceMemoryBase HloRunner::TransferToDevice(const Literal& literal) {
+  // Allocate memory on the device using the stream executor.
+  int64 allocation_size =
+      backend().transfer_manager()->GetByteSizeRequirement(literal.shape());
+  se::DeviceMemoryBase allocation =
+      backend().default_stream_executor()->AllocateArray<uint8>(
+          allocation_size);
+  allocations_.push_back(allocation);
+
+  TF_CHECK_OK(backend().transfer_manager()->TransferLiteralToDevice(
+      backend().default_stream_executor(), literal, &allocation));
+
+  return allocation;
+}
+
+std::unique_ptr<Literal> HloRunner::TransferFromDevice(
+    const Shape& shape, se::DeviceMemoryBase device_base) {
+  auto literal = MakeUnique<Literal>();
+  TF_CHECK_OK(backend().transfer_manager()->TransferLiteralFromDevice(
+      backend().default_stream_executor(), device_base, shape, shape,
+      literal.get()));
+  return literal;
+}
+
+std::unique_ptr<Literal> HloRunner::ExecuteAndTransfer(
+    std::unique_ptr<HloModule> module,
+    tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments) {
+  Shape result_shape;
+  se::DeviceMemoryBase device_base =
+      Execute(std::move(module), arguments, &result_shape).ValueOrDie();
+  return TransferFromDevice(result_shape, device_base);
+}
+
+template <>
+std::unique_ptr<Literal> HloRunner::Execute(
+    std::unique_ptr<HloModule> module,
+    const tensorflow::gtl::ArraySlice<std::unique_ptr<Literal>>& literals) {
+  std::vector<se::DeviceMemoryBase> arguments;
+  for (const auto& literal : literals) {
+    arguments.push_back(TransferToDevice(*literal));
+  }
+  return ExecuteAndTransfer(std::move(module), arguments);
+}
+
+template <>
+std::unique_ptr<Literal> HloRunner::Execute(
+    std::unique_ptr<HloModule> module,
+    const tensorflow::gtl::ArraySlice<Literal*>& literals) {
+  std::vector<se::DeviceMemoryBase> arguments;
+  for (const auto& literal : literals) {
+    arguments.push_back(TransferToDevice(*literal));
+  }
+  return ExecuteAndTransfer(std::move(module), arguments);
+}
+
+Backend& HloRunner::backend() {
+  if (!backend_) {
+    backend_ = Backend::CreateDefaultBackend().ConsumeValueOrDie();
+    VLOG(1) << "executing on platform " << backend().platform()->Name();
+  }
+  return *backend_;
+}
+
+}  // namespace xla

View File

@ -0,0 +1,100 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_H_
#include <memory>
#include <string>
#include <vector>
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
namespace xla {
// A base class for running an HloModule. This executes the given HloModule on a
// certain backend directly without using the client interface. HloModule can be
// explicitly built, or loaded from a serialization file (e.g., hlo proto file).
class HloRunner {
public:
HloRunner();
HloRunner(::perftools::gputools::Platform* platform);
~HloRunner();
// Reads the binary proto file in xla.HloProto format, creates and returns the
// HloModule.
static StatusOr<std::unique_ptr<HloModule>> ReadModuleFromHloProtoFile(
const char* filename, const DebugOptions& debug_options);
// Executes the given module with given literals as input and returns the
// result as a Literal. The LiteralPtr type accepts Literal* or
// std::unique_ptr<Literal>.
template <typename LiteralPtr>
std::unique_ptr<Literal> Execute(
std::unique_ptr<HloModule> module,
const tensorflow::gtl::ArraySlice<LiteralPtr>& literals);
// Executes the given module and returns a global data handle.
StatusOr<perftools::gputools::DeviceMemoryBase> Execute(
std::unique_ptr<HloModule> module,
tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
arguments,
Shape* result_shape);
// Transfers the given literal to the device and returns the data handle.
perftools::gputools::DeviceMemoryBase TransferToDevice(
const Literal& literal);
// Transfers the array referred to by the given handle from the device and
// returns as a Literal.
std::unique_ptr<Literal> TransferFromDevice(
const Shape& shape, perftools::gputools::DeviceMemoryBase device_base);
// Executes the given module and return the result as a Literal.
std::unique_ptr<Literal> ExecuteAndTransfer(
std::unique_ptr<HloModule> module,
tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
arguments);
// If backend is not created in the constructor, creates and returns the
// default backend. If creation fails, crashes the program.
//
// This creates the backend lazily so it's possible to instantiate an
// HloRunner in a program without any backends linked in.
Backend& backend();
private:
struct EigenThreadPoolWrapper;
std::vector<perftools::gputools::DeviceMemoryBase> allocations_;
std::unique_ptr<EigenThreadPoolWrapper> thread_pool_wrapper_;
std::unique_ptr<Backend> backend_;
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_H_
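Editor's note: a hypothetical loader (not part of this commit) built from ReadModuleFromHloProtoFile; the file path is illustrative and the DebugOptions are assumed to come from the usual legacy flags.

StatusOr<std::unique_ptr<Literal>> ReplayDumpedModule() {
  // The proto is assumed to come from the --xla_dump_hlo_proto_to debug flag.
  TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
                      HloRunner::ReadModuleFromHloProtoFile(
                          "/tmp/module_0000.pb",
                          legacy_flags::GetDebugOptionsFromFlags()));
  HloRunner runner;
  return runner.ExecuteAndTransfer(std::move(module), /*arguments=*/{});
}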

View File

@ -58,14 +58,32 @@ TransposeFolding::OperandIndices CanFoldOperandsIntoConvolution(
     return {};
   }
-  // We only support folding the RHS.
-  const int64 kRhsOperandIndex = 1;
-  auto& operand = *convolution.operand(kRhsOperandIndex);
-  if (operand.opcode() == HloOpcode::kTranspose && operand.user_count() == 1) {
-    return transposable_conv_operands(convolution, {kRhsOperandIndex});
-  }
-  return {};
+  const ConvolutionDimensionNumbers& dnums =
+      convolution.convolution_dimension_numbers();
+  TransposeFolding::OperandIndices operand_set;
+  for (int64 i = 0; i < convolution.operand_count(); ++i) {
+    auto& operand = *convolution.operand(i);
+    if (operand.opcode() == HloOpcode::kTranspose &&
+        operand.user_count() == 1) {
+      const auto& transpose_dimensions = operand.dimensions();
+      // We can transpose the LHS so long as it doesn't move around spatial
+      // dimensions because ConvolutionDimensionNumbers doesn't have different
+      // fields for input and output spatial dimensions.
+      if (i == 0 &&
+          std::any_of(dnums.spatial_dimensions().begin(),
+                      dnums.spatial_dimensions().end(),
+                      [&](const int64 spatial_dimension) {
+                        return transpose_dimensions[spatial_dimension] !=
+                               spatial_dimension;
+                      })) {
+        continue;
+      }
+      operand_set.push_back(i);
+    }
+  }
+  return transposable_conv_operands(convolution, operand_set);
 }
 using InstructionOperandsPair =
@ -98,40 +116,61 @@ bool FoldTransposeIntoDot(InstructionOperandsPair pair) {
 // Returns whether the module is changed.
 bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) {
   auto& convolution = *pair.first;
-  // We only support fusing the RHS transpose into convolution.
-  //
-  // ConvolutionDimensionNumbers doesn't make enough of a distinction between
-  // the output and the activations.
-  //
-  // TODO(b/37125184): Support transposing the LHS too.
-  if (pair.second.size() != 1 || pair.second.front() != 1) {
-    return false;
-  }
+  auto& operand_indices = pair.second;
   const ConvolutionDimensionNumbers& dnums =
       convolution.convolution_dimension_numbers();
-  HloInstruction& transpose = *convolution.mutable_operand(1);
-  CHECK_EQ(transpose.opcode(), HloOpcode::kTranspose);
-  const auto& transpose_dimensions = transpose.dimensions();
-  HloInstruction& transpose_operand = *transpose.mutable_operand(0);
-  // Everything remains the same except for the kernel dimension numbers. We
-  // need to apply the transpose permutation to the original shape to figure out
-  // what the new logical dimensions are.
   ConvolutionDimensionNumbers new_dnums = dnums;
-  new_dnums.set_kernel_input_feature_dimension(
-      transpose_dimensions[dnums.kernel_input_feature_dimension()]);
-  new_dnums.set_kernel_output_feature_dimension(
-      transpose_dimensions[dnums.kernel_output_feature_dimension()]);
-  for (auto& kernel_spatial_dimension :
-       *new_dnums.mutable_kernel_spatial_dimensions()) {
-    kernel_spatial_dimension = transpose_dimensions[kernel_spatial_dimension];
-  }
+  HloInstruction* new_lhs;
+  const int64 kLhsIdx = 0;
+  if (std::find(operand_indices.begin(), operand_indices.end(), kLhsIdx) !=
+      operand_indices.end()) {
+    HloInstruction& transpose = *convolution.mutable_operand(kLhsIdx);
+    const auto& transpose_dimensions = transpose.dimensions();
+    HloInstruction& transpose_operand = *transpose.mutable_operand(0);
+    // Everything remains the same except for the input/output dimension
+    // numbers. We need to apply the transpose permutation to the original shape
+    // to figure out what the new logical dimensions are.
+    new_dnums.set_input_batch_dimension(
+        transpose_dimensions[dnums.input_batch_dimension()]);
+    new_dnums.set_input_feature_dimension(
+        transpose_dimensions[dnums.input_feature_dimension()]);
+    for (const auto& spatial_dimension : dnums.spatial_dimensions()) {
+      CHECK_EQ(spatial_dimension, transpose_dimensions[spatial_dimension]);
+    }
+    new_lhs = &transpose_operand;
+  } else {
+    new_lhs = convolution.mutable_operand(kLhsIdx);
+  }
+  HloInstruction* new_rhs;
+  const int64 kRhsIdx = 1;
+  if (std::find(operand_indices.begin(), operand_indices.end(), kRhsIdx) !=
+      operand_indices.end()) {
+    HloInstruction& transpose = *convolution.mutable_operand(kRhsIdx);
+    const auto& transpose_dimensions = transpose.dimensions();
+    HloInstruction& transpose_operand = *transpose.mutable_operand(0);
+    // Everything remains the same except for the kernel dimension numbers. We
+    // need to apply the transpose permutation to the original shape to figure
+    // out what the new logical dimensions are.
+    new_dnums.set_kernel_input_feature_dimension(
+        transpose_dimensions[dnums.kernel_input_feature_dimension()]);
+    new_dnums.set_kernel_output_feature_dimension(
+        transpose_dimensions[dnums.kernel_output_feature_dimension()]);
+    for (auto& kernel_spatial_dimension :
+         *new_dnums.mutable_kernel_spatial_dimensions()) {
+      kernel_spatial_dimension = transpose_dimensions[kernel_spatial_dimension];
+    }
+    new_rhs = &transpose_operand;
+  } else {
+    new_rhs = convolution.mutable_operand(kRhsIdx);
+  }
   auto new_conv = HloInstruction::CreateConvolve(
-      convolution.shape(), convolution.mutable_operand(0), &transpose_operand,
-      convolution.window(), new_dnums);
+      convolution.shape(), new_lhs, new_rhs, convolution.window(), new_dnums);
   TF_CHECK_OK(convolution.parent()->ReplaceWithNewInstruction(
       &convolution, std::move(new_conv)));
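Editor's note: to make the permutation bookkeeping concrete, a tiny illustration (not from the commit) of how a transpose's dimension mapping rewrites one dimension number.

// With a kernel transpose that permutes dimensions as {1, 0, 2, 3}, an
// original kernel_input_feature_dimension of 0 is read through the
// permutation from position transpose_dimensions[0] == 1 of the operand.
int64 NewKernelInputFeatureDimension() {
  const std::vector<int64> transpose_dimensions = {1, 0, 2, 3};
  const int64 old_dimension = 0;
  return transpose_dimensions[old_dimension];  // == 1
}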

View File

@ -313,8 +313,7 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeRhs) {
       new_conv->convolution_dimension_numbers().kernel_spatial_dimensions(1));
 }
-// Test that a transpose of the activations does not get folded into
-// convolution.
+// Test that a transpose of the activations gets folded into convolution.
 TEST_F(TransposeFoldingTest, FoldConvTransposeLhs) {
   auto builder = HloComputation::Builder("entry_computation");
   HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
@ -348,18 +347,25 @@ TEST_F(TransposeFoldingTest, FoldConvTransposeLhs) {
   module.AddEntryComputation(builder.Build(conv));
   FoldTranspose(&module);
-  // Instructions after folding: transpose_x, y, and the convolution.
+  // Instructions after folding: x, y, and the convolution.
   std::unordered_set<HloInstruction*> instruction_set(
       entry_computation->instructions().begin(),
       entry_computation->instructions().end());
-  CHECK_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation.";
-  CHECK_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation.";
-  CHECK_EQ(1, instruction_set.erase(transpose_x))
-      << "transpose_x is not in entry_computation.";
-  CHECK_EQ(1, instruction_set.erase(conv))
-      << "transpose_x is not in entry_computation.";
-  CHECK_EQ(0, instruction_set.size())
-      << "entry_computation should contain exactly 4 instructions.";
+  EXPECT_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation.";
+  EXPECT_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation.";
+  EXPECT_EQ(1, instruction_set.size())
+      << "entry_computation should contain exactly 3 instructions.";
+  HloInstruction* new_conv = *instruction_set.begin();
+  EXPECT_EQ(HloOpcode::kConvolution, new_conv->opcode());
+  EXPECT_EQ(dnums.input_feature_dimension(),
+            new_conv->convolution_dimension_numbers().input_batch_dimension());
+  EXPECT_EQ(
+      dnums.input_batch_dimension(),
+      new_conv->convolution_dimension_numbers().input_feature_dimension());
+  EXPECT_EQ(dnums.spatial_dimensions(0),
+            new_conv->convolution_dimension_numbers().spatial_dimensions(0));
+  EXPECT_EQ(dnums.spatial_dimensions(1),
+            new_conv->convolution_dimension_numbers().spatial_dimensions(1));
 }
 }  // namespace

View File

@ -20,6 +20,7 @@ limitations under the License.
 #include <stack>
 #include <unordered_map>
 #include <utility>
+#include <vector>
 #include "tensorflow/compiler/xla/layout_util.h"
 #include "tensorflow/compiler/xla/literal_util.h"
@ -1843,10 +1844,17 @@ UserComputation::GetEmbeddedComputations(
   XLA_VLOG_LINES(3, session_computation_.DebugString());
   std::vector<VersionedComputationHandle> computations;
+  std::vector<int64> sorted_handles;
   for (const auto& handle_request : session_computation_.requests()) {
-    int64 handle_value = handle_request.first;
+    sorted_handles.push_back(handle_request.first);
+  }
+  std::sort(sorted_handles.begin(), sorted_handles.end());
+  for (int64 handle : sorted_handles) {
+    const auto& handle_request = session_computation_.requests().find(handle);
+    CHECK(handle_request != session_computation_.requests().end());
+    int64 handle_value = handle_request->first;
     if (handle_value <= version) {
-      const OperationRequest& request = handle_request.second;
+      const OperationRequest& request = handle_request->second;
       switch (request.request().op_case()) {
         case OpRequest::kCallRequest: {
           CHECK_EQ(1, request.embedded_computation_versions_size());
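Editor's note: the sorted-handle loop above exists because iteration order over the requests() proto map is nondeterministic. A generic sketch of the pattern, with hypothetical names (ProtoMap, Visitor):

template <typename ProtoMap, typename Visitor>
void VisitInSortedOrder(const ProtoMap& requests, Visitor visit) {
  // Collect the keys, sort them, then look each entry up again so the visit
  // order no longer depends on the map's internal ordering.
  std::vector<int64> keys;
  for (const auto& kv : requests) {
    keys.push_back(kv.first);
  }
  std::sort(keys.begin(), keys.end());
  for (int64 key : keys) {
    const auto it = requests.find(key);
    CHECK(it != requests.end());
    visit(it->second);
  }
}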

View File

@ -102,6 +102,32 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) {
   return true;
 }
+// Constructs and returns the new shape with the given minor_to_major order in
+// its Layout.
+StatusOr<Shape> MakeShapeWithLayoutInternal(
+    PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions,
+    tensorflow::gtl::ArraySlice<int64> minor_to_major) {
+  if (dimensions.size() != minor_to_major.size()) {
+    return InvalidArgument("Dimensions size is %ld, but layout size is %ld.",
+                           dimensions.size(), minor_to_major.size());
+  }
+  if (element_type == OPAQUE || element_type == TUPLE) {
+    return InvalidArgument("Unsupported element type: %s",
+                           PrimitiveType_Name(element_type).c_str());
+  }
+  Shape shape = ShapeUtil::MakeShape(element_type, dimensions);
+  auto min2maj = shape.mutable_layout()->mutable_minor_to_major();
+  min2maj->Clear();
+  for (int64 value : minor_to_major) {
+    min2maj->Add(value);
+  }
+  if (!shape.has_layout()) {
+    return InvalidArgument("Shape has no layout.");
+  }
+  TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(shape));
+  return shape;
+}
 }  // namespace
 /* static */ bool ShapeUtil::Equal(const Shape& lhs, const Shape& rhs) {
@ -152,16 +178,8 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) {
 /* static */ Shape ShapeUtil::MakeShapeWithLayout(
     PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions,
     tensorflow::gtl::ArraySlice<int64> minor_to_major) {
-  CHECK_EQ(dimensions.size(), minor_to_major.size());
-  Shape shape = MakeShape(element_type, dimensions);
-  auto min2maj = shape.mutable_layout()->mutable_minor_to_major();
-  min2maj->Clear();
-  for (int64 value : minor_to_major) {
-    min2maj->Add(value);
-  }
-  DCHECK(shape.has_layout());
-  TF_DCHECK_OK(ValidateShape(shape));
-  return shape;
+  return MakeShapeWithLayoutInternal(element_type, dimensions, minor_to_major)
+      .ValueOrDie();
 }
 /* static */ Shape ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(
@ -499,11 +517,10 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
     // Extract the layout minor-to-major and set it.
     TF_ASSIGN_OR_RETURN(std::vector<int64> min2maj,
                         comma_list_to_int64s(layout_string));
-    TF_RET_CHECK(dimensions.size() == min2maj.size());
-    result =
-        ShapeUtil::MakeShapeWithLayout(primitive_type, dimensions, min2maj);
+    TF_ASSIGN_OR_RETURN(result, MakeShapeWithLayoutInternal(
+                                    primitive_type, dimensions, min2maj));
   }
-  TF_DCHECK_OK(ShapeUtil::ValidateShape(result));
+  TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(result));
   return std::move(result);
 }
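Editor's note: for illustration (not in the diff), the behavior after this change, assuming the usual ShapeUtil entry points.

// The happy path is unchanged: a 2x3 f32 array laid out column-major.
Shape MakeColumnMajor2x3() {
  return ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{2, 3},
                                        /*minor_to_major=*/{0, 1});
}
// A rank mismatch between dimensions and minor_to_major now yields an
// InvalidArgument status from MakeShapeWithLayoutInternal: the parser path
// propagates it, while MakeShapeWithLayout still crashes via ValueOrDie().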

View File

@ -102,28 +102,18 @@ cc_library(
     deps = [
         ":literal_test_util",
         "//tensorflow/compiler/xla:shape_layout",
-        "//tensorflow/compiler/xla:shape_util",
-        "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla:types",
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
-        "//tensorflow/compiler/xla/service",
         "//tensorflow/compiler/xla/service:backend",
-        "//tensorflow/compiler/xla/service:compiler",
         "//tensorflow/compiler/xla/service:computation_layout",
-        "//tensorflow/compiler/xla/service:computation_placer",
-        "//tensorflow/compiler/xla/service:executable",
         "//tensorflow/compiler/xla/service:hlo",
-        "//tensorflow/compiler/xla/service:hlo_execution_profile",
-        "//tensorflow/compiler/xla/service:hlo_graph_dumper",
-        "//tensorflow/compiler/xla/service:transfer_manager",
+        "//tensorflow/compiler/xla/service:hlo_runner",
-        "//tensorflow/core:core_cpu_internal",
         "//tensorflow/core:lib",
         "//tensorflow/core:stream_executor_no_cuda",
         "//tensorflow/core:test",
-        "//third_party/eigen3",
     ],
 )

View File

@ -19,24 +19,9 @@ limitations under the License.
 #include <string>
 #include <utility>
-#define EIGEN_USE_THREADS
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "tensorflow/compiler/xla/layout_util.h"
 #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
 #include "tensorflow/compiler/xla/ptr_util.h"
-#include "tensorflow/compiler/xla/service/backend.h"
-#include "tensorflow/compiler/xla/service/computation_layout.h"
-#include "tensorflow/compiler/xla/service/executable.h"
-#include "tensorflow/compiler/xla/service/hlo_computation.h"
-#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
-#include "tensorflow/compiler/xla/service/hlo_instruction.h"
-#include "tensorflow/compiler/xla/service/transfer_manager.h"
-#include "tensorflow/compiler/xla/shape_layout.h"
-#include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/test.h"
 #include "tensorflow/core/platform/types.h"
@ -45,22 +30,6 @@ namespace se = ::perftools::gputools;
 namespace xla {
-// Define this in .cc file to avoid having to include eigen or forward declare
-// these types in the header.
-struct HloTestBase::EigenThreadPoolWrapper {
-  std::unique_ptr<EigenThreadPoolWrapper> pool;
-  std::unique_ptr<Eigen::ThreadPoolDevice> device;
-};
-HloTestBase::HloTestBase() {}
-HloTestBase::~HloTestBase() {
-  // Deallocate all the memory allocated during the tests.
-  for (auto& allocation : allocations_) {
-    backend().default_stream_executor()->Deallocate(&allocation);
-  }
-}
 /* static */
 std::unique_ptr<HloModule> HloTestBase::CreateNewModule() {
   HloModuleConfig config;
@ -80,98 +49,25 @@ StatusOr<perftools::gputools::DeviceMemoryBase> HloTestBase::Execute(
     tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
         arguments,
     Shape* result_shape) {
-  TF_ASSIGN_OR_RETURN(
-      std::unique_ptr<Executable> executable,
-      backend().compiler()->Compile(std::move(module),
-                                    backend().default_stream_executor()));
-  se::Stream stream(backend().default_stream_executor());
-  stream.Init();
-  ExecutableRunOptions run_options;
-  run_options.set_stream(&stream);
-  run_options.set_allocator(backend().memory_allocator());
-  run_options.set_inter_op_thread_pool(backend().inter_op_thread_pool());
-  run_options.set_intra_op_thread_pool(
-      backend().eigen_intra_op_thread_pool_device());
-  HloExecutionProfile hlo_execution_profile;
-  ServiceExecutableRunOptions service_run_options(
-      run_options, backend().StreamBorrower(),
-      backend().inter_op_thread_pool());
-  TF_ASSIGN_OR_RETURN(
-      se::DeviceMemoryBase result,
-      executable->ExecuteOnStream(&service_run_options, arguments,
-                                  &hlo_execution_profile));
-  TF_RET_CHECK(stream.BlockHostUntilDone());
-  allocations_.push_back(result);
-  *result_shape = executable->result_shape();
-  if (ShapeUtil::IsTuple(*result_shape)) {
-    // We must record element buffers of tuples as well to avoid leaks.
-    DCHECK(!ShapeUtil::IsNestedTuple(*result_shape));
-    TF_ASSIGN_OR_RETURN(
-        std::vector<se::DeviceMemoryBase> element_buffers,
-        backend().transfer_manager()->ShallowCopyTupleFromDevice(
-            backend().default_stream_executor(), result, *result_shape));
-    // A tuple may contain the same buffer in more than one element. Keep track
-    // of the buffers already added to avoid duplicates in allocations_.
-    std::set<void*> added_opaques;
-    for (auto element_buffer : element_buffers) {
-      if (added_opaques.count(element_buffer.opaque()) == 0) {
-        CHECK(element_buffer.opaque() != nullptr);
-        added_opaques.insert(element_buffer.opaque());
-        allocations_.push_back(element_buffer);
-      }
-    }
-  }
-  return result;
+  return runner_.Execute(std::move(module), arguments, result_shape);
 }
 se::DeviceMemoryBase HloTestBase::TransferToDevice(const Literal& literal) {
-  // Allocate memory on the device using the stream executor.
-  int64 allocation_size =
-      backend().transfer_manager()->GetByteSizeRequirement(literal.shape());
-  se::DeviceMemoryBase allocation =
-      backend().default_stream_executor()->AllocateArray<uint8>(
-          allocation_size);
-  allocations_.push_back(allocation);
-  TF_CHECK_OK(backend().transfer_manager()->TransferLiteralToDevice(
-      backend().default_stream_executor(), literal, &allocation));
-  return allocation;
+  return runner_.TransferToDevice(literal);
 }
 std::unique_ptr<Literal> HloTestBase::TransferFromDevice(
     const Shape& shape, se::DeviceMemoryBase device_base) {
-  auto literal = MakeUnique<Literal>();
-  TF_CHECK_OK(backend().transfer_manager()->TransferLiteralFromDevice(
-      backend().default_stream_executor(), device_base, shape, shape,
-      literal.get()));
-  return literal;
+  return runner_.TransferFromDevice(shape, device_base);
 }
 std::unique_ptr<Literal> HloTestBase::ExecuteAndTransfer(
     std::unique_ptr<HloModule> module,
     tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments) {
-  Shape result_shape;
-  se::DeviceMemoryBase device_base =
-      Execute(std::move(module), arguments, &result_shape).ValueOrDie();
-  return TransferFromDevice(result_shape, device_base);
+  return runner_.ExecuteAndTransfer(std::move(module), arguments);
 }
-Backend& HloTestBase::backend() {
-  if (!backend_) {
-    backend_ = Backend::CreateDefaultBackend().ConsumeValueOrDie();
-    VLOG(1) << "executing on platform " << backend().platform()->Name();
-  }
-  return *backend_;
-}
+Backend& HloTestBase::backend() { return runner_.backend(); }
 /* static */
 string HloTestBase::TestName() {

View File

@ -21,12 +21,12 @@ limitations under the License.
 #include <vector>
 #include "tensorflow/compiler/xla/service/backend.h"
-#include "tensorflow/compiler/xla/service/compiler.h"
-#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/computation_layout.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_runner.h"
+#include "tensorflow/compiler/xla/shape_layout.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
-#include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/lib/gtl/array_slice.h"
 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
@ -39,10 +39,9 @@ namespace xla {
 // building a graph of HLO instructions to run.
 class HloTestBase : public ::testing::Test {
  protected:
-  struct EigenThreadPoolWrapper;
-  HloTestBase();
-  ~HloTestBase() override;
+  HloTestBase() {}
+  ~HloTestBase() override {}
   // Creates a new HLO module for a test. The module created will have
   // TestName() for its name; it will also automatically populate its debug
@ -102,23 +101,12 @@ class HloTestBase : public ::testing::Test {
   static string TestName();
-  // Creates (if necessary) and returns the default backend. If creation fails,
-  // crashes the program.
-  //
-  // This creates the backend lazily so it's possible to instantiate an
-  // HloTestBase in a program without any backends linked in.
+  // Returns the backend owned by the HloRunner.
   Backend& backend();
-  // This vector contains handles of all the device memory allocations performed
-  // by the test. These are deallocated on destruction of the test object.
-  std::vector<perftools::gputools::DeviceMemoryBase> allocations_;
+  HloRunner runner_;
   ErrorSpec error_spec_{0.0001};
-  std::unique_ptr<EigenThreadPoolWrapper> thread_pool_wrapper_;
-
- private:
-  std::unique_ptr<Backend> backend_;  // Lazily populated. Access via backend().
 };
 }  // namespace xla
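Editor's note: a sketch (not part of the commit) of a test written against the refactored base class; device allocations and the backend are now owned by runner_. The fixture name is hypothetical.

class HloRunnerBackedTest : public HloTestBase {};

TEST_F(HloRunnerBackedTest, ConstantR0) {
  auto builder = HloComputation::Builder(TestName());
  builder.AddInstruction(
      HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
  auto module = CreateNewModule();
  module->AddEntryComputation(builder.Build());
  // Device memory is tracked by runner_, not by the test object.
  std::unique_ptr<Literal> result = ExecuteAndTransfer(std::move(module), {});
  LiteralTestUtil::ExpectR0Equal<float>(42.0f, *result);
}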

View File

@ -210,6 +210,18 @@ tf_cc_binary(
     ],
 )
+tf_cc_binary(
+    name = "hlo_proto_to_json",
+    srcs = ["hlo_proto_to_json.cc"],
+    deps = [
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla:util",
+        "//tensorflow/compiler/xla/service:hlo_proto",
+        "//tensorflow/core:framework_internal",
+        "//tensorflow/core:lib",
+    ],
+)
 # -----------------------------------------------------------------------------
 filegroup(

View File

@ -0,0 +1,91 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Usage:
// hlo_proto_to_json --input_file=some_binary_proto
// --output_file=path_to_dump_output
//
// Reads one serialized HLO module, converts it into JSON format, and dumps it
// to the output file. some_binary_proto is obtained by serializing an HLO
// module to disk using the --xla_dump_hlo_proto_to debug option.
#include <stdio.h>
#include <string>
#include <vector>
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/command_line_flags.h"
using tensorflow::Env;
using xla::string;
namespace xla {
namespace tools {
StatusOr<string> ToJson(const tensorflow::protobuf::Message& message) {
string json_output;
tensorflow::protobuf::util::JsonPrintOptions json_options;
json_options.add_whitespace = true;
json_options.always_print_primitive_fields = true;
auto status = tensorflow::protobuf::util::MessageToJsonString(
message, &json_output, json_options);
if (!status.ok()) {
return InternalError("MessageToJsonString failed: %s",
status.error_message().data());
}
return json_output;
}
void RealMain(const string& input, const string& output) {
HloProto hlo_proto;
TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), input,
&hlo_proto))
<< "Can't open, read, or parse input file " << input;
auto statusor = ToJson(hlo_proto);
QCHECK(statusor.ok()) << "Error converting " << input << " to JSON."
<< statusor.status();
TF_CHECK_OK(tensorflow::WriteStringToFile(tensorflow::Env::Default(), output,
statusor.ValueOrDie()));
}
} // namespace tools
} // namespace xla
int main(int argc, char** argv) {
string input_file, output_file;
const std::vector<tensorflow::Flag> flag_list = {
tensorflow::Flag("input_file", &input_file, "file to convert."),
tensorflow::Flag("output_file", &output_file, "converted file"),
};
const string usage = tensorflow::Flags::Usage(argv[0], flag_list);
bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
tensorflow::port::InitMain(usage.c_str(), &argc, &argv);
QCHECK(parse_ok && argc == 1) << "\n" << usage;
QCHECK(!input_file.empty()) << "--input_file is required";
QCHECK(!output_file.empty()) << "--output_file is required";
xla::tools::RealMain(input_file, output_file);
return 0;
}
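Editor's note: typical invocation, mirroring the usage comment at the top of the file; the paths are hypothetical.

hlo_proto_to_json --input_file=/tmp/module_0000.pb --output_file=/tmp/module_0000.json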

View File

@ -0,0 +1,84 @@
# Build file for the Hlo parser.
licenses(["notice"]) # Apache 2.0
package(
default_visibility = [":friends"],
)
package_group(
name = "friends",
includes = [
"//tensorflow/compiler/xla:friends",
],
)
# Filegroup used to collect source files for dependency checking.
filegroup(
name = "c_srcs",
data = glob([
"**/*.cc",
"**/*.h",
]),
)
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
cc_library(
name = "hlo_lexer",
srcs = ["hlo_lexer.cc"],
hdrs = [
"hlo_lexer.h",
"hlo_token.h",
],
deps = [
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/core:lib",
"//tensorflow/core:regexp_internal",
],
)
cc_library(
name = "hlo_parser",
srcs = ["hlo_parser.cc"],
hdrs = ["hlo_parser.h"],
deps = [
":hlo_lexer",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
],
)
tf_cc_test(
name = "hlo_parser_test",
size = "small",
srcs = ["hlo_parser_test.cc"],
deps = [
":hlo_parser",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
# -----------------------------------------------------------------------------
filegroup(
name = "all_files",
srcs = glob(
["**/*"],
exclude = [
"**/METADATA",
"**/OWNERS",
],
),
visibility = ["//tensorflow:__subpackages__"],
)

View File

@ -0,0 +1,69 @@
# HloModule string syntax
TODO: Support subcomputations (for fusion, reduce, while, ...).
TODO: Support ops that require extra attributes, e.g. dimensions, strides.
```yacc
hlo_module
: 'HloModule' name computation
;
computation
: 'ENTRY' name param_list '->' shape instruction_list
;
instruction_list
: '{' instruction_list1 '}'
;
instruction_list1
: instruction
| instruction_list1 instruction
;
instruction
: name '=' shape opcode operands
;
operands
: '(' operands1 ')'
;
operands1
: /*empty*/
| operand
| operands1 ',' operand
;
operand
: shape name
;
param_list
: '(' param_list1 ')'
;
param_list1
: /*empty*/
| param
| param_list1 ',' param
;
param
: name shape
;
shape
: shape_val_
| '(' tuple_elements ')'
;
tuple_elements
: /*empty*/
| shape (',' shape)*
;
name
: identifier ':'
| '%' identifier
;
identifier
: [a-zA-Z_][a-zA-Z0-9_.-]*
;
```
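For example, the following string (modeled on the parser tests added in this change) conforms to the grammar above:

```
HloModule add_module:

ENTRY %add.v1 (x: f32[], y: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  %y = f32[] parameter(1)
  %add = f32[] add(f32[] %x, f32[] %y)
}
```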

View File

@ -0,0 +1,270 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/tools/parser/hlo_lexer.h"
#include <unordered_map>
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/platform/regexp.h"
namespace xla {
namespace tools {
using tensorflow::StringPiece;
namespace {
constexpr int kEOF = -1;
constexpr int kError = -2;
// [a-zA-Z0-9_.-]
bool IsIdentifierChar(char c) {
return isalnum(static_cast<unsigned char>(c)) || c == '-' || c == '.' ||
c == '_';
}
} // namespace
int HloLexer::GetNextChar() {
int current_char = PeekCurrentChar();
if (current_char != kEOF && current_char != kError) {
current_ptr_++;
}
return current_char;
}
int HloLexer::PeekCurrentChar() const {
if (current_ptr_ == buf_.end()) {
return kEOF;
}
char current_char = *current_ptr_;
if (current_char == 0) {
// '\0' should not appear in the middle of the string.
return kError;
}
return static_cast<unsigned char>(current_char);
}
bool HloLexer::CanDereference(const char* ptr) const {
return ptr < buf_.end() && ptr >= buf_.begin();
}
StringPiece HloLexer::StringPieceFromPointers(const char* begin,
const char* end) const {
CHECK(begin <= end);
CHECK(begin == buf_.end() || CanDereference(begin));
CHECK(end == buf_.end() || CanDereference(end));
return StringPiece(begin, end - begin);
}
tensorflow::RegexpStringPiece HloLexer::RegexpStringPieceFromPointers(
const char* begin, const char* end) const {
CHECK(begin <= end);
CHECK(begin == buf_.end() || CanDereference(begin));
CHECK(end == buf_.end() || CanDereference(end));
return tensorflow::RegexpStringPiece(begin, end - begin);
}
TokKind HloLexer::LexToken() {
while (true) {
token_start_ = current_ptr_;
int current_char = GetNextChar();
switch (current_char) {
default:
// [a-zA-Z_]
if (isalpha(static_cast<unsigned char>(current_char)) ||
current_char == '_') {
return LexIdentifier();
}
return TokKind::kError;
case kEOF:
// Hit the end of the input buffer.
return TokKind::kEof;
case kError:
// Hit an invalid character in the input buffer.
return TokKind::kError;
case ' ':
case '\t':
case '\n':
case '\r':
// Ignore whitespace.
continue;
case '0':
case '1':
case '2':
case '3':
case '4':
case '5':
case '6':
case '7':
case '8':
case '9':
case '-':
if (current_char == '-' && PeekCurrentChar() == '>') {
current_ptr_++;
return TokKind::kArrow;
}
return LexDigitOrNegative();
case '=':
return TokKind::kEqual;
case ',':
return TokKind::kComma;
case '%':
return LexPercent();
case ':':
return TokKind::kColon;
case '[':
return TokKind::kLsquare;
case ']':
return TokKind::kRsquare;
case '{':
return TokKind::kLbrace;
case '}':
return TokKind::kRbrace;
case '(':
return TokKind::kLparen;
case ')':
return TokKind::kRparen;
}
}
}
// Lex a shape, name, keyword, or opcode.
// shape ::= ([a-zA-Z0-9_]*[0-9]*)\[([0-9,]*)\](?:\s*{([0-9,]*)})?
// name ::= [a-zA-Z_][a-zA-Z0-9_.-]*:
// keyword ::= HloModule, ENTRY, ...
// opcode ::= add, greater-than, ...
TokKind HloLexer::LexIdentifier() {
{
auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end());
// 'consumable' will be advanced iff its prefix matches the pattern.
static LazyRE2 shape_pattern = {
R"(^(\w*\d*)\[([\d,]*)\](?:\s*{([\d,]*)})?)"};
if (RE2::Consume(&consumable, *shape_pattern)) {
auto status_or_shape = ShapeUtil::ParseShapeString(
StringPieceFromPointers(token_start_, consumable.begin()));
if (status_or_shape.ok()) {
// This is a shape string.
shape_val_ = status_or_shape.ValueOrDie();
current_ptr_ = consumable.begin();
return TokKind::kShape;
}
}
}
while (IsIdentifierChar(PeekCurrentChar())) {
current_ptr_++;
}
// If followed by ':', it's a name.
if (PeekCurrentChar() == ':') {
str_val_.assign(token_start_, current_ptr_);
current_ptr_++; // skip ':'
return TokKind::kName;
}
StringPiece identifier = StringPieceFromPointers(token_start_, current_ptr_);
// See if this is a keyword.
#define KEYWORD(STR) \
do { \
if (identifier == #STR) { \
return TokKind::kw_##STR; \
} \
} while (false)
KEYWORD(true);
KEYWORD(false);
KEYWORD(HloModule);
KEYWORD(ENTRY);
#undef KEYWORD
// See if this is an opcode.
auto opcode = StringToHloOpcode(identifier.ToString());
if (opcode.ok()) {
opcode_val_ = opcode.ValueOrDie();
return TokKind::kOpcode;
}
current_ptr_ = token_start_ + 1;
return TokKind::kError;
}
// Lex names after a % character.
// name ::= [a-zA-Z_][a-zA-Z0-9_.-]*
TokKind HloLexer::LexPercent() {
const char* name_start = current_ptr_;
if (isalpha(static_cast<unsigned char>(PeekCurrentChar())) ||
PeekCurrentChar() == '_') {
current_ptr_++;
while (IsIdentifierChar(PeekCurrentChar())) {
current_ptr_++;
}
str_val_.assign(name_start, current_ptr_);
return TokKind::kName;
}
return TokKind::kError;
}
// Lex integer and floating-point values.
// int [-]?[0-9]+
// fp with exp [-]?([0-9]+|[0-9]+[.][0-9]*|[0-9]*[.][0-9]+)([eE][+-]?[0-9]+)
// fp without exp [-]?([0-9]+[.][0-9]*|[0-9]*[.][0-9]+)
TokKind HloLexer::LexDigitOrNegative() {
auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end());
static LazyRE2 float_pattern = {
R"([-]?((\d+|\d+[.]\d*|\d*[.]\d+)([eE][+-]?\d+))|(\d+[.]\d*|\d*[.]\d+))"};
if (RE2::Consume(&consumable, *float_pattern)) {
current_ptr_ = consumable.begin();
tensorflow::strings::safe_strtod(string(token_start_, current_ptr_).c_str(),
&decimal_val_);
return TokKind::kDecimal;
}
static LazyRE2 int_pattern = {R"([-]?\d+)"};
if (RE2::Consume(&consumable, *int_pattern)) {
current_ptr_ = consumable.begin();
tensorflow::strings::safe_strto64(
StringPieceFromPointers(token_start_, current_ptr_), &int64_val_);
return TokKind::kInt;
}
return TokKind::kError;
}
StringPiece HloLexer::GetCurrentLine() const {
const char* start = token_start_;
const char* end = current_ptr_;
if (!CanDereference(start) || !CanDereference(end)) {
return "LINE OUT OF RANGE";
}
while (start > buf_.begin() && *start != '\n') {
start--;
}
while (end < buf_.end() && *end != '\n') {
end++;
}
return StringPieceFromPointers(start, end);
}
} // namespace tools
} // namespace xla
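Editor's note: an illustrative lexing loop (not part of the commit) showing the token stream for a small input.

// For "%x = f32[2]{0} parameter(0)" the kinds come out as kName, kEqual,
// kShape, kOpcode, kLparen, kInt, kRparen, then kEof.
HloLexer lexer("%x = f32[2]{0} parameter(0)");
for (TokKind kind = lexer.Lex();
     kind != TokKind::kEof && kind != TokKind::kError; kind = lexer.Lex()) {
  // Inspect kind here, e.g. lexer.GetStrVal() for kName or
  // lexer.GetShapeVal() for kShape.
}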

View File

@ -0,0 +1,108 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_LEXER_H_
#define TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_LEXER_H_
#include <string>
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/tools/parser/hlo_token.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
namespace tools {
// Lexer for the HloModule::ToString() format text.
class HloLexer {
public:
explicit HloLexer(tensorflow::StringPiece buf) : buf_(buf) {
current_ptr_ = buf_.begin();
}
TokKind Lex() { return current_kind_ = LexToken(); }
TokKind GetKind() const { return current_kind_; }
string GetStrVal() const {
CHECK(GetKind() == TokKind::kName);
return str_val_;
}
Shape GetShapeVal() const {
CHECK(GetKind() == TokKind::kShape);
return shape_val_;
}
HloOpcode GetOpcodeVal() const {
CHECK(GetKind() == TokKind::kOpcode);
return opcode_val_;
}
int64 GetInt64Val() const {
CHECK(GetKind() == TokKind::kInt);
return int64_val_;
}
double GetDecimalVal() const {
CHECK(GetKind() == TokKind::kDecimal);
return decimal_val_;
}
// Returns the line of text that is currently being lexed.
tensorflow::StringPiece GetCurrentLine() const;
private:
// Returns the current character. If it's neither the end of input buffer nor
// an invalid character, moves the pointer forward.
int GetNextChar();
// Returns the current character.
int PeekCurrentChar() const;
// Creates a StringPiece with the given begin and end. Exits if begin > end,
// or if either is out of the range of the current buffer.
tensorflow::StringPiece StringPieceFromPointers(const char* begin,
const char* end) const;
tensorflow::RegexpStringPiece RegexpStringPieceFromPointers(
const char* begin, const char* end) const;
// Returns true if the given ptr is dereferenceable within the range of the
// current buffer.
bool CanDereference(const char* ptr) const;
TokKind LexToken();
TokKind LexIdentifier();
TokKind LexPercent();
TokKind LexShape();
TokKind LexConstant();
TokKind LexDigitOrNegative();
const tensorflow::StringPiece buf_;
const char* current_ptr_;
// Information about the current token.
const char* token_start_;
TokKind current_kind_;
string str_val_;
Shape shape_val_;
HloOpcode opcode_val_;
int64 int64_val_;
double decimal_val_;
};
} // namespace tools
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_LEXER_H_

View File

@ -0,0 +1,502 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace xla {
namespace tools {
namespace {
using tensorflow::StringPiece;
using tensorflow::strings::StrCat;
// Parser for the HloModule::ToString() format text.
class HloParser {
public:
explicit HloParser(StringPiece str) : lexer_(str) {}
// Runs the parser. Returns false if an error occurred.
bool Run();
// Returns the parsed HloModule.
std::unique_ptr<HloModule> ConsumeHloModule() { return std::move(module_); }
// Returns the error information.
string GetError() const { return tensorflow::str_util::Join(error_, "\n"); }
private:
// ParseXXX returns false if an error occurred.
bool ParseHloModule();
bool ParseComputation();
bool ParseInstructionList(HloComputation::Builder* builder);
bool ParseInstruction(HloComputation::Builder* builder);
bool ParseLiteral(std::unique_ptr<Literal>* literal, const Shape& shape);
bool ParseOperands(std::vector<HloInstruction*>* operands,
const int expected_size);
bool ParseParamList();
bool ParseName(string* result);
bool ParseShape(Shape* result);
bool ParseOpcode(HloOpcode* result);
bool ParseInt64(int64* result);
bool ParseDecimal(double* result);
bool ParseBool(bool* result);
bool ParseToken(TokKind kind, const string& msg);
// Logs the current parsing line and the given message. Always returns false.
bool TokenError(StringPiece msg);
// If the current token is 'kind', eats it (i.e. lexes the next token) and
// returns true.
bool EatIfPresent(TokKind kind);
// Adds the instruction to the pool. Returns false and emits an error if the
// instruction already exists.
bool AddInstruction(const string& name, HloInstruction* instruction);
// The map from the instruction name to the instruction. This does not own the
// instructions.
std::unordered_map<string, HloInstruction*> instruction_pool_;
HloLexer lexer_;
std::unique_ptr<HloModule> module_;
std::vector<string> error_;
};
bool HloParser::TokenError(StringPiece msg) {
error_.push_back(
StrCat("was parsing \"", lexer_.GetCurrentLine(), "\"; ", msg));
return false;
}
bool HloParser::Run() {
lexer_.Lex();
return ParseHloModule();
}
// ::= 'HloModule' name computation
bool HloParser::ParseHloModule() {
if (lexer_.GetKind() != TokKind::kw_HloModule) {
return TokenError("expects HloModule");
}
// Eat 'HloModule'
lexer_.Lex();
string name;
if (!ParseName(&name)) {
return false;
}
module_ = MakeUnique<HloModule>(name);
return ParseComputation();
}
// computation ::= 'ENTRY' name param_list '->' shape instruction_list
bool HloParser::ParseComputation() {
string name;
if (!ParseToken(TokKind::kw_ENTRY, "expects 'ENTRY'") || !ParseName(&name)) {
return false;
}
auto builder = MakeUnique<HloComputation::Builder>(name);
Shape shape;
if (!ParseParamList() || !ParseToken(TokKind::kArrow, "expects '->'") ||
!ParseShape(&shape) || !ParseInstructionList(builder.get())) {
return false;
}
module_->AddEntryComputation(builder->Build());
return true;
}
// instruction_list ::= '{' instruction_list1 '}'
// instruction_list1 ::= (instruction)+
bool HloParser::ParseInstructionList(HloComputation::Builder* builder) {
if (!ParseToken(TokKind::kLbrace,
"expects '{' at the beginning of instruction list.")) {
return false;
}
do {
if (!ParseInstruction(builder)) {
return false;
}
} while (lexer_.GetKind() != TokKind::kRbrace);
return ParseToken(TokKind::kRbrace,
"expects '}' at the end of instruction list.");
}
// instruction ::= name '=' shape opcode operands
bool HloParser::ParseInstruction(HloComputation::Builder* builder) {
string name;
Shape shape;
HloOpcode opcode;
std::vector<HloInstruction*> operands;
if (!ParseName(&name) ||
!ParseToken(TokKind::kEqual, "expects '=' in instruction") ||
!ParseShape(&shape) || !ParseOpcode(&opcode)) {
return false;
}
switch (opcode) {
case HloOpcode::kParameter: {
int64 parameter_number;
return ParseToken(TokKind::kLparen,
"expects '(' before parameter number") &&
ParseInt64(&parameter_number) &&
ParseToken(TokKind::kRparen,
"expects ')' after parameter number") &&
AddInstruction(
name, builder->AddInstruction(HloInstruction::CreateParameter(
parameter_number, shape, name)));
}
case HloOpcode::kConstant: {
std::unique_ptr<Literal> literal;
return ParseToken(TokKind::kLparen,
"expects '(' before parameter number") &&
ParseLiteral(&literal, shape) &&
ParseToken(TokKind::kRparen,
"expects ')' after parameter number") &&
AddInstruction(
name, builder->AddInstruction(
HloInstruction::CreateConstant(std::move(literal))));
}
// Unary ops.
case HloOpcode::kAbs:
case HloOpcode::kRoundNearestAfz:
case HloOpcode::kBitcast:
case HloOpcode::kCeil:
case HloOpcode::kCopy:
case HloOpcode::kCos:
case HloOpcode::kExp:
case HloOpcode::kIsFinite:
case HloOpcode::kFloor:
case HloOpcode::kLog:
case HloOpcode::kNot:
case HloOpcode::kNegate:
case HloOpcode::kSign:
case HloOpcode::kSin:
case HloOpcode::kSort:
case HloOpcode::kTanh: {
return ParseOperands(&operands, /*expected_size=*/1) &&
AddInstruction(name,
builder->AddInstruction(HloInstruction::CreateUnary(
shape, opcode, operands[0])));
}
// Binary ops.
case HloOpcode::kAdd:
case HloOpcode::kDivide:
case HloOpcode::kMultiply:
case HloOpcode::kSubtract:
case HloOpcode::kEq:
case HloOpcode::kGe:
case HloOpcode::kGt:
case HloOpcode::kLe:
case HloOpcode::kLt:
case HloOpcode::kNe:
case HloOpcode::kDot:
case HloOpcode::kMaximum:
case HloOpcode::kMinimum:
case HloOpcode::kPower:
case HloOpcode::kRemainder:
case HloOpcode::kAnd:
case HloOpcode::kOr:
case HloOpcode::kShiftLeft:
case HloOpcode::kShiftRightArithmetic:
case HloOpcode::kShiftRightLogical: {
return ParseOperands(&operands, /*expected_size=*/2) &&
AddInstruction(
name, builder->AddInstruction(HloInstruction::CreateBinary(
shape, opcode, operands[0], operands[1])));
}
// Ternary ops.
case HloOpcode::kClamp:
case HloOpcode::kSelect: {
return ParseOperands(&operands, /*expected_size=*/3) &&
AddInstruction(
name,
builder->AddInstruction(HloInstruction::CreateTernary(
shape, opcode, operands[0], operands[1], operands[2])));
}
// Other supported ops.
case HloOpcode::kConvert: {
return ParseOperands(&operands, /*expected_size=*/1) &&
AddInstruction(
name, builder->AddInstruction(
HloInstruction::CreateConvert(shape, operands[0])));
}
case HloOpcode::kCrossReplicaSum: {
return ParseOperands(&operands, /*expected_size=*/1) &&
AddInstruction(name, builder->AddInstruction(
HloInstruction::CreateCrossReplicaSum(
shape, operands[0])));
}
case HloOpcode::kReshape: {
return ParseOperands(&operands, /*expected_size=*/1) &&
AddInstruction(
name, builder->AddInstruction(
HloInstruction::CreateReshape(shape, operands[0])));
}
case HloOpcode::kBroadcast:
case HloOpcode::kCall:
case HloOpcode::kCustomCall:
case HloOpcode::kConcatenate:
case HloOpcode::kReducePrecision:
case HloOpcode::kConvolution:
case HloOpcode::kGetTupleElement:
case HloOpcode::kMap:
case HloOpcode::kPad:
case HloOpcode::kReduce:
case HloOpcode::kReduceWindow:
case HloOpcode::kSelectAndScatter:
case HloOpcode::kReverse:
case HloOpcode::kRng:
case HloOpcode::kSlice:
case HloOpcode::kDynamicSlice:
case HloOpcode::kDynamicUpdateSlice:
case HloOpcode::kTranspose:
case HloOpcode::kTuple:
case HloOpcode::kWhile:
case HloOpcode::kFusion:
case HloOpcode::kBatchNormTraining:
case HloOpcode::kBatchNormInference:
case HloOpcode::kInfeed:
case HloOpcode::kOutfeed:
case HloOpcode::kBatchNormGrad:
case HloOpcode::kRecv:
case HloOpcode::kSend:
case HloOpcode::kUpdate:
case HloOpcode::kIndex:
case HloOpcode::kTrace:
return TokenError(StrCat("parsing not yet implemented for op: ",
HloOpcodeString(opcode)));
}
}
bool HloParser::ParseLiteral(std::unique_ptr<Literal>* literal,
const Shape& shape) {
switch (shape.element_type()) {
case PRED:
bool b;
if (!ParseBool(&b)) {
return false;
}
*literal = Literal::CreateR0<bool>(b);
return true;
case S32:
int64 i;
if (!ParseInt64(&i)) {
return false;
}
*literal = Literal::CreateR0<int32>(i);
return true;
case F32:
double d;
if (!ParseDecimal(&d)) {
return false;
}
*literal = Literal::CreateR0<float>(d);
return true;
default:
return TokenError(StrCat("unsupported constant in shape: ",
ShapeUtil::HumanString(shape)));
}
}
// operands ::= '(' operands1 ')'
// operands1
// ::= /*empty*/
// ::= operand (, operand)*
// operand ::= shape name
bool HloParser::ParseOperands(std::vector<HloInstruction*>* operands,
const int expected_size) {
if (!ParseToken(TokKind::kLparen,
"expects '(' at the beginning of operands")) {
return false;
}
if (lexer_.GetKind() == TokKind::kRparen) {
// empty
} else {
do {
Shape shape;
string name;
if (!ParseShape(&shape) || !ParseName(&name)) {
return false;
}
HloInstruction* instruction =
tensorflow::gtl::FindPtrOrNull(instruction_pool_, name);
if (!instruction) {
return TokenError(StrCat("instruction does not exist: ", name));
}
operands->push_back(instruction);
} while (EatIfPresent(TokKind::kComma));
}
if (expected_size != operands->size()) {
return TokenError(StrCat("expects ", expected_size, " operands, but has ",
operands->size(), " operands"));
}
return ParseToken(TokKind::kRparen, "expects ')' at the end of operands");
}
// param_list ::= '(' param_list1 ')'
// param_list1
// ::= /*empty*/
// ::= param (',' param)*
// param ::= name shape
bool HloParser::ParseParamList() {
if (!ParseToken(TokKind::kLparen,
"expects '(' at the beginning of param list")) {
return false;
}
if (lexer_.GetKind() == TokKind::kRparen) {
// empty
} else {
do {
Shape shape;
if (!ParseToken(TokKind::kName, "expects name in parameter") ||
!ParseShape(&shape)) {
return false;
}
} while (EatIfPresent(TokKind::kComma));
}
return ParseToken(TokKind::kRparen, "expects ')' at the end of param list");
}
// shape ::= shape_val_
// shape ::= '(' tuple_elements ')'
// tuple_elements
// ::= /*empty*/
// ::= shape (',' shape)*
bool HloParser::ParseShape(Shape* result) {
if (EatIfPresent(TokKind::kLparen)) { // Tuple
std::vector<Shape> shapes;
if (lexer_.GetKind() == TokKind::kRparen) {
/*empty*/
} else {
// shape (',' shape)*
do {
shapes.emplace_back();
if (!ParseShape(&shapes.back())) {
return false;
}
} while (EatIfPresent(TokKind::kComma));
}
*result = ShapeUtil::MakeTupleShape(shapes);
return ParseToken(TokKind::kRparen, "expects ')' at the end of tuple.");
}
if (lexer_.GetKind() != TokKind::kShape) {
return TokenError("expects shape");
}
*result = lexer_.GetShapeVal();
lexer_.Lex();
return true;
}
bool HloParser::ParseName(string* result) {
VLOG(1) << "ParseName";
if (lexer_.GetKind() != TokKind::kName) {
return TokenError("expects name");
}
*result = lexer_.GetStrVal();
lexer_.Lex();
return true;
}
bool HloParser::ParseOpcode(HloOpcode* result) {
VLOG(1) << "ParseOpcode";
if (lexer_.GetKind() != TokKind::kOpcode) {
return TokenError("expects opcode");
}
*result = lexer_.GetOpcodeVal();
lexer_.Lex();
return true;
}
bool HloParser::ParseInt64(int64* result) {
VLOG(1) << "ParseInt64";
if (lexer_.GetKind() != TokKind::kInt) {
return TokenError("expects integer");
}
*result = lexer_.GetInt64Val();
lexer_.Lex();
return true;
}
bool HloParser::ParseDecimal(double* result) {
switch (lexer_.GetKind()) {
case TokKind::kDecimal:
*result = lexer_.GetDecimalVal();
break;
case TokKind::kInt:
*result = static_cast<double>(lexer_.GetInt64Val());
break;
default:
return TokenError("expects decimal or integer");
}
lexer_.Lex();
return true;
}
bool HloParser::ParseBool(bool* result) {
if (lexer_.GetKind() != TokKind::kw_true &&
lexer_.GetKind() != TokKind::kw_false) {
return TokenError("expects true or false");
}
*result = lexer_.GetKind() == TokKind::kw_true;
lexer_.Lex();
return true;
}
bool HloParser::ParseToken(TokKind kind, const string& msg) {
if (lexer_.GetKind() != kind) {
return TokenError(msg);
}
lexer_.Lex();
return true;
}
bool HloParser::EatIfPresent(TokKind kind) {
if (lexer_.GetKind() != kind) {
return false;
}
lexer_.Lex();
return true;
}
bool HloParser::AddInstruction(const string& name,
HloInstruction* instruction) {
auto result = instruction_pool_.insert({name, instruction});
if (!result.second) {
return TokenError(StrCat("instruction already exists: ", name));
}
return true;
}
} // namespace
StatusOr<std::unique_ptr<HloModule>> Parse(StringPiece str) {
HloParser parser(str);
if (!parser.Run()) {
return InvalidArgument("Syntax error: %s", parser.GetError().c_str());
}
return parser.ConsumeHloModule();
}
} // namespace tools
} // namespace xla

View File

@ -0,0 +1,37 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_PARSER_H_
#define TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_PARSER_H_
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tools/parser/hlo_lexer.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
namespace tools {
// The API of the HLO parser. Given a string in the HloModule::ToString()
// format, returns the parsed HloModule.
StatusOr<std::unique_ptr<HloModule>> Parse(tensorflow::StringPiece str);
} // namespace tools
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_PARSER_H_
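Editor's note: a round-trip sketch under the header above, mirroring what the parser tests later in this change verify (ToString() of the parsed module reproduces the input). The function name is hypothetical.

void ParseAndPrint() {
  StatusOr<std::unique_ptr<HloModule>> module_or = tools::Parse(
      R"(HloModule m:

ENTRY %e () -> f32[] {
  %c = f32[] constant(42)
}

)");
  if (module_or.ok()) {
    LOG(INFO) << module_or.ValueOrDie()->ToString();
  }
}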

View File

@ -0,0 +1,240 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include <string>
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace xla {
namespace tools {
namespace {
struct TestData {
string test_name;
string module_string;
};
string TestDataToString(const ::testing::TestParamInfo<TestData>& data) {
return data.param.test_name;
}
std::vector<TestData> CreateTestCases() {
// clang-format off
return std::vector<TestData>({
// ax + y
{
"AxpyParam",
R"(HloModule axpy_module:
ENTRY %axpy.v5 (alpha: f32[2,4], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
%alpha = f32[2,4]{1,0} parameter(0)
%x = f32[2,4]{1,0} parameter(1)
%multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %alpha, f32[2,4]{1,0} %x)
%y = f32[2,4]{1,0} parameter(2)
%add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
}
)"
},
// pred constant
{
"ConstantPred",
R"(HloModule constant_pred_module:
ENTRY %constant_pred () -> pred[] {
%constant = pred[] constant(true)
}
)"
},
// s32 constant
{
"ConstantS32",
R"(HloModule constant_s32_module:
ENTRY %constant_s32 () -> s32[] {
%constant = s32[] constant(-42)
}
)"
},
// f32 constant, but the value is not a decimal
{
"ConstantF32", R"(HloModule ConstantF32_module:
ENTRY %ConstantF32.v4 () -> f32[] {
%constant = f32[] constant(42)
}
)"
},
// constant + constant
{
"AddConstants",
R"(HloModule add_constants_module:
ENTRY %add_constants () -> f32[] {
%constant = f32[] constant(3.14)
%add = f32[] add(f32[] %constant, f32[] %constant)
}
)"
},
// v1 > v2 ? v1 : v2
{
"SelectR1F32",
R"(HloModule SelectR1F32WithCmpR1F32sFromParamsSmall_module:
ENTRY %SelectR1F32WithCmpR1F32sFromParamsSmall.v4 (v1: f32[4], v2: f32[4]) -> f32[4] {
%v1 = f32[4]{0} parameter(0)
%v2 = f32[4]{0} parameter(1)
%greater-than = pred[4]{0} greater-than(f32[4]{0} %v1, f32[4]{0} %v2)
%select = f32[4]{0} select(pred[4]{0} %greater-than, f32[4]{0} %v1, f32[4]{0} %v2)
}
)"
}
});
// clang-format on
}
class HloParserTest : public ::testing::Test,
public ::testing::WithParamInterface<TestData> {
protected:
void ExpectSuccess() {
const string& original = GetParam().module_string;
auto result = Parse(original);
TF_EXPECT_OK(result.status());
EXPECT_EQ(original, result.ValueOrDie()->ToString());
}
};
TEST_P(HloParserTest, Run) { ExpectSuccess(); }
INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserTest,
::testing::ValuesIn(CreateTestCases()),
TestDataToString);
TEST_F(HloParserTest, Empty) {
const string original = "";
auto result = Parse(original);
EXPECT_NE(tensorflow::Status::OK(), result.status());
}
TEST_F(HloParserTest, Garbage) {
const string original = "HloModule thi$ str1ng makes# N0 sen$e @all!*&^%$";
auto result = Parse(original);
EXPECT_NE(tensorflow::Status::OK(), result.status());
}
TEST_F(HloParserTest, WrongOpcode) {
const string original = R"(HloModule wrong_opcode:
ENTRY %blabla (x: f32[], y: f32[]) -> f32[] {
%x = f32[]{} parameter(0)
%y = f32[]{} parameter(1)
%le = pred[]{} le(f32[]{} %x, f32[]{} %y)
}
)";
auto result = Parse(original);
EXPECT_NE(tensorflow::Status::OK(), result.status());
}
TEST_F(HloParserTest, WrongShape) {
const string original = R"(HloModule wrong_opcode:
ENTRY %blabla (x: g32[]) -> g32[] {
%x = g32[]{} parameter(0)
}
)";
auto result = Parse(original);
EXPECT_NE(tensorflow::Status::OK(), result.status());
}
TEST_F(HloParserTest, WrongOperandsSize) {
const string original = R"(HloModule wrong_opcode:
ENTRY %blabla (x: f32[]) -> pred[] {
%x = f32[]{} parameter(0)
%eq = pred[]{} equal-to(f32[]{} %x)
}
)";
auto result = Parse(original);
EXPECT_NE(tensorflow::Status::OK(), result.status());
}
TEST_F(HloParserTest, OperandNotFound) {
const string original = R"(HloModule operand_not_found:
ENTRY %blabla (x: f32[]) -> pred[] {
%x = f32[]{} parameter(0)
%eq = pred[]{} equal-to(f32[]{} %x, f32[]{} %y)
}
)";
auto result = Parse(original);
EXPECT_NE(tensorflow::Status::OK(), result.status());
}
TEST_F(HloParserTest, MoreConstants) {
const string original = R"(HloModule SelectScalarS32True_module:
ENTRY %SelectScalarS32True.v4 () -> s32[] {
%constant.2 = pred[] constant(true)
%constant.1 = s32[] constant(-42)
%constant = s32[] constant(42)
%select = s32[] select(pred[] %constant.2, s32[] %constant.1, s32[] %constant)
}
)";
auto result = Parse(original);
TF_EXPECT_OK(result.status());
// Constant instructions have no name. The string will be parsed successfully
// but the constant names will not be exactly the same.
}
TEST_F(HloParserTest, ConstantWithExp) {
const string original = R"(HloModule ConstantWithExp_module:
ENTRY %ConstantWithExp.v4 () -> f32[] {
%constant.1 = f32[] constant(3e+2)
}
)";
auto result = Parse(original);
TF_EXPECT_OK(result.status());
// The string will be parsed successfully but the output strings are not
// exactly the same, because "3e2" is parsed into value 300 and will be
// printed as "300".
}
TEST_F(HloParserTest, Tuple) {
const string original = R"(HloModule EmptyTupleCreate_module:
ENTRY %EmptyTupleCreate.v1 () -> () {
%tuple = () tuple()
}
)";
auto result = Parse(original);
EXPECT_NE(tensorflow::Status::OK(), result.status());
}
} // namespace
} // namespace tools
} // namespace xla

View File

@@ -0,0 +1,58 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_TOKEN_H_
#define TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_TOKEN_H_
namespace xla {
namespace tools {
// Defines different kinds of tokens in an HLO module string.
enum class TokKind {
// Markers
kEof,
kError,
// Tokens with no info.
kEqual, // =
kComma, // ,
kColon, // :
kLsquare,
kRsquare, // [ ]
kLbrace,
kRbrace, // { }
kLparen,
kRparen, // ( )
kArrow, // ->
// Keywords
kw_HloModule,
kw_ENTRY,
kw_true,
kw_false,
// Typed tokens.
kName, // %foo
kShape, // f32[2,3]{1,0}
kOpcode, // add
kInt, // 42
kDecimal, // 4.2
};
} // namespace tools
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_TOKEN_H_
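
As a rough illustration of these categories: an instruction line such as %x = f32[2,4]{1,0} parameter(0) would lex to kName, kEqual, kShape, kOpcode, kLparen, kInt, kRparen, assuming the layout braces are folded into the kShape token as the f32[2,3]{1,0} example above suggests.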

View File

@@ -82,8 +82,8 @@ message DebugOptions {
   // Dump all HLO modules as text into the provided directory path.
   string xla_generate_hlo_text_to = 7;
-  // Dump compilation artifacts as JSON into this directory.
-  string xla_dump_debug_json_to = 8;
+  // Dump compilation artifacts in binary proto into this directory.
+  string xla_dump_hlo_proto_to = 8;
   // Instrument the computation to collect per-HLO cycle counts.
   bool xla_hlo_profile = 9;

View File

@@ -69,6 +69,28 @@ tf_cc_test(
     ],
 )

+cc_library(
+    name = "adaptive_shared_batch_scheduler",
+    hdrs = ["adaptive_shared_batch_scheduler.h"],
+    deps = [
+        ":batch_scheduler",
+        "//tensorflow/contrib/batching/util:periodic_function_dynamic",
+        "//tensorflow/core:lib",
+    ],
+)
+
+tf_cc_test(
+    name = "adaptive_shared_batch_scheduler_test",
+    srcs = ["adaptive_shared_batch_scheduler_test.cc"],
+    deps = [
+        ":adaptive_shared_batch_scheduler",
+        "//tensorflow/contrib/batching/test_util:fake_clock_env",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+    ],
+)
+
 cc_library(
     name = "basic_batch_scheduler",
     hdrs = ["basic_batch_scheduler.h"],

View File

@@ -0,0 +1,463 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_
#define THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_
#include <functional>
#include <memory>
#include <queue>
#include <unordered_map>
#include <vector>
#include "tensorflow/contrib/batching/batch_scheduler.h"
#include "tensorflow/contrib/batching/util/periodic_function.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace serving {
namespace internal {
template <typename TaskType>
class ASBSBatch;
template <typename TaskType>
class ASBSQueue;
} // namespace internal
// Shared batch scheduler designed to minimize latency. The scheduler keeps
// track of a number of queues (one per model or model version) which are
// continuously enqueuing requests. The scheduler groups the requests into
// batches which it periodically sends off for processing (see
// shared_batch_scheduler.h for more details). The AdaptiveSharedBatchScheduler
// prioritizes batches by age (i.e. the batch's oldest request) irrespective of
// queue. The scheduler will process the oldest batch at an adjustable rate,
// regardless of batch size. The user can provide feedback to help set this rate
// to achieve some goal (e.g. minimize overall latency, limit CPU usage, etc.).
//
// The rate (or rather, the corresponding period) is adjusted each time a batch
// is processed, using an exponentially weighted moving average to smooth
// potentially noisy feedback:
// ewma_feedback = ((N - 1) * ewma_feedback + feedback()) / N
// period *= (1 + K * ewma_feedback)
//
// Some potential use cases:
// Hardware Accelerators (GPUs & TPUs) - If some phase of batch processing
// involves serial processing by a device, from a latency perspective it is
// desirable to keep the device evenly loaded, avoiding the need to wait for
// the device to process prior batches.
// feedback = num_pending_on_device() - desired_pending.
// CPU utilization - If the batch processing is cpu dominated, you can reap
// latency gains when underutilized by increasing the processing rate, but
// back the rate off when the load increases to avoid overload.
// feedback = cpu_rate() - desired_cpu_rate.
template <typename TaskType>
class AdaptiveSharedBatchScheduler
: public std::enable_shared_from_this<
AdaptiveSharedBatchScheduler<TaskType>> {
public:
struct Options {
// The name to use for the pool of batch threads.
string thread_pool_name = {"batch_threads"};
// Number of batch processing threads; equivalently the maximum number of
// concurrently running batches.
int64 num_batch_threads = port::NumSchedulableCPUs();
// The environment to use (typically only overridden by test code).
Env* env = Env::Default();
// Initial batch scheduling period in microseconds. Will be altered for
// non-zero rate_feedback.
double initial_scheduling_period_micros = 500;
// Minimum batch scheduling period in microseconds. Recommend setting this
// value greater than 0, otherwise it may take a while to recover from a
// sustained time of negative scheduling_period_feedback (which may occur
// under low load).
double min_scheduling_period_micros = 100;
// Maximum batch scheduling period in microseconds.
double max_scheduling_period_micros = 10000;
// Feedback function used to modify the scheduling period each time a batch
// is scheduled. Should return values roughly O(1), with positive values
// resulting in an increased period.
std::function<double()> scheduling_period_feedback = [] { return 0.; };
// To handle potentially noisy scheduling_period_feedback, the period is
// adjusted using an exponentially weighted moving average over the previous
// feedback_smoothing_batches batches. Must be greater than 0.
int64 feedback_smoothing_batches = 10;
};
// Ownership is shared between the caller of Create() and any queues created
// via AddQueue().
static Status Create(
const Options& options,
std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>>* scheduler);
struct QueueOptions {
// Maximum size of each batch.
int max_batch_size = 1000;
// Maximum number of enqueued (i.e. non-scheduled) batches.
int max_enqueued_batches = 10;
};
using BatchProcessor = std::function<void(std::unique_ptr<Batch<TaskType>>)>;
// Adds queue (and its callback) to be managed by this scheduler.
Status AddQueue(const QueueOptions& options,
BatchProcessor process_batch_callback,
std::unique_ptr<BatchScheduler<TaskType>>* queue);
private:
// access to AddBatch, RemoveQueue, GetEnv.
friend class internal::ASBSQueue<TaskType>;
explicit AdaptiveSharedBatchScheduler(const Options& options);
// Batch scheduling function which runs every scheduling_period_ microseconds.
void ProcessOneBatch();
// Notifies scheduler of non-empty batch which is eligible for processing.
void AddBatch(internal::ASBSBatch<TaskType>*);
// Removes queue from scheduler.
void RemoveQueue(const internal::ASBSQueue<TaskType>* queue);
Env* GetEnv() const { return options_.env; }
const Options options_;
struct BatchCompare {
bool operator()(const internal::ASBSBatch<TaskType>* a,
const internal::ASBSBatch<TaskType>* b);
};
// Collection of batches added by AddBatch, ordered by age. Owned by scheduler
// until they are released for processing.
std::priority_queue<const internal::ASBSBatch<TaskType>*,
std::vector<internal::ASBSBatch<TaskType>*>, BatchCompare>
batches_ GUARDED_BY(mu_);
// Unowned queues and callbacks added by AddQueue.
std::unordered_map<const internal::ASBSQueue<TaskType>*, BatchProcessor>
queues_and_callbacks_ GUARDED_BY(mu_);
mutex mu_;
// Responsible for running ProcessOneBatch. A PeriodicFunction is used (rather
// than a raw thread) so that deletion of the scheduler is detected and the
// thread can be shut down cleanly.
std::unique_ptr<PeriodicFunction> scheduling_thread_;
// Responsible for running the batch processing callbacks.
std::unique_ptr<thread::ThreadPool> batch_thread_pool_;
// Time interval in microseconds between successive ProcessOneBatch calls.
double scheduling_period_;
// Exponentially weighted moving average of
// options_.scheduling_period_feedback() evaluated in each ProcessOneBatch
// call.
double ewma_feedback_ = 0;
TF_DISALLOW_COPY_AND_ASSIGN(AdaptiveSharedBatchScheduler);
};
//////////////////////////////////////////////////////////
// Implementation details follow. API users need not read.
namespace internal {
// Consolidates tasks into batches, passing them off to the
// AdaptiveSharedBatchScheduler for processing.
template <typename TaskType>
class ASBSQueue : public BatchScheduler<TaskType> {
public:
using QueueOptions =
typename AdaptiveSharedBatchScheduler<TaskType>::QueueOptions;
ASBSQueue(std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>> scheduler,
const QueueOptions& options);
~ASBSQueue() override;
// Adds task to current batch. Fails if the task size is larger than the batch
// size or if the current batch is full and this queue's number of outstanding
// batches is at its maximum.
Status Schedule(std::unique_ptr<TaskType>* task) override;
// Number of tasks waiting to be scheduled.
size_t NumEnqueuedTasks() const override;
// Number of size 1 tasks which could currently be scheduled without failing.
size_t SchedulingCapacity() const override;
// Notifies queue that a batch is about to be scheduled; the queue should not
// place any more tasks in this batch.
void ReleaseBatch(const ASBSBatch<TaskType>* batch);
private:
std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>> scheduler_;
const QueueOptions options_;
// Owned by scheduler_.
ASBSBatch<TaskType>* current_batch_ GUARDED_BY(mu_) = nullptr;
int64 num_enqueued_batches_ GUARDED_BY(mu_) = 0;
int64 num_enqueued_tasks_ GUARDED_BY(mu_) = 0;
mutable mutex mu_;
TF_DISALLOW_COPY_AND_ASSIGN(ASBSQueue);
};
// Batch which remembers when and by whom it was created.
template <typename TaskType>
class ASBSBatch : public Batch<TaskType> {
public:
ASBSBatch(ASBSQueue<TaskType>* queue, int64 creation_time_micros)
: queue_(queue), creation_time_micros_(creation_time_micros) {}
~ASBSBatch() override {}
ASBSQueue<TaskType>* queue() const { return queue_; }
int64 creation_time_micros() const { return creation_time_micros_; }
private:
ASBSQueue<TaskType>* queue_;
const int64 creation_time_micros_;
TF_DISALLOW_COPY_AND_ASSIGN(ASBSBatch);
};
} // namespace internal
// ---------------- AdaptiveSharedBatchScheduler ----------------
template <typename TaskType>
Status AdaptiveSharedBatchScheduler<TaskType>::Create(
const Options& options,
std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>>* scheduler) {
if (options.num_batch_threads < 1) {
return errors::InvalidArgument("num_batch_threads must be positive; was ",
options.num_batch_threads);
}
if (options.min_scheduling_period_micros < 0) {
return errors::InvalidArgument(
"min_scheduling_period_micros must be >= 0; was ",
options.min_scheduling_period_micros);
}
if (options.min_scheduling_period_micros >
options.initial_scheduling_period_micros) {
return errors::InvalidArgument(
"initial_scheduling_period_micros (",
options.initial_scheduling_period_micros,
") must be >= min_scheduling_period_micros (",
options.min_scheduling_period_micros, ")");
}
if (options.initial_scheduling_period_micros >
options.max_scheduling_period_micros) {
return errors::InvalidArgument(
"initial_scheduling_period_micros (",
options.initial_scheduling_period_micros,
") must be <= max_scheduling_period_micros (",
options.max_scheduling_period_micros, ")");
}
if (options.feedback_smoothing_batches < 1) {
return errors::InvalidArgument(
"feedback_smoothing_batches must be positive; was ",
options.feedback_smoothing_batches);
}
scheduler->reset(new AdaptiveSharedBatchScheduler<TaskType>(options));
return Status::OK();
}
template <typename TaskType>
AdaptiveSharedBatchScheduler<TaskType>::AdaptiveSharedBatchScheduler(
const Options& options)
: options_(options),
scheduling_period_(options.initial_scheduling_period_micros) {
PeriodicFunction::Options opts;
opts.thread_name_prefix = "scheduling_thread";
opts.env = GetEnv();
scheduling_thread_.reset(
new PeriodicFunction([this] { ProcessOneBatch(); }, 0, opts));
batch_thread_pool_.reset(new thread::ThreadPool(
GetEnv(), options.thread_pool_name, options.num_batch_threads));
}
template <typename TaskType>
Status AdaptiveSharedBatchScheduler<TaskType>::AddQueue(
const QueueOptions& options, BatchProcessor process_batch_callback,
std::unique_ptr<BatchScheduler<TaskType>>* queue) {
if (options.max_batch_size <= 0) {
return errors::InvalidArgument("max_batch_size must be positive; was ",
options.max_batch_size);
}
if (options.max_enqueued_batches <= 0) {
return errors::InvalidArgument(
"max_enqueued_batches must be positive; was ",
options.max_enqueued_batches);
}
internal::ASBSQueue<TaskType>* asbs_queue_raw;
queue->reset(asbs_queue_raw = new internal::ASBSQueue<TaskType>(
this->shared_from_this(), options));
mutex_lock l(mu_);
queues_and_callbacks_[asbs_queue_raw] = process_batch_callback;
return Status::OK();
}
template <typename TaskType>
void AdaptiveSharedBatchScheduler<TaskType>::AddBatch(
internal::ASBSBatch<TaskType>* batch) {
mutex_lock l(mu_);
batches_.push(batch);
}
template <typename TaskType>
void AdaptiveSharedBatchScheduler<TaskType>::RemoveQueue(
const internal::ASBSQueue<TaskType>* queue) {
mutex_lock l(mu_);
queues_and_callbacks_.erase(queue);
}
template <typename TaskType>
void AdaptiveSharedBatchScheduler<TaskType>::ProcessOneBatch() {
static const double kFeedbackMultiplier = .001;
internal::ASBSBatch<TaskType>* batch = nullptr;
BatchProcessor callback;
const int64 start_time_micros = GetEnv()->NowMicros();
{
mutex_lock l(mu_);
if (!batches_.empty()) {
batch = batches_.top();
batches_.pop();
callback = queues_and_callbacks_[batch->queue()];
}
}
if (batch != nullptr) {
double feedback = options_.scheduling_period_feedback();
const int64 N = options_.feedback_smoothing_batches;
ewma_feedback_ = ((N - 1) * ewma_feedback_ + feedback) / N;
scheduling_period_ *= (1 + kFeedbackMultiplier * ewma_feedback_);
if (scheduling_period_ < options_.min_scheduling_period_micros) {
scheduling_period_ = options_.min_scheduling_period_micros;
} else if (scheduling_period_ > options_.max_scheduling_period_micros) {
scheduling_period_ = options_.max_scheduling_period_micros;
}
// Queue may destroy itself after ReleaseBatch is called.
batch->queue()->ReleaseBatch(batch);
batch_thread_pool_->Schedule([callback, batch] {
callback(std::unique_ptr<Batch<TaskType>>(batch));
});
}
const int64 sleep_time =
scheduling_period_ - (GetEnv()->NowMicros() - start_time_micros);
if (sleep_time > 0) {
GetEnv()->SleepForMicroseconds(sleep_time);
}
}
template <typename TaskType>
bool AdaptiveSharedBatchScheduler<TaskType>::BatchCompare::operator()(
const internal::ASBSBatch<TaskType>* a,
const internal::ASBSBatch<TaskType>* b) {
return a->creation_time_micros() > b->creation_time_micros();
}
// ---------------- ASBSQueue ----------------
namespace internal {
template <typename TaskType>
ASBSQueue<TaskType>::ASBSQueue(
std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>> scheduler,
const QueueOptions& options)
: scheduler_(scheduler), options_(options) {}
template <typename TaskType>
ASBSQueue<TaskType>::~ASBSQueue() {
// Wait until last batch has been scheduled.
const int kSleepMicros = 1000;
for (;;) {
{
mutex_lock l(mu_);
if (num_enqueued_batches_ == 0) {
break;
}
}
scheduler_->GetEnv()->SleepForMicroseconds(kSleepMicros);
}
scheduler_->RemoveQueue(this);
}
template <typename TaskType>
Status ASBSQueue<TaskType>::Schedule(std::unique_ptr<TaskType>* task) {
bool added_new_batch = false;
size_t size = (*task)->size();
if (size > options_.max_batch_size) {
return errors::InvalidArgument("Task size ", size,
" is larger than maximum batch size ",
options_.max_batch_size);
}
{
mutex_lock l(mu_);
// Current batch is full, create another if allowed.
if (current_batch_ &&
current_batch_->size() + size > options_.max_batch_size) {
if (num_enqueued_batches_ >= options_.max_enqueued_batches) {
return errors::Unavailable("The batch scheduling queue is full");
}
current_batch_->Close();
current_batch_ = nullptr;
}
if (!current_batch_) {
added_new_batch = true;
num_enqueued_batches_++;
current_batch_ =
new ASBSBatch<TaskType>(this, scheduler_->GetEnv()->NowMicros());
}
current_batch_->AddTask(std::move(*task));
num_enqueued_tasks_++;
}
if (added_new_batch) scheduler_->AddBatch(current_batch_);
return Status::OK();
}
template <typename TaskType>
void ASBSQueue<TaskType>::ReleaseBatch(const ASBSBatch<TaskType>* batch) {
mutex_lock l(mu_);
num_enqueued_batches_--;
num_enqueued_tasks_ -= batch->num_tasks();
if (batch == current_batch_) {
current_batch_->Close();
current_batch_ = nullptr;
}
}
template <typename TaskType>
size_t ASBSQueue<TaskType>::NumEnqueuedTasks() const {
mutex_lock l(mu_);
return num_enqueued_tasks_;
}
template <typename TaskType>
size_t ASBSQueue<TaskType>::SchedulingCapacity() const {
mutex_lock l(mu_);
const int current_batch_capacity =
current_batch_ ? options_.max_batch_size - current_batch_->size() : 0;
const int spare_batches =
options_.max_enqueued_batches - num_enqueued_batches_;
return spare_batches * options_.max_batch_size + current_batch_capacity;
}
} // namespace internal
} // namespace serving
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_
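
For orientation, a minimal sketch of the intended wiring; MyTask, the callback body, and main() are illustrative stand-ins, not part of this change:

#include <memory>

#include "tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h"
#include "tensorflow/core/lib/core/status.h"

// Deriving from BatchTask supplies the size() interface the scheduler
// requires of its TaskType.
class MyTask : public tensorflow::serving::BatchTask {
 public:
  explicit MyTask(size_t size) : size_(size) {}
  size_t size() const override { return size_; }

 private:
  const size_t size_;
};

int main() {
  using Scheduler = tensorflow::serving::AdaptiveSharedBatchScheduler<MyTask>;

  // Ownership is shared between this handle and every queue created below.
  std::shared_ptr<Scheduler> scheduler;
  TF_CHECK_OK(Scheduler::Create(Scheduler::Options(), &scheduler));

  std::unique_ptr<tensorflow::serving::BatchScheduler<MyTask>> queue;
  TF_CHECK_OK(scheduler->AddQueue(
      Scheduler::QueueOptions(),
      [](std::unique_ptr<tensorflow::serving::Batch<MyTask>> batch) {
        // Runs on the batch thread pool once the scheduler releases the
        // (closed) batch; batch->task(i) is the i-th enqueued task.
      },
      &queue));

  std::unique_ptr<MyTask> task(new MyTask(1));
  TF_CHECK_OK(queue->Schedule(&task));  // Consumes 'task' on success.
  return 0;
}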

View File

@@ -0,0 +1,438 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h"
#include "tensorflow/contrib/batching/test_util/fake_clock_env.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace serving {
namespace anonymous {
class FakeTask : public BatchTask {
public:
explicit FakeTask(size_t size) : size_(size) {}
~FakeTask() override = default;
size_t size() const override { return size_; }
private:
const size_t size_;
TF_DISALLOW_COPY_AND_ASSIGN(FakeTask);
};
// Creates a FakeTask of size 'task_size', and calls 'scheduler->Schedule()' on
// that task. Returns the resulting status.
Status ScheduleTask(size_t task_size, BatchScheduler<FakeTask>* scheduler) {
std::unique_ptr<FakeTask> task(new FakeTask(task_size));
Status status = scheduler->Schedule(&task);
// Schedule() should have consumed 'task' iff it returned Status::OK.
CHECK_EQ(status.ok(), task == nullptr);
return status;
}
// Creates a thread that waits on 'start' and then advances the fake clock in
// 'env' in a loop until 'stop' is notified. Useful for allowing objects that
// use the clock to be destroyed.
std::unique_ptr<Thread> CreateFakeClockAdvancerThread(
test_util::FakeClockEnv* env, Notification* start, Notification* stop) {
return std::unique_ptr<Thread>(Env::Default()->StartThread(
{}, "FakeClockAdvancerThread", [env, start, stop] {
start->WaitForNotification();
while (!stop->HasBeenNotified()) {
env->AdvanceByMicroseconds(10);
Env::Default()->SleepForMicroseconds(10);
}
}));
}
TEST(AdaptiveSharedBatchSchedulerTest, Basic) {
for (const bool delete_scheduler_early : {false, true}) {
for (const bool delete_queue_1_early : {false, true}) {
int queue_0_tasks = 0;
auto queue_0_callback =
[&queue_0_tasks](std::unique_ptr<Batch<FakeTask>> batch) {
ASSERT_TRUE(batch->IsClosed());
EXPECT_GT(batch->num_tasks(), 0);
for (int i = 0; i < batch->num_tasks(); i++) {
queue_0_tasks += batch->task(i).size();
}
};
int queue_1_tasks = 0;
auto queue_1_callback =
[&queue_1_tasks](std::unique_ptr<Batch<FakeTask>> batch) {
ASSERT_TRUE(batch->IsClosed());
EXPECT_GT(batch->num_tasks(), 0);
for (int i = 0; i < batch->num_tasks(); i++) {
queue_1_tasks += batch->task(i).size();
}
};
{
std::shared_ptr<AdaptiveSharedBatchScheduler<FakeTask>> scheduler;
TF_ASSERT_OK(
AdaptiveSharedBatchScheduler<FakeTask>::Create({}, &scheduler));
// Create two queues.
std::unique_ptr<BatchScheduler<FakeTask>> queue_0;
TF_ASSERT_OK(scheduler->AddQueue({}, queue_0_callback, &queue_0));
std::unique_ptr<BatchScheduler<FakeTask>> queue_1;
TF_ASSERT_OK(scheduler->AddQueue({}, queue_1_callback, &queue_1));
if (delete_scheduler_early) {
// Delete our copy of the scheduler. The queues should keep it alive
// under the covers.
scheduler = nullptr;
}
// Submit tasks to the two queues, and (optionally) remove the queues.
TF_ASSERT_OK(ScheduleTask(1, queue_0.get()));
TF_ASSERT_OK(ScheduleTask(2, queue_1.get()));
TF_ASSERT_OK(ScheduleTask(3, queue_0.get()));
TF_ASSERT_OK(ScheduleTask(4, queue_1.get()));
if (delete_queue_1_early) {
queue_1 = nullptr;
}
TF_ASSERT_OK(ScheduleTask(5, queue_0.get()));
}
EXPECT_EQ(queue_0_tasks, 9);
EXPECT_EQ(queue_1_tasks, 6);
}
}
}
TEST(AdaptiveSharedBatchSchedulerTest, BadOptions) {
using Scheduler = AdaptiveSharedBatchScheduler<FakeTask>;
std::shared_ptr<Scheduler> scheduler;
Scheduler::Options options;
options.num_batch_threads = 0;
EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok());
options = Scheduler::Options();
options.min_scheduling_period_micros = 50;
options.max_scheduling_period_micros = 100;
options.initial_scheduling_period_micros = 1;
EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok());
options = Scheduler::Options();
options.min_scheduling_period_micros = 50;
options.max_scheduling_period_micros = 100;
options.initial_scheduling_period_micros = 1000;
EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok());
options = Scheduler::Options();
options.min_scheduling_period_micros = 100;
options.max_scheduling_period_micros = 50;
options.initial_scheduling_period_micros = 75;
EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok());
options = Scheduler::Options();
options.feedback_smoothing_batches = 0;
EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok());
}
TEST(AdaptiveSharedBatchSchedulerTest, ObeysQueueOptions) {
test_util::FakeClockEnv env(Env::Default());
Notification start_teardown, stop_teardown;
std::unique_ptr<Thread> teardown_thread =
CreateFakeClockAdvancerThread(&env, &start_teardown, &stop_teardown);
{
AdaptiveSharedBatchScheduler<FakeTask>::Options options;
options.initial_scheduling_period_micros = 1000;
options.env = &env;
std::shared_ptr<AdaptiveSharedBatchScheduler<FakeTask>> scheduler;
TF_ASSERT_OK(
AdaptiveSharedBatchScheduler<FakeTask>::Create(options, &scheduler));
std::unique_ptr<BatchScheduler<FakeTask>> queue_0;
std::unique_ptr<BatchScheduler<FakeTask>> queue_1;
int queue_0_tasks = 0;
int queue_1_tasks = 0;
auto queue_0_callback = [&queue_0_tasks,
&env](std::unique_ptr<Batch<FakeTask>> batch) {
ASSERT_TRUE(batch->IsClosed());
EXPECT_GT(batch->num_tasks(), 0);
for (int i = 0; i < batch->num_tasks(); i++) {
queue_0_tasks += batch->task(i).size();
}
env.SleepForMicroseconds(1);
};
auto queue_1_callback = [&queue_1_tasks,
&env](std::unique_ptr<Batch<FakeTask>> batch) {
ASSERT_TRUE(batch->IsClosed());
EXPECT_GT(batch->num_tasks(), 0);
for (int i = 0; i < batch->num_tasks(); i++) {
queue_1_tasks += batch->task(i).size();
}
env.SleepForMicroseconds(1);
};
AdaptiveSharedBatchScheduler<FakeTask>::QueueOptions queue_options;
queue_options.max_batch_size = 10;
queue_options.max_enqueued_batches = 0;
// Queue must have max_enqueued_batches > 0.
EXPECT_FALSE(
scheduler->AddQueue(queue_options, queue_0_callback, &queue_0).ok());
queue_options.max_enqueued_batches = 2;
TF_ASSERT_OK(
scheduler->AddQueue(queue_options, queue_0_callback, &queue_0));
queue_options.max_batch_size = 0;
// Queue must have max_batch_size > 0.
EXPECT_FALSE(
scheduler->AddQueue(queue_options, queue_1_callback, &queue_1).ok());
queue_options.max_batch_size = 2;
queue_options.max_enqueued_batches = 1;
TF_ASSERT_OK(
scheduler->AddQueue(queue_options, queue_1_callback, &queue_1));
// Wait for scheduling_thread to sleep.
env.BlockUntilThreadsAsleep(1);
// Task larger than max_batch_size shouldn't schedule.
EXPECT_FALSE(ScheduleTask(15, queue_0.get()).ok());
TF_ASSERT_OK(ScheduleTask(5, queue_0.get()));
TF_ASSERT_OK(ScheduleTask(5, queue_0.get()));
env.AdvanceByMicroseconds(1);
// Task larger than max_batch_size shouldn't schedule.
EXPECT_FALSE(ScheduleTask(3, queue_1.get()).ok());
TF_ASSERT_OK(ScheduleTask(1, queue_1.get()));
TF_ASSERT_OK(ScheduleTask(1, queue_1.get()));
env.AdvanceByMicroseconds(1);
// Exceeds max_enqueued_batches, shouldn't schedule.
EXPECT_FALSE(ScheduleTask(1, queue_1.get()).ok());
TF_ASSERT_OK(ScheduleTask(5, queue_0.get()));
// Exceeds max_enqueued_batches, shouldn't schedule.
EXPECT_FALSE(ScheduleTask(6, queue_0.get()).ok());
TF_ASSERT_OK(ScheduleTask(4, queue_0.get()));
// Batches should be processed in order from oldest to newest.
env.AdvanceByMicroseconds(1000);
env.BlockUntilThreadsAsleep(2);
EXPECT_EQ(queue_0_tasks, 10);
EXPECT_EQ(queue_1_tasks, 0);
env.AdvanceByMicroseconds(1000);
env.BlockUntilThreadsAsleep(2);
EXPECT_EQ(queue_0_tasks, 10);
EXPECT_EQ(queue_1_tasks, 2);
env.AdvanceByMicroseconds(1000);
env.BlockUntilThreadsAsleep(2);
EXPECT_EQ(queue_0_tasks, 19);
EXPECT_EQ(queue_1_tasks, 2);
start_teardown.Notify();
}
stop_teardown.Notify();
}
TEST(AdaptiveSharedBatchSchedulerTest, RateFeedback) {
test_util::FakeClockEnv env(Env::Default());
Notification start_teardown, stop_teardown;
std::unique_ptr<Thread> teardown_thread =
CreateFakeClockAdvancerThread(&env, &start_teardown, &stop_teardown);
{
double feedback = 0;
AdaptiveSharedBatchScheduler<FakeTask>::Options options;
options.initial_scheduling_period_micros = 1000;
options.min_scheduling_period_micros = 200;
options.max_scheduling_period_micros = 2000;
options.env = &env;
options.scheduling_period_feedback = [&feedback] { return feedback; };
options.feedback_smoothing_batches = 1;
std::shared_ptr<AdaptiveSharedBatchScheduler<FakeTask>> scheduler;
TF_ASSERT_OK(
AdaptiveSharedBatchScheduler<FakeTask>::Create(options, &scheduler));
std::unique_ptr<BatchScheduler<FakeTask>> queue;
int scheduled_items = 0;
auto queue_callback = [&scheduled_items,
&env](std::unique_ptr<Batch<FakeTask>> batch) {
ASSERT_TRUE(batch->IsClosed());
EXPECT_GT(batch->num_tasks(), 0);
scheduled_items = 0;
for (int i = 0; i < batch->num_tasks(); i++) {
scheduled_items += batch->task(i).size();
}
env.SleepForMicroseconds(1);
};
TF_ASSERT_OK(scheduler->AddQueue({}, queue_callback, &queue));
// Wait for scheduling_thread to sleep.
env.BlockUntilThreadsAsleep(1);
// Enqueue 6 batches.
for (int i = 0; i < 6; i++) {
TF_ASSERT_OK(ScheduleTask(900 + i, queue.get()));
env.AdvanceByMicroseconds(1);
}
feedback = -500;
env.AdvanceByMicroseconds(994);
env.BlockUntilThreadsAsleep(2); // scheduling period = 500 usec.
EXPECT_EQ(scheduled_items, 900);
env.AdvanceByMicroseconds(500);
env.BlockUntilThreadsAsleep(2); // scheduling period = 250 usec.
EXPECT_EQ(scheduled_items, 901);
feedback = 0;
env.AdvanceByMicroseconds(250);
env.BlockUntilThreadsAsleep(2); // scheduling period = 250 usec.
EXPECT_EQ(scheduled_items, 902);
feedback = 10000; // large feedback should hit max_scheduling_period.
env.AdvanceByMicroseconds(250);
env.BlockUntilThreadsAsleep(2); // scheduling period = 2000 usec.
EXPECT_EQ(scheduled_items, 903);
feedback = -10000; // large feedback should hit min_scheduling_period.
env.AdvanceByMicroseconds(1999);
// No callback scheduled, only scheduling thread sleeping.
env.BlockUntilThreadsAsleep(1);
EXPECT_EQ(scheduled_items, 903);
env.AdvanceByMicroseconds(1);
env.BlockUntilThreadsAsleep(2); // scheduling period = 200 usec.
EXPECT_EQ(scheduled_items, 904);
env.AdvanceByMicroseconds(200);
env.BlockUntilThreadsAsleep(2);
EXPECT_EQ(scheduled_items, 905);
start_teardown.Notify();
}
stop_teardown.Notify();
}
TEST(AdaptiveSharedBatchSchedulerTest, FeedbackSmoothing) {
test_util::FakeClockEnv env(Env::Default());
Notification start_teardown, stop_teardown;
std::unique_ptr<Thread> teardown_thread =
CreateFakeClockAdvancerThread(&env, &start_teardown, &stop_teardown);
{
double feedback = 0;
AdaptiveSharedBatchScheduler<FakeTask>::Options options;
options.initial_scheduling_period_micros = 1000;
options.env = &env;
options.scheduling_period_feedback = [&feedback] { return feedback; };
options.feedback_smoothing_batches = 3;
std::shared_ptr<AdaptiveSharedBatchScheduler<FakeTask>> scheduler;
TF_ASSERT_OK(
AdaptiveSharedBatchScheduler<FakeTask>::Create(options, &scheduler));
std::unique_ptr<BatchScheduler<FakeTask>> queue;
int scheduled_items = 0;
auto queue_callback = [&scheduled_items,
&env](std::unique_ptr<Batch<FakeTask>> batch) {
ASSERT_TRUE(batch->IsClosed());
EXPECT_GT(batch->num_tasks(), 0);
scheduled_items = 0;
for (int i = 0; i < batch->num_tasks(); i++) {
scheduled_items += batch->task(i).size();
}
env.SleepForMicroseconds(1);
};
TF_ASSERT_OK(scheduler->AddQueue({}, queue_callback, &queue));
// Wait for scheduling_thread to sleep.
env.BlockUntilThreadsAsleep(1);
// Enqueue 4 batches.
for (int i = 0; i < 4; i++) {
TF_ASSERT_OK(ScheduleTask(900 + i, queue.get()));
env.AdvanceByMicroseconds(1);
}
feedback = -300;
env.AdvanceByMicroseconds(996);
env.BlockUntilThreadsAsleep(2);
// ewma_feedback = -100, scheduling_period = 900.
EXPECT_EQ(scheduled_items, 900);
env.AdvanceByMicroseconds(899);
// No callback scheduled, only scheduling thread sleeping.
env.BlockUntilThreadsAsleep(1);
EXPECT_EQ(scheduled_items, 900);
env.AdvanceByMicroseconds(1);
env.BlockUntilThreadsAsleep(2);
// ewma_feedback = -167, scheduling_period = 750.
EXPECT_EQ(scheduled_items, 901);
env.AdvanceByMicroseconds(749);
// No callback scheduled, only scheduling thread sleeping.
env.BlockUntilThreadsAsleep(1);
EXPECT_EQ(scheduled_items, 901);
feedback = 1000 / 3.;
env.AdvanceByMicroseconds(1);
env.BlockUntilThreadsAsleep(2);
// ewma_feedback = 0, scheduling_period = 750.
EXPECT_EQ(scheduled_items, 902);
env.AdvanceByMicroseconds(749);
// No callback scheduled, only scheduling thread sleeping.
env.BlockUntilThreadsAsleep(1);
EXPECT_EQ(scheduled_items, 902);
env.AdvanceByMicroseconds(1);
env.BlockUntilThreadsAsleep(2);
EXPECT_EQ(scheduled_items, 903);
start_teardown.Notify();
}
stop_teardown.Notify();
}
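
The expected periods in this test can be checked against the recurrence in the header, with $K = 0.001$, $N = 3$, and an initial period of $1000$ microseconds:

$\mathrm{ewma}_1 = \tfrac{2(0) + (-300)}{3} = -100,\quad p_1 = 1000\,(1 - 0.100) = 900$
$\mathrm{ewma}_2 = \tfrac{2(-100) + (-300)}{3} \approx -166.7,\quad p_2 = 900\,(1 - 0.167) = 750$
$\mathrm{ewma}_3 = \tfrac{2(-166.7) + 1000/3}{3} = 0,\quad p_3 = 750\,(1 + 0) = 750$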
TEST(AdaptiveSharedBatchSchedulerTest, QueueCapacityInfo) {
test_util::FakeClockEnv env(Env::Default());
Notification start_teardown, stop_teardown;
std::unique_ptr<Thread> teardown_thread =
CreateFakeClockAdvancerThread(&env, &start_teardown, &stop_teardown);
{
AdaptiveSharedBatchScheduler<FakeTask>::Options options;
options.initial_scheduling_period_micros = 1000;
options.env = &env;
std::shared_ptr<AdaptiveSharedBatchScheduler<FakeTask>> scheduler;
TF_ASSERT_OK(
AdaptiveSharedBatchScheduler<FakeTask>::Create(options, &scheduler));
std::unique_ptr<BatchScheduler<FakeTask>> queue;
int scheduled_items = 0;
auto queue_callback = [&scheduled_items,
&env](std::unique_ptr<Batch<FakeTask>> batch) {
ASSERT_TRUE(batch->IsClosed());
EXPECT_GT(batch->num_tasks(), 0);
scheduled_items = 0;
for (int i = 0; i < batch->num_tasks(); i++) {
scheduled_items += batch->task(i).size();
}
env.SleepForMicroseconds(1);
};
AdaptiveSharedBatchScheduler<FakeTask>::QueueOptions queue_options;
queue_options.max_batch_size = 10;
queue_options.max_enqueued_batches = 10;
TF_ASSERT_OK(scheduler->AddQueue(queue_options, queue_callback, &queue));
// Wait for scheduling_thread to sleep.
env.BlockUntilThreadsAsleep(1);
// Enqueue 3 tasks.
EXPECT_EQ(queue->NumEnqueuedTasks(), 0);
EXPECT_EQ(queue->SchedulingCapacity(), 100);
TF_ASSERT_OK(ScheduleTask(5, queue.get()));
EXPECT_EQ(queue->NumEnqueuedTasks(), 1);
EXPECT_EQ(queue->SchedulingCapacity(), 95);
env.AdvanceByMicroseconds(1);
TF_ASSERT_OK(ScheduleTask(6, queue.get()));
EXPECT_EQ(queue->NumEnqueuedTasks(), 2);
EXPECT_EQ(queue->SchedulingCapacity(), 84);
env.AdvanceByMicroseconds(1);
TF_ASSERT_OK(ScheduleTask(1, queue.get()));
EXPECT_EQ(queue->NumEnqueuedTasks(), 3);
EXPECT_EQ(queue->SchedulingCapacity(), 83);
env.AdvanceByMicroseconds(998);
env.BlockUntilThreadsAsleep(2);
EXPECT_EQ(scheduled_items, 5);
env.AdvanceByMicroseconds(1000);
env.BlockUntilThreadsAsleep(2);
EXPECT_EQ(scheduled_items, 7);
start_teardown.Notify();
}
stop_teardown.Notify();
}
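
The capacity numbers above follow directly from SchedulingCapacity(): with max_batch_size = 10 and max_enqueued_batches = 10, an empty queue reports 10 * 10 = 100; after the size-5 task there are 9 spare batches plus 5 free slots in the open batch (95); the size-6 task cannot fit, so it closes that batch and opens a new one (8 * 10 + 4 = 84); and the size-1 task shrinks the open batch's headroom to 3 (8 * 10 + 3 = 83).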
} // namespace anonymous
} // namespace serving
} // namespace tensorflow

View File

@@ -78,7 +78,7 @@ template <typename TaskType>
 class Batch {
  public:
   Batch() = default;
-  ~Batch();  // Blocks until the batch is closed.
+  virtual ~Batch();  // Blocks until the batch is closed.

   // Appends 'task' to the batch. After calling AddTask(), the newly-added task
   // can be accessed via task(num_tasks()-1) or mutable_task(num_tasks()-1).
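
The destructor becomes virtual because ASBSBatch in the new scheduler above derives from Batch, and batches are ultimately destroyed through Batch<TaskType> pointers handed to the processing callback; without a virtual destructor the derived object would only be partially destroyed.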

View File

@@ -14,7 +14,7 @@
 # ==============================================================================
 include (ExternalProject)
-set(cub_URL https://github.com/NVlabs/cub/archive/1.7.4.zip)
+set(cub_URL https://mirror.bazel.build/github.com/NVlabs/cub/archive/1.7.4.zip)
 set(cub_HASH SHA256=20a1a39fd97e5da7f40f5f2e7fd73fd2ea59f9dc4bb8a6c5f228aa543e727e31)
 set(cub_BUILD ${CMAKE_CURRENT_BINARY_DIR}/cub/src/cub)
 set(cub_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/cub/src/cub)

View File

@@ -15,7 +15,7 @@
 include (ExternalProject)
 set(gif_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/gif_archive/giflib-5.1.4/)
-set(gif_URL http://mirror.bazel.build/ufpr.dl.sourceforge.net/project/giflib/giflib-5.1.4.tar.gz)
+set(gif_URL https://mirror.bazel.build/ufpr.dl.sourceforge.net/project/giflib/giflib-5.1.4.tar.gz)
 set(gif_HASH SHA256=34a7377ba834397db019e8eb122e551a49c98f49df75ec3fcc92b9a794a4f6d1)
 set(gif_INSTALL ${CMAKE_BINARY_DIR}/gif/install)
 set(gif_BUILD ${CMAKE_BINARY_DIR}/gif/src/gif)

View File

@@ -15,7 +15,7 @@
 include (ExternalProject)
 set(jpeg_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/jpeg_archive)
-set(jpeg_URL http://mirror.bazel.build/www.ijg.org/files/jpegsrc.v9a.tar.gz)
+set(jpeg_URL https://mirror.bazel.build/www.ijg.org/files/jpegsrc.v9a.tar.gz)
 set(jpeg_HASH SHA256=3a753ea48d917945dd54a2d97de388aa06ca2eb1066cbfdc6652036349fe05a7)
 set(jpeg_BUILD ${CMAKE_CURRENT_BINARY_DIR}/jpeg/src/jpeg)
 set(jpeg_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/jpeg/install)

View File

@@ -15,7 +15,7 @@
 include (ExternalProject)
 set(lmdb_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/lmdb)
-set(lmdb_URL http://mirror.bazel.build/github.com/LMDB/lmdb/archive/LMDB_0.9.19.tar.gz)
+set(lmdb_URL https://mirror.bazel.build/github.com/LMDB/lmdb/archive/LMDB_0.9.19.tar.gz)
 set(lmdb_HASH SHA256=108532fb94c6f227558d45be3f3347b52539f0f58290a7bb31ec06c462d05326)
 set(lmdb_BUILD ${CMAKE_BINARY_DIR}/lmdb/src/lmdb)
 set(lmdb_INSTALL ${CMAKE_BINARY_DIR}/lmdb/install)

View File

@@ -47,4 +47,4 @@ ExternalProject_Add(snappy
 )
 # actually enables snappy in the source code
-add_definitions(-DSNAPPY)
+add_definitions(-DTF_USE_SNAPPY)

View File

@@ -86,7 +86,7 @@ cuda_py_test(
         "//tensorflow/python:client",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python/eager:graph_callable",
-        "//tensorflow/python:platform_test",
+        "//tensorflow/python/eager:test",
         "//tensorflow/python:variables",
     ],
 )
@@ -132,11 +132,12 @@ py_library(
         "//tensorflow/python:array_ops",
         "//tensorflow/python:dtypes",
         "//tensorflow/python:framework_ops",
+        "//tensorflow/python:init_ops",
         "//tensorflow/python:layers_base",
         "//tensorflow/python:math_ops",
         "//tensorflow/python:util",
         "//tensorflow/python:variable_scope",
+        "//tensorflow/python/eager:context",
+        "//tensorflow/python/eager:function",
     ],
 )
@@ -146,6 +147,10 @@ py_test(
     srcs_version = "PY2AND3",
     deps = [
         ":metrics",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:dtypes",
+        "//tensorflow/python:variables",
+        "//tensorflow/python/eager:context",
         "//tensorflow/python/eager:test",
     ],
 )
@@ -160,6 +165,8 @@ py_library(
     deps = [
         ":datasets",
         ":metrics",
+        "//tensorflow/python/eager:context",
+        "//tensorflow/python/eager:function",
     ],
 )

View File

@@ -86,7 +86,7 @@ class EvaluatorTest(test.TestCase):
     for v in e.metric_variables:
       p = v.name.split("/")[0]
       prefix_count[p] = prefix_count.get(p, 0) + 1
-    self.assertEqual({"outer-mean": 2, "mean": 2}, prefix_count)
+    self.assertEqual({"outer_mean": 2, "mean": 2}, prefix_count)

   def testDataset(self):
     e = SimpleEvaluator(IdentityModel())

View File

@@ -18,6 +18,10 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import re
+
+from tensorflow.python.eager import context
+from tensorflow.python.eager import function
 from tensorflow.python.framework import dtypes
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import init_ops
@@ -25,55 +29,69 @@ from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variable_scope

+_to_replace = re.compile("[^A-Za-z0-9.]")
+

 class Metric(object):
   """A metric holds state for aggregating statistics over an evaluation run.

   Users will use Evaluator.add_metric() to add Metric objects to their
-  evaluation, call them in each step, and then use
-  Evaluator.all_metric_results() at the end.
+  evaluation, call them in each step (treating the object as a callable),
+  and then use Evaluator.all_metric_results() at the end.

   Descendants will implement:
-  * call(): Should follow this pattern:
-      if not self.built:
-        self.var = self.add_variable(...)
-      self.add_update(self.var.assign_add(...))
-  * aggregate(): Adds in the state from a list of metrics of the same type
-    as `self`. (Default of summing all the variables will be fine for most
-    descendants.)
-  * result(): Computes and returns a final value for the metric
+  * `build()`: All variables should be created in this method, by calling
+    `self.add_variable()` as in: `self.var = self.add_variable(...)`
+    build() will be called in the first invocation of `__call__()`, with
+    the same arguments passed to `call()`.
+  * `call()`: Has all updates to variables, as in:
+      self.var.assign_add(...)
+  * `result()`: Computes and returns a final value for the metric
     from the variables in `self`.
+
+  Descendants may override, but usually won't need to:
+  * `aggregate()`: Adds in the state from a list of metrics of the same type
+    as `self`. (Default is to sum all the variables.)
+  * `reset()`: Reset all variables to their initial state. (Default is to
+    zero all the variables.)
+  Note that users should not call `aggregate()` or `reset()`, they are for
+  use by TensorFlow infrastructure.
   """

   def __init__(self, name=None):
-    self.built = False
+    self._built = False
     self._vars = []
     self._updates = []
-    self._name = name or self.__class__.__name__
-    # TODO(josh11b): Need some way to make sure two Metrics in the same
-    # Network have distinct names. Maybe we can get a unique name from
-    # a name/variable scope?
-    # TODO(josh11b): self._in_graph_mode = context.in_graph_mode()
+    name = name or self.__class__.__name__
+    # Replace things like spaces in name to create a valid scope name.
+    scope_name = _to_replace.sub("_", name)
+    # We create the variable scope now to get the unique name that will
+    # be used as a variable prefix when build() calls add_variable().
+    with variable_scope.variable_scope(
+        None, default_name=scope_name, use_resource=True, reuse=False) as scope:
+      pos = scope.name.rfind(scope_name)
+      self._name = name + scope.name[pos + len(scope_name):]
+      self._scope = scope
+    if context.in_graph_mode():
+      # We make self.call() into a graph callable here, so that we can
+      # return a single op that performs all of the variable updates.
+      self.call = function.defun(self.call)

   # ---- API for users ----
   def __call__(self, *args, **kwargs):
-    # TODO(josh11b): If self._in_graph_mode is true, make self.call() into a
-    # graph callable here, so that variable updates happen without requiring
-    # a separate fetch.
-    # TODO(josh11b): Do we need a separate build() method to separate
-    # initialization from each update? If so, how do we get the arguments
-    # to it? We *could* just pass in *args and **kwargs...
-    if not self.built:
-      # TODO(ashankar): Set up container isolation so there is no chance
-      # distinct metrics objects accidentally share variables.
-      # TODO(josh11b): Replace things like spaces in self._name to create
-      # a valid scope name.
-      with variable_scope.variable_scope(
-          self._name, use_resource=True, reuse=False):
-        ret = self.call(*args, **kwargs)
-      self.built = True
-    else:
-      ret = self.call(*args, **kwargs)
-    return ret
+    """Returns op to execute to update this metric for these inputs.
+
+    Returns None if eager execution is enabled.
+
+    Args:
+      *args:
+      **kwargs: A mini-batch of inputs to the Metric, passed on to `call()`.
+    """
+    if not self._built:
+      with variable_scope.variable_scope(self._scope):
+        self.build(*args, **kwargs)
+      self._built = True
+    return self.call(*args, **kwargs)

   @property
   def name(self):
@@ -84,10 +102,43 @@ class Metric(object):
     return self._vars

   # ---- To be implemented by descendants ---
+  def build(self, *args, **kwargs):
+    """Method to create variables.
+
+    Called by `__call__()` before `call()` for the first time.
+
+    Args:
+      *args:
+      **kwargs: The arguments to the first invocation of `__call__()`.
+        `build()` may use the shape and/or dtype of these arguments
+        when deciding how to create variables.
+    """
+    raise NotImplementedError("Metrics must define a build() member function")
+
   def call(self, *args, **kwargs):
-    """Accumulates statistics for the metric."""
+    """Accumulates statistics for the metric. Users should use __call__ instead.
+
+    Note: This function is executed as a graph function in graph mode.
+    This means:
+    a) Operations on the same resource are executed in textual order.
+       This should make it easier to do things like add the updated
+       value of a variable to another, for example.
+    b) You don't need to worry about collecting the update ops to execute.
+       All update ops added to the graph by this function will be executed.
+    As a result, code should generally work the same way with graph or
+    eager execution.
+
+    Args:
+      *args:
+      **kwargs: A mini-batch of inputs to the Metric, as passed to
+        `__call__()`.
+    """
     raise NotImplementedError("Metrics must define a call() member function")

+  def result(self):  # TODO(josh11b): Add an optional summary_writer parameter.
+    """Computes and returns a final value for the metric."""
+    raise NotImplementedError("Metrics must define a result() member function")
+
   # We can support two different strategies for doing data-parallel
   # distributed metric computations:
   # * Put metric variables on the first device and rely on small
@@ -123,16 +174,19 @@ class Metric(object):
       self._vars[i].assign_add(math_ops.add_n([m._vars[i] for m in metrics]))
     # pylint: enable=protected-access

-  def result(self):  # TODO(josh11b): Add an optional summary_writer parameter.
-    """Computes and returns a final value for the metric."""
-    raise NotImplementedError("Metrics must define a result() member function")
+  def reset(self):
+    """Reset this metric to a freshly initialized state.
+
+    Default implementation zeros all the metric variables.
+    """
+    for v in self._vars:
+      v.assign(math_ops.zeros_like(v))

   # ---- For use by descendants ---
   def add_variable(self, name, shape=None, dtype=None, initializer=None):
     """***Only for use by descendants of Metric***."""
-    if self.built:
-      raise RuntimeError("Can't call add_variable() after a Metric has been "
-                         "built in the first call().")
+    if self._built:
+      raise RuntimeError("Can't call add_variable() except in build().")
     v = variable_scope.get_variable(name, shape, dtype, initializer,
                                     trainable=False, use_resource=True)
     self._vars.append(v)
@@ -144,6 +198,15 @@ class Mean(Metric):
   # TODO(josh11b): Maybe have a dtype argument that defaults to tf.float64?
   # Or defaults to type of the input if it is tf.float32, else tf.float64?

+  def build(self, values, weights=None):
+    del values, weights  # build() does not use call's arguments
+    self.numer = self.add_variable(name="numer", shape=(),
+                                   dtype=dtypes.float64,
+                                   initializer=init_ops.zeros_initializer)
+    self.denom = self.add_variable(name="denom", shape=(),
+                                   dtype=dtypes.float64,
+                                   initializer=init_ops.zeros_initializer)
+
   def call(self, values, weights=None):
     """Accumulate statistics for computing the mean.
@@ -154,13 +217,6 @@ class Mean(Metric):
       values: Tensor with the per-example value.
       weights: Optional weighting of each example. Defaults to 1.
     """
-    if not self.built:  # False only in the first call().
-      self.numer = self.add_variable(name="numer", shape=(),
-                                     dtype=dtypes.float64,
-                                     initializer=init_ops.zeros_initializer)
-      self.denom = self.add_variable(name="denom", shape=(),
-                                     dtype=dtypes.float64,
-                                     initializer=init_ops.zeros_initializer)
     if weights is None:
       self.denom.assign_add(
           math_ops.cast(array_ops.size(values), dtypes.float64))
@@ -179,6 +235,10 @@ class Mean(Metric):
 class Accuracy(Mean):
   """Calculates how often `predictions` matches `labels`."""

+  def build(self, labels, predictions, weights=None):
+    del labels, predictions, weights
+    super(Accuracy, self).build(None)  # Arguments are unused
+
   def call(self, labels, predictions, weights=None):
     """Accumulate accuracy statistics.

View File

@@ -19,7 +19,11 @@ from __future__ import division
 from __future__ import print_function

 from tensorflow.contrib.eager.python import metrics
+from tensorflow.python.eager import context
 from tensorflow.python.eager import test
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import variables


 class MetricsTest(test.TestCase):
@ -56,6 +60,53 @@ class MetricsTest(test.TestCase):
m([7], [2]) # 0 correct, weight 1 m([7], [2]) # 0 correct, weight 1
self.assertEqual(2.5/5, m.result().numpy()) self.assertEqual(2.5/5, m.result().numpy())
def testTwoMeans(self):
# Verify two metrics with the same class and name don't
# accidentally share state.
m1 = metrics.Mean()
m2 = metrics.Mean()
m1(0)
m2(2)
self.assertEqual(0, m1.result().numpy())
self.assertEqual(2, m2.result().numpy())
self.assertNotEqual(m1.name, m2.name)
def testNamesWithSpaces(self):
# Verify two metrics with the same class and name don't
# accidentally share state.
m1 = metrics.Mean("has space")
m2 = metrics.Mean("has space")
m2(2)
m1(0)
self.assertEqual(m1.name, "has space")
self.assertEqual(m1.numer.name, "has_space/numer:0")
self.assertEqual(m2.name, "has space_1")
self.assertEqual(m2.numer.name, "has_space_1/numer:0")
def testGraph(self):
with context.graph_mode(), self.test_session() as sess:
m = metrics.Mean()
p = array_ops.placeholder(dtypes.float32)
accumulate = m(p)
variables.global_variables_initializer().run()
sess.run(accumulate, feed_dict={p: [1, 10, 100]})
sess.run(accumulate, feed_dict={p: 1000})
sess.run(accumulate, feed_dict={p: [10000, 100000]})
self.assertAllEqual(m.result().eval(), 111111.0/6)
def testTwoMeansGraph(self):
# Verify two metrics with the same class and name don't
# accidentally share state.
with context.graph_mode(), self.test_session() as sess:
m1 = metrics.Mean()
m2 = metrics.Mean()
accumulate1 = m1(0)
accumulate2 = m2(2)
variables.global_variables_initializer().run()
sess.run([accumulate1, accumulate2])
self.assertEqual(0, m1.result().eval())
self.assertEqual(2, m2.result().eval())
if __name__ == "__main__":
  test.main()


@@ -22,6 +22,7 @@ import os
from tensorflow.contrib.eager.python import saver as _saver
from tensorflow.python.eager import context
from tensorflow.python.eager import graph_callable
from tensorflow.python.eager import test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
@@ -29,7 +30,6 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
-from tensorflow.python.platform import test
class SaverTest(test.TestCase):
@@ -38,7 +38,7 @@ class SaverTest(test.TestCase):
    return '/device:GPU:0' if context.num_gpus() else '/device:CPU:0'
  def testBasics(self):
-    with context.eager_mode(), ops.device(self._dev()):
+    with ops.device(self._dev()):
      v1 = resource_variable_ops.ResourceVariable(1.0, name='v1')
      def model():
        return array_ops.constant(2.0) * v1
@@ -54,8 +54,42 @@ class SaverTest(test.TestCase):
      saver.restore(ckpt_prefix)
      self.assertEqual(v1.read_value().numpy(), 1.0)
-  def testRestoreOnCreate(self):
+  def testSameNameNoClobbering(self):
    with context.eager_mode(), ops.device(self._dev()):
# Note that this test purposefully uses Graphs rather than
# IsolateTest. Users are more likely to accidentally create the same
# variable name this way.
first_graph = ops.Graph()
with first_graph.as_default():
v1_first_graph = resource_variable_ops.ResourceVariable(1.0, name='v1')
with ops.Graph().as_default():
v1_second_graph = resource_variable_ops.ResourceVariable(2.0, name='v1')
saver = _saver.Saver([v1_first_graph, v1_second_graph])
ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt')
with self.assertRaisesRegexp(ValueError, 'v1'):
saver.save(ckpt_prefix)
def testDifferentGraphError(self):
with context.eager_mode(), ops.device(self._dev()):
with ops.Graph().as_default():
v1 = resource_variable_ops.ResourceVariable(1.0, name='v1')
with ops.Graph().as_default():
saver = _saver.Saver([v1])
ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt')
with self.assertRaisesRegexp(ValueError, 'Graph'):
saver.save(ckpt_prefix)
def testSameObjectOK(self):
with context.eager_mode(), ops.device(self._dev()):
v1 = resource_variable_ops.ResourceVariable(1.0, name='v1')
# While different objects with the same shared_name are not good, passing
# in the same object multiple times is fine.
saver = _saver.Saver([v1, v1])
ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt')
saver.save(ckpt_prefix)
def testRestoreOnCreate(self):
with ops.device(self._dev()):
      def model(init_val):
        v1 = resource_variable_ops.ResourceVariable(init_val, name='v1')
        return array_ops.constant(1.0) * v1, v1
@@ -71,12 +105,9 @@ class SaverTest(test.TestCase):
      # Value is from checkpoint, but not from argument.
      ret, _ = model(2.0)
      self.assertEqual(ret.numpy(), 1.0)
-      # Create it a second time won't re-assign the checkpoint value.
-      v1_2 = resource_variable_ops.ResourceVariable(3.0, name='v1')
-      self.assertEqual(v1_2.read_value().numpy(), 3.0)
  def testRestoreNotFound(self):
-    with context.eager_mode(), ops.device(self._dev()):
+    with ops.device(self._dev()):
      def model(v):
        return array_ops.constant(1.0) * v
@@ -92,7 +123,7 @@ class SaverTest(test.TestCase):
    _ = model(resource_variable_ops.ResourceVariable(1.0, name='v2'))
  def testSaveRestoreGraphCallable(self):
-    with context.eager_mode(), ops.device(self._dev()):
+    with ops.device(self._dev()):
      @graph_callable.graph_callable(
          [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)])
      def model(x):


@@ -53,6 +53,7 @@ To use, at program startup, call `tfe.enable_eager_execution()`.
@@in_eager_mode
@@in_graph_mode
@@IsolateTest
@@run_test_in_graph_and_eager_modes
"""
@@ -84,6 +85,7 @@ from tensorflow.python.eager.execution_callbacks import nan_callback
from tensorflow.python.eager.execution_callbacks import seterr
from tensorflow.python.framework.ops import enable_eager_execution
from tensorflow.python.framework.ops import eager_run as run
from tensorflow.python.framework.test_util import IsolateTest
from tensorflow.python.framework.test_util import run_in_graph_and_eager_modes as run_test_in_graph_and_eager_modes
from tensorflow.python.ops.resource_variable_ops import ResourceVariable as Variable
from tensorflow.python.util.all_util import remove_undocumented


@@ -24,7 +24,11 @@ the full-batch version.
approach for computing the initial cluster assignments that is expensive but is
typically less prone to getting stuck in bad local minima.

-We provide distributed implementations of both full-batch and mini-batch
-K-Means algorithm. Both K-Means++ and random initialization are supported.
-The user can also choose between **Cosine** and **Squared Euclidean** distance
-metrics.
+**[k-MC2](https://www.aaai.org/ocs/index.php/AAAI/AAAI16/paper/view/12147/11759)**
+provides a very fast seeding method that yields high-quality centers comparable
+to K-Means++ seeding. k-MC2 works particularly well when combined with
+mini-batch K-Means.
+
+We provide distributed implementations of both the full-batch and mini-batch
+K-Means algorithms. K-Means++, k-MC2 and random initialization are supported.
+The user can also choose between **Cosine** and **Squared Euclidean** distance
+metrics.
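To make the seeding step concrete, here is a rough NumPy sketch (an illustration only, not the kernel added later in this commit) of a single k-MC2 Markov chain: each candidate replaces the current choice with probability min(1, d_candidate / d_current), where d is the squared distance to the nearest already-chosen center.

import numpy as np

def kmc2_chain(distances, rng):
  # distances[i]: squared distance of candidate i to its nearest chosen center.
  selected = 0
  selected_dist = distances[0]
  for i in range(1, len(distances)):
    # Accept candidate i with probability min(1, distances[i] / selected_dist).
    if distances[i] > rng.random() * selected_dist:
      selected, selected_dist = i, distances[i]
  return selected  # index of the point to add as the next center

rng = np.random.default_rng(12345)
print(kmc2_chain(np.array([0.5, 4.0, 1.0, 9.0]), rng))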


@@ -224,6 +224,58 @@ class KmeansPlusPlusInitializationOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("KmeansPlusPlusInitialization").Device(DEVICE_CPU),
                        KmeansPlusPlusInitializationOp);
// Implementation of one single Markov Chain for the k-MC^2 algorithm
class KMC2ChainInitializationOp : public OpKernel {
public:
explicit KMC2ChainInitializationOp(OpKernelConstruction* context)
: OpKernel(context) {
OP_REQUIRES_OK(context,
context->MatchSignature({DT_FLOAT, DT_INT64}, {DT_INT64}));
}
void Compute(OpKernelContext* context) override {
const Tensor& distances_tensor = context->input(0);
const Tensor& seed_tensor = context->input(1);
OP_REQUIRES(context, TensorShapeUtils::IsVector(distances_tensor.shape()),
InvalidArgument("Input distances should be a vector."));
OP_REQUIRES(context, TensorShapeUtils::IsScalar(seed_tensor.shape()),
InvalidArgument("Input seed should be a scalar."));
const int64 num_points = distances_tensor.dim_size(0);
const int64 seed = seed_tensor.scalar<int64>()();
OP_REQUIRES(context, num_points > 0,
InvalidArgument("Expected distances_tensor.size() > 0."));
random::PhiloxRandom random(seed);
random::SimplePhilox rng(&random);
auto distances = distances_tensor.flat<float>();
// Set the initial state of the Markov chain to be the first candidate.
int64 selected_index = 0;
float selected_distance = distances(selected_index);
// Build a Markov chain of length num_points.
for (int64 i = 1; i < num_points; ++i) {
const float candidate_distance = distances(i);
// Set the next state of the Markov chain to be the candidate with
// probability min(1, candidate_distance/selected_distance).
if (candidate_distance > rng.RandFloat() * selected_distance) {
selected_index = i;
selected_distance = candidate_distance;
}
}
Tensor* output_sampled_index_tensor;
OP_REQUIRES_OK(context,
context->allocate_output(0, TensorShape({}),
&output_sampled_index_tensor));
auto output = output_sampled_index_tensor->scalar<int64>();
// Return the last state of the Markov chain as the new center.
output() = selected_index;
}
};
REGISTER_KERNEL_BUILDER(Name("KMC2ChainInitialization").Device(DEVICE_CPU),
KMC2ChainInitializationOp);
// Operator for computing the nearest neighbors for a set of points.
class NearestNeighborsOp : public OpKernel {
 public:


@@ -116,6 +116,62 @@ RUN_BM_KmeansPlusPlusInitialization(k3RetriesPerSample);
#undef RUN_BM_KmeansPlusPlusInitialization
#undef BENCHMARK_KMEANS_PLUS_PLUS
Graph* SetUpKMC2Initialization(int num_points) {
Graph* g = new Graph(OpRegistry::Global());
Tensor distances(DT_FLOAT, TensorShape({num_points}));
Tensor seed(DT_INT64, TensorShape({}));
distances.flat<float>().setRandom();
seed.flat<int64>().setConstant(12345);
TF_CHECK_OK(
NodeBuilder("KMC2ChainInitializationOp", "KMC2ChainInitialization")
.Input(test::graph::Constant(g, distances))
.Input(test::graph::Constant(g, seed))
.Finalize(g, nullptr /* node */));
return g;
}
template <int num_points, int num_to_sample, int num_dims>
void BM_KMC2Initialization(int iters) {
testing::StopTiming();
testing::ItemsProcessed(static_cast<int64>(iters) * num_points * num_dims *
num_to_sample);
testing::UseRealTime();
Graph* g = SetUpKMC2Initialization(num_points);
testing::StartTiming();
test::Benchmark("cpu", g).Run(iters);
}
#define BENCHMARK_KMC2(p, c, d) \
void BM_KMC2Initialization_##p##_##c##_##d(int iters) { \
BM_KMC2Initialization<p, c, d>(iters); \
} \
BENCHMARK(BM_KMC2Initialization_##p##_##c##_##d);
#define RUN_BM_KMC2Initialization \
BENCHMARK_KMC2(k10Points, k2Centers, k100Dim); \
BENCHMARK_KMC2(k10Points, k5Centers, k100Dim); \
BENCHMARK_KMC2(k10Points, k10Centers, k100Dim); \
BENCHMARK_KMC2(k100Points, k10Centers, k100Dim); \
BENCHMARK_KMC2(k100Points, k20Centers, k100Dim); \
BENCHMARK_KMC2(k100Points, k50Centers, k100Dim); \
BENCHMARK_KMC2(k100Points, k100Centers, k100Dim); \
BENCHMARK_KMC2(k1kPoints, k100Centers, k100Dim); \
BENCHMARK_KMC2(k1kPoints, k200Centers, k100Dim); \
BENCHMARK_KMC2(k1kPoints, k500Centers, k100Dim); \
BENCHMARK_KMC2(k1kPoints, k1kCenters, k100Dim); \
BENCHMARK_KMC2(k10kPoints, k100Centers, k100Dim); \
BENCHMARK_KMC2(k10kPoints, k200Centers, k100Dim); \
BENCHMARK_KMC2(k10kPoints, k500Centers, k100Dim); \
BENCHMARK_KMC2(k10kPoints, k1kCenters, k100Dim); \
BENCHMARK_KMC2(k1MPoints, k100Centers, k100Dim); \
BENCHMARK_KMC2(k1MPoints, k200Centers, k100Dim); \
BENCHMARK_KMC2(k1MPoints, k500Centers, k100Dim); \
BENCHMARK_KMC2(k1MPoints, k1kCenters, k100Dim)
RUN_BM_KMC2Initialization;
#undef RUN_BM_KMC2Initialization
#undef BENCHMARK_KMC2
Graph* SetUpNearestNeighbors(int num_dims, int num_points, int num_centers,
                             int k) {
  Graph* g = new Graph(OpRegistry::Global());


@@ -44,6 +44,25 @@ num_retries_per_sample: Scalar. For each row that is sampled, this parameter
samples: Matrix of shape (num_to_sample, d). The sampled rows.
)");
REGISTER_OP("KMC2ChainInitialization")
.Input("distances: float32")
.Input("seed: int64")
.Output("index: int64")
.SetShapeFn(shape_inference::ScalarShape)
.Doc(R"(
Returns the index of a data point that should be added to the seed set.
Entries in distances are assumed to be squared distances of candidate points to
the already sampled centers in the seed set. The op constructs one Markov chain
of the k-MC^2 algorithm and returns the index of one candidate point to be added
as an additional cluster center.
distances: Vector with squared distances to the closest previously sampled
cluster center for each candidate point.
seed: Scalar. Seed for initializing the random number generator.
index: Scalar with the index of the sampled point.
)");
REGISTER_OP("NearestNeighbors")
    .Input("points: float32")
    .Input("centers: float32")


@@ -55,6 +55,63 @@ class KmeansPlusPlusInitializationTest(test.TestCase):
      self.runTestWithSeed(seed)
class KMC2InitializationTest(test.TestCase):
def runTestWithSeed(self, seed):
with self.test_session():
distances = np.zeros(1000).astype(np.float32)
distances[6] = 10e7
distances[4] = 10e3
sampled_point = clustering_ops.kmc2_chain_initialization(distances, seed)
self.assertEquals(sampled_point.eval(), 6)
distances[6] = 0.0
sampled_point = clustering_ops.kmc2_chain_initialization(distances, seed)
self.assertEquals(sampled_point.eval(), 4)
def testBasic(self):
for seed in range(100):
self.runTestWithSeed(seed)
class KMC2InitializationLargeTest(test.TestCase):
def setUp(self):
self._distances = np.zeros(1001)
self._distances[500] = 100.0
self._distances[1000] = 50.0
def testBasic(self):
with self.test_session():
counts = {}
seed = 0
for i in range(50):
sample = clustering_ops.kmc2_chain_initialization(
self._distances, seed + i).eval()
counts[sample] = counts.get(sample, 0) + 1
self.assertEquals(len(counts), 2)
self.assertTrue(500 in counts)
self.assertTrue(1000 in counts)
self.assertGreaterEqual(counts[500], 5)
self.assertGreaterEqual(counts[1000], 5)
class KMC2InitializationCornercaseTest(test.TestCase):
def setUp(self):
self._distances = np.zeros(10)
def runTestWithSeed(self, seed):
with self.test_session():
sampled_point = clustering_ops.kmc2_chain_initialization(
self._distances, seed)
self.assertEquals(sampled_point.eval(), 0)
def testBasic(self):
for seed in range(100):
self.runTestWithSeed(seed)
# A simple test that can be verified by hand.
class NearestCentersTest(test.TestCase):


@@ -50,6 +50,7 @@ COSINE_DISTANCE = 'cosine'
RANDOM_INIT = 'random'
KMEANS_PLUS_PLUS_INIT = 'kmeans_plus_plus'
KMC2_INIT = 'kmc2'
# The name of the variable holding the cluster centers. Used by the Estimator.
CLUSTERS_VAR_NAME = 'clusters'
@@ -66,7 +67,8 @@ class KMeans(object):
               use_mini_batch=False,
               mini_batch_steps_per_iteration=1,
               random_seed=0,
-               kmeans_plus_plus_num_retries=2):
+               kmeans_plus_plus_num_retries=2,
+               kmc2_chain_length=200):
    """Creates an object for generating KMeans clustering graph.
    This class implements the following variants of K-means algorithm:
@@ -95,7 +97,8 @@ class KMeans(object):
    exactly like a full-batch version.

    Args:
-      inputs: An input tensor or list of input tensors
+      inputs: An input tensor or list of input tensors. It is assumed that the
+        data points have been previously randomly permuted.
      num_clusters: An integer tensor specifying the number of clusters. This
        argument is ignored if initial_clusters is a tensor or numpy array.
      initial_clusters: Specifies the clusters used during initialization. One
@@ -104,6 +107,7 @@ class KMeans(object):
        - a function f(inputs, k) that returns up to k centers from `inputs`.
        - "random": Choose centers randomly from `inputs`.
        - "kmeans_plus_plus": Use kmeans++ to choose centers from `inputs`.
- "kmc2": Use the fast k-MC2 algorithm to choose centers from `inputs`.
        In the last three cases, one batch of `inputs` may not yield
        `num_clusters` centers, in which case initialization will require
        multiple batches until enough centers are chosen. In the case of
@@ -121,13 +125,17 @@ class KMeans(object):
        additional points to draw from the current distribution before selecting
        the best. If a negative value is specified, a heuristic is used to
        sample O(log(num_to_sample)) additional points.
+      kmc2_chain_length: Determines how many candidate points are used by the
+        k-MC2 algorithm to produce one new cluster center. If a (mini-)batch
+        contains fewer points, one new cluster center is generated from the
+        (mini-)batch.
    Raises:
      ValueError: An invalid argument was passed to initial_clusters or
        distance_metric.
    """
    if isinstance(initial_clusters, str) and initial_clusters not in [
-        RANDOM_INIT, KMEANS_PLUS_PLUS_INIT
+        RANDOM_INIT, KMEANS_PLUS_PLUS_INIT, KMC2_INIT
    ]:
      raise ValueError(
          "Unsupported initialization algorithm '%s'" % initial_clusters)
@@ -141,6 +149,7 @@ class KMeans(object):
    self._mini_batch_steps_per_iteration = int(mini_batch_steps_per_iteration)
    self._random_seed = random_seed
    self._kmeans_plus_plus_num_retries = kmeans_plus_plus_num_retries
self._kmc2_chain_length = kmc2_chain_length
  @classmethod
  def _distance_graph(cls, inputs, clusters, distance_metric):
@@ -302,9 +311,10 @@ class KMeans(object):
    else:
      cluster_centers_updated = cluster_centers
      update_in_steps = None
-    cluster_counts = (variable_scope.variable(
-        array_ops.ones([num_clusters], dtype=dtypes.int64))
-                      if self._use_mini_batch else None)
+    cluster_counts = (
+        variable_scope.variable(
+            array_ops.ones([num_clusters], dtype=dtypes.int64))
+        if self._use_mini_batch else None)
    return (cluster_centers, cluster_centers_initialized, cluster_counts,
            cluster_centers_updated, update_in_steps)
@@ -359,7 +369,7 @@ class KMeans(object):
    init_op = _InitializeClustersOpFactory(
        self._inputs, num_clusters, initial_clusters, self._distance_metric,
        self._random_seed, self._kmeans_plus_plus_num_retries,
-        cluster_centers_var, cluster_centers_updated,
+        self._kmc2_chain_length, cluster_centers_var, cluster_centers_updated,
        cluster_centers_initialized).op()
    cluster_centers = cluster_centers_var
@@ -520,8 +530,9 @@ class KMeans(object):
              array_ops.reshape(array_ops.shape(inp)[0], [-1])),
          [-1, 1]), cluster_idx, num_clusters))
    with ops.colocate_with(cluster_centers, ignore_existing=True):
-      new_clusters_centers = math_ops.add_n(cluster_sums) / (math_ops.cast(
-          math_ops.add_n(cluster_counts), cluster_sums[0].dtype) + epsilon)
+      new_clusters_centers = math_ops.add_n(cluster_sums) / (
+          math_ops.cast(math_ops.add_n(cluster_counts), cluster_sums[0].dtype) +
+          epsilon)
    if self._clusters_l2_normalized():
      new_clusters_centers = nn_impl.l2_normalize(new_clusters_centers, dim=1)
    return state_ops.assign(cluster_centers, new_clusters_centers)
@@ -548,9 +559,12 @@ class _InitializeClustersOpFactory(object):
    cluster_centers_initialized := true
  """
# TODO(ccolby): Refactor this class so that kmc2 isn't so much a special case.
  def __init__(self, inputs, num_clusters, initial_clusters, distance_metric,
-               random_seed, kmeans_plus_plus_num_retries, cluster_centers,
-               cluster_centers_updated, cluster_centers_initialized):
+               random_seed, kmeans_plus_plus_num_retries, kmc2_chain_length,
+               cluster_centers, cluster_centers_updated,
+               cluster_centers_initialized):
    """Creates an op factory.

    Args:
@@ -560,6 +574,7 @@ class _InitializeClustersOpFactory(object):
      distance_metric: See KMeans constructor.
      random_seed: See KMeans constructor.
      kmeans_plus_plus_num_retries: See KMeans constructor.
kmc2_chain_length: See KMeans constructor.
      cluster_centers: The TF variable holding the initial centers. It may
        already contain some centers when the op is executed.
      cluster_centers_updated: A second TF variable to hold a copy of the
@@ -575,6 +590,7 @@ class _InitializeClustersOpFactory(object):
    self._distance_metric = distance_metric
    self._random_seed = random_seed
    self._kmeans_plus_plus_num_retries = kmeans_plus_plus_num_retries
self._kmc2_chain_length = kmc2_chain_length
    self._cluster_centers = cluster_centers
    self._cluster_centers_updated = cluster_centers_updated
    self._cluster_centers_initialized = cluster_centers_initialized
@@ -604,6 +620,90 @@ class _InitializeClustersOpFactory(object):
        math_ops.to_int64(self._num_remaining), self._random_seed,
        self._kmeans_plus_plus_num_retries)
def _kmc2_multiple_centers(self):
"""Adds new initial cluster centers using the k-MC2 algorithm.
In each call to the op, the provided batch is split into subsets based on
the specified `kmc2_chain_length`. On each subset, a single Markov chain of
    the k-MC2 algorithm is used to add *one* new cluster center. If there
    are fewer than `kmc2_chain_length` points in the subset, a single center is
added using one Markov chain on the full input. It is assumed that the
provided batch has previously been randomly permuted. Otherwise, k-MC2 may
return suboptimal centers.
Returns:
An op that adds new cluster centers.
"""
# The op only operates on the first shard of data.
first_shard = self._inputs[0]
# Number of points in the input that can be used.
batch_size = array_ops.shape(first_shard)[0]
# Maximum number of subsets such that the size of each subset is at least
# `kmc2_chain_length`. Final subsets may be larger.
max_to_sample = math_ops.cast(
batch_size / self._kmc2_chain_length, dtype=dtypes.int32)
# We sample at least one new center and at most all remaining centers.
num_to_sample = math_ops.maximum(
math_ops.minimum(self._num_remaining, max_to_sample), 1)
def _cond(i, _):
"""Stopping condition for the while loop."""
return math_ops.less(i, num_to_sample)
def _body(i, _):
"""Body that adds a single new center based on a subset."""
def _sample_random():
"""Returns a random point as a cluster center."""
# By assumption the batch is reshuffled and _sample_random is always
# called for i=0. Hence, we simply return the first point.
new_center = array_ops.reshape(first_shard[0], [1, -1])
if self._distance_metric == COSINE_DISTANCE:
new_center = nn_impl.l2_normalize(new_center, dim=1)
return new_center
def _sample_kmc2_chain():
"""Returns previous centers as well as a new center sampled using k-MC2.
"""
# Extract the subset from the underlying batch.
start = i * self._kmc2_chain_length
end = start + self._kmc2_chain_length
subset = first_shard[start:end]
# Compute the distances from points in the subset to previous centers.
_, distances = gen_clustering_ops.nearest_neighbors(
subset, self._cluster_centers, 1)
# Sample index of new center using k-MC2 Markov chain.
new_center_index = gen_clustering_ops.kmc2_chain_initialization(
array_ops.squeeze(distances), self._random_seed)
# Extract actual new center.
newly_sampled_center = array_ops.reshape(subset[new_center_index],
[1, -1])
# Return concatenation with previously sampled centers.
if self._distance_metric == COSINE_DISTANCE:
newly_sampled_center = nn_impl.l2_normalize(
newly_sampled_center, dim=1)
return array_ops.concat([self._cluster_centers, newly_sampled_center],
0)
# Obtain a random point if there are no previously sampled centers.
# Otherwise, construct a k-MC2 Markov chain.
new_centers = control_flow_ops.cond(
math_ops.equal(self._num_selected, 0), _sample_random,
_sample_kmc2_chain)
# Assign new cluster centers to underlying variable.
assigned_centers = state_ops.assign(
self._cluster_centers, new_centers, validate_shape=False)
if self._cluster_centers_updated is not self._cluster_centers:
assigned_centers = state_ops.assign(
self._cluster_centers_updated,
assigned_centers,
validate_shape=False)
return i + 1, self._num_clusters - array_ops.shape(assigned_centers)[0]
# Add num_to_sample new data points.
_, num_remaining = control_flow_ops.while_loop(_cond, _body, [0, 0])
return num_remaining
  def _greedy_batch_sampler(self, sampler):
    # If the input dataset size is smaller than the number of centers
    # remaining, choose the entire input dataset as centers. This can happen
@@ -657,7 +757,10 @@ class _InitializeClustersOpFactory(object):
    with ops.control_dependencies([
        check_ops.assert_positive(self._num_remaining),
    ]):
-      num_now_remaining = self._add_new_centers()
+      if self._initial_clusters == KMC2_INIT:
+        num_now_remaining = self._kmc2_multiple_centers()
+      else:
+        num_now_remaining = self._add_new_centers()
      return control_flow_ops.cond(
          math_ops.equal(num_now_remaining, 0),
          lambda: state_ops.assign(self._cluster_centers_initialized, True),
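A hedged usage sketch of the new initializer, with constructor arguments as defined above (the `points` tensor is a hypothetical input pipeline):

# `points` is assumed to be a [batch, dims] Tensor that has already been
# randomly shuffled, as the constructor docstring requires for k-MC2.
kmeans = KMeans(
    inputs=points,
    num_clusters=100,
    initial_clusters=KMC2_INIT,  # i.e. 'kmc2'
    use_mini_batch=True,
    kmc2_chain_length=200)       # one chain (one new center) per ~200 points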


@@ -37,6 +37,7 @@ See the @{$python/contrib.framework} guide.
@@arg_scope
@@add_arg_scope
@@current_arg_scope
@@has_arg_scope
@@arg_scoped_arguments


@@ -67,6 +67,7 @@ from tensorflow.python.util import tf_decorator
__all__ = ['arg_scope',
           'add_arg_scope',
'current_arg_scope',
           'has_arg_scope',
           'arg_scoped_arguments']
@@ -83,7 +84,7 @@ def _get_arg_stack():
  return _ARGSTACK

-def _current_arg_scope():
+def current_arg_scope():
  stack = _get_arg_stack()
  return stack[-1]
@@ -144,7 +145,7 @@ def arg_scope(list_ops_or_scope, **kwargs):
    raise TypeError('list_ops_or_scope must either be a list/tuple or reused'
                    'scope (i.e. dict)')
  try:
-    current_scope = _current_arg_scope().copy()
+    current_scope = current_arg_scope().copy()
    for op in list_ops_or_scope:
      key_op = _key_op(op)
      if not has_arg_scope(op):
@@ -172,7 +173,7 @@ def add_arg_scope(func):
    A tuple with the decorated function func_with_args().
  """
  def func_with_args(*args, **kwargs):
-    current_scope = _current_arg_scope()
+    current_scope = current_arg_scope()
    current_args = kwargs
    key_func = _key_op(func)
    if key_func in current_scope:
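A small sketch of what the newly exported `current_arg_scope` returns (the decorated `conv` function here is hypothetical):

@add_arg_scope
def conv(x, padding='VALID'):
  return x  # stand-in body

with arg_scope([conv], padding='SAME'):
  # Top of the scope stack: a dict keyed per decorated op, holding the
  # keyword-argument overrides currently in effect.
  scope = current_arg_scope()
  assert list(scope.values()) == [{'padding': 'SAME'}]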


@@ -442,7 +442,8 @@ def read_keyed_batch_features(file_pattern,
                              feature_queue_capacity=100,
                              num_enqueue_threads=2,
                              parse_fn=None,
-                              name=None):
+                              name=None,
+                              read_batch_size=None):
  """Adds operations to read, queue, batch and parse `Example` protos.

  Given file pattern (or list of files), will setup a queue for file names,
@@ -482,6 +483,8 @@ def read_keyed_batch_features(file_pattern,
    parse_fn: Parsing function, takes `Example` Tensor returns parsed
      representation. If `None`, no parsing is done.
    name: Name of resulting op.
read_batch_size: An int or scalar `Tensor` specifying the number of
records to read at once. If `None`, defaults to `batch_size`.
  Returns:
    Returns tuple of:
@@ -493,6 +496,7 @@ def read_keyed_batch_features(file_pattern,
  """
  with ops.name_scope(name, 'read_batch_features', [file_pattern]) as scope:
    if read_batch_size is None:
      read_batch_size = batch_size
    keys, examples = read_keyed_batch_examples(
        file_pattern,
        batch_size,
@@ -501,7 +505,7 @@ def read_keyed_batch_features(file_pattern,
        num_epochs=num_epochs,
        queue_capacity=queue_capacity,
        num_threads=reader_num_threads,
-        read_batch_size=batch_size,
+        read_batch_size=read_batch_size,
        parse_fn=parse_fn,
        name=scope)
    # Parse the example.
@@ -727,7 +731,8 @@ def read_batch_features(file_pattern,
                        reader_num_threads=1,
                        num_enqueue_threads=2,
                        parse_fn=None,
-                        name=None):
+                        name=None,
+                        read_batch_size=None):
  """Adds operations to read, queue, batch and parse `Example` protos.

  Given file pattern (or list of files), will setup a queue for file names,
@@ -768,6 +773,8 @@ def read_batch_features(file_pattern,
    parse_fn: Parsing function, takes `Example` Tensor returns parsed
      representation. If `None`, no parsing is done.
    name: Name of resulting op.
read_batch_size: An int or scalar `Tensor` specifying the number of
records to read at once. If `None`, defaults to `batch_size`.
  Returns:
    A dict of `Tensor` or `SparseTensor` objects for each in `features`.
@@ -786,6 +793,7 @@ def read_batch_features(file_pattern,
      reader_num_threads=reader_num_threads,
      feature_queue_capacity=feature_queue_capacity,
      num_enqueue_threads=num_enqueue_threads,
read_batch_size=read_batch_size,
      parse_fn=parse_fn,
      name=name)
  return features
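A hedged sketch of the new `read_batch_size` knob; the file pattern, feature spec, and reader below are illustrative assumptions, not taken from this commit:

features = read_batch_features(
    file_pattern='/tmp/train-*.tfrecord',  # hypothetical path
    batch_size=128,                        # size of each parsed batch
    features={'x': parsing_ops.FixedLenFeature([], dtypes.float32)},
    reader=io_ops.TFRecordReader,
    read_batch_size=512)                   # records pulled per read op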


@@ -502,6 +502,7 @@ $(wildcard tensorflow/core/platform/google/*) \
$(wildcard tensorflow/core/platform/google/*/*) \
$(wildcard tensorflow/core/platform/jpeg.*) \
$(wildcard tensorflow/core/platform/png.*) \
$(wildcard tensorflow/core/platform/s3/*) \
$(wildcard tensorflow/core/platform/stream_executor.*) \
$(wildcard tensorflow/core/platform/windows/*) \
$(wildcard tensorflow/core/user_ops/*.cu.cc) \


@@ -20,11 +20,11 @@ DOWNLOADS_DIR=tensorflow/contrib/makefile/downloads
BZL_FILE_PATH=tensorflow/workspace.bzl

EIGEN_URL="$(grep -o 'http.*bitbucket.org/eigen/eigen/get/.*tar\.gz' "${BZL_FILE_PATH}" | grep -v bazel-mirror | head -n1)"
-GEMMLOWP_URL="$(grep -o 'http://mirror.bazel.build/github.com/google/gemmlowp/.*zip' "${BZL_FILE_PATH}" | head -n1)"
+GEMMLOWP_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/gemmlowp/.*zip' "${BZL_FILE_PATH}" | head -n1)"
GOOGLETEST_URL="https://github.com/google/googletest/archive/release-1.8.0.tar.gz"
-NSYNC_URL="$(grep -o 'http://mirror.bazel.build/github.com/google/nsync/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)"
+NSYNC_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/nsync/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)"
-PROTOBUF_URL="$(grep -o 'http://mirror.bazel.build/github.com/google/protobuf/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)"
+PROTOBUF_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/protobuf/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)"
-RE2_URL="$(grep -o 'http://mirror.bazel.build/github.com/google/re2/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)"
+RE2_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/re2/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)"
FFT2D_URL="$(grep -o 'http.*fft\.tgz' "${BZL_FILE_PATH}" | grep -v bazel-mirror | head -n1)"

# TODO(petewarden): Some new code in Eigen triggers a clang bug with iOS arm64,

File diff suppressed because it is too large.


@@ -1101,7 +1101,7 @@ class StreamingPrecisionTest(test.TestCase):
    predictions = random_ops.random_uniform(
        (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1)
    labels = random_ops.random_uniform(
-        (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2)
+        (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
    precision, update_op = metrics.streaming_precision(predictions, labels)

    with self.test_session() as sess:
@@ -1265,7 +1265,7 @@ class StreamingRecallTest(test.TestCase):
    predictions = random_ops.random_uniform(
        (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1)
    labels = random_ops.random_uniform(
-        (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2)
+        (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
    recall, update_op = metrics.streaming_recall(predictions, labels)

    with self.test_session() as sess:
@@ -1388,7 +1388,7 @@ class StreamingFPRTest(test.TestCase):
    predictions = random_ops.random_uniform(
        (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1)
    labels = random_ops.random_uniform(
-        (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2)
+        (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
    fpr, update_op = metrics.streaming_false_positive_rate(
        predictions, labels)
@@ -1516,7 +1516,7 @@ class StreamingFNRTest(test.TestCase):
    predictions = random_ops.random_uniform(
        (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1)
    labels = random_ops.random_uniform(
-        (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2)
+        (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
    fnr, update_op = metrics.streaming_false_negative_rate(
        predictions, labels)
@@ -1737,7 +1737,7 @@ class StreamingAUCTest(test.TestCase):
    predictions = random_ops.random_uniform(
        (10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1)
    labels = random_ops.random_uniform(
-        (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2)
+        (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
    auc, update_op = metrics.streaming_auc(predictions, labels)

    with self.test_session() as sess:
@@ -2009,7 +2009,7 @@ class StreamingSpecificityAtSensitivityTest(test.TestCase):
    predictions = random_ops.random_uniform(
        (10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1)
    labels = random_ops.random_uniform(
-        (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2)
+        (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
    specificity, update_op = metrics.streaming_specificity_at_sensitivity(
        predictions, labels, sensitivity=0.7)
@@ -2271,7 +2271,7 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
    predictions = random_ops.random_uniform(
        (10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1)
    labels = random_ops.random_uniform(
-        (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2)
+        (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
    thresholds = [0, 0.5, 1.0]
    prec, prec_op = metrics.streaming_precision_at_thresholds(predictions,
                                                              labels,
@@ -2282,12 +2282,14 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
    with self.test_session() as sess:
      sess.run(variables.local_variables_initializer())
-      # Run several updates, then verify idempotency.
-      sess.run([prec_op, rec_op])
+      # Run several updates.
+      for _ in range(10):
+        sess.run([prec_op, rec_op])
+
+      # Then verify idempotency.
      initial_prec = prec.eval()
      initial_rec = rec.eval()
      for _ in range(10):
-        sess.run([prec_op, rec_op])
        self.assertAllClose(initial_prec, prec.eval())
        self.assertAllClose(initial_rec, rec.eval())
@@ -2361,14 +2363,10 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
      rec, rec_op = metrics.streaming_recall_at_thresholds(
          predictions, labels, thresholds, weights=weights)
-      [prec_low, prec_high] = array_ops.split(
-          value=prec, num_or_size_splits=2, axis=0)
-      prec_low = array_ops.reshape(prec_low, shape=())
-      prec_high = array_ops.reshape(prec_high, shape=())
-      [rec_low, rec_high] = array_ops.split(
-          value=rec, num_or_size_splits=2, axis=0)
-      rec_low = array_ops.reshape(rec_low, shape=())
-      rec_high = array_ops.reshape(rec_high, shape=())
+      prec_low = prec[0]
+      prec_high = prec[1]
+      rec_low = rec[0]
+      rec_high = rec[1]
      sess.run(variables.local_variables_initializer())
      sess.run([prec_op, rec_op])
@@ -2391,14 +2389,10 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
      rec, rec_op = metrics.streaming_recall_at_thresholds(
          predictions, labels, thresholds, weights=weights)
-      [prec_low, prec_high] = array_ops.split(
-          value=prec, num_or_size_splits=2, axis=0)
-      prec_low = array_ops.reshape(prec_low, shape=())
-      prec_high = array_ops.reshape(prec_high, shape=())
-      [rec_low, rec_high] = array_ops.split(
-          value=rec, num_or_size_splits=2, axis=0)
-      rec_low = array_ops.reshape(rec_low, shape=())
-      rec_high = array_ops.reshape(rec_high, shape=())
+      prec_low = prec[0]
+      prec_high = prec[1]
+      rec_low = rec[0]
+      rec_high = rec[1]
      sess.run(variables.local_variables_initializer())
      sess.run([prec_op, rec_op])
@@ -2420,10 +2414,10 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
      rec, rec_op = metrics.streaming_recall_at_thresholds(predictions, labels,
                                                           thresholds)
-      [prec_low, prec_high] = array_ops.split(
-          value=prec, num_or_size_splits=2, axis=0)
-      [rec_low, rec_high] = array_ops.split(
-          value=rec, num_or_size_splits=2, axis=0)
+      prec_low = prec[0]
+      prec_high = prec[1]
+      rec_low = rec[0]
+      rec_high = rec[1]
      sess.run(variables.local_variables_initializer())
      sess.run([prec_op, rec_op])
@@ -2562,7 +2556,7 @@ class StreamingFPRThresholdsTest(test.TestCase):
    predictions = random_ops.random_uniform(
        (10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1)
    labels = random_ops.random_uniform(
-        (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2)
+        (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
    thresholds = [0, 0.5, 1.0]
    fpr, fpr_op = metrics.streaming_false_positive_rate_at_thresholds(
        predictions, labels, thresholds)
@@ -2794,7 +2788,7 @@ class StreamingFNRThresholdsTest(test.TestCase):
    predictions = random_ops.random_uniform(
        (10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1)
    labels = random_ops.random_uniform(
-        (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2)
+        (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
    thresholds = [0, 0.5, 1.0]
    fnr, fnr_op = metrics.streaming_false_negative_rate_at_thresholds(
        predictions, labels, thresholds)
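For context on the repeated `maxval=1` to `maxval=2` fixes above: with an integer dtype, `random_uniform` samples from the half-open range `[0, maxval)`, so `maxval=1` made every label 0 and the positive class was never exercised. `maxval=2` yields labels in {0, 1}:

labels = random_ops.random_uniform(
    (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)  # values in {0, 1}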


@@ -13,6 +13,34 @@ py_library(
    deps = [],
)
py_library(
name = "graph_matcher",
srcs = [
"python/graph_matcher.py",
],
srcs_version = "PY2AND3",
deps = [],
)
py_test(
name = "graph_matcher_test",
size = "small",
srcs = ["python/graph_matcher_test.py"],
srcs_version = "PY2AND3",
deps = [
":graph_matcher",
"//tensorflow/contrib/layers:layers_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:init_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:nn_ops",
"//tensorflow/python:platform_test",
],
)
py_library(
    name = "input_to_ops",
    srcs = ["python/input_to_ops.py"],
@@ -43,6 +71,7 @@ py_library(
    srcs_version = "PY2AND3",
    deps = [
        ":common",
+        ":graph_matcher",
        ":input_to_ops",
        "//tensorflow/contrib/graph_editor:graph_editor_py",
        "//tensorflow/python:array_ops",
@@ -58,6 +87,7 @@ py_test(
    srcs_version = "PY2AND3",
    deps = [
        ":fold_batch_norms",
+        ":graph_matcher",
        "//tensorflow/contrib/layers:layers_py",
        "//tensorflow/python:array_ops",
        "//tensorflow/python:dtypes",
@@ -147,10 +177,11 @@ py_test(
py_test(
    name = "quantize_parameterized_test",
-    size = "medium",
+    size = "large",
    srcs = ["python/quantize_parameterized_test.py"],
    srcs_version = "PY2AND3",
    deps = [
+        ":fold_batch_norms",
        ":quantize",
        "//tensorflow/contrib/layers:layers_py",
        "//tensorflow/python:array_ops",


@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for tensorflow.quantized.mangle.copy_graph."""
+"""Tests for copy_graph."""
from __future__ import absolute_import
from __future__ import division


@@ -21,7 +21,9 @@ from __future__ import print_function
import re
from tensorflow.contrib import graph_editor
from tensorflow.contrib.quantize.python import common
from tensorflow.contrib.quantize.python import graph_matcher
from tensorflow.contrib.quantize.python import input_to_ops
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
@@ -29,7 +31,7 @@ from tensorflow.python.ops import nn_ops
def FoldBatchNorms(graph):
-  """Finds batch norm layers in the graph, folds them into preceding layers.
+  """Finds batch norm layers and folds them into preceding layers.
  Folding only affects the following layers: Conv2D, fully connected, depthwise
  convolution.
@@ -40,10 +42,269 @@ def FoldBatchNorms(graph):
  Raises:
    ValueError: When batch norm folding fails.
  """
-  # Fail immediately when the graph contains unsupported fused batch norm ops.
-  if any(op for op in graph.get_operations() if op.type == 'FusedBatchNorm'):
-    raise ValueError('Fused batch norm is not supported')
+  _FoldFusedBatchNorms(graph)
+  _FoldUnfusedBatchNorms(graph)
def _FoldFusedBatchNorms(graph):
"""Finds fused batch norm layers and folds them into preceding layers.
Folding only affects the following layers: Conv2D, fully connected, depthwise
convolution.
Args:
graph: Graph to walk and modify.
Raises:
ValueError: When batch norm folding fails.
"""
for match in _FindFusedBatchNorms(graph):
scope, sep, _ = match.layer_op.name.rpartition('/')
# Make sure new ops are added to `graph` and put on the same device as
# `bn_op`. The '/' (i.e. `sep`) ensures that we reuse the existing scope
# named `scope`. Otherwise, TF creates a unique scope whose name starts with
# `scope`.
with graph.as_default(), graph.name_scope(scope + sep), ops.device(
match.bn_op.device):
# new weights = old weights * gamma / sqrt(variance + epsilon)
# new biases = -mean * gamma / sqrt(variance + epsilon) + beta
multiplier_tensor = match.gamma_tensor * math_ops.rsqrt(
match.variance_tensor + match.bn_op.get_attr('epsilon'))
bias_tensor = math_ops.subtract(
match.beta_tensor, match.mean_tensor * multiplier_tensor, name='bias')
# The shape of depthwise weights is different, so we need to reshape the
# multiplier_tensor to ensure that the scaled_weight_tensor has the
# expected shape.
if match.layer_op.type == 'DepthwiseConv2dNative':
new_shape = [
match.weight_tensor.get_shape().as_list()[2],
match.weight_tensor.get_shape().as_list()[3]
]
multiplier_tensor = array_ops.reshape(
multiplier_tensor, new_shape, name='scale_reshape')
# TODO(suharshs): This naming of the following ops needs to carefully
# follow the naming expected by quantize.py. Generalize the quantize code
# to not require these delicate naming conventions.
scaled_weight_tensor = math_ops.multiply(
match.weight_tensor, multiplier_tensor, name='mul_fold')
new_layer_tensor = _CloneWithNewOperands(
match.layer_op, match.input_tensor, scaled_weight_tensor)
bias_add_tensor = math_ops.add(
new_layer_tensor, bias_tensor, name='add_fold')
nodes_modified_count = graph_editor.reroute_ts(bias_add_tensor,
match.output_tensor)
if nodes_modified_count != 1:
raise ValueError(
'Unexpected inputs to op: %s' % match.output_tensor.name)
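As a sanity check on the fold arithmetic above, here is a small NumPy sketch (illustrative only, not TF code) of the identity being exploited: batch norm applied to a layer output y equals y * multiplier + bias, and since multiplier * (W . x) = (multiplier * W) . x, the multiplier can be folded into the weights and the bias into a following add:

import numpy as np

y = np.random.randn(8).astype(np.float32)  # pre-batch-norm layer output
gamma, beta, eps = 1.5, 0.2, 1e-3
mean, var = y.mean(), y.var()

bn = gamma * (y - mean) / np.sqrt(var + eps) + beta
multiplier = gamma / np.sqrt(var + eps)     # folded into the weights (mul_fold)
bias = beta - mean * multiplier             # folded into an add (add_fold)
assert np.allclose(bn, y * multiplier + bias, atol=1e-5)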
def _CloneWithNewOperands(layer_op, input_tensor, weight_tensor):
"""Clones layer_op with input_tensor and weight_tensor as new inputs."""
new_layer_name = layer_op.name.split('/')[-1] + '_Fold'
if layer_op.type == 'Conv2D':
return nn_ops.conv2d(
input_tensor,
weight_tensor,
strides=layer_op.get_attr('strides'),
padding=layer_op.get_attr('padding'),
use_cudnn_on_gpu=layer_op.get_attr('use_cudnn_on_gpu'),
data_format=layer_op.get_attr('data_format'),
name=new_layer_name)
elif layer_op.type == 'MatMul':
return math_ops.matmul(
input_tensor,
weight_tensor,
transpose_a=layer_op.get_attr('transpose_a'),
transpose_b=layer_op.get_attr('transpose_b'),
name=new_layer_name)
elif layer_op.type == 'DepthwiseConv2dNative':
return nn.depthwise_conv2d(
input_tensor,
weight_tensor,
strides=layer_op.get_attr('strides'),
padding=layer_op.get_attr('padding'),
name=new_layer_name)
else:
raise ValueError('Cannot handle operation of type: %s' % layer_op.type)
def _FindFusedBatchNorms(graph):
"""Finds all ops and tensors related to found FusedBatchNorms.
Args:
graph: Graph to inspect.
Yields:
_FusedBatchNormMatches.
"""
input_pattern = graph_matcher.OpTypePattern('*')
weight_pattern = graph_matcher.OpTypePattern('*')
gamma_pattern = graph_matcher.OpTypePattern('*')
beta_pattern = graph_matcher.OpTypePattern('*')
mean_pattern = graph_matcher.OpTypePattern('*')
variance_pattern = graph_matcher.OpTypePattern('*')
conv_pattern = graph_matcher.OpTypePattern(
'Conv2D|DepthwiseConv2dNative', inputs=[input_pattern, weight_pattern])
# MatMul has a Reshape between it and FusedBatchNorm.
matmul_pattern = graph_matcher.OpTypePattern(
'MatMul', inputs=[input_pattern, weight_pattern])
matmul_reshape_pattern = graph_matcher.OpTypePattern(
'Reshape', inputs=[matmul_pattern,
graph_matcher.OpTypePattern('*')])
conv_batch_norm_pattern = graph_matcher.OpTypePattern(
'FusedBatchNorm',
inputs=[
conv_pattern, gamma_pattern, beta_pattern, mean_pattern,
variance_pattern
])
matmul_batch_norm_pattern = graph_matcher.OpTypePattern(
'FusedBatchNorm',
inputs=[
matmul_reshape_pattern, gamma_pattern, beta_pattern, mean_pattern,
variance_pattern
])
matmul_bn_output_reshape_pattern = graph_matcher.OpTypePattern(
'Reshape',
inputs=[matmul_batch_norm_pattern,
graph_matcher.OpTypePattern('*')])
conv_matcher = graph_matcher.GraphMatcher(conv_batch_norm_pattern)
matmul_matcher = graph_matcher.GraphMatcher(matmul_bn_output_reshape_pattern)
def _GetCommonTensors(match_result):
"""Gets tensors needed for FusedBatchNormMatch from match_result."""
input_tensor = match_result.get_tensor(input_pattern)
weight_tensor = match_result.get_tensor(weight_pattern)
gamma_tensor = match_result.get_tensor(gamma_pattern)
beta_tensor = match_result.get_tensor(beta_pattern)
# FusedBatchNorm in training is different from that in inference. It takes
# empty 'mean' and empty 'variance', and produces the mean and the variance
# of the batch. Therefore, when is_training is true, mean_tensor and
# variance_tensor point to 1st and 2nd (0-based) output of bn_op,
# respectively; when is_training is false, they point to bn_op's inputs.
is_training = bn_op.get_attr('is_training')
if is_training:
mean_tensor = bn_op.outputs[1]
variance_tensor = bn_op.outputs[2]
else:
mean_tensor = match_result.get_tensor(mean_pattern)
variance_tensor = match_result.get_tensor(variance_pattern)
return (input_tensor, weight_tensor, gamma_tensor, beta_tensor, mean_tensor,
variance_tensor)
for match_result in conv_matcher.match_graph(graph):
layer_op = match_result.get_op(conv_pattern)
bn_op = match_result.get_op(conv_batch_norm_pattern)
# In the case of convolution the output_tensor is the output of bn_op.
output_tensor = bn_op.outputs[0]
(input_tensor, weight_tensor, gamma_tensor, beta_tensor, mean_tensor,
variance_tensor) = _GetCommonTensors(match_result)
yield _FusedBatchNormMatch(
layer_op=layer_op,
bn_op=bn_op,
output_tensor=output_tensor,
input_tensor=input_tensor,
weight_tensor=weight_tensor,
gamma_tensor=gamma_tensor,
beta_tensor=beta_tensor,
mean_tensor=mean_tensor,
variance_tensor=variance_tensor)
for match_result in matmul_matcher.match_graph(graph):
layer_op = match_result.get_op(matmul_pattern)
bn_op = match_result.get_op(matmul_batch_norm_pattern)
# In the MatMul case, the output of batch norm is reshaped back into a
# 2D tensor, so the output_tensor is the output of the Reshape op.
output_reshape_op = match_result.get_op(matmul_bn_output_reshape_pattern)
output_tensor = output_reshape_op.outputs[0]
(input_tensor, weight_tensor, gamma_tensor, beta_tensor, mean_tensor,
variance_tensor) = _GetCommonTensors(match_result)
yield _FusedBatchNormMatch(
layer_op=layer_op,
bn_op=bn_op,
output_tensor=output_tensor,
input_tensor=input_tensor,
weight_tensor=weight_tensor,
gamma_tensor=gamma_tensor,
beta_tensor=beta_tensor,
mean_tensor=mean_tensor,
variance_tensor=variance_tensor)
class _FusedBatchNormMatch(object):
"""Contains all information related to a found FusedBatchNorm."""
def __init__(self, layer_op, bn_op, output_tensor, input_tensor,
weight_tensor, gamma_tensor, beta_tensor, mean_tensor,
variance_tensor):
self._layer_op = layer_op
self._bn_op = bn_op
self._output_tensor = output_tensor
self._input_tensor = input_tensor
self._weight_tensor = weight_tensor
self._gamma_tensor = gamma_tensor
self._beta_tensor = beta_tensor
self._mean_tensor = mean_tensor
self._variance_tensor = variance_tensor
@property
def layer_op(self):
return self._layer_op
@property
def bn_op(self):
return self._bn_op
@property
def output_tensor(self):
return self._output_tensor
@property
def input_tensor(self):
return self._input_tensor
@property
def weight_tensor(self):
return self._weight_tensor
@property
def gamma_tensor(self):
return self._gamma_tensor
@property
def beta_tensor(self):
return self._beta_tensor
@property
def mean_tensor(self):
return self._mean_tensor
@property
def variance_tensor(self):
return self._variance_tensor
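# Sketch of the algebra behind the folding below: batch norm applies
#   gamma * (x - mean) / sqrt(variance + epsilon) + beta
# which, with multiplier = gamma / sqrt(variance + epsilon), is just
#   x * multiplier + (beta - mean * multiplier)
# so the multiplier can be merged into the preceding layer's weights (the
# '/mul_fold' ops asserted in the tests) and the remainder becomes a bias
# (the '/add_fold' ops).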
def _FoldUnfusedBatchNorms(graph):
"""Finds unfused batch norm layers and folds them into preceding layers.
Folding only affects the following layers: Conv2D, fully connected, depthwise
convolution.
Args:
graph: Graph to walk and modify.
Raises:
ValueError: When batch norm folding fails.
"""
  input_to_ops_map = input_to_ops.InputToOps(graph)

  for bn in common.BatchNormGroups(graph):

View File

@ -18,7 +18,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
-import copy
 from tensorflow.contrib.layers.python.layers import layers
 from tensorflow.contrib.quantize.python import fold_batch_norms
 from tensorflow.python.framework import dtypes
@ -35,57 +34,32 @@ conv2d = layers.conv2d
 fully_connected = layers.fully_connected
 separable_conv2d = layers.separable_conv2d

-_DEFAULT_BATCH_NORM_PARAMS = {
-    'center': True,
-    'scale': True,
-    'decay': 1.0 - 0.003,
-    'fused': False,
-}

 # TODO(suharshs): Use parameterized test once OSS TF supports it.
 class FoldBatchNormsTest(test_util.TensorFlowTestCase):

   def _RunTestOverParameters(self, test_fn):
     parameters_list = [
-        # (relu, relu_op_name, with_bypass)
-        (nn_ops.relu6, 'Relu6', False),
-        (nn_ops.relu, 'Relu', False),
-        (nn_ops.relu6, 'Relu6', True),
-        (nn_ops.relu, 'Relu', True),
+        # (relu, relu_op_name, with_bypass, has_scaling, fused_batch_norm)
+        (nn_ops.relu6, 'Relu6', False, False, False),
+        (nn_ops.relu, 'Relu', False, False, False),
+        (nn_ops.relu6, 'Relu6', True, False, False),
+        (nn_ops.relu, 'Relu', True, False, False),
+        (nn_ops.relu6, 'Relu6', False, True, False),
+        (nn_ops.relu, 'Relu', False, True, False),
+        (nn_ops.relu6, 'Relu6', True, True, False),
+        (nn_ops.relu, 'Relu', True, True, False),
+        # Fused batch norm always has scaling enabled.
+        (nn_ops.relu6, 'Relu6', False, True, True),
+        (nn_ops.relu, 'Relu', False, True, True),
+        (nn_ops.relu6, 'Relu6', True, True, True),
+        (nn_ops.relu, 'Relu', True, True, True),
     ]
-    for parameters in parameters_list:
-      test_fn(parameters[0], parameters[1], parameters[2])
+    for params in parameters_list:
+      test_fn(params[0], params[1], params[2], params[3], params[4])

-  def testFailsWithFusedBatchNorm(self):
-    self._RunTestOverParameters(self._TestFailsWithFusedBatchNorm)
-
-  def _TestFailsWithFusedBatchNorm(self, relu, relu_op_name, with_bypass):
-    """Tests that batch norm fails when fused batch norm ops are present."""
-    g = ops.Graph()
-    with g.as_default():
-      batch_size, height, width = 5, 128, 128
-      inputs = array_ops.zeros((batch_size, height, width, 3))
-      out_depth = 3 if with_bypass else 32
-      stride = 1 if with_bypass else 2
-      activation_fn = None if with_bypass else relu
-      batch_norm_params = _DEFAULT_BATCH_NORM_PARAMS.copy()
-      batch_norm_params['fused'] = True
-      scope = 'test/test2' if with_bypass else 'test'
-      node = conv2d(inputs, out_depth, [5, 5], stride=stride, padding='SAME',
-                    weights_initializer=self._WeightInit(0.09),
-                    activation_fn=activation_fn,
-                    normalizer_fn=batch_norm,
-                    normalizer_params=batch_norm_params,
-                    scope=scope)
-      if with_bypass:
-        node = math_ops.add(inputs, node, name='test/Add')
-      relu(node, name='test/' + relu_op_name)
-    with self.assertRaises(ValueError):
-      fold_batch_norms.FoldBatchNorms(g)
-
-  def _TestFoldConv2d(self, relu, relu_op_name, with_bypass):
+  def _TestFoldConv2d(self, relu, relu_op_name, with_bypass, has_scaling,
+                      fused_batch_norm):
     """Tests folding cases: inputs -> Conv2d with batch norm -> Relu*.

     Args:
@ -93,6 +67,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
       relu_op_name: String, name of the Relu* operation.
       with_bypass: Bool, when true there is an extra connection added from
         inputs to just before Relu*.
+      has_scaling: Bool, when true the batch norm has scaling.
+      fused_batch_norm: Bool, when true the batch norm is fused.
     """
     g = ops.Graph()
     with g.as_default():
@ -102,12 +78,17 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
       stride = 1 if with_bypass else 2
       activation_fn = None if with_bypass else relu
       scope = 'test/test2' if with_bypass else 'test'
-      node = conv2d(inputs, out_depth, [5, 5], stride=stride, padding='SAME',
-                    weights_initializer=self._WeightInit(0.09),
-                    activation_fn=activation_fn,
-                    normalizer_fn=batch_norm,
-                    normalizer_params=_DEFAULT_BATCH_NORM_PARAMS,
-                    scope=scope)
+      node = conv2d(
+          inputs,
+          out_depth, [5, 5],
+          stride=stride,
+          padding='SAME',
+          weights_initializer=self._WeightInit(0.09),
+          activation_fn=activation_fn,
+          normalizer_fn=batch_norm,
+          normalizer_params=self._BatchNormParams(
+              scale=has_scaling, fused=fused_batch_norm),
+          scope=scope)
       if with_bypass:
         node = math_ops.add(inputs, node, name='test/Add')
       relu(node, name='test/' + relu_op_name)
@ -116,9 +97,10 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
     folded_mul = g.get_operation_by_name(scope + '/mul_fold')
     self.assertEqual(folded_mul.type, 'Mul')
-    self._AssertInputOpsAre(folded_mul,
-                            [scope + '/weights/read',
-                             scope + '/BatchNorm/batchnorm/mul'])
+    self._AssertInputOpsAre(folded_mul, [
+        scope + '/weights/read',
+        self._BatchNormMultiplierName(scope, has_scaling, fused_batch_norm)
+    ])
     self._AssertOutputGoesToOps(folded_mul, g, [scope + '/convolution_Fold'])

     folded_conv = g.get_operation_by_name(scope + '/convolution_Fold')
@ -129,16 +111,18 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
     folded_add = g.get_operation_by_name(scope + '/add_fold')
     self.assertEqual(folded_add.type, 'Add')
-    self._AssertInputOpsAre(folded_add,
-                            [scope + '/convolution_Fold',
-                             scope + '/BatchNorm/batchnorm/sub'])
+    self._AssertInputOpsAre(folded_add, [
+        scope + '/convolution_Fold',
+        self._BathNormBiasName(scope, fused_batch_norm)
+    ])
     output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name]
     self._AssertOutputGoesToOps(folded_add, g, output_op_names)

   def testFoldConv2d(self):
     self._RunTestOverParameters(self._TestFoldConv2d)

-  def _TestFoldConv2dUnknownShape(self, relu, relu_op_name, with_bypass):
+  def _TestFoldConv2dUnknownShape(self, relu, relu_op_name, with_bypass,
+                                  has_scaling, fused_batch_norm):
     """Tests folding cases: inputs -> Conv2d with batch norm -> Relu*.

     Tests that folding works even with an input shape where some dimensions are
@ -149,6 +133,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
       relu_op_name: String, name of the Relu* operation.
       with_bypass: Bool, when true there is an extra connection added from
         inputs to just before Relu*.
+      has_scaling: Bool, when true the batch norm has scaling.
+      fused_batch_norm: Bool, when true the batch norm is fused.
     """
     g = ops.Graph()
     with g.as_default():
@ -165,7 +151,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
           weights_initializer=self._WeightInit(0.09),
           activation_fn=activation_fn,
           normalizer_fn=batch_norm,
-          normalizer_params=_DEFAULT_BATCH_NORM_PARAMS,
+          normalizer_params=self._BatchNormParams(
+              scale=has_scaling, fused=fused_batch_norm),
           scope=scope)
       if with_bypass:
         node = math_ops.add(inputs, node, name='test/Add')
@ -176,7 +163,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
     folded_mul = g.get_operation_by_name(scope + '/mul_fold')
     self.assertEqual(folded_mul.type, 'Mul')
     self._AssertInputOpsAre(folded_mul, [
-        scope + '/weights/read', scope + '/BatchNorm/batchnorm/mul'
+        scope + '/weights/read',
+        self._BatchNormMultiplierName(scope, has_scaling, fused_batch_norm)
     ])
     self._AssertOutputGoesToOps(folded_mul, g, [scope + '/convolution_Fold'])
@ -188,7 +176,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
     folded_add = g.get_operation_by_name(scope + '/add_fold')
     self.assertEqual(folded_add.type, 'Add')
     self._AssertInputOpsAre(folded_add, [
-        scope + '/convolution_Fold', scope + '/BatchNorm/batchnorm/sub'
+        scope + '/convolution_Fold',
+        self._BathNormBiasName(scope, fused_batch_norm)
     ])
     output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name]
     self._AssertOutputGoesToOps(folded_add, g, output_op_names)
@ -196,62 +185,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
   def testFoldConv2dUnknownShape(self):
     self._RunTestOverParameters(self._TestFoldConv2dUnknownShape)

-  def _TestFoldConv2dWithoutScale(self, relu, relu_op_name, with_bypass):
-    """Tests folding cases: inputs -> Conv2d with batch norm -> Relu*.
-
-    Args:
-      relu: Callable that returns an Operation, a factory method for the Relu*.
-      relu_op_name: String, name of the Relu* operation.
-      with_bypass: Bool, when true there is an extra connection added from
-        inputs to just before Relu*.
-    """
-    g = ops.Graph()
-    with g.as_default():
-      batch_size, height, width = 5, 128, 128
-      inputs = array_ops.zeros((batch_size, height, width, 3))
-      out_depth = 3 if with_bypass else 32
-      stride = 1 if with_bypass else 2
-      activation_fn = None if with_bypass else relu
-      bn_params = copy.copy(_DEFAULT_BATCH_NORM_PARAMS)
-      bn_params['scale'] = False
-      scope = 'test/test2' if with_bypass else 'test'
-      node = conv2d(inputs, out_depth, [5, 5], stride=stride, padding='SAME',
-                    weights_initializer=self._WeightInit(0.09),
-                    activation_fn=activation_fn,
-                    normalizer_fn=batch_norm,
-                    normalizer_params=bn_params,
-                    scope=scope)
-      if with_bypass:
-        node = math_ops.add(inputs, node, name='test/Add')
-      relu(node, name='test/' + relu_op_name)
-
-    fold_batch_norms.FoldBatchNorms(g)
-
-    folded_mul = g.get_operation_by_name(scope + '/mul_fold')
-    self.assertEqual(folded_mul.type, 'Mul')
-    self._AssertInputOpsAre(folded_mul,
-                            [scope + '/weights/read',
-                             scope + '/BatchNorm/batchnorm/Rsqrt'])
-    self._AssertOutputGoesToOps(folded_mul, g, [scope + '/convolution_Fold'])
-
-    folded_conv = g.get_operation_by_name(scope + '/convolution_Fold')
-    self.assertEqual(folded_conv.type, 'Conv2D')
-    self._AssertInputOpsAre(folded_conv,
-                            [scope + '/mul_fold', inputs.op.name])
-    self._AssertOutputGoesToOps(folded_conv, g, [scope + '/add_fold'])
-
-    folded_add = g.get_operation_by_name(scope + '/add_fold')
-    self.assertEqual(folded_add.type, 'Add')
-    self._AssertInputOpsAre(folded_add,
-                            [scope + '/convolution_Fold',
-                             scope + '/BatchNorm/batchnorm/sub'])
-    output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name]
-    self._AssertOutputGoesToOps(folded_add, g, output_op_names)
-
-  def testFoldConv2dWithoutScale(self):
-    self._RunTestOverParameters(self._TestFoldConv2dWithoutScale)
-
-  def _TestFoldFullyConnectedLayer(self, relu, relu_op_name, with_bypass):
+  def _TestFoldFullyConnectedLayer(self, relu, relu_op_name, with_bypass,
+                                   has_scaling, fused_batch_norm):
     """Tests folding cases: inputs -> FC with batch norm -> Relu*.

     Args:
@ -259,6 +194,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
       relu_op_name: String, name of the Relu* operation.
       with_bypass: Bool, when true there is an extra connection added from
         inputs to just before Relu*.
+      has_scaling: Bool, when true the batch norm has scaling.
+      fused_batch_norm: Bool, when true the batch norm is fused.
     """
     g = ops.Graph()
     with g.as_default():
@ -267,12 +204,15 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
       out_depth = 256 if with_bypass else 128
       activation_fn = None if with_bypass else relu
       scope = 'test/test2' if with_bypass else 'test'
-      node = fully_connected(inputs, out_depth,
-                             weights_initializer=self._WeightInit(0.03),
-                             activation_fn=activation_fn,
-                             normalizer_fn=batch_norm,
-                             normalizer_params=_DEFAULT_BATCH_NORM_PARAMS,
-                             scope=scope)
+      node = fully_connected(
+          inputs,
+          out_depth,
+          weights_initializer=self._WeightInit(0.03),
+          activation_fn=activation_fn,
+          normalizer_fn=batch_norm,
+          normalizer_params=self._BatchNormParams(
+              scale=has_scaling, fused=fused_batch_norm),
+          scope=scope)
       if with_bypass:
         node = math_ops.add(inputs, node, name='test/Add')
       relu(node, name='test/' + relu_op_name)
@ -281,9 +221,10 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
     folded_mul = g.get_operation_by_name(scope + '/mul_fold')
     self.assertEqual(folded_mul.type, 'Mul')
-    self._AssertInputOpsAre(folded_mul,
-                            [scope + '/weights/read',
-                             scope + '/BatchNorm/batchnorm/mul'])
+    self._AssertInputOpsAre(folded_mul, [
+        scope + '/weights/read',
+        self._BatchNormMultiplierName(scope, has_scaling, fused_batch_norm)
+    ])
     self._AssertOutputGoesToOps(folded_mul, g, [scope + '/MatMul_Fold'])

     folded_conv = g.get_operation_by_name(scope + '/MatMul_Fold')
@ -294,71 +235,18 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
     folded_add = g.get_operation_by_name(scope + '/add_fold')
     self.assertEqual(folded_add.type, 'Add')
-    self._AssertInputOpsAre(folded_add,
-                            [scope + '/MatMul_Fold',
-                             scope + '/BatchNorm/batchnorm/sub'])
+    self._AssertInputOpsAre(folded_add, [
+        scope + '/MatMul_Fold',
+        self._BathNormBiasName(scope, fused_batch_norm)
+    ])
     output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name]
     self._AssertOutputGoesToOps(folded_add, g, output_op_names)

   def testFoldFullyConnectedLayer(self):
     self._RunTestOverParameters(self._TestFoldFullyConnectedLayer)

-  def _TestFoldFullyConnectedLayerWithoutScale(self, relu, relu_op_name,
-                                               with_bypass):
-    """Tests folding cases: inputs -> FC with batch norm -> Relu*.
-
-    Args:
-      relu: Callable that returns an Operation, a factory method for the Relu*.
-      relu_op_name: String, name of the Relu* operation.
-      with_bypass: Bool, when true there is an extra connection added from
-        inputs to just before Relu*.
-    """
-    g = ops.Graph()
-    with g.as_default():
-      batch_size, depth = 5, 256
-      inputs = array_ops.zeros((batch_size, depth))
-      out_depth = 256 if with_bypass else 128
-      activation_fn = None if with_bypass else relu
-      bn_params = copy.copy(_DEFAULT_BATCH_NORM_PARAMS)
-      bn_params['scale'] = False
-      scope = 'test/test2' if with_bypass else 'test'
-      node = fully_connected(inputs, out_depth,
-                             weights_initializer=self._WeightInit(0.03),
-                             activation_fn=activation_fn,
-                             normalizer_fn=batch_norm,
-                             normalizer_params=bn_params,
-                             scope=scope)
-      if with_bypass:
-        node = math_ops.add(inputs, node, name='test/Add')
-      relu(node, name='test/' + relu_op_name)
-
-    fold_batch_norms.FoldBatchNorms(g)
-
-    folded_mul = g.get_operation_by_name(scope + '/mul_fold')
-    self.assertEqual(folded_mul.type, 'Mul')
-    self._AssertInputOpsAre(folded_mul,
-                            [scope + '/weights/read',
-                             scope + '/BatchNorm/batchnorm/Rsqrt'])
-    self._AssertOutputGoesToOps(folded_mul, g, [scope + '/MatMul_Fold'])
-
-    folded_conv = g.get_operation_by_name(scope + '/MatMul_Fold')
-    self.assertEqual(folded_conv.type, 'MatMul')
-    self._AssertInputOpsAre(folded_conv,
-                            [scope + '/mul_fold', inputs.op.name])
-    self._AssertOutputGoesToOps(folded_conv, g, [scope + '/add_fold'])
-
-    folded_add = g.get_operation_by_name(scope + '/add_fold')
-    self.assertEqual(folded_add.type, 'Add')
-    self._AssertInputOpsAre(folded_add,
-                            [scope + '/MatMul_Fold',
-                             scope + '/BatchNorm/batchnorm/sub'])
-    output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name]
-    self._AssertOutputGoesToOps(folded_add, g, output_op_names)
-
-  def testFoldFullyConnectedLayerWithoutScale(self):
-    self._RunTestOverParameters(self._TestFoldFullyConnectedLayerWithoutScale)
-
-  def _TestFoldDepthwiseConv2d(self, relu, relu_op_name, with_bypass):
+  def _TestFoldDepthwiseConv2d(self, relu, relu_op_name, with_bypass,
+                               has_scaling, fused_batch_norm):
     """Tests folding: inputs -> DepthwiseConv2d with batch norm -> Relu*.

     Args:
@ -366,6 +254,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
       relu_op_name: String, name of the Relu* operation.
       with_bypass: Bool, when true there is an extra connection added from
         inputs to just before Relu*.
+      has_scaling: Bool, when true the batch norm has scaling.
+      fused_batch_norm: Bool, when true the batch norm is fused.
     """
     g = ops.Graph()
     with g.as_default():
@ -374,13 +264,18 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
       stride = 1 if with_bypass else 2
       activation_fn = None if with_bypass else relu
       scope = 'test/test2' if with_bypass else 'test'
-      node = separable_conv2d(inputs, None, [5, 5], stride=stride,
-                              depth_multiplier=1.0, padding='SAME',
-                              weights_initializer=self._WeightInit(0.09),
-                              activation_fn=activation_fn,
-                              normalizer_fn=batch_norm,
-                              normalizer_params=_DEFAULT_BATCH_NORM_PARAMS,
-                              scope=scope)
+      node = separable_conv2d(
+          inputs,
+          None, [5, 5],
+          stride=stride,
+          depth_multiplier=1.0,
+          padding='SAME',
+          weights_initializer=self._WeightInit(0.09),
+          activation_fn=activation_fn,
+          normalizer_fn=batch_norm,
+          normalizer_params=self._BatchNormParams(
+              scale=has_scaling, fused=fused_batch_norm),
+          scope=scope)
       if with_bypass:
         node = math_ops.add(inputs, node, name='test/Add')
       relu(node, name='test/' + relu_op_name)
@ -396,9 +291,10 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
     scale_reshape = g.get_operation_by_name(scope + '/scale_reshape')
     self.assertEqual(scale_reshape.type, 'Reshape')
-    self._AssertInputOpsAre(scale_reshape,
-                            [scope + '/BatchNorm/batchnorm/mul',
-                             scope + '/scale_reshape/shape'])
+    self._AssertInputOpsAre(scale_reshape, [
+        self._BatchNormMultiplierName(scope, has_scaling, fused_batch_norm),
+        scope + '/scale_reshape/shape'
+    ])
     self._AssertOutputGoesToOps(scale_reshape, g, [scope + '/mul_fold'])

     folded_conv = g.get_operation_by_name(scope + '/depthwise_Fold')
@ -409,77 +305,35 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
     folded_add = g.get_operation_by_name(scope + '/add_fold')
     self.assertEqual(folded_add.type, 'Add')
-    self._AssertInputOpsAre(folded_add,
-                            [scope + '/depthwise_Fold',
-                             scope + '/BatchNorm/batchnorm/sub'])
+    self._AssertInputOpsAre(folded_add, [
+        scope + '/depthwise_Fold',
+        self._BathNormBiasName(scope, fused_batch_norm)
+    ])
     output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name]
     self._AssertOutputGoesToOps(folded_add, g, output_op_names)

   def testFoldDepthwiseConv2d(self):
     self._RunTestOverParameters(self._TestFoldDepthwiseConv2d)

-  def _TestFoldDepthwiseConv2dWithoutScale(self, relu, relu_op_name,
-                                           with_bypass):
-    """Tests folding: inputs -> DepthwiseConv2d with batch norm -> Relu*.
-
-    Args:
-      relu: Callable that returns an Operation, a factory method for the Relu*.
-      relu_op_name: String, name of the Relu* operation.
-      with_bypass: Bool, when true there is an extra connection added from
-        inputs to just before Relu*.
-    """
-    g = ops.Graph()
-    with g.as_default():
-      batch_size, height, width = 5, 128, 128
-      inputs = array_ops.zeros((batch_size, height, width, 3))
-      stride = 1 if with_bypass else 2
-      activation_fn = None if with_bypass else relu
-      bn_params = copy.copy(_DEFAULT_BATCH_NORM_PARAMS)
-      bn_params['scale'] = False
-      scope = 'test/test2' if with_bypass else 'test'
-      node = separable_conv2d(inputs, None, [5, 5], stride=stride,
-                              depth_multiplier=1.0, padding='SAME',
-                              weights_initializer=self._WeightInit(0.09),
-                              activation_fn=activation_fn,
-                              normalizer_fn=batch_norm,
-                              normalizer_params=bn_params,
-                              scope=scope)
-      if with_bypass:
-        node = math_ops.add(inputs, node, name='test/Add')
-      relu(node, name='test/' + relu_op_name)
-
-    fold_batch_norms.FoldBatchNorms(g)
-
-    folded_mul = g.get_operation_by_name(scope + '/mul_fold')
-    self.assertEqual(folded_mul.type, 'Mul')
-    self._AssertInputOpsAre(folded_mul,
-                            [scope + '/depthwise_weights/read',
-                             scope + '/scale_reshape'])
-    self._AssertOutputGoesToOps(folded_mul, g, [scope + '/depthwise_Fold'])
-
-    scale_reshape = g.get_operation_by_name(scope + '/scale_reshape')
-    self.assertEqual(scale_reshape.type, 'Reshape')
-    self._AssertInputOpsAre(scale_reshape,
-                            [scope + '/BatchNorm/batchnorm/Rsqrt',
-                             scope + '/scale_reshape/shape'])
-    self._AssertOutputGoesToOps(scale_reshape, g, [scope + '/mul_fold'])
-
-    folded_conv = g.get_operation_by_name(scope + '/depthwise_Fold')
-    self.assertEqual(folded_conv.type, 'DepthwiseConv2dNative')
-    self._AssertInputOpsAre(folded_conv,
-                            [scope + '/mul_fold', inputs.op.name])
-    self._AssertOutputGoesToOps(folded_conv, g, [scope + '/add_fold'])
-
-    folded_add = g.get_operation_by_name(scope + '/add_fold')
-    self.assertEqual(folded_add.type, 'Add')
-    self._AssertInputOpsAre(folded_add,
-                            [scope + '/depthwise_Fold',
-                             scope + '/BatchNorm/batchnorm/sub'])
-    output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name]
-    self._AssertOutputGoesToOps(folded_add, g, output_op_names)
-
-  def testFoldDepthwiseConv2dWithoutScale(self):
-    self._RunTestOverParameters(self._TestFoldDepthwiseConv2dWithoutScale)
+  def _BatchNormParams(self, scale=True, fused=False):
+    return {
+        'center': True,
+        'scale': scale,
+        'decay': 1.0 - 0.003,
+        'fused': fused
+    }
+
+  def _BatchNormMultiplierName(self, scope, has_scaling, fused):
+    if has_scaling:
+      if fused:
+        return scope + '/mul'
+      return scope + '/BatchNorm/batchnorm/mul'
+    return scope + '/BatchNorm/batchnorm/Rsqrt'
+
+  def _BathNormBiasName(self, scope, fused):
+    if fused:
+      return scope + '/bias'
+    return scope + '/BatchNorm/batchnorm/sub'

   def _WeightInit(self, stddev):
     """Returns a truncated normal variable initializer.

View File

@ -0,0 +1,200 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities that match patterns in a tf.Graph."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
class OpTypePattern(object):
"""A tree pattern that matches TF expressions with certain op types."""
def __init__(self, op_type, name=None, inputs=None):
"""Initializes an OpTypePattern.
Args:
op_type: string that specifies the allowed types of the root. It can be
(1) an op type, e.g. 'Conv2D',
(2) '*', i.e. wildcard, or
(3) multiple op types separated by '|', e.g., 'Relu|Relu6'.
We could use regex strings, which might be worthwhile when we have many
similar TF op types.
name: Optional string. The name of the pattern that can be looked up in
MatchResult.
inputs: Optional list of `OpTypePattern`s or strings that specify the
patterns for the inputs of a matching op. If None, this pattern accepts
any inputs of a matching op.
"""
self._op_type = op_type
self._name = name
if inputs is None:
inputs = []
self._inputs = [
input_pattern if isinstance(input_pattern, OpTypePattern) else
OpTypePattern(input_pattern) for input_pattern in inputs
]
@property
def op_type(self):
return self._op_type
@property
def inputs(self):
return self._inputs
@property
def name(self):
return self._name
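# Example: a pattern tree matching a Conv2D feeding a BiasAdd, with wildcard
# inputs, could be built as
#   conv_pattern = OpTypePattern('Conv2D', inputs=['*', '*'])
#   bias_pattern = OpTypePattern('BiasAdd', name='bias',
#                                inputs=[conv_pattern, '*'])
# Plain strings in `inputs` are shorthand and are wrapped in OpTypePattern.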
class MatchResult(object):
r"""Encapsulates the result of a match done by GraphMatcher.
MatchResult contains a map from OpTypePattern to the matching op and tensor.
When the matching op has multiple output tensors, the matching tensor is the
output tensor used by the matching op of the parent pattern. E.g., when we
match graph
      -         +
     / \y0   y1/ \
    x    split    z
           |
           y (nodes are ops; edges are going up)

  against add_pattern defined as

    y1_pattern = OpTypePattern('*')
    z_pattern = OpTypePattern('*')
    add_pattern = OpTypePattern('+', inputs=[y1_pattern, z_pattern])

  the matching op of `y1_pattern` is `split`, and the matching tensor of
  `y1_pattern` is `y1`, not `y0`.
"""
def __init__(self):
self._pattern_to_op_tensor = {}
self._name_to_pattern = {}
def add(self, pattern, op, tensor):
self._pattern_to_op_tensor[pattern] = op, tensor
if pattern.name is not None:
if pattern.name in self._name_to_pattern:
raise ValueError(
'Name %s is already bound to another pattern' % pattern.name)
self._name_to_pattern[pattern.name] = pattern
def _to_pattern(self, pattern_or_name):
if isinstance(pattern_or_name, OpTypePattern):
return pattern_or_name
if isinstance(pattern_or_name, str):
return self._name_to_pattern[pattern_or_name]
raise ValueError('pattern_or_name has type %s. Expect OpTypePattern or str.'
% type(pattern_or_name))
def get_op(self, pattern_or_name):
return self._pattern_to_op_tensor[self._to_pattern(pattern_or_name)][0]
def get_tensor(self, pattern_or_name):
return self._pattern_to_op_tensor[self._to_pattern(pattern_or_name)][1]
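# Continuing the docstring's example graph: when add_pattern matches the `+`
# op, result.get_op(y1_pattern) is the `split` op and
# result.get_tensor(y1_pattern) is `y1`, the split output consumed by `+`.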
class GraphMatcher(object):
"""Checks if a particular subgraph matches a given pattern."""
def __init__(self, pattern):
"""Initializes a GraphMatcher.
Args:
pattern: The `OpTypePattern` against which `GraphMatcher` matches
subgraphs.
"""
self._pattern = pattern
def _match_pattern(self, pattern, op, tensor):
"""Returns whether an TF expression rooted at `op` matches `pattern`.
If there is a match, adds to `self._match_result` the matching op and tensor
with key `pattern`.
Args:
pattern: An `OpTypePattern`.
op: A `tf.Operation` to match against the pattern.
tensor: the output `tf.Tensor` of `op` that is used by the matching op of
`pattern`'s parent. Can be None if `pattern` is already the root of the
pattern tree.
Returns:
      True if a TF expression rooted at `op` matches `pattern`.
"""
if pattern.op_type != '*':
if op.type not in pattern.op_type.split('|'):
return False
self._match_result.add(pattern, op, tensor)
if not pattern.inputs:
# If pattern.inputs is empty, skips the rest and accepts all the inputs.
return True
return len(op.inputs) == len(pattern.inputs) and all([
self._match_pattern(input_pattern, input_tensor.op, input_tensor)
for input_tensor, input_pattern in zip(op.inputs, pattern.inputs)
])
def match_op(self, op):
"""Matches `op` against `self._pattern`.
Args:
op: `tf.Operation` to match against the pattern.
Returns:
Returns a `MatchResult` if `op` matches the pattern; otherwise, returns
None.
"""
self._match_result = MatchResult()
if not self._match_pattern(self._pattern, op, tensor=None):
return None
return self._match_result
def match_ops(self, ops):
"""Matches each operation in `ops` against `self._pattern`.
Args:
ops: collection of `tf.Operation` to match against the pattern.
Yields:
`MatchResult` for each `tf.Operation` that matches the pattern.
"""
for op in ops:
match_result = self.match_op(op)
if match_result:
yield match_result
def match_graph(self, graph):
"""Matches each operation in `graph` against `self._pattern`.
Args:
graph: `tf.Graph` containing operations to match.
Yields:
`MatchResult` for each `tf.Operation` in `graph` that matches the pattern.
"""
# Python 3.3.2+ implements `yield from`, but for now:
for match_result in self.match_ops(graph.get_operations()):
yield match_result
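# Typical usage (where `some_graph` stands for any tf.Graph):
#   relu_pattern = OpTypePattern('Relu|Relu6', inputs=['*'])
#   matcher = GraphMatcher(relu_pattern)
#   for match in matcher.match_graph(some_graph):
#     relu_op = match.get_op(relu_pattern)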

View File

@ -0,0 +1,130 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for graph_matcher."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.framework.python import ops as contrib_ops
from tensorflow.contrib.layers.python.layers import initializers
from tensorflow.contrib.layers.python.layers import layers
from tensorflow.contrib.quantize.python import graph_matcher
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import googletest
class GraphMatcherTest(test_util.TensorFlowTestCase):
def test_conv_layer(self):
g = ops.Graph()
with g.as_default():
inputs = array_ops.placeholder(dtypes.float32, shape=[8, 5, 5, 3])
with contrib_ops.arg_scope(
[layers.batch_norm], fused=True, is_training=True, trainable=True):
        layers.convolution(
inputs,
num_outputs=16,
kernel_size=3,
stride=1,
padding='VALID',
activation_fn=nn_ops.relu,
normalizer_fn=layers.batch_norm,
normalizer_params={},
weights_initializer=initializers.xavier_initializer(),
weights_regularizer=None,
biases_initializer=init_ops.zeros_initializer(),
biases_regularizer=None,
reuse=None,
trainable=True,
scope=None)
inputs_pattern = graph_matcher.OpTypePattern('*', name='inputs')
relu_pattern = graph_matcher.OpTypePattern(
'Relu',
name='relu',
inputs=[
graph_matcher.OpTypePattern(
'FusedBatchNorm',
inputs=[
graph_matcher.OpTypePattern(
'Conv2D', inputs=[inputs_pattern, '*']), '*', '*', '*',
'*'
])
])
matcher = graph_matcher.GraphMatcher(relu_pattern)
match_results = list(matcher.match_graph(g))
self.assertEqual(1, len(match_results))
match_result = match_results[0]
self.assertEqual(match_result.get_tensor(inputs_pattern), inputs)
self.assertEqual(match_result.get_tensor('inputs'), inputs)
def test_multiple_outputs(self):
    #   -         +
    #  / \y0   y1/ \
    # x    split    z
    #        |
    #        y (nodes are ops; edges are going up)
g = ops.Graph()
with g.as_default():
x = array_ops.placeholder(dtypes.float32, shape=[1], name='x')
y = array_ops.placeholder(dtypes.float32, shape=[2], name='y')
y0, y1 = array_ops.split(y, num_or_size_splits=2, axis=0)
z = array_ops.placeholder(dtypes.float32, shape=[1], name='z')
math_ops.add(x, y0)
math_ops.subtract(y1, z)
y1_pattern = graph_matcher.OpTypePattern('*')
minus_pattern = graph_matcher.OpTypePattern('Sub', inputs=[y1_pattern, '*'])
matcher = graph_matcher.GraphMatcher(minus_pattern)
match_results = list(matcher.match_graph(g))
self.assertEqual(1, len(match_results))
match_result = match_results[0]
self.assertEqual(y0.op, y1.op)
self.assertEqual(match_result.get_op(y1_pattern), y1.op)
self.assertEqual(match_result.get_tensor(y1_pattern), y1)
def test_oneof_pattern(self):
    #   -   +
    #  / \ / \
    # x   y   z
g = ops.Graph()
with g.as_default():
x = array_ops.placeholder(dtypes.float32, shape=[], name='x')
y = array_ops.placeholder(dtypes.float32, shape=[], name='y')
z = array_ops.placeholder(dtypes.float32, shape=[], name='z')
plus = x + y
minus = y - z
add_or_sub_pattern = graph_matcher.OpTypePattern(
'Add|Sub', inputs=['*', '*'])
matcher = graph_matcher.GraphMatcher(add_or_sub_pattern)
self.assertEqual([
match_result.get_op(add_or_sub_pattern)
for match_result in matcher.match_graph(g)
], [plus.op, minus.op])
if __name__ == '__main__':
googletest.main()

View File

@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.contrib.layers.python.layers import layers from tensorflow.contrib.layers.python.layers import layers
from tensorflow.contrib.quantize.python import fold_batch_norms
from tensorflow.contrib.quantize.python import quantize from tensorflow.contrib.quantize.python import quantize
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util from tensorflow.python.framework import test_util
@ -35,18 +36,11 @@ conv2d = layers.conv2d
fully_connected = layers.fully_connected fully_connected = layers.fully_connected
separable_conv2d = layers.separable_conv2d separable_conv2d = layers.separable_conv2d
_DEFAULT_BATCH_NORM_PARAMS = {
'center': True,
'scale': True,
'decay': 1.0 - 0.003,
'fused': False,
}
# TODO(suharshs): Use parameterized test once OSS TF supports it.
class QuantizeTest(test_util.TensorFlowTestCase): class QuantizeTest(test_util.TensorFlowTestCase):
def _RunTestOverParameters(self, test_fn): def _RunWithoutBatchNormTestOverParameters(self, test_fn):
# TODO(suharshs): Use parameterized test once OSS TF supports it.
parameters_list = [ parameters_list = [
# (activation, activation_op_name, with_bypass, delay) # (activation, activation_op_name, with_bypass, delay)
(nn_ops.relu6, 'Relu6', False, None), (nn_ops.relu6, 'Relu6', False, None),
@ -60,10 +54,10 @@ class QuantizeTest(test_util.TensorFlowTestCase):
(array_ops.identity, 'Identity', True, None), (array_ops.identity, 'Identity', True, None),
(nn_ops.relu6, 'Relu6', True, 5000), (nn_ops.relu6, 'Relu6', True, 5000),
(nn_ops.relu, 'Relu', True, 5000), (nn_ops.relu, 'Relu', True, 5000),
(array_ops.identity, 'Identity', True, 5000) (array_ops.identity, 'Identity', True, 5000),
] ]
for parameters in parameters_list: for params in parameters_list:
test_fn(parameters[0], parameters[1], parameters[2], parameters[3]) test_fn(params[0], params[1], params[2], params[3])
def _TestQuantize_Conv2dWithoutBatchNorm(self, activation, activation_op_name, def _TestQuantize_Conv2dWithoutBatchNorm(self, activation, activation_op_name,
with_bypass, delay): with_bypass, delay):
@ -137,7 +131,8 @@ class QuantizeTest(test_util.TensorFlowTestCase):
self._AssertOutputGoesToOps(act_quant, graph, [output_op_name]) self._AssertOutputGoesToOps(act_quant, graph, [output_op_name])
def testQuantize_Conv2dWithoutBatchNorm(self): def testQuantize_Conv2dWithoutBatchNorm(self):
self._RunTestOverParameters(self._TestQuantize_Conv2dWithoutBatchNorm) self._RunWithoutBatchNormTestOverParameters(
self._TestQuantize_Conv2dWithoutBatchNorm)
def _TestQuantize_FCWithoutBatchNorm(self, activation, activation_op_name, def _TestQuantize_FCWithoutBatchNorm(self, activation, activation_op_name,
with_bypass, delay): with_bypass, delay):
@ -210,7 +205,8 @@ class QuantizeTest(test_util.TensorFlowTestCase):
self._AssertOutputGoesToOps(act_quant, graph, [output_op_name]) self._AssertOutputGoesToOps(act_quant, graph, [output_op_name])
def testQuantize_FCWithoutBatchNorm(self): def testQuantize_FCWithoutBatchNorm(self):
self._RunTestOverParameters(self._TestQuantize_FCWithoutBatchNorm) self._RunWithoutBatchNormTestOverParameters(
self._TestQuantize_FCWithoutBatchNorm)
def _TestQuantize_DepthwiseConv2dWithoutBatchNorm( def _TestQuantize_DepthwiseConv2dWithoutBatchNorm(
self, activation, activation_op_name, with_bypass, delay): self, activation, activation_op_name, with_bypass, delay):
@ -284,11 +280,43 @@ class QuantizeTest(test_util.TensorFlowTestCase):
self._AssertOutputGoesToOps(act_quant, graph, [output_op_name]) self._AssertOutputGoesToOps(act_quant, graph, [output_op_name])
def testQuantize_DepthwiseConv2dWithoutBatchNorm(self): def testQuantize_DepthwiseConv2dWithoutBatchNorm(self):
self._RunTestOverParameters( self._RunWithoutBatchNormTestOverParameters(
self._TestQuantize_DepthwiseConv2dWithoutBatchNorm) self._TestQuantize_DepthwiseConv2dWithoutBatchNorm)
def _RunBatchNormTestOverParameters(self, test_fn):
# TODO(suharshs): Use parameterized test once OSS TF supports it.
parameters_list = [
# (activation, activation_op_name, with_bypass, delay, fused_batch_norm)
(nn_ops.relu6, 'Relu6', False, None, False),
(nn_ops.relu, 'Relu', False, None, False),
(array_ops.identity, 'Identity', False, None, False),
(nn_ops.relu6, 'Relu6', False, 5000, False),
(nn_ops.relu, 'Relu', False, 5000, False),
(array_ops.identity, 'Identity', False, 5000, False),
(nn_ops.relu6, 'Relu6', True, None, False),
(nn_ops.relu, 'Relu', True, None, False),
(array_ops.identity, 'Identity', True, None, False),
(nn_ops.relu6, 'Relu6', True, 5000, False),
(nn_ops.relu, 'Relu', True, 5000, False),
(array_ops.identity, 'Identity', True, 5000, False),
(nn_ops.relu6, 'Relu6', False, None, True),
(nn_ops.relu, 'Relu', False, None, True),
(array_ops.identity, 'Identity', False, None, True),
(nn_ops.relu6, 'Relu6', False, 5000, True),
(nn_ops.relu, 'Relu', False, 5000, True),
(array_ops.identity, 'Identity', False, 5000, True),
(nn_ops.relu6, 'Relu6', True, None, True),
(nn_ops.relu, 'Relu', True, None, True),
(array_ops.identity, 'Identity', True, None, True),
(nn_ops.relu6, 'Relu6', True, 5000, True),
(nn_ops.relu, 'Relu', True, 5000, True),
(array_ops.identity, 'Identity', True, 5000, True)
]
for params in parameters_list:
test_fn(params[0], params[1], params[2], params[3], params[4])
def _TestQuantize_Conv2dWithBatchNorm(self, activation, activation_op_name, def _TestQuantize_Conv2dWithBatchNorm(self, activation, activation_op_name,
with_bypass, delay): with_bypass, delay, fused_batch_norm):
"""Tests quantization: inputs -> Conv2d with batch norm -> Activation. """Tests quantization: inputs -> Conv2d with batch norm -> Activation.
Args: Args:
@ -298,25 +326,29 @@ class QuantizeTest(test_util.TensorFlowTestCase):
with_bypass: Bool, when true there is an extra connection added from with_bypass: Bool, when true there is an extra connection added from
inputs to just before Activation. inputs to just before Activation.
delay: Int (optional), delay in number of steps until quantization starts. delay: Int (optional), delay in number of steps until quantization starts.
fused_batch_norm: Bool, when true use FusedBatchNorm.
""" """
self._testQuantize_Conv2dWithBatchNorm( self._testQuantize_Conv2dWithBatchNorm(
activation, activation,
activation_op_name, activation_op_name,
with_bypass, with_bypass,
delay, delay,
fused_batch_norm,
use_ema=True) use_ema=True)
self._testQuantize_Conv2dWithBatchNorm( self._testQuantize_Conv2dWithBatchNorm(
activation, activation,
activation_op_name, activation_op_name,
with_bypass, with_bypass,
delay, delay,
fused_batch_norm,
use_ema=False) use_ema=False)
def testQuantize_Conv2dWithBatchNorm(self): def testQuantize_Conv2dWithBatchNorm(self):
self._RunTestOverParameters(self._TestQuantize_Conv2dWithBatchNorm) self._RunBatchNormTestOverParameters(self._TestQuantize_Conv2dWithBatchNorm)
def _testQuantize_Conv2dWithBatchNorm(self, activation, activation_op_name, def _testQuantize_Conv2dWithBatchNorm(self, activation, activation_op_name,
with_bypass, delay, use_ema): with_bypass, delay, fused_batch_norm,
use_ema):
"""Tests quantization: inputs -> Conv2d with batch norm -> Activation. """Tests quantization: inputs -> Conv2d with batch norm -> Activation.
Args: Args:
@ -326,6 +358,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
with_bypass: Bool, when true there is an extra connection added from with_bypass: Bool, when true there is an extra connection added from
inputs to just before Activation. inputs to just before Activation.
delay: Int (optional), delay in number of steps until quantization starts. delay: Int (optional), delay in number of steps until quantization starts.
fused_batch_norm: Bool, when true use FusedBatchNorm.
use_ema: Bool, when true uses EMA quantization for BN folded weights. use_ema: Bool, when true uses EMA quantization for BN folded weights.
""" """
graph = ops.Graph() graph = ops.Graph()
@ -337,39 +370,29 @@ class QuantizeTest(test_util.TensorFlowTestCase):
stride = 1 if with_bypass else 2 stride = 1 if with_bypass else 2
out_depth = 3 if with_bypass else 32 out_depth = 3 if with_bypass else 32
scope = 'test/test2' if with_bypass else 'test' scope = 'test/test2' if with_bypass else 'test'
node = conv2d(inputs, out_depth, [5, 5], stride=stride, padding='SAME', node = conv2d(
weights_initializer=self._WeightInit(0.09), inputs,
activation_fn=None, out_depth, [5, 5],
normalizer_fn=batch_norm, stride=stride,
normalizer_params=_DEFAULT_BATCH_NORM_PARAMS,
scope=scope)
# Manually fold the batch norm.
weights = graph.get_operation_by_name(scope + '/weights/read').outputs[0]
bn_mult = (graph.get_operation_by_name(scope + '/BatchNorm/batchnorm/mul')
.outputs[0])
mul_fold = math_ops.multiply(weights, bn_mult, name=scope + '/mul_fold')
stride = [stride, stride]
conv_fold = nn_ops.convolution(
input=inputs,
filter=mul_fold,
padding='SAME', padding='SAME',
strides=stride, weights_initializer=self._WeightInit(0.09),
data_format='NHWC', activation_fn=None,
name=scope + '/convolution_Fold') normalizer_fn=batch_norm,
bn_bias = (graph.get_operation_by_name(scope + '/BatchNorm/batchnorm/sub') normalizer_params=self._BatchNormParams(fused_batch_norm),
.outputs[0]) scope=scope)
add_fold = math_ops.add(conv_fold, bn_bias, name=scope + '/add_fold')
# Manually add a bypass (optionaly) and an activation. # Manually add a bypass (optionaly) and an activation.
if with_bypass: if with_bypass:
node = math_ops.add(inputs, add_fold, name='test/Add') node = math_ops.add(inputs, node, name='test/Add')
else:
node = add_fold
node = activation(node, name='test/' + activation_op_name) node = activation(node, name='test/' + activation_op_name)
update_barrier = control_flow_ops.no_op(name='update_barrier') update_barrier = control_flow_ops.no_op(name='update_barrier')
with ops.control_dependencies([update_barrier]): with ops.control_dependencies([update_barrier]):
array_ops.identity(node, name='control_dependency') array_ops.identity(node, name='control_dependency')
fold_batch_norms.FoldBatchNorms(graph)
quantize.Quantize( quantize.Quantize(
graph, quant_delay=delay, quantize_folded_weights_use_ema=use_ema) graph, quant_delay=delay, quantize_folded_weights_use_ema=use_ema)
@ -413,7 +436,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
self._AssertOutputGoesToOps(act_quant, graph, [output_op_name]) self._AssertOutputGoesToOps(act_quant, graph, [output_op_name])
def _TestQuantize_FCWithBatchNorm(self, activation, activation_op_name, def _TestQuantize_FCWithBatchNorm(self, activation, activation_op_name,
with_bypass, delay): with_bypass, delay, fused_batch_norm):
"""Tests quantization: inputs -> FC with batch norm -> Activation. """Tests quantization: inputs -> FC with batch norm -> Activation.
Args: Args:
@ -423,25 +446,29 @@ class QuantizeTest(test_util.TensorFlowTestCase):
with_bypass: Bool, when true there is an extra connection added from with_bypass: Bool, when true there is an extra connection added from
inputs to just before Activation. inputs to just before Activation.
delay: Int (optional), delay in number of steps until quantization starts. delay: Int (optional), delay in number of steps until quantization starts.
fused_batch_norm: Bool, when true use FusedBatchNorm.
""" """
self._testQuantize_FCWithBatchNorm( self._testQuantize_FCWithBatchNorm(
activation, activation,
activation_op_name, activation_op_name,
with_bypass, with_bypass,
delay, delay,
fused_batch_norm,
use_ema=True) use_ema=True)
self._testQuantize_FCWithBatchNorm( self._testQuantize_FCWithBatchNorm(
activation, activation,
activation_op_name, activation_op_name,
with_bypass, with_bypass,
delay, delay,
fused_batch_norm,
use_ema=False) use_ema=False)
def testQuantize_FCWithBatchNorm(self): def testQuantize_FCWithBatchNorm(self):
self._RunTestOverParameters(self._TestQuantize_FCWithBatchNorm) self._RunBatchNormTestOverParameters(self._TestQuantize_FCWithBatchNorm)
def _testQuantize_FCWithBatchNorm(self, activation, activation_op_name, def _testQuantize_FCWithBatchNorm(self, activation, activation_op_name,
with_bypass, delay, use_ema): with_bypass, delay, fused_batch_norm,
use_ema):
"""Tests quantization: inputs -> FC with batch norm -> Activation. """Tests quantization: inputs -> FC with batch norm -> Activation.
Args: Args:
@ -451,6 +478,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
with_bypass: Bool, when true there is an extra connection added from with_bypass: Bool, when true there is an extra connection added from
inputs to just before Activation. inputs to just before Activation.
delay: Int (optional), delay in number of steps until quantization starts. delay: Int (optional), delay in number of steps until quantization starts.
fused_batch_norm: Bool, when true use FusedBatchNorm.
use_ema: Bool, when true uses EMA quantization for BN folded weights. use_ema: Bool, when true uses EMA quantization for BN folded weights.
""" """
graph = ops.Graph() graph = ops.Graph()
@ -461,32 +489,27 @@ class QuantizeTest(test_util.TensorFlowTestCase):
inputs = array_ops.zeros((batch_size, depth)) inputs = array_ops.zeros((batch_size, depth))
out_depth = 256 if with_bypass else 128 out_depth = 256 if with_bypass else 128
scope = 'test/test2' if with_bypass else 'test' scope = 'test/test2' if with_bypass else 'test'
node = fully_connected(inputs, out_depth, node = fully_connected(
weights_initializer=self._WeightInit(0.03), inputs,
activation_fn=None, out_depth,
normalizer_fn=batch_norm, weights_initializer=self._WeightInit(0.03),
normalizer_params=_DEFAULT_BATCH_NORM_PARAMS, activation_fn=None,
scope=scope) normalizer_fn=batch_norm,
# Manually fold the batch norm. normalizer_params=self._BatchNormParams(fused_batch_norm),
weights = graph.get_operation_by_name(scope + '/weights/read').outputs[0] scope=scope)
bn_mult = (graph.get_operation_by_name(scope + '/BatchNorm/batchnorm/mul')
.outputs[0])
mul_fold = math_ops.multiply(weights, bn_mult, name=scope + '/mul_fold')
fc_fold = math_ops.matmul(inputs, mul_fold, name=scope + '/MatMul_Fold')
bn_bias = (graph.get_operation_by_name(scope + '/BatchNorm/batchnorm/sub')
.outputs[0])
add_fold = math_ops.add(fc_fold, bn_bias, name=scope + '/add_fold')
# Manually add a bypass (optionaly) and an activation. # Manually add a bypass (optionaly) and an activation.
if with_bypass: if with_bypass:
node = math_ops.add(inputs, add_fold, name='test/Add') node = math_ops.add(inputs, node, name='test/Add')
else:
node = add_fold
node = activation(node, name='test/' + activation_op_name) node = activation(node, name='test/' + activation_op_name)
update_barrier = control_flow_ops.no_op(name='update_barrier') update_barrier = control_flow_ops.no_op(name='update_barrier')
with ops.control_dependencies([update_barrier]): with ops.control_dependencies([update_barrier]):
array_ops.identity(node, name='control_dependency') array_ops.identity(node, name='control_dependency')
fold_batch_norms.FoldBatchNorms(graph)
quantize.Quantize( quantize.Quantize(
graph, quant_delay=delay, quantize_folded_weights_use_ema=use_ema) graph, quant_delay=delay, quantize_folded_weights_use_ema=use_ema)
@ -530,7 +553,8 @@ class QuantizeTest(test_util.TensorFlowTestCase):
self._AssertOutputGoesToOps(act_quant, graph, [output_op_name]) self._AssertOutputGoesToOps(act_quant, graph, [output_op_name])
def _TestQuantize_DepthwiseConv2dWithBatchNorm( def _TestQuantize_DepthwiseConv2dWithBatchNorm(
self, activation, activation_op_name, with_bypass, delay): self, activation, activation_op_name, with_bypass, delay,
fused_batch_norm):
"""Tests quantization: inputs -> DWConv2d with batch norm -> Activation. """Tests quantization: inputs -> DWConv2d with batch norm -> Activation.
Args: Args:
@ -540,26 +564,30 @@ class QuantizeTest(test_util.TensorFlowTestCase):
with_bypass: Bool, when true there is an extra connection added from with_bypass: Bool, when true there is an extra connection added from
inputs to just before Activation. inputs to just before Activation.
delay: Int (optional), delay in number of steps until quantization starts. delay: Int (optional), delay in number of steps until quantization starts.
+      fused_batch_norm: Bool, when true use FusedBatchNorm.
     """
     self._testQuantize_DepthwiseConv2dWithBatchNorm(
         activation,
         activation_op_name,
         with_bypass,
         delay,
+        fused_batch_norm,
         use_ema=True)
     self._testQuantize_DepthwiseConv2dWithBatchNorm(
         activation,
         activation_op_name,
         with_bypass,
         delay,
+        fused_batch_norm,
         use_ema=False)

   def testQuantize_DepthwiseConv2dWithBatchNorm(self):
-    self._RunTestOverParameters(
-        self._TestQuantize_DepthwiseConv2dWithoutBatchNorm)
+    self._RunBatchNormTestOverParameters(
+        self._TestQuantize_DepthwiseConv2dWithBatchNorm)

   def _testQuantize_DepthwiseConv2dWithBatchNorm(
-      self, activation, activation_op_name, with_bypass, delay, use_ema):
+      self, activation, activation_op_name, with_bypass, delay,
+      fused_batch_norm, use_ema):
     """Tests quantization: inputs -> DWConv2d with batch norm -> Activation.

     Args:
@@ -569,6 +597,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
       with_bypass: Bool, when true there is an extra connection added from
         inputs to just before Activation.
       delay: Int (optional), delay in number of steps until quantization starts.
+      fused_batch_norm: Bool, when true use FusedBatchNorm.
       use_ema: Bool, when true uses EMA quantization for BN folded weights.
     """
     graph = ops.Graph()
@@ -579,46 +608,30 @@ class QuantizeTest(test_util.TensorFlowTestCase):
       inputs = array_ops.zeros((batch_size, height, width, depth))
       stride = 1 if with_bypass else 2
       scope = 'test/test2' if with_bypass else 'test'
-      node = separable_conv2d(inputs, None, [5, 5], stride=stride,
-                              depth_multiplier=1.0, padding='SAME',
-                              weights_initializer=self._WeightInit(0.09),
-                              activation_fn=None,
-                              normalizer_fn=batch_norm,
-                              normalizer_params=_DEFAULT_BATCH_NORM_PARAMS,
-                              scope=scope)
-      # Manually fold the batch norm.
-      weights = (graph.get_operation_by_name(scope + '/depthwise_weights/read')
-                 .outputs[0])
-      bn_mult = (graph.get_operation_by_name(scope + '/BatchNorm/batchnorm/mul')
-                 .outputs[0])
-      new_shape = [
-          weights.get_shape().as_list()[2], weights.get_shape().as_list()[3]
-      ]
-      bn_mult_reshaped = array_ops.reshape(
-          bn_mult, new_shape, name=scope + '/gamma_reshape')
-      mul_fold = math_ops.multiply(
-          weights, bn_mult_reshaped, name=scope + '/mul_fold')
-      stride = [1, stride, stride, 1]
-      conv_fold = nn_ops.depthwise_conv2d(
-          input=inputs,
-          filter=mul_fold,
-          padding='SAME',
-          strides=stride,
-          name=scope + '/depthwise_Fold')
-      bn_bias = (graph.get_operation_by_name(scope + '/BatchNorm/batchnorm/sub')
-                 .outputs[0])
-      add_fold = math_ops.add(conv_fold, bn_bias, name=scope + '/add_fold')
+      node = separable_conv2d(
+          inputs,
+          None, [5, 5],
+          stride=stride,
+          depth_multiplier=1.0,
+          padding='SAME',
+          weights_initializer=self._WeightInit(0.09),
+          activation_fn=None,
+          normalizer_fn=batch_norm,
+          normalizer_params=self._BatchNormParams(fused_batch_norm),
+          scope=scope)
       # Manually add a bypass (optionaly) and an activation.
       if with_bypass:
-        node = math_ops.add(inputs, add_fold, name='test/Add')
-      else:
-        node = add_fold
+        node = math_ops.add(inputs, node, name='test/Add')
       node = activation(node, name='test/' + activation_op_name)
       update_barrier = control_flow_ops.no_op(name='update_barrier')
       with ops.control_dependencies([update_barrier]):
         array_ops.identity(node, name='control_dependency')
+      fold_batch_norms.FoldBatchNorms(graph)
       quantize.Quantize(
           graph, quant_delay=delay, quantize_folded_weights_use_ema=use_ema)
     quantization_node_name = 'FakeQuantWithMinMaxVars'
@@ -660,6 +673,9 @@ class QuantizeTest(test_util.TensorFlowTestCase):
                       if delay else 'control_dependency')
     self._AssertOutputGoesToOps(act_quant, graph, [output_op_name])

+  def _BatchNormParams(self, fused=False):
+    return {'center': True, 'scale': True, 'decay': 1.0 - 0.003, 'fused': fused}
+
   def _WeightInit(self, stddev):
     """Returns truncated normal variable initializer.


@@ -156,6 +156,7 @@ cuda_py_tests(
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:control_flow_ops",
         "//tensorflow/python:framework_for_generated_wrappers",
+        "//tensorflow/python:framework_test_lib",
         "//tensorflow/python:gradients",
         "//tensorflow/python:init_ops",
         "//tensorflow/python:math_ops",
@@ -165,6 +166,7 @@ cuda_py_tests(
         "//tensorflow/python:util",
         "//tensorflow/python:variable_scope",
         "//tensorflow/python:variables",
+        "//tensorflow/python/eager:context",
     ],
     shard_count = 10,
 )


@@ -25,10 +25,12 @@ from six.moves import xrange  # pylint: disable=redefined-builtin
 from tensorflow.contrib import rnn as rnn_lib
 from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops as ops_lib
 from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import gradients_impl
@@ -881,6 +883,7 @@ class LSTMTest(test.TestCase):
       # Smoke test, this should not raise an error
       rnn.dynamic_rnn(cell, inputs, dtype=dtypes.float32)

+  @test_util.run_in_graph_and_eager_modes()
   def testDynamicRNNWithTupleStates(self):
     num_units = 3
     input_size = 5
@@ -888,13 +891,20 @@ class LSTMTest(test.TestCase):
     num_proj = 4
     max_length = 8
     sequence_length = [4, 6]
+    in_graph_mode = context.in_graph_mode()
     with self.test_session(graph=ops_lib.Graph()) as sess:
       initializer = init_ops.random_uniform_initializer(
           -0.01, 0.01, seed=self._seed)
-      inputs = max_length * [
-          array_ops.placeholder(
-              dtypes.float32, shape=(None, input_size))
-      ]
+      if in_graph_mode:
+        inputs = max_length * [
+            array_ops.placeholder(
+                dtypes.float32, shape=(None, input_size))
+        ]
+      else:
+        inputs = max_length * [
+            constant_op.constant(
+                np.random.randn(batch_size, input_size).astype(np.float32))
+        ]
       inputs_c = array_ops.stack(inputs)
       cell = rnn_cell.LSTMCell(
           num_units,
@@ -924,21 +934,34 @@ class LSTMTest(test.TestCase):
       self.assertEqual(state_dynamic[0], state_dynamic.c)
       self.assertEqual(state_dynamic[1], state_dynamic.h)

-      variables_lib.global_variables_initializer().run()
-
-      input_value = np.random.randn(batch_size, input_size)
-      outputs_static_v = sess.run(outputs_static,
-                                  feed_dict={inputs[0]: input_value})
-      outputs_dynamic_v = sess.run(outputs_dynamic,
-                                   feed_dict={inputs[0]: input_value})
-      self.assertAllEqual(outputs_static_v, outputs_dynamic_v)
-
-      state_static_v = sess.run(state_static,
-                                feed_dict={inputs[0]: input_value})
-      state_dynamic_v = sess.run(state_dynamic,
-                                 feed_dict={inputs[0]: input_value})
-      self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_dynamic_v))
+      if in_graph_mode:
+        variables_lib.global_variables_initializer().run()
+        input_value = np.random.randn(batch_size, input_size)
+        outputs_static = sess.run(
+            outputs_static, feed_dict={
+                inputs[0]: input_value
+            })
+        outputs_dynamic = sess.run(
+            outputs_dynamic, feed_dict={
+                inputs[0]: input_value
+            })
+        state_static = sess.run(
+            state_static, feed_dict={
+                inputs[0]: input_value
+            })
+        state_dynamic = sess.run(
+            state_dynamic, feed_dict={
+                inputs[0]: input_value
+            })
+
+      if in_graph_mode:
+        self.assertAllEqual(outputs_static, outputs_dynamic)
+      else:
+        self.assertAllEqual(
+            array_ops.stack(outputs_static).numpy(), outputs_dynamic.numpy())
+      self.assertAllEqual(np.hstack(state_static), np.hstack(state_dynamic))

+  @test_util.run_in_graph_and_eager_modes()
   def testDynamicRNNWithNestedTupleStates(self):
     num_units = 3
     input_size = 5
@@ -946,13 +969,20 @@ class LSTMTest(test.TestCase):
     num_proj = 4
     max_length = 8
     sequence_length = [4, 6]
+    in_graph_mode = context.in_graph_mode()
     with self.test_session(graph=ops_lib.Graph()) as sess:
       initializer = init_ops.random_uniform_initializer(
           -0.01, 0.01, seed=self._seed)
-      inputs = max_length * [
-          array_ops.placeholder(
-              dtypes.float32, shape=(None, input_size))
-      ]
+      if in_graph_mode:
+        inputs = max_length * [
+            array_ops.placeholder(
+                dtypes.float32, shape=(None, input_size))
+        ]
+      else:
+        inputs = max_length * [
+            constant_op.constant(
+                np.random.randn(batch_size, input_size).astype(np.float32))
+        ]
       inputs_c = array_ops.stack(inputs)

       def _cell(i):
@@ -993,20 +1023,34 @@ class LSTMTest(test.TestCase):
           sequence_length=sequence_length,
           scope=scope)

-      variables_lib.global_variables_initializer().run()
-
-      input_value = np.random.randn(batch_size, input_size)
-      outputs_static_v = sess.run(outputs_static,
-                                  feed_dict={inputs[0]: input_value})
-      outputs_dynamic_v = sess.run(outputs_dynamic,
-                                   feed_dict={inputs[0]: input_value})
-      self.assertAllEqual(outputs_static_v, outputs_dynamic_v)
-
-      state_static_v = sess.run(nest.flatten(state_static),
-                                feed_dict={inputs[0]: input_value})
-      state_dynamic_v = sess.run(nest.flatten(state_dynamic),
-                                 feed_dict={inputs[0]: input_value})
-      self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_dynamic_v))
+      if in_graph_mode:
+        input_value = np.random.randn(batch_size, input_size)
+        variables_lib.global_variables_initializer().run()
+        outputs_static = sess.run(
+            outputs_static, feed_dict={
+                inputs[0]: input_value
+            })
+        outputs_dynamic = sess.run(
+            outputs_dynamic, feed_dict={
+                inputs[0]: input_value
+            })
+        state_static = sess.run(
+            nest.flatten(state_static), feed_dict={
+                inputs[0]: input_value
+            })
+        state_dynamic = sess.run(
+            nest.flatten(state_dynamic), feed_dict={
+                inputs[0]: input_value
+            })
+
+      if in_graph_mode:
+        self.assertAllEqual(outputs_static, outputs_dynamic)
+      else:
+        self.assertAllEqual(
+            array_ops.stack(outputs_static).numpy(), outputs_dynamic.numpy())
+        state_static = [s.numpy() for s in nest.flatten(state_static)]
+        state_dynamic = [s.numpy() for s in nest.flatten(state_dynamic)]
+      self.assertAllEqual(np.hstack(state_static), np.hstack(state_dynamic))

   def _testDynamicEquivalentToStaticRNN(self, use_gpu, use_sequence_length):
     time_steps = 8
@@ -1015,21 +1059,22 @@ class LSTMTest(test.TestCase):
     input_size = 5
     batch_size = 2
-    input_values = np.random.randn(time_steps, batch_size, input_size)
+    input_values = np.random.randn(time_steps, batch_size, input_size).astype(
+        np.float32)
     if use_sequence_length:
       sequence_length = np.random.randint(0, time_steps, size=batch_size)
     else:
       sequence_length = None

-    ########### Step 1: Run static graph and generate readouts
-    with self.test_session(use_gpu=use_gpu, graph=ops_lib.Graph()) as sess:
-      concat_inputs = array_ops.placeholder(
-          dtypes.float32, shape=(time_steps, batch_size, input_size))
-      inputs = array_ops.unstack(concat_inputs)
-      initializer = init_ops.random_uniform_initializer(
-          -0.01, 0.01, seed=self._seed)
-      cell = rnn_cell.LSTMCell(
-          num_units,
-          use_peepholes=True,
+    in_graph_mode = context.in_graph_mode()
+
+    # TODO(b/68017812): Eager ignores operation seeds, so we need to create a
+    # single cell and reuse it across the static and dynamic RNNs. Remove this
+    # special case once it is fixed.
+    if not in_graph_mode:
+      initializer = init_ops.random_uniform_initializer(
+          -0.01, 0.01, seed=self._seed)
+      cell = rnn_cell.LSTMCell(
+          num_units,
+          use_peepholes=True,
@@ -1037,63 +1082,85 @@ class LSTMTest(test.TestCase):
           num_proj=num_proj,
           state_is_tuple=False)

+    ########### Step 1: Run static graph and generate readouts
+    with self.test_session(use_gpu=use_gpu, graph=ops_lib.Graph()) as sess:
+      if in_graph_mode:
+        concat_inputs = array_ops.placeholder(
+            dtypes.float32, shape=(time_steps, batch_size, input_size))
+      else:
+        concat_inputs = constant_op.constant(input_values)
+      inputs = array_ops.unstack(concat_inputs)
+      initializer = init_ops.random_uniform_initializer(
+          -0.01, 0.01, seed=self._seed)
+      # TODO(akshayka): Remove special case once b/68017812 is fixed.
+      if in_graph_mode:
+        cell = rnn_cell.LSTMCell(
+            num_units,
+            use_peepholes=True,
+            initializer=initializer,
+            num_proj=num_proj,
+            state_is_tuple=False)

       with variable_scope.variable_scope("dynamic_scope"):
         outputs_static, state_static = rnn.static_rnn(
             cell, inputs, sequence_length=sequence_length, dtype=dtypes.float32)

-      feeds = {concat_inputs: input_values}
-
-      # Initialize
-      variables_lib.global_variables_initializer().run(feed_dict=feeds)
-      # Generate gradients of sum of outputs w.r.t. inputs
-      static_gradients = gradients_impl.gradients(
-          outputs_static + [state_static], [concat_inputs])
-      # Generate gradients of individual outputs w.r.t. inputs
-      static_individual_gradients = nest.flatten([
-          gradients_impl.gradients(y, [concat_inputs])
-          for y in [outputs_static[0], outputs_static[-1], state_static]
-      ])
-      # Generate gradients of individual variables w.r.t. inputs
-      trainable_variables = ops_lib.get_collection(
-          ops_lib.GraphKeys.TRAINABLE_VARIABLES)
-      assert len(trainable_variables) > 1, ("Count of trainable variables: %d" %
-                                            len(trainable_variables))
-      # pylint: disable=bad-builtin
-      static_individual_variable_gradients = nest.flatten([
-          gradients_impl.gradients(y, trainable_variables)
-          for y in [outputs_static[0], outputs_static[-1], state_static]
-      ])
-      # Test forward pass
-      values_static = sess.run(outputs_static, feed_dict=feeds)
-      (state_value_static,) = sess.run((state_static,), feed_dict=feeds)
-
-      # Test gradients to inputs and variables w.r.t. outputs & final state
-      static_grad_values = sess.run(static_gradients, feed_dict=feeds)
-
-      static_individual_grad_values = sess.run(static_individual_gradients,
-                                               feed_dict=feeds)
-
-      static_individual_var_grad_values = sess.run(
-          static_individual_variable_gradients, feed_dict=feeds)
+      if in_graph_mode:
+        # Generate gradients and run sessions to obtain outputs
+        feeds = {concat_inputs: input_values}
+        # Initialize
+        variables_lib.global_variables_initializer().run(feed_dict=feeds)
+        # Generate gradients of sum of outputs w.r.t. inputs
+        static_gradients = gradients_impl.gradients(
+            outputs_static + [state_static], [concat_inputs])
+        # Generate gradients of individual outputs w.r.t. inputs
+        static_individual_gradients = nest.flatten([
+            gradients_impl.gradients(y, [concat_inputs])
+            for y in [outputs_static[0], outputs_static[-1], state_static]
+        ])
+        # Generate gradients of individual variables w.r.t. inputs
+        trainable_variables = ops_lib.get_collection(
+            ops_lib.GraphKeys.TRAINABLE_VARIABLES)
+        assert len(trainable_variables) > 1, (
+            "Count of trainable variables: %d" % len(trainable_variables))
+        # pylint: disable=bad-builtin
+        static_individual_variable_gradients = nest.flatten([
+            gradients_impl.gradients(y, trainable_variables)
+            for y in [outputs_static[0], outputs_static[-1], state_static]
+        ])
+        # Test forward pass
+        values_static = sess.run(outputs_static, feed_dict=feeds)
+        (state_value_static,) = sess.run((state_static,), feed_dict=feeds)
+        # Test gradients to inputs and variables w.r.t. outputs & final state
+        static_grad_values = sess.run(static_gradients, feed_dict=feeds)
+        static_individual_grad_values = sess.run(static_individual_gradients,
+                                                 feed_dict=feeds)
+        static_individual_var_grad_values = sess.run(
+            static_individual_variable_gradients, feed_dict=feeds)

     ########## Step 2: Run dynamic graph and generate readouts
     with self.test_session(use_gpu=use_gpu, graph=ops_lib.Graph()) as sess:
-      concat_inputs = array_ops.placeholder(
-          dtypes.float32, shape=(time_steps, batch_size, input_size))
-      inputs = array_ops.unstack(concat_inputs)
+      if in_graph_mode:
+        concat_inputs = array_ops.placeholder(
+            dtypes.float32, shape=(time_steps, batch_size, input_size))
+      else:
+        concat_inputs = constant_op.constant(input_values)
       initializer = init_ops.random_uniform_initializer(
           -0.01, 0.01, seed=self._seed)
-      cell = rnn_cell.LSTMCell(
-          num_units,
-          use_peepholes=True,
-          initializer=initializer,
-          num_proj=num_proj,
-          state_is_tuple=False)
+      # TODO(akshayka): Remove this special case once b/68017812 is
+      # fixed.
+      if in_graph_mode:
+        cell = rnn_cell.LSTMCell(
+            num_units,
+            use_peepholes=True,
+            initializer=initializer,
+            num_proj=num_proj,
+            state_is_tuple=False)

       with variable_scope.variable_scope("dynamic_scope"):
         outputs_dynamic, state_dynamic = rnn.dynamic_rnn(
@@ -1104,72 +1171,83 @@ class LSTMTest(test.TestCase):
             dtype=dtypes.float32)
         split_outputs_dynamic = array_ops.unstack(outputs_dynamic, time_steps)

-      feeds = {concat_inputs: input_values}
-
-      # Initialize
-      variables_lib.global_variables_initializer().run(feed_dict=feeds)
-      # Generate gradients of sum of outputs w.r.t. inputs
-      dynamic_gradients = gradients_impl.gradients(
-          split_outputs_dynamic + [state_dynamic], [concat_inputs])
-      # Generate gradients of several individual outputs w.r.t. inputs
-      dynamic_individual_gradients = nest.flatten([
-          gradients_impl.gradients(y, [concat_inputs])
-          for y in
-          [split_outputs_dynamic[0], split_outputs_dynamic[-1], state_dynamic]
-      ])
-      # Generate gradients of individual variables w.r.t. inputs
-      trainable_variables = ops_lib.get_collection(
-          ops_lib.GraphKeys.TRAINABLE_VARIABLES)
-      assert len(trainable_variables) > 1, ("Count of trainable variables: %d" %
-                                            len(trainable_variables))
-      dynamic_individual_variable_gradients = nest.flatten([
-          gradients_impl.gradients(y, trainable_variables)
-          for y in
-          [split_outputs_dynamic[0], split_outputs_dynamic[-1], state_dynamic]
-      ])
-      # Test forward pass
-      values_dynamic = sess.run(split_outputs_dynamic, feed_dict=feeds)
-      (state_value_dynamic,) = sess.run((state_dynamic,), feed_dict=feeds)
-
-      # Test gradients to inputs and variables w.r.t. outputs & final state
-      dynamic_grad_values = sess.run(dynamic_gradients, feed_dict=feeds)
-      dynamic_individual_grad_values = sess.run(dynamic_individual_gradients,
-                                                feed_dict=feeds)
-      dynamic_individual_var_grad_values = sess.run(
-          dynamic_individual_variable_gradients, feed_dict=feeds)
+      if in_graph_mode:
+        feeds = {concat_inputs: input_values}
+        # Initialize
+        variables_lib.global_variables_initializer().run(feed_dict=feeds)
+        # Generate gradients of sum of outputs w.r.t. inputs
+        dynamic_gradients = gradients_impl.gradients(
+            split_outputs_dynamic + [state_dynamic], [concat_inputs])
+        # Generate gradients of several individual outputs w.r.t. inputs
+        dynamic_individual_gradients = nest.flatten([
+            gradients_impl.gradients(y, [concat_inputs])
+            for y in
+            [split_outputs_dynamic[0], split_outputs_dynamic[-1], state_dynamic]
+        ])
+        # Generate gradients of individual variables w.r.t. inputs
+        trainable_variables = ops_lib.get_collection(
+            ops_lib.GraphKeys.TRAINABLE_VARIABLES)
+        assert len(trainable_variables) > 1, (
+            "Count of trainable variables: %d" % len(trainable_variables))
+        dynamic_individual_variable_gradients = nest.flatten([
+            gradients_impl.gradients(y, trainable_variables)
+            for y in
+            [split_outputs_dynamic[0], split_outputs_dynamic[-1], state_dynamic]
+        ])
+        # Test forward pass
+        values_dynamic = sess.run(split_outputs_dynamic, feed_dict=feeds)
+        (state_value_dynamic,) = sess.run((state_dynamic,), feed_dict=feeds)
+        # Test gradients to inputs and variables w.r.t. outputs & final state
+        dynamic_grad_values = sess.run(dynamic_gradients, feed_dict=feeds)
+        dynamic_individual_grad_values = sess.run(dynamic_individual_gradients,
+                                                  feed_dict=feeds)
+        dynamic_individual_var_grad_values = sess.run(
+            dynamic_individual_variable_gradients, feed_dict=feeds)

     ######### Step 3: Comparisons
+    if not in_graph_mode:
+      values_static = outputs_static
+      values_dynamic = split_outputs_dynamic
+      state_value_static = state_static
+      state_value_dynamic = state_dynamic
     self.assertEqual(len(values_static), len(values_dynamic))
     for (value_static, value_dynamic) in zip(values_static, values_dynamic):
       self.assertAllEqual(value_static, value_dynamic)
     self.assertAllEqual(state_value_static, state_value_dynamic)

-    self.assertAllEqual(static_grad_values, dynamic_grad_values)
-
-    self.assertEqual(
-        len(static_individual_grad_values), len(dynamic_individual_grad_values))
-    self.assertEqual(
-        len(static_individual_var_grad_values),
-        len(dynamic_individual_var_grad_values))
-
-    for i, (a, b) in enumerate(
-        zip(static_individual_grad_values, dynamic_individual_grad_values)):
-      tf_logging.info("Comparing individual gradients iteration %d" % i)
-      self.assertAllEqual(a, b)
-
-    for i, (a, b) in enumerate(
-        zip(static_individual_var_grad_values,
-            dynamic_individual_var_grad_values)):
-      tf_logging.info("Comparing individual variable gradients iteration %d" %
-                      i)
-      self.assertAllEqual(a, b)
+    if in_graph_mode:
+      self.assertAllEqual(static_grad_values, dynamic_grad_values)
+      self.assertEqual(
+          len(static_individual_grad_values),
+          len(dynamic_individual_grad_values))
+      self.assertEqual(
+          len(static_individual_var_grad_values),
+          len(dynamic_individual_var_grad_values))
+      for i, (a, b) in enumerate(
+          zip(static_individual_grad_values, dynamic_individual_grad_values)):
+        tf_logging.info("Comparing individual gradients iteration %d" % i)
+        self.assertAllEqual(a, b)
+      for i, (a, b) in enumerate(
+          zip(static_individual_var_grad_values,
              dynamic_individual_var_grad_values)):
+        tf_logging.info("Comparing individual variable gradients iteration %d" %
+                        i)
+        self.assertAllEqual(a, b)

+  @test_util.run_in_graph_and_eager_modes()
   def testDynamicEquivalentToStaticRNN(self):
     self._testDynamicEquivalentToStaticRNN(
         use_gpu=False, use_sequence_length=False)
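
The recurring pattern in these test changes is worth calling out:
test_util.run_in_graph_and_eager_modes() runs the decorated body once while
building a graph and once under eager execution, and context.in_graph_mode()
selects feedable placeholders in the first case versus concrete constants in
the second, since placeholders and Session.run feeds do not exist under eager
execution. A condensed, hedged sketch of the pattern (the modules are the ones
the diff imports; the shapes and test name are illustrative):

    import numpy as np

    from tensorflow.python.eager import context
    from tensorflow.python.framework import constant_op
    from tensorflow.python.framework import dtypes
    from tensorflow.python.framework import test_util
    from tensorflow.python.ops import array_ops
    from tensorflow.python.platform import test


    class ModeAgnosticTest(test.TestCase):

      @test_util.run_in_graph_and_eager_modes()
      def testBuildInputs(self):
        if context.in_graph_mode():
          # Graph mode: values arrive via feed_dict at Session.run time.
          inputs = array_ops.placeholder(dtypes.float32, shape=(2, 5))
        else:
          # Eager mode: tensors carry concrete values immediately.
          inputs = constant_op.constant(
              np.random.randn(2, 5).astype(np.float32))
        self.assertEqual([2, 5], inputs.shape.as_list())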


@@ -112,7 +112,7 @@ struct GatherTree<CPUDevice, int32> {
     const int32 max_time = parent_ids.dimension(0);
     const int32 batch_size = parent_ids.dimension(1);
     const int32 beam_width = parent_ids.dimension(2);
-    beams.setConstant(-1);
+    beams.setConstant(end_token);

     auto DoWork = [&, ctx, end_token](int start_batch_beam,
                                       int limit_batch_beam) {
@@ -138,10 +138,13 @@ struct GatherTree<CPUDevice, int32> {
           beams(level, batch, beam) = step_ids(level, batch, parent);
           parent = parent_ids(level, batch, parent);
         }
+        // Not necessary when using a BeamSearchDecoder, but necessary
+        // when a user feeds in possibly broken trajectory (i.e., non-eos
+        // entries in a beam following eos entries).
         bool finished = false;
         for (int32 time = 0; time < max_seq_len_b; ++time) {
           if (finished) {
-            beams(time, batch, beam) = -1;
+            beams(time, batch, beam) = end_token;
           } else if (beams(time, batch, beam) == end_token) {
             finished = true;
           }


@@ -46,24 +46,31 @@ __global__ void GatherTreeOpKernel(const int32 batch_size, const int32 max_time,
     const int32 initial_beam_ix = GET_IX(max_seq_len_b - 1, beam);
     beams[initial_beam_ix] = ldg(step_ids + initial_beam_ix);
     int32 parent = ldg(parent_ids + initial_beam_ix);
+    bool found_bad = false;
     for (int32 level = max_seq_len_b - 2; level >= 0; --level) {
       const int32 level_beam_ix = GET_IX(level, beam);
       const int32 level_parent_ix = GET_IX(level, parent);
       if (parent < 0 || parent > beam_width) {
         beams[level_beam_ix] = -1;
         parent = -1;
+        found_bad = true;
       } else {
         beams[level_beam_ix] = ldg(step_ids + level_parent_ix);
         parent = ldg(parent_ids + level_parent_ix);
       }
     }
-    bool finished = false;
-    for (int32 time = 0; time < max_seq_len_b; ++time) {
-      const int32 level_beam_ix = GET_IX(time, beam);
-      if (finished) {
-        beams[level_beam_ix] = -1;
-      } else if (beams[level_beam_ix] == end_token) {
-        finished = true;
+    // Not necessary when using a BeamSearchDecoder, but necessary
+    // when a user feeds in possibly broken trajectory (i.e., non-eos
+    // entries in a beam following eos entries).
+    if (!found_bad) {
+      bool finished = false;
+      for (int32 time = 0; time < max_seq_len_b; ++time) {
+        const int32 level_beam_ix = GET_IX(time, beam);
+        if (finished) {
+          beams[level_beam_ix] = end_token;
+        } else if (beams[level_beam_ix] == end_token) {
+          finished = true;
+        }
       }
     }
 #undef GET_IX
@@ -80,8 +87,8 @@ struct GatherTree<GPUDevice, T> {
     const int32 max_time = parent_ids.dimension(0);
     const int32 batch_size = parent_ids.dimension(1);
     const int32 beam_width = parent_ids.dimension(2);
-    // First kernel launch to zero things out
-    beams.device(d) = beams.constant(T(-1));
+    // First kernel launch to "zero" things out
+    beams.device(d) = beams.constant(end_token);
     CudaLaunchConfig config = GetCudaLaunchConfig(batch_size * beam_width, d);
     // clang-format off


@@ -53,11 +53,14 @@ REGISTER_OP("GatherTree")
     .Doc(R"doc(
 Calculates the full beams from the per-step ids and parent beam ids.

-This op implements the following mathematical equations:
+On CPU, if an out of bound parent id is found, an error is returned.
+On GPU, if an out of bound parent id is found, a -1 is stored in the
+corresponding output value and the execution for that beam returns early.

-```python
-TODO(ebrevdo): fill in
-```
+For a given beam, past the time step containing the first decoded `end_token`
+all values are filled in with `end_token`.
+
+TODO(ebrevdo): fill in the remainder of this docstring.

 step_ids: `[max_time, batch_size, beam_width]`.
 parent_ids: `[max_time, batch_size, beam_width]`.
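
To make the documented semantics concrete, here is a hedged sketch assembled
from the literals in testGatherTreeOne below; only the small transpose wrapper
around the test's values is new:

    import tensorflow as tf
    from tensorflow.contrib.seq2seq.python.ops import beam_search_ops

    def _time_major(x):  # the tests below call this _transpose_batch_time
      return tf.transpose(tf.constant(x), perm=[1, 0, 2])

    # Shapes: [max_time=4, batch_size=1, beam_width=3].
    beams = beam_search_ops.gather_tree(
        step_ids=_time_major([[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]]),
        parent_ids=_time_major([[[0, 0, 0], [0, 1, 1], [2, 1, 2], [-1, -1, -1]]]),
        max_sequence_lengths=[3],
        end_token=10)
    # Per the test, the result is the time-major form of
    # [[[2, 2, 2], [6, 5, 6], [7, 8, 9], [10, 10, 10]]]: each beam is traced
    # backwards through parent_ids, and every slot from the first end_token on
    # is filled with end_token rather than -1.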


@@ -36,24 +36,26 @@ class GatherTreeTest(test.TestCase):

   def testGatherTreeOne(self):
     # (max_time = 4, batch_size = 1, beams = 3)
+    end_token = 10
     step_ids = _transpose_batch_time(
         [[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]])
     parent_ids = _transpose_batch_time(
         [[[0, 0, 0], [0, 1, 1], [2, 1, 2], [-1, -1, -1]]])
     max_sequence_lengths = [3]
-    expected_result = _transpose_batch_time(
-        [[[2, 2, 2], [6, 5, 6], [7, 8, 9], [-1, -1, -1]]])
+    expected_result = _transpose_batch_time([[[2, 2, 2], [6, 5, 6], [7, 8, 9],
+                                              [10, 10, 10]]])
     beams = beam_search_ops.gather_tree(
         step_ids=step_ids,
         parent_ids=parent_ids,
         max_sequence_lengths=max_sequence_lengths,
-        end_token=10)
+        end_token=end_token)
     with self.test_session(use_gpu=True):
       self.assertAllEqual(expected_result, beams.eval())

   def testBadParentValuesOnCPU(self):
     # (batch_size = 1, max_time = 4, beams = 3)
     # bad parent in beam 1 time 1
+    end_token = 10
     step_ids = _transpose_batch_time(
         [[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]])
     parent_ids = _transpose_batch_time(
@@ -64,7 +66,7 @@ class GatherTreeTest(test.TestCase):
         step_ids=step_ids,
         parent_ids=parent_ids,
         max_sequence_lengths=max_sequence_lengths,
-        end_token=10)
+        end_token=end_token)
     with self.test_session():
       with self.assertRaisesOpError(
           r"parent id -1 at \(batch, time, beam\) == \(0, 0, 1\)"):
@@ -77,19 +79,20 @@ class GatherTreeTest(test.TestCase):
       return
     # (max_time = 4, batch_size = 1, beams = 3)
     # bad parent in beam 1 time 1; appears as a negative index at time 0
+    end_token = 10
     step_ids = _transpose_batch_time(
         [[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]])
     parent_ids = _transpose_batch_time(
         [[[0, 0, 0], [0, -1, 1], [2, 1, 2], [-1, -1, -1]]])
     max_sequence_lengths = [3]
-    expected_result = _transpose_batch_time(
-        [[[2, -1, 2], [6, 5, 6], [7, 8, 9], [-1, -1, -1]]])
+    expected_result = _transpose_batch_time([[[2, -1, 2], [6, 5, 6], [7, 8, 9],
+                                              [10, 10, 10]]])
     with ops.device("/device:GPU:0"):
       beams = beam_search_ops.gather_tree(
           step_ids=step_ids,
           parent_ids=parent_ids,
           max_sequence_lengths=max_sequence_lengths,
-          end_token=10)
+          end_token=end_token)
     with self.test_session(use_gpu=True):
       self.assertAllEqual(expected_result, beams.eval())
@@ -115,24 +118,24 @@ class GatherTreeTest(test.TestCase):
       self.assertEqual((max_time, batch_size, beam_width), beams.shape)
       beams_value = beams.eval()
       for b in range(batch_size):
-        # Past max_sequence_lengths[b], we emit all -1s.
+        # Past max_sequence_lengths[b], we emit all end tokens.
         b_value = beams_value[max_sequence_lengths[b]:, b, :]
-        self.assertAllClose(b_value, -1. * np.ones_like(b_value))
+        self.assertAllClose(b_value, end_token * np.ones_like(b_value))
       for batch, beam in itertools.product(
           range(batch_size), range(beam_width)):
         v = np.squeeze(beams_value[:, batch, beam])
         if end_token in v:
+          found_bad = np.where(v == -1)[0]
+          self.assertEqual(0, len(found_bad))
           found = np.where(v == end_token)[0]
-          # Should be up to 1 instance of end_token per beam.
-          self.assertEqual(len(found), 1)
-          found = found[0]
+          found = found[0]  # First occurrence of end_token.
           # If an end_token is found, everything before it should be a
           # valid id and everything after it should be -1.
           if found > 0:
             self.assertAllEqual(
                 v[:found - 1] >= 0, np.ones_like(v[:found - 1], dtype=bool))
-          self.assertAllClose(
-              v[found + 1:], -1 * np.ones_like(v[found + 1:]))
+          self.assertAllClose(v[found + 1:],
                              end_token * np.ones_like(v[found + 1:]))

 if __name__ == "__main__":

(File diff suppressed because it is too large.)


@@ -31,6 +31,7 @@ from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_util
+from tensorflow.python.layers import utils
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import data_flow_ops
@@ -47,7 +48,6 @@ _dtypes = input_py._dtypes
 _store_sparse_tensors = input_py._store_sparse_tensors
 _validate_keep_input = input_py._validate_keep_input
 _shapes = input_py._shapes
-_smart_cond = input_py._smart_cond
 _which_queue = input_py._which_queue

 # pylint: enable=protected-access
@@ -239,7 +239,7 @@ def bucket(tensors,
     ]
     return control_flow_ops.group(*enqueues, name="group_enqueues")

-  maybe_enqueue = _smart_cond(
+  maybe_enqueue = utils.smart_cond(
       keep_input,
       enqueue_which,
       control_flow_ops.no_op)
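
For context on this swap: the private _smart_cond re-exported from input_py is
replaced by the helper in tensorflow/python/layers/utils.py. A hedged sketch of
the behavior bucket() relies on, assuming the TF 1.4-era signature
smart_cond(pred, fn1, fn2, name=None):

    import tensorflow as tf
    from tensorflow.python.layers import utils

    # Statically known predicate: the chosen branch function is called
    # directly and no cond structure is added to the graph.
    static_branch = utils.smart_cond(
        True, lambda: tf.constant(1), lambda: tf.constant(0))

    # Tensor-valued predicate: falls back to an ordinary tf.cond.
    pred = tf.placeholder(tf.bool)
    dynamic_branch = utils.smart_cond(
        pred, lambda: tf.constant(1), lambda: tf.constant(0))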


@@ -1411,7 +1411,7 @@ cc_library(
     hdrs = LIB_INTERNAL_PUBLIC_HEADERS,
     copts = tf_copts(),
     defines = tf_additional_lib_defines() + [
-        "SNAPPY",
+        "TF_USE_SNAPPY",
     ] + tf_additional_verbs_lib_defines() +
     tf_additional_mpi_lib_defines() +
     tf_additional_gdr_lib_defines(),


@@ -51,7 +51,8 @@ message ApiDef {
   // endpoints are deprecated).
   message Endpoint {
     // Name should be either like "CamelCaseName" or
-    // "Package.CamelCaseName".
+    // "Package.CamelCaseName". Client-language-specific ApiDefs may
+    // use a snake_case convention instead of CamelCase.
     string name = 1;

     // First GraphDef version at which the op is disallowed.
@@ -74,7 +75,7 @@ message ApiDef {
   }
   repeated Arg in_arg = 4;
   repeated Arg out_arg = 5;
-  // List of post-rename in_arg names to specify new argument order.
+  // List of original in_arg names to specify new argument order.
   // Length of arg_order should be either empty to keep current order
   // or match size of in_arg.
   repeated string arg_order = 11;


@@ -412,6 +412,8 @@ void InitApiDefFromOpDef(const OpDef& op_def, ApiDef* api_def) {
     api_in_arg->set_name(op_in_arg.name());
     api_in_arg->set_rename_to(op_in_arg.name());
     api_in_arg->set_description(op_in_arg.description());
+
+    *api_def->add_arg_order() = op_in_arg.name();
   }
   for (const auto& op_out_arg : op_def.output_arg()) {
     auto* api_out_arg = api_def->add_out_arg();
@@ -503,6 +505,22 @@ Status MergeApiDefs(ApiDef* base_api_def, const ApiDef& new_api_def) {
   }
   // Merge arg order
   if (new_api_def.arg_order_size() > 0) {
+    // Validate that new arg_order is correct.
+    if (new_api_def.arg_order_size() != base_api_def->arg_order_size()) {
+      return errors::FailedPrecondition(
+          "Invalid number of arguments ", new_api_def.arg_order_size(), " for ",
+          base_api_def->graph_op_name(),
+          ". Expected: ", base_api_def->arg_order_size());
+    }
+    if (!std::is_permutation(new_api_def.arg_order().begin(),
+                             new_api_def.arg_order().end(),
+                             base_api_def->arg_order().begin())) {
+      return errors::FailedPrecondition(
+          "Invalid arg_order: ", str_util::Join(new_api_def.arg_order(), ", "),
+          " for ", base_api_def->graph_op_name(),
+          ". All elements in arg_order override must match base arg_order: ",
+          str_util::Join(base_api_def->arg_order(), ", "));
+    }
     base_api_def->clear_arg_order();
     std::copy(
         new_api_def.arg_order().begin(), new_api_def.arg_order().end(),


@@ -207,6 +207,8 @@ attr {
   name: "attr_a"
   rename_to: "attr_a"
 }
+arg_order: "arg_a"
+arg_order: "arg_b"
 )";
 OpList op_list;
 protobuf::TextFormat::ParseFromString(kTestOpList, &op_list);  // NOLINT
@@ -331,8 +333,8 @@ op {
     name: "arg_c"
     rename_to: "arg_cc"
   }
-  arg_order: "arg_aa"
   arg_order: "arg_b"
+  arg_order: "arg_a"
 }
 )";
 OpList op_list;
@@ -351,8 +353,8 @@ op {
   EXPECT_EQ("arg_cc", api_def->out_arg(0).rename_to());

   ASSERT_EQ(2, api_def->arg_order_size());
-  EXPECT_EQ("arg_aa", api_def->arg_order(0));
-  EXPECT_EQ("arg_b", api_def->arg_order(1));
+  EXPECT_EQ("arg_b", api_def->arg_order(0));
+  EXPECT_EQ("arg_a", api_def->arg_order(1));
 }

 TEST(OpGenLibTest, ApiDefOverrideDescriptions) {
@@ -411,5 +413,47 @@ op {
   auto status = api_map.LoadApiDef(api_def1);
   ASSERT_EQ(tensorflow::error::FAILED_PRECONDITION, status.code());
 }

+TEST(OpGenLibTest, ApiDefInvalidArgOrder) {
+  const string api_def1 = R"(
+op {
+  graph_op_name: "testop"
+  arg_order: "arg_a"
+  arg_order: "unexpected_arg"
+}
+)";
+
+  const string api_def2 = R"(
+op {
+  graph_op_name: "testop"
+  arg_order: "arg_a"
+}
+)";
+
+  const string api_def3 = R"(
+op {
+  graph_op_name: "testop"
+  arg_order: "arg_a"
+  arg_order: "arg_a"
+}
+)";
+
+  OpList op_list;
+  protobuf::TextFormat::ParseFromString(kTestOpList, &op_list);  // NOLINT
+  ApiDefMap api_map(op_list);
+  TF_CHECK_OK(api_map.LoadApiDef(kTestApiDef));
+
+  // Loading with incorrect arg name in arg_order should fail.
+  auto status = api_map.LoadApiDef(api_def1);
+  ASSERT_EQ(tensorflow::error::FAILED_PRECONDITION, status.code());
+
+  // Loading with incorrect number of args in arg_order should fail.
+  status = api_map.LoadApiDef(api_def2);
+  ASSERT_EQ(tensorflow::error::FAILED_PRECONDITION, status.code());
+
+  // Loading with the same argument twice in arg_order should fail.
+  status = api_map.LoadApiDef(api_def3);
+  ASSERT_EQ(tensorflow::error::FAILED_PRECONDITION, status.code());
+}
+
 }  // namespace
 }  // namespace tensorflow
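
Read together, the new validation and tests pin down the contract: an ApiDef
override's arg_order must list every original in_arg name exactly once, so it
can only permute them. In ApiDef text format, the one non-trivial override the
checks would admit for the two-input testop above is the swap the merge test
already exercises (restated here, nothing new):

    op {
      graph_op_name: "testop"
      arg_order: "arg_b"
      arg_order: "arg_a"
    }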


@@ -1068,10 +1068,16 @@ Status ImportGraphDef(const ImportGraphDefOptions& opts, const GraphDef& gdef,
   refiner->set_graph_def_version(
       std::min(refiner->graph_def_version(), gdef.versions().producer()));

-  return GraphConstructor::Construct(
-      opts, gdef.node(), &gdef.versions(), &gdef.library(), g, refiner,
-      &results->return_tensors, &results->return_nodes,
-      &results->unused_input_map_keys);
+  if (results == nullptr) {
+    return GraphConstructor::Construct(opts, gdef.node(), &gdef.versions(),
+                                       &gdef.library(), g, refiner, nullptr,
+                                       nullptr, nullptr);
+  } else {
+    return GraphConstructor::Construct(
+        opts, gdef.node(), &gdef.versions(), &gdef.library(), g, refiner,
+        &results->return_tensors, &results->return_nodes,
+        &results->unused_input_map_keys);
+  }
 }

 void CopyGraph(const Graph& src, Graph* dest) {


@@ -450,12 +450,16 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
   }

   // Optimize the graph (function inlining, l1 optimizations, etc).
+  VLOG(1) << "Number of nodes in graph before OptimizeGraph: "
+          << new_item->graph.node_size();
   Status optimize_status =
       OptimizeGraph(new_item->graph, &new_item->graph, cfg);
   if (!optimize_status.ok()) {
     LOG(ERROR) << "Graph preprocessing failed: " << optimize_status;
     return nullptr;
   }
+  VLOG(1) << "Number of nodes in graph after OptimizeGraph: "
+          << new_item->graph.node_size();

   if (cfg.prune_graph) {
     VLOG(1) << "Pruning graph...";
@@ -464,7 +468,8 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
       LOG(ERROR) << "Pruning failed: " << status.error_message();
       return nullptr;
     }
-    VLOG(1) << "Pruning ran succesfully.";
+    VLOG(1) << "Number of nodes in graph after pruning: "
+            << new_item->graph.node_size();
   }

   // Validate feed, fetch and init nodes


@@ -18,6 +18,9 @@ limitations under the License.
 #ifdef INTEL_MKL
 #define EIGEN_USE_THREADS

+#include "tensorflow/core/framework/numeric_types.h"
+#define MKL_Complex8 tensorflow::complex64
+#define MKL_Complex16 tensorflow::complex128
 #include "mkl_trans.h"
 #include "tensorflow/core/kernels/transpose_functor.h"
 #include "tensorflow/core/kernels/transpose_op.h"
@@ -41,7 +44,7 @@ namespace tensorflow {
 namespace {

 template <typename T>
-void MKLTranspose2D(const char trans, const Tensor& in, Tensor* out) {}
+Status MKLTranspose2D(const char trans, const Tensor& in, Tensor* out);

 // Documentation here: https://software.intel.com/en-us/node/520863
 // Parameters: (ordering:row-major, operation:transpose, num_rows, num_cols,
@@ -54,70 +57,73 @@ void MKLTranspose2D(const char trans, const Tensor& in, Tensor* out) {}
     mkl_##PREFIX##omatcopy('R', trans, in.dim_size(0), in.dim_size(1), 1, \
                            in.flat<T>().data(), in.dim_size(1),          \
                            out->flat<T>().data(), in.dim_size(0));       \
-    return Status::OK();
+    return Status::OK();                                                 \
   }

 INSTANTIATE(float, s)
 INSTANTIATE(double, d)
 INSTANTIATE(complex64, c)
 INSTANTIATE(complex128, z)
 #undef INSTANTIATE

 static const char kMKLTranspose = 'T';
 static const char kMKLConjugateTranspose = 'C';

-}  // namespace tensorflow
+}  // namespace

 Status MklTransposeCpuOp::DoTranspose(OpKernelContext* ctx, const Tensor& in,
                                       gtl::ArraySlice<int32> perm,
                                       Tensor* out) {
   if (in.dims() == 2) {
-    switch (in.dtype()) {
-      case DT_FLOAT:
-        return MKLTranspose2D<float>(kMKLTranspose, in, out);
-      case DT_DOUBLE:
-        return MKLTranspose2D<double>(kMKLTranspose, in, out);
-      case DT_COMPLEX64:
-        return MKLTranspose2D<complex64>(kMKLTranspose, in, out);
-      case DT_COMPLEX128:
-        return MKLTranspose2D<complex128>(kMKLTranspose, in, out);
-      default:
-        break;
-    }
-  }
-  // Fallback to eigen if transpose parameters not supported by MKL
-  typedef Eigen::ThreadPoolDevice CPUDevice;
-  return ::tensorflow::DoTranspose(ctx->eigen_device<CPUDevice>(), in, perm,
-                                   out);
-}
-
-Status MklConjugateTransposeCpuOp::DoTranspose(OpKernelContext* ctx,
-                                               const Tensor& in,
-                                               gtl::ArraySlice<int32> perm,
-                                               Tensor* out) {
-  if (in.dims() == 2) {
-    // TODO(rmlarsen): By setting lda and ldb, we could use the MKL kernels
-    // for any transpose that can be reduced to swapping the last two
-    // dimensions in a rank-3 tensor. We can even run each outer dimension in
-    // a separate thread.
+    if (perm[0] == 0 && perm[1] == 1) {
+      return Status::OK();
+    }
     switch (in.dtype()) {
       case DT_FLOAT:
         return MKLTranspose2D<float>(kMKLTranspose, in, out);
       case DT_DOUBLE:
         return MKLTranspose2D<double>(kMKLTranspose, in, out);
       case DT_COMPLEX64:
-        return MKLTranspose2D<complex64>(kMKLConjugateTranspose, in, out);
+        return MKLTranspose2D<complex64>(kMKLTranspose, in, out);
       case DT_COMPLEX128:
-        return MKLTranspose2D<complex128>(kMKLConjugateTranspose, in, out);
+        return MKLTranspose2D<complex128>(kMKLTranspose, in, out);
       default:
         break;
     }
   }
   // Fallback to eigen if transpose parameters not supported by MKL
   typedef Eigen::ThreadPoolDevice CPUDevice;
-  return ::tensorflow::DoConjugateTranspose(ctx->eigen_device<CPUDevice>(),
-                                            in, perm, out);
+  return ::tensorflow::DoTranspose(ctx->eigen_device<CPUDevice>(), in, perm,
+                                   out);
+}
+
+Status MklConjugateTransposeCpuOp::DoTranspose(OpKernelContext* ctx,
+                                               const Tensor& in,
+                                               gtl::ArraySlice<int32> perm,
+                                               Tensor* out) {
+  if (in.dims() == 2 && perm[0] == 1 && perm[1] == 0) {
+    // TODO(rmlarsen): By setting lda and ldb, we could use the MKL kernels
+    // for any transpose that can be reduced to swapping the last two
+    // dimensions in a rank-3 tensor. We can even run each outer dimension in
+    // a separate thread.
+    switch (in.dtype()) {
+      case DT_FLOAT:
+        return MKLTranspose2D<float>(kMKLTranspose, in, out);
+      case DT_DOUBLE:
+        return MKLTranspose2D<double>(kMKLTranspose, in, out);
+      case DT_COMPLEX64:
+        return MKLTranspose2D<complex64>(kMKLConjugateTranspose, in, out);
+      case DT_COMPLEX128:
+        return MKLTranspose2D<complex128>(kMKLConjugateTranspose, in, out);
+      default:
+        break;
+    }
+  }
+  // Fallback to eigen if transpose parameters not supported by MKL
+  typedef Eigen::ThreadPoolDevice CPUDevice;
+  return ::tensorflow::DoConjugateTranspose(ctx->eigen_device<CPUDevice>(), in,
+                                            perm, out);
 }

 }  // namespace tensorflow


@@ -201,17 +201,26 @@ Status DoTransposeImpl(const Device& d, const Tensor& in,
     case DT_COMPLEX64:
       if (conjugate) {
-        Transpose<Device, complex64, true>::run(d, in, perm, out);
+#if defined(__ANDROID__) and !defined(__clang__)
+        // Workaround for GCC compiler bug in Android toolchain.
+        return errors::Unimplemented(
+            "Conjugate transpose of complex64 not supported for GCC on "
+            "Android.");
+#else
+        Transpose<Device, complex64, /*conjugate=*/true>::run(d, in, perm, out);
+#endif
       } else {
-        Transpose<Device, complex64, false>::run(d, in, perm, out);
+        Transpose<Device, uint64>::run(d, in, perm, out);
       }
       break;

     case DT_COMPLEX128:
       if (conjugate) {
-        Transpose<Device, complex128, true>::run(d, in, perm, out);
+        Transpose<Device, complex128, /*conjugate=*/true>::run(d, in, perm,
+                                                               out);
       } else {
-        Transpose<Device, complex128, false>::run(d, in, perm, out);
+        Transpose<Device, complex128, /*conjugate=*/false>::run(d, in, perm,
+                                                                out);
       }
       break;


@@ -467,7 +467,7 @@ def tf_additional_core_deps():
         "//conditions:default": [],
     }) + select({
         "//tensorflow:with_s3_support": [
-            "//tensorflow/contrib/s3:s3_file_system",
+            "//tensorflow/core/platform/s3:s3_file_system",
         ],
         "//conditions:default": [],
     })


@@ -29,7 +29,7 @@ limitations under the License.
 #include <stdlib.h>
 #include <string.h>
 #include <unistd.h>
-#ifdef SNAPPY
+#ifdef TF_USE_SNAPPY
 #include "snappy.h"
 #endif
 #if (defined(__APPLE__) && defined(__MACH__)) || defined(__FreeBSD__)
@@ -126,7 +126,7 @@ void AdjustFilenameForLogging(string* filename) {
 }

 bool Snappy_Compress(const char* input, size_t length, string* output) {
-#ifdef SNAPPY
+#ifdef TF_USE_SNAPPY
   output->resize(snappy::MaxCompressedLength(length));
   size_t outlen;
   snappy::RawCompress(input, length, &(*output)[0], &outlen);
@@ -139,7 +139,7 @@ bool Snappy_Compress(const char* input, size_t length, string* output) {
 bool Snappy_GetUncompressedLength(const char* input, size_t length,
                                   size_t* result) {
-#ifdef SNAPPY
+#ifdef TF_USE_SNAPPY
   return snappy::GetUncompressedLength(input, length, result);
 #else
   return false;
@@ -147,7 +147,7 @@ bool Snappy_GetUncompressedLength(const char* input, size_t length,
 }

 bool Snappy_Uncompress(const char* input, size_t length, char* output) {
-#ifdef SNAPPY
+#ifdef TF_USE_SNAPPY
   return snappy::RawUncompress(input, length, output);
 #else
   return false;


@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include "tensorflow/contrib/s3/s3_crypto.h"
+#include "tensorflow/core/platform/s3/s3_crypto.h"
 #include <openssl/hmac.h>
 #include <openssl/sha.h>


@@ -12,10 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include "tensorflow/contrib/s3/s3_file_system.h"
-#include "tensorflow/contrib/s3/s3_crypto.h"
 #include "tensorflow/core/lib/io/path.h"
 #include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/s3/s3_file_system.h"
+#include "tensorflow/core/platform/s3/s3_crypto.h"

 #include <aws/core/Aws.h>
 #include <aws/core/utils/FileSystemUtils.h>


@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include "tensorflow/contrib/s3/s3_file_system.h"
+#include "tensorflow/core/platform/s3/s3_file_system.h"

 #include "tensorflow/core/lib/core/status_test_util.h"
 #include "tensorflow/core/lib/gtl/stl_util.h"


@@ -20,7 +20,7 @@ limitations under the License.
 #include <stdio.h>
 #include <stdlib.h>
 #include <string.h>
-#ifdef SNAPPY
+#ifdef TF_USE_SNAPPY
 #include "snappy.h"
 #endif
@@ -118,7 +118,7 @@ void AdjustFilenameForLogging(string* filename) {
 }

 bool Snappy_Compress(const char* input, size_t length, string* output) {
-#ifdef SNAPPY
+#ifdef TF_USE_SNAPPY
   output->resize(snappy::MaxCompressedLength(length));
   size_t outlen;
   snappy::RawCompress(input, length, &(*output)[0], &outlen);
@@ -131,7 +131,7 @@ bool Snappy_Compress(const char* input, size_t length, string* output) {
 bool Snappy_GetUncompressedLength(const char* input, size_t length,
                                   size_t* result) {
-#ifdef SNAPPY
+#ifdef TF_USE_SNAPPY
   return snappy::GetUncompressedLength(input, length, result);
 #else
   return false;
@@ -139,7 +139,7 @@ bool Snappy_GetUncompressedLength(const char* input, size_t length,
 }

 bool Snappy_Uncompress(const char* input, size_t length, char* output) {
-#ifdef SNAPPY
+#ifdef TF_USE_SNAPPY
   return snappy::RawUncompress(input, length, output);
 #else
   return false;


@@ -17,47 +17,94 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-import numpy as np
-from sklearn import datasets
-from sklearn import metrics
-from sklearn import model_selection
+import os
+import urllib
+
 import tensorflow as tf

-X_FEATURE = 'x'  # Name of the input feature.
+# Data sets
+IRIS_TRAINING = 'iris_training.csv'
+IRIS_TRAINING_URL = 'http://download.tensorflow.org/data/iris_training.csv'
+
+IRIS_TEST = 'iris_test.csv'
+IRIS_TEST_URL = 'http://download.tensorflow.org/data/iris_test.csv'
+
+FEATURE_KEYS = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']
+
+
+def maybe_download_iris_data(file_name, download_url):
+  """Downloads the file and returns the number of data."""
+  if not os.path.exists(file_name):
+    raw = urllib.urlopen(download_url).read()
+    with open(file_name, 'w') as f:
+      f.write(raw)
+
+  # The first line is a comma-separated string. The first one is the number of
+  # total data in the file.
+  with open(file_name, 'r') as f:
+    first_line = f.readline()
+  num_elements = first_line.split(',')[0]
+  return int(num_elements)
+
+
+def input_fn(file_name, num_data, batch_size, is_training):
+  """Creates an input_fn required by Estimator train/evaluate."""
+  # If the data sets aren't stored locally, download them.
+
+  def _parse_csv(rows_string_tensor):
+    """Takes the string input tensor and returns tuple of (features, labels)."""
+    # Last dim is the label.
+    num_features = len(FEATURE_KEYS)
+    num_columns = num_features + 1
+    columns = tf.decode_csv(rows_string_tensor,
+                            record_defaults=[[]] * num_columns)
+    features = dict(zip(FEATURE_KEYS, columns[:num_features]))
+    labels = tf.cast(columns[num_features], tf.int32)
+    return features, labels
+
+  def _input_fn():
+    """The input_fn."""
+    dataset = tf.data.TextLineDataset([file_name])
+    # Skip the first line (which does not have data).
+    dataset = dataset.skip(1)
+    dataset = dataset.map(_parse_csv)
+
+    if is_training:
+      # For this small dataset, which can fit into memory, to achieve true
+      # randomness, the shuffle buffer size is set as the total number of
+      # elements in the dataset.
+      dataset = dataset.shuffle(num_data)
+      dataset = dataset.repeat()
+
+    dataset = dataset.batch(batch_size)
+    iterator = dataset.make_one_shot_iterator()
+    features, labels = iterator.get_next()
+    return features, labels
+
+  return _input_fn


 def main(unused_argv):
-  # Load dataset.
-  iris = datasets.load_iris()
-  x_train, x_test, y_train, y_test = model_selection.train_test_split(
-      iris.data, iris.target, test_size=0.2, random_state=42)
+  tf.logging.set_verbosity(tf.logging.INFO)
+
+  num_training_data = maybe_download_iris_data(
+      IRIS_TRAINING, IRIS_TRAINING_URL)
+  num_test_data = maybe_download_iris_data(IRIS_TEST, IRIS_TEST_URL)

   # Build 3 layer DNN with 10, 20, 10 units respectively.
   feature_columns = [
-      tf.feature_column.numeric_column(
-          X_FEATURE, shape=np.array(x_train).shape[1:])]
+      tf.feature_column.numeric_column(key, shape=1) for key in FEATURE_KEYS]
   classifier = tf.estimator.DNNClassifier(
       feature_columns=feature_columns, hidden_units=[10, 20, 10], n_classes=3)

   # Train.
-  train_input_fn = tf.estimator.inputs.numpy_input_fn(
-      x={X_FEATURE: x_train}, y=y_train, num_epochs=None, shuffle=True)
-  classifier.train(input_fn=train_input_fn, steps=200)
+  train_input_fn = input_fn(IRIS_TRAINING, num_training_data, batch_size=32,
+                            is_training=True)
+  classifier.train(input_fn=train_input_fn, steps=400)

-  # Predict.
-  test_input_fn = tf.estimator.inputs.numpy_input_fn(
-      x={X_FEATURE: x_test}, y=y_test, num_epochs=1, shuffle=False)
-  predictions = classifier.predict(input_fn=test_input_fn)
-  y_predicted = np.array(list(p['class_ids'] for p in predictions))
-  y_predicted = y_predicted.reshape(np.array(y_test).shape)
-
-  # Score with sklearn.
-  score = metrics.accuracy_score(y_test, y_predicted)
-  print('Accuracy (sklearn): {0:f}'.format(score))
-
-  # Score with tensorflow.
+  # Eval.
+  test_input_fn = input_fn(IRIS_TEST, num_test_data, batch_size=32,
+                           is_training=False)
   scores = classifier.evaluate(input_fn=test_input_fn)
   print('Accuracy (tensorflow): {0:f}'.format(scores['accuracy']))
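
For intuition about _parse_csv in the rewritten example: each data line holds
the four float features followed by an integer class label, and the empty
record_defaults mark every column as required with float32 as the inferred
dtype. A hedged sketch with an illustrative row:

    import tensorflow as tf

    FEATURE_KEYS = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']

    row = tf.constant('6.4,2.8,5.6,2.2,2')           # illustrative CSV line
    columns = tf.decode_csv(row, record_defaults=[[]] * 5)
    features = dict(zip(FEATURE_KEYS, columns[:4]))  # four float32 scalars
    label = tf.cast(columns[4], tf.int32)            # class id, here 2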

(Some files were not shown because too many files have changed in this diff.)