commit b3d5ec90bc

Changed files and directories:
WORKSPACE
tensorflow
  BUILD
  c/eager
  compiler/xla
    client
    layout_util.cc
    layout_util.h
    legacy_flags
    literal_util.h
    protobuf_util.cc
    protobuf_util.h
    service
      BUILD
      shape_util.cc
      cpu
      gpu
      hlo_computation.cc
      hlo_runner.cc
      hlo_runner.h
      transpose_folding.cc
      transpose_folding_test.cc
      user_computation.cc
      tests
    tools
    xla.proto
  contrib
    batching
    cmake/external
    eager/python
    factorization
    g3doc
    kernels
    ops
    python
    framework
    learn/python/learn/learn_io
    makefile
    metrics/python/ops
    quantize
      BUILD
      python
    rnn
    seq2seq
      kernels
      ops
      python/kernel_tests
    tpu/python/tpu
    training/python/training
  core
  examples/learn
WORKSPACE
@@ -5,7 +5,7 @@ http_archive(
     sha256 = "110fe68753413777944b473c25eed6368c4a0487cee23a7bac1b13cc49d3e257",
     strip_prefix = "rules_closure-4af89ef1db659eb41f110df189b67d4cf14073e1",
     urls = [
-        "http://mirror.bazel.build/github.com/bazelbuild/rules_closure/archive/4af89ef1db659eb41f110df189b67d4cf14073e1.tar.gz",
+        "https://mirror.bazel.build/github.com/bazelbuild/rules_closure/archive/4af89ef1db659eb41f110df189b67d4cf14073e1.tar.gz",
         "https://github.com/bazelbuild/rules_closure/archive/4af89ef1db659eb41f110df189b67d4cf14073e1.tar.gz",  # 2017-08-28
     ],
 )
tensorflow/BUILD
@@ -348,6 +348,7 @@ filegroup(
         "//tensorflow/compiler/xla/service/llvm_ir:all_files",
         "//tensorflow/compiler/xla/tests:all_files",
         "//tensorflow/compiler/xla/tools:all_files",
+        "//tensorflow/compiler/xla/tools/parser:all_files",
         "//tensorflow/contrib:all_files",
         "//tensorflow/contrib/all_reduce:all_files",
         "//tensorflow/contrib/android:all_files",
@@ -421,7 +422,6 @@ filegroup(
         "//tensorflow/contrib/remote_fused_graph/pylib:all_files",
         "//tensorflow/contrib/resampler:all_files",
         "//tensorflow/contrib/rnn:all_files",
-        "//tensorflow/contrib/s3:all_files",
         "//tensorflow/contrib/saved_model:all_files",
         "//tensorflow/contrib/saved_model/cc/saved_model:all_files",
         "//tensorflow/contrib/seq2seq:all_files",
@@ -475,6 +475,7 @@ filegroup(
         "//tensorflow/core/platform/cloud:all_files",
         "//tensorflow/core/platform/default/build_config:all_files",
         "//tensorflow/core/platform/hadoop:all_files",
+        "//tensorflow/core/platform/s3:all_files",
         "//tensorflow/core/profiler:all_files",
         "//tensorflow/core/profiler/internal:all_files",
         "//tensorflow/core/profiler/internal/advisor:all_files",
tensorflow/c/eager/BUILD
@@ -3,6 +3,7 @@ licenses(["notice"])  # Apache 2.0

 load(
     "//tensorflow:tensorflow.bzl",
+    "tf_cuda_cc_test",
     "tf_cc_test",
     "tf_copts",
     "tf_cuda_library",
@@ -50,7 +51,7 @@ tf_cuda_library(
     ],
 )

-tf_cc_test(
+tf_cuda_cc_test(
     name = "c_api_test",
     srcs = ["c_api_test.cc"],
     deps = [
tensorflow/c/eager/c_api.cc
@@ -54,9 +54,23 @@ string DeviceName(tensorflow::Device* d) {

 extern "C" {

-TFE_Context* TFE_NewContext(const TF_SessionOptions* opts, TF_Status* status) {
+TFE_ContextOptions* TFE_NewContextOptions() { return new TFE_ContextOptions; }
+
+void TFE_ContextOptionsSetConfig(TFE_ContextOptions* options, const void* proto,
+                                 size_t proto_len, TF_Status* status) {
+  TF_SetConfig(&options->session_options, proto, proto_len, status);
+}
+
+void TFE_ContextOptionsSetDevicePlacementPolicy(
+    TFE_ContextOptions* options, TFE_ContextDevicePlacementPolicy policy) {
+  options->policy = policy;
+}
+
+void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
+
+TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
   TF_Graph* graph = TF_NewGraph();
-  TF_Session* session = TF_NewSession(graph, opts, status);
+  TF_Session* session = TF_NewSession(graph, &opts->session_options, status);
   if (status->status.ok()) {
     if (session->device_mgr == nullptr || session->devices.empty()) {
       status->status = tensorflow::errors::InvalidArgument(
@@ -71,9 +85,10 @@ TFE_Context* TFE_NewContext(const TF_SessionOptions* opts, TF_Status* status) {
   }

   TFE_Context* ret = new TFE_Context(session);
+  ret->policy = opts->policy;
   ret->pflr.reset(new tensorflow::ProcessFunctionLibraryRuntime(
-      ret->session->device_mgr, opts->options.env, TF_GRAPH_DEF_VERSION,
-      &ret->func_lib_def, {}));
+      ret->session->device_mgr, opts->session_options.options.env,
+      TF_GRAPH_DEF_VERSION, &ret->func_lib_def, {}));
   ret->rendezvous =
       new tensorflow::IntraProcessRendezvous(ret->session->device_mgr);

@@ -408,8 +423,10 @@ void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name,
 namespace {

 tensorflow::Status ValidateInputTypeAndPlacement(
-    tensorflow::Device* host_device, tensorflow::Device* op_device, TFE_Op* op,
-    const tensorflow::OpKernel* kernel) {
+    TFE_Context* ctx, tensorflow::Device* host_device,
+    tensorflow::Device* op_device, TFE_Op* op,
+    const tensorflow::OpKernel* kernel,
+    std::vector<TFE_TensorHandle*>* copied_tensors) {
   const tensorflow::MemoryTypeVector& memtypes = kernel->input_memory_types();
   if (memtypes.size() != op->inputs.size()) {
     return tensorflow::errors::InvalidArgument(
@@ -421,11 +438,42 @@ tensorflow::Status ValidateInputTypeAndPlacement(
     const tensorflow::Device* actual_device =
         op->input_devices[i] == nullptr ? host_device : op->input_devices[i];
     if (expected_device != actual_device) {
-      return tensorflow::errors::InvalidArgument(
-          "cannot compute ", op->name, " as input #", i,
-          " was expected to be on ", expected_device->name(),
-          " but is actually on ", actual_device->name(),
-          " (operation running on ", op_device->name(), ")");
+      switch (ctx->policy) {
+        case TFE_DEVICE_PLACEMENT_EXPLICIT:
+          return tensorflow::errors::InvalidArgument(
+              "cannot compute ", op->name, " as input #", i,
+              " was expected to be on ", expected_device->name(),
+              " but is actually on ", actual_device->name(),
+              " (operation running on ", op_device->name(), ")");
+        case TFE_DEVICE_PLACEMENT_WARN:
+          LOG(WARNING) << "before computing " << op->name << " input #" << i
+                       << " was expected to be on " << expected_device->name()
+                       << " but is actually on " << actual_device->name()
+                       << " (operation running on " << op_device->name()
+                       << "). This triggers a copy which can be a performance "
+                          "bottleneck.";
+          break;
+        case TFE_DEVICE_PLACEMENT_SILENT:  // Do nothing.
+          break;
+      }
+      // We are only here if the policy is warn or silent copies, so we should
+      // trigger a copy.
+      TFE_TensorHandle original{op->inputs[i], op->input_devices[i]};
+      TF_Status* s = TF_NewStatus();
+      TFE_TensorHandle* copied_tensor = TFE_TensorHandleCopyToDevice(
+          &original, ctx, expected_device->name().c_str(), s);
+      if (!s->status.ok()) {
+        tensorflow::Status status = s->status;
+        delete s;
+        return tensorflow::errors::Internal(
+            "Failed copying input tensor from ", actual_device->name(), " to ",
+            expected_device->name(), " in order to run ", op->name, ": ",
+            status.error_message());
+      }
+      op->inputs[i] = copied_tensor->t;
+      copied_tensors->push_back(copied_tensor);
+      op->input_devices[i] = copied_tensor->d;
+      delete s;
     }
     if (op->inputs[i].dtype() != kernel->input_type(i)) {
       return tensorflow::errors::InvalidArgument(
@@ -468,10 +516,14 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
     }
     tensorflow::gtl::InsertOrUpdate(&(ctx->kernel_cache), cache_key, kernel);
   }
-  status->status = ValidateInputTypeAndPlacement(ctx->devices()[0], device, op,
-                                                 kernel->kernel());
+  std::vector<TFE_TensorHandle*> copied_tensors;
+  status->status = ValidateInputTypeAndPlacement(
+      ctx, ctx->devices()[0], device, op, kernel->kernel(), &copied_tensors);
   output_memory_types = &kernel->kernel()->output_memory_types();
   if (!status->status.ok()) {
+    for (auto* t : copied_tensors) {
+      TFE_DeleteTensorHandle(t);
+    }
     return;
   }
   // WARNING: kernel->Run utilizes the FunctionLibraryRuntime
@@ -483,6 +535,9 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
   // sense for FunctionLibraryRuntime to ensure thread-safe access to
   // FunctionLibraryDefinition?).
   status->status = kernel->Run(&op->inputs, &outputs);
+  for (auto* t : copied_tensors) {
+    TFE_DeleteTensorHandle(t);
+  }
   if (!status->status.ok()) return;
   *num_retvals = std::min<int>(*num_retvals, outputs.size());
   for (int i = 0; i < *num_retvals; ++i) {
tensorflow/c/eager/c_api.h
@@ -43,14 +43,46 @@ limitations under the License.
 extern "C" {
 #endif

+typedef struct TFE_ContextOptions TFE_ContextOptions;
+
+// Return a new options object.
+TF_CAPI_EXPORT extern TFE_ContextOptions* TFE_NewContextOptions();
+
+// Set the config in TF_ContextOptions.options.
+// config should be a serialized tensorflow.ConfigProto proto.
+// If config was not parsed successfully as a ConfigProto, record the
+// error information in *status.
+TF_CAPI_EXPORT extern void TFE_ContextOptionsSetConfig(
+    TFE_ContextOptions* options, const void* proto, size_t proto_len,
+    TF_Status* status);
+
+// Controls how to act when we try to run an operation on a given device but
+// some input tensors are not on that device.
+typedef enum TFE_ContextDevicePlacementPolicy {
+  // The default: running operations with input tensors on the wrong device
+  // will fail.
+  TFE_DEVICE_PLACEMENT_EXPLICIT = 0,
+  // Copy the tensor to the right device but log a warning.
+  TFE_DEVICE_PLACEMENT_WARN = 1,
+  // Silently copy the tensor, which has a performance cost since the
+  // operation will be blocked till the copy completes.
+  TFE_DEVICE_PLACEMENT_SILENT = 2,
+} TFE_ContextDevicePlacementPolicy;
+
+TF_CAPI_EXPORT extern void TFE_ContextOptionsSetDevicePlacementPolicy(
+    TFE_ContextOptions*, TFE_ContextDevicePlacementPolicy);
+
+// Destroy an options object.
+TF_CAPI_EXPORT extern void TFE_DeleteContextOptions(TFE_ContextOptions*);
+
 // "Context" under which operations/functions are executed. It encapsulates
 // things like the available devices, resource manager etc.
 //
 // TODO(ashankar): Merge with TF_Session?
 typedef struct TFE_Context TFE_Context;

-TF_CAPI_EXPORT extern TFE_Context* TFE_NewContext(const TF_SessionOptions* opts,
-                                                  TF_Status* status);
+TF_CAPI_EXPORT extern TFE_Context* TFE_NewContext(
+    const TFE_ContextOptions* opts, TF_Status* status);
 TF_CAPI_EXPORT extern void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status);
 TF_CAPI_EXPORT extern TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx,
                                                             TF_Status* status);
tensorflow/c/eager/c_api_internal.h
@@ -35,9 +35,16 @@ limitations under the License.
 #include "tensorflow/core/platform/mutex.h"
 #include "tensorflow/core/platform/thread_annotations.h"

+struct TFE_ContextOptions {
+  TF_SessionOptions session_options;
+  TFE_ContextDevicePlacementPolicy policy{TFE_DEVICE_PLACEMENT_EXPLICIT};
+};
+
 struct TFE_Context {
   explicit TFE_Context(TF_Session* s) : session(s) {}

+  TFE_ContextDevicePlacementPolicy policy;
+
   // TFE_Context is an extension of TF_Session. And TF_Session needs a TF_Graph.
   TF_Session* session;
   tensorflow::Rendezvous* rendezvous;
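For orientation, a minimal sketch of how a caller drives the new options API end to end (illustrative, not part of the diff; it uses only functions declared in the c_api.h hunk above):

TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
// Opt into silent copies instead of the default hard failure when an input
// tensor lives on a different device than the op.
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// The options object is only needed at construction time.
TFE_DeleteContextOptions(opts);
// ... build and run ops with TFE_NewOp/TFE_Execute ...
TFE_DeleteContext(ctx, status);
TF_DeleteStatus(status);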
tensorflow/c/eager/c_api_test.cc
@@ -62,10 +62,10 @@ TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) {
 void BM_InitOp(int iters) {
   tensorflow::testing::StopTiming();
   TF_Status* status = TF_NewStatus();
-  TF_SessionOptions* opts = TF_NewSessionOptions();
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
   TFE_Context* ctx = TFE_NewContext(opts, status);
   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
-  TF_DeleteSessionOptions(opts);
+  TFE_DeleteContextOptions(opts);

   TFE_TensorHandle* m = TestMatrixTensorHandle();
   tensorflow::testing::StartTiming();
@@ -84,10 +84,10 @@ BENCHMARK(BM_InitOp);
 void BM_Execute(int iters) {
   tensorflow::testing::StopTiming();
   TF_Status* status = TF_NewStatus();
-  TF_SessionOptions* opts = TF_NewSessionOptions();
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
   TFE_Context* ctx = TFE_NewContext(opts, status);
   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
-  TF_DeleteSessionOptions(opts);
+  TFE_DeleteContextOptions(opts);

   TFE_TensorHandle* m = TestMatrixTensorHandle();
   TFE_Op* matmul = MatMulOp(ctx, m, m);
@@ -109,9 +109,9 @@ BENCHMARK(BM_Execute);

 TEST(CAPI, Context) {
   TF_Status* status = TF_NewStatus();
-  TF_SessionOptions* opts = TF_NewSessionOptions();
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
   TFE_Context* ctx = TFE_NewContext(opts, status);
-  TF_DeleteSessionOptions(opts);
+  TFE_DeleteContextOptions(opts);

   TF_DeviceList* devices = TFE_ContextListDevices(ctx, status);
   EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
@@ -150,9 +150,9 @@ TEST(CAPI, TensorHandle) {
 TEST(CAPI, TensorHandleCopyBetweenDevices) {
   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
       TF_NewStatus(), TF_DeleteStatus);
-  TF_SessionOptions* opts = TF_NewSessionOptions();
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
   TFE_Context* ctx = TFE_NewContext(opts, status.get());
-  TF_DeleteSessionOptions(opts);
+  TFE_DeleteContextOptions(opts);
   ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());

   TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
@@ -216,12 +216,58 @@ TEST(CAPI, TensorHandleCopyBetweenDevices) {
   EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
 }

+TEST(CAPI, TensorHandleSilentCopy) {
+  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
+      TF_NewStatus(), TF_DeleteStatus);
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
+  TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
+  TFE_Context* ctx = TFE_NewContext(opts, status.get());
+  TFE_DeleteContextOptions(opts);
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+
+  TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
+  TF_Tensor* t = TFE_TensorHandleResolve(hcpu, status.get());
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+
+  TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get());
+  ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+  const int num_devices = TF_DeviceListCount(devices);
+
+  // Disable the test if no GPU is present.
+  if (num_devices > 1) {
+    const int device_to_use = 1;
+    const string name(TF_DeviceListName(devices, device_to_use, status.get()));
+    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+    TFE_TensorHandle* hgpu =
+        TFE_TensorHandleCopyToDevice(hcpu, ctx, name.c_str(), status.get());
+    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+
+    TFE_Op* matmul = MatMulOp(ctx, hcpu, hgpu);
+    TFE_OpSetDevice(matmul, name.c_str(), status.get());
+    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+    TFE_TensorHandle* retvals[1];
+    int num_retvals = 1;
+    TFE_Execute(matmul, &retvals[0], &num_retvals, status.get());
+    ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
+    TFE_DeleteOp(matmul);
+    TFE_DeleteTensorHandle(retvals[0]);
+    TFE_DeleteTensorHandle(hgpu);
+  }
+
+  TF_DeleteDeviceList(devices);
+  TF_DeleteTensor(t);
+  TFE_DeleteTensorHandle(hcpu);
+  TFE_DeleteContext(ctx, status.get());
+  EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
+}
+
 TEST(CAPI, Execute) {
   TF_Status* status = TF_NewStatus();
-  TF_SessionOptions* opts = TF_NewSessionOptions();
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
   TFE_Context* ctx = TFE_NewContext(opts, status);
   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
-  TF_DeleteSessionOptions(opts);
+  TFE_DeleteContextOptions(opts);

   TFE_TensorHandle* m = TestMatrixTensorHandle();
   TFE_Op* matmul = MatMulOp(ctx, m, m);
@@ -285,10 +331,10 @@ string MatMulFunction() {

 TEST(CAPI, FunctionDefAndExecute) {
   TF_Status* status = TF_NewStatus();
-  TF_SessionOptions* opts = TF_NewSessionOptions();
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
   TFE_Context* ctx = TFE_NewContext(opts, status);
   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
-  TF_DeleteSessionOptions(opts);
+  TFE_DeleteContextOptions(opts);

   string function_def = MatMulFunction();
   TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
@@ -326,10 +372,10 @@ TEST(CAPI, FunctionDefAndExecute) {
 void BM_ExecuteFunction(int iters) {
   tensorflow::testing::StopTiming();
   TF_Status* status = TF_NewStatus();
-  TF_SessionOptions* opts = TF_NewSessionOptions();
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
   TFE_Context* ctx = TFE_NewContext(opts, status);
   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
-  TF_DeleteSessionOptions(opts);
+  TFE_DeleteContextOptions(opts);

   string function_def = MatMulFunction();
   TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
@@ -406,10 +452,10 @@ TEST(CAPI, Variables) {
   // Variables use resource handles, so this is really a test for resource
   // tensor handling.
   TF_Status* status = TF_NewStatus();
-  TF_SessionOptions* opts = TF_NewSessionOptions();
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
   TFE_Context* ctx = TFE_NewContext(opts, status);
   ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
-  TF_DeleteSessionOptions(opts);
+  TFE_DeleteContextOptions(opts);

   TFE_TensorHandle* var_handle = CreateVariable(ctx, 12.0, status);
   ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
@@ -446,10 +492,10 @@ TEST(CAPI, Variables) {
 void BM_ReadVariable(int iters) {
   tensorflow::testing::StopTiming();
   TF_Status* status = TF_NewStatus();
-  TF_SessionOptions* opts = TF_NewSessionOptions();
+  TFE_ContextOptions* opts = TFE_NewContextOptions();
   TFE_Context* ctx = TFE_NewContext(opts, status);
   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
-  TF_DeleteSessionOptions(opts);
+  TFE_DeleteContextOptions(opts);

   TFE_TensorHandle* var_handle = CreateVariable(ctx, 5.0, status);
   CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
tensorflow/compiler/xla/client/computation_builder.h
@@ -138,6 +138,11 @@ class ComputationBuilder {
   ComputationDataHandle ConstantR2(
       std::initializer_list<std::initializer_list<NativeT>> values);
   template <typename NativeT>
+  ComputationDataHandle ConstantFromArrayWithLayout(
+      const Array<NativeT>& values, const Layout& layout);
+  template <typename NativeT>
+  ComputationDataHandle ConstantFromArray(const Array<NativeT>& values);
+  template <typename NativeT>
   ComputationDataHandle ConstantR2FromArray2DWithLayout(
       const Array2D<NativeT>& values, const Layout& layout);
   template <typename NativeT>
@@ -909,49 +914,55 @@ ComputationDataHandle ComputationBuilder::ConstantR2(
       [&values](Literal* literal) { literal->PopulateR2(values); });
 }

+template <typename NativeT>
+ComputationDataHandle ComputationBuilder::ConstantFromArrayWithLayout(
+    const Array<NativeT>& values, const Layout& layout) {
+  return ConstantOp([&values, &layout](Literal* literal) {
+    literal->PopulateFromArrayWithLayout(values, layout);
+  });
+}
+
+template <typename NativeT>
+ComputationDataHandle ComputationBuilder::ConstantFromArray(
+    const Array<NativeT>& values) {
+  return ConstantOp(
+      [&values](Literal* literal) { literal->PopulateFromArray(values); });
+}
+
 template <typename NativeT>
 ComputationDataHandle ComputationBuilder::ConstantR2FromArray2DWithLayout(
     const Array2D<NativeT>& values, const Layout& layout) {
-  return ConstantOp([&values, &layout](Literal* literal) {
-    literal->PopulateR2FromArray2DWithLayout(values, layout);
-  });
+  return ConstantFromArrayWithLayout(values, layout);
 }

 template <typename NativeT>
 ComputationDataHandle ComputationBuilder::ConstantR2FromArray2D(
     const Array2D<NativeT>& values) {
-  return ConstantOp(
-      [&values](Literal* literal) { literal->PopulateR2FromArray2D(values); });
+  return ConstantFromArray(values);
 }

 template <typename NativeT>
 ComputationDataHandle ComputationBuilder::ConstantR3FromArray3DWithLayout(
     const Array3D<NativeT>& values, const Layout& layout) {
-  return ConstantOp([&values, &layout](Literal* literal) {
-    literal->PopulateR3FromArray3DWithLayout(values, layout);
-  });
+  return ConstantFromArrayWithLayout(values, layout);
 }

 template <typename NativeT>
 ComputationDataHandle ComputationBuilder::ConstantR3FromArray3D(
     const Array3D<NativeT>& values) {
-  return ConstantOp(
-      [&values](Literal* literal) { literal->PopulateR3FromArray3D(values); });
+  return ConstantFromArray(values);
 }

 template <typename NativeT>
 ComputationDataHandle ComputationBuilder::ConstantR4FromArray4DWithLayout(
     const Array4D<NativeT>& values, const Layout& layout) {
-  return ConstantOp([&values, &layout](Literal* literal) {
-    literal->PopulateR4FromArray4DWithLayout(values, layout);
-  });
+  return ConstantFromArrayWithLayout(values, layout);
 }

 template <typename NativeT>
 ComputationDataHandle ComputationBuilder::ConstantR4FromArray4D(
     const Array4D<NativeT>& values) {
-  return ConstantOp(
-      [&values](Literal* literal) { literal->PopulateR4FromArray4D(values); });
+  return ConstantFromArray(values);
 }

 }  // namespace xla
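Since Array2D/Array3D/Array4D all derive from the rank-generic Array, callers can now reach one entry point for any rank, and the old rank-specific methods become thin wrappers. An illustrative sketch (assumes an existing ComputationBuilder named builder):

// Rank-generic: the default layout for rank 2 is chosen internally.
xla::Array2D<float> values({{1.0f, 2.0f}, {3.0f, 4.0f}});
xla::ComputationDataHandle h = builder.ConstantFromArray(values);
// Equivalent to the older rank-specific call, which now just delegates:
xla::ComputationDataHandle h2 = builder.ConstantR2FromArray2D(values);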
tensorflow/compiler/xla/layout_util.cc
@@ -83,6 +83,10 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
   return CreateDefaultLayoutForRank(shape.dimensions_size());
 }

+/* static */ Layout LayoutUtil::GetDefaultLayoutForRank(int64 rank) {
+  return CreateDefaultLayoutForRank(rank);
+}
+
 /* static */ Layout LayoutUtil::GetDefaultLayoutForR2() {
   return CreateDefaultLayoutForRank(2);
 }

tensorflow/compiler/xla/layout_util.h
@@ -40,6 +40,7 @@ class LayoutUtil {
   static Layout GetDefaultLayoutForShape(const Shape& shape);

   // Helper functions that create default layouts for various ranks.
+  static Layout GetDefaultLayoutForRank(int64 rank);
   static Layout GetDefaultLayoutForR2();
   static Layout GetDefaultLayoutForR3();
   static Layout GetDefaultLayoutForR4();
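A note on what the new helper returns (illustrative sketch, stating my understanding of XLA's default layout rather than anything in this diff): GetDefaultLayoutForRank(n) produces the major-to-minor default, minor_to_major = {n-1, ..., 0}, matching the fixed-rank helpers.

xla::Layout r2 = xla::LayoutUtil::GetDefaultLayoutForRank(2);
// Same layout the rank-2 helper has always produced.
CHECK(xla::LayoutUtil::Equal(r2, xla::LayoutUtil::GetDefaultLayoutForR2()));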
tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
@@ -206,9 +206,9 @@ void AllocateFlags() {
           flag_values->xla_gpu_disable_multi_streaming(),
           "If true, multi-streaming in the GPU backend is disabled."),
       tensorflow::Flag(
-          "xla_dump_debug_json_to",
-          flag_values->mutable_xla_dump_debug_json_to(),
-          "Dump compilation artifacts as JSON into this directory."),
+          "xla_dump_hlo_proto_to",
+          flag_values->mutable_xla_dump_hlo_proto_to(),
+          "Dump compilation artifacts as proto binary into this directory."),
       tensorflow::Flag(
           "xla_test_all_output_layouts",
           bool_setter_for(&DebugOptions::set_xla_test_all_output_layouts),
tensorflow/compiler/xla/literal_util.h
@@ -334,6 +334,11 @@ class Literal {
   // WithLayout use the default XLA layout for the literal's linear
   // representation in memory.
   template <typename NativeT>
+  static std::unique_ptr<Literal> CreateFromArray(const Array<NativeT>& values);
+  template <typename NativeT>
+  static std::unique_ptr<Literal> CreateFromArrayWithLayout(
+      const Array<NativeT>& values, const Layout& layout);
+  template <typename NativeT>
   static std::unique_ptr<Literal> CreateR2FromArray2D(
       const Array2D<NativeT>& values);
   template <typename NativeT>
@@ -481,6 +486,11 @@ class Literal {
       std::initializer_list<std::initializer_list<NativeT>> values,
       const Layout& layout);
   template <typename NativeT>
+  void PopulateFromArray(const Array<NativeT>& values);
+  template <typename NativeT>
+  void PopulateFromArrayWithLayout(const Array<NativeT>& values,
+                                   const Layout& layout);
+  template <typename NativeT>
   void PopulateR2FromArray2D(const Array2D<NativeT>& values);
   template <typename NativeT>
   void PopulateR2FromArray2DWithLayout(const Array2D<NativeT>& values,
@@ -815,34 +825,43 @@ template <typename NativeT>
   return CreateR4WithLayout(values, LayoutUtil::GetDefaultLayoutForR4());
 }

+template <typename NativeT>
+/* static */ std::unique_ptr<Literal> Literal::CreateFromArrayWithLayout(
+    const Array<NativeT>& values, const Layout& layout) {
+  auto literal = MakeUnique<Literal>();
+  literal->PopulateFromArrayWithLayout(values, layout);
+  return literal;
+}
+
+template <typename NativeT>
+/* static */ std::unique_ptr<Literal> Literal::CreateFromArray(
+    const Array<NativeT>& values) {
+  return CreateFromArrayWithLayout(
+      values, LayoutUtil::GetDefaultLayoutForRank(values.num_dimensions()));
+}
+
 template <typename NativeT>
 /* static */ std::unique_ptr<Literal> Literal::CreateR2FromArray2DWithLayout(
     const Array2D<NativeT>& values, const Layout& layout) {
-  auto literal = MakeUnique<Literal>();
-  literal->PopulateR2FromArray2DWithLayout(values, layout);
-  return literal;
+  return CreateFromArrayWithLayout(values, layout);
 }

 template <typename NativeT>
 /* static */ std::unique_ptr<Literal> Literal::CreateR2FromArray2D(
     const Array2D<NativeT>& values) {
-  return CreateR2FromArray2DWithLayout(values,
-                                       LayoutUtil::GetDefaultLayoutForR2());
+  return CreateFromArray(values);
 }

 template <typename NativeT>
 /* static */ std::unique_ptr<Literal> Literal::CreateR3FromArray3DWithLayout(
     const Array3D<NativeT>& values, const Layout& layout) {
-  auto literal = MakeUnique<Literal>();
-  literal->PopulateR3FromArray3DWithLayout(values, layout);
-  return literal;
+  return CreateFromArrayWithLayout(values, layout);
 }

 template <typename NativeT>
 /* static */ std::unique_ptr<Literal> Literal::CreateR3FromArray3D(
     const Array3D<NativeT>& values) {
-  return CreateR3FromArray3DWithLayout(values,
-                                       LayoutUtil::GetDefaultLayoutForR3());
+  return CreateFromArray(values);
 }

 template <typename NativeT>
@@ -901,16 +920,13 @@ template <typename NativeT>
 template <typename NativeT>
 /* static */ std::unique_ptr<Literal> Literal::CreateR4FromArray4D(
     const Array4D<NativeT>& values) {
-  return CreateR4FromArray4DWithLayout(values,
-                                       LayoutUtil::GetDefaultLayoutForR4());
+  return CreateFromArray(values);
 }

 template <typename NativeT>
 /* static */ std::unique_ptr<Literal> Literal::CreateR4FromArray4DWithLayout(
     const Array4D<NativeT>& values, const Layout& layout) {
-  auto literal = MakeUnique<Literal>();
-  literal->PopulateR4FromArray4DWithLayout(values, layout);
-  return literal;
+  return CreateFromArrayWithLayout(values, layout);
 }

 template <typename NativeT>
@@ -1069,83 +1085,54 @@ void Literal::PopulateR2(
   PopulateR2WithLayout(values, LayoutUtil::GetDefaultLayoutForR2());
 }

+template <typename NativeT>
+void Literal::PopulateFromArrayWithLayout(const Array<NativeT>& values,
+                                          const Layout& layout) {
+  *mutable_shape() = ShapeUtil::MakeShapeWithLayout(
+      primitive_util::NativeToPrimitiveType<NativeT>(), values.dimensions(),
+      AsInt64Slice(layout.minor_to_major()));
+  Reserve(values.num_elements());
+  values.Each([this](tensorflow::gtl::ArraySlice<int64> indices,
+                     NativeT value) { this->Set(indices, value); });
+}
+
+template <typename NativeT>
+void Literal::PopulateFromArray(const Array<NativeT>& values) {
+  PopulateFromArrayWithLayout(
+      values, LayoutUtil::GetDefaultLayoutForRank(values.num_dimensions()));
+}
+
 template <typename NativeT>
 void Literal::PopulateR2FromArray2DWithLayout(const Array2D<NativeT>& values,
                                               const Layout& layout) {
-  *mutable_shape() = ShapeUtil::MakeShapeWithLayout(
-      primitive_util::NativeToPrimitiveType<NativeT>(),
-      {values.height(), values.width()}, AsInt64Slice(layout.minor_to_major()));
-
-  const int64 dim1_size = values.width();
-  const int64 dim0_size = values.height();
-  CHECK_EQ(dim0_size, shape().dimensions(0));
-  CHECK_EQ(dim1_size, shape().dimensions(1));
-  Reserve(dim1_size * dim0_size);
-  for (int64 dim0 = 0; dim0 < dim0_size; ++dim0) {
-    for (int64 dim1 = 0; dim1 < dim1_size; ++dim1) {
-      Set({dim0, dim1}, values(dim0, dim1));
-    }
-  }
+  PopulateFromArrayWithLayout(values, layout);
 }

 template <typename NativeT>
 void Literal::PopulateR2FromArray2D(const Array2D<NativeT>& values) {
-  PopulateR2FromArray2DWithLayout(values, LayoutUtil::GetDefaultLayoutForR2());
+  PopulateFromArray(values);
 }

 template <typename NativeT>
 void Literal::PopulateR3FromArray3DWithLayout(const Array3D<NativeT>& values,
                                               const Layout& layout) {
-  *mutable_shape() = ShapeUtil::MakeShapeWithLayout(
-      primitive_util::NativeToPrimitiveType<NativeT>(),
-      {values.n1(), values.n2(), values.n3()},
-      AsInt64Slice(layout.minor_to_major()));
-
-  CHECK_EQ(values.n1(), shape().dimensions(0));
-  CHECK_EQ(values.n2(), shape().dimensions(1));
-  CHECK_EQ(values.n3(), shape().dimensions(2));
-  Reserve(values.n1() * values.n2() * values.n3());
-  for (int64 dim0 = 0; dim0 < values.n1(); ++dim0) {
-    for (int64 dim1 = 0; dim1 < values.n2(); ++dim1) {
-      for (int64 dim2 = 0; dim2 < values.n3(); ++dim2) {
-        Set({dim0, dim1, dim2}, values(dim0, dim1, dim2));
-      }
-    }
-  }
+  PopulateFromArrayWithLayout(values, layout);
 }

 template <typename NativeT>
 void Literal::PopulateR3FromArray3D(const Array3D<NativeT>& values) {
-  PopulateR3FromArray3DWithLayout(values, LayoutUtil::GetDefaultLayoutForR3());
+  PopulateFromArray(values);
 }

 template <typename NativeT>
 void Literal::PopulateR4FromArray4DWithLayout(const Array4D<NativeT>& values,
                                               const Layout& layout) {
-  *mutable_shape() = ShapeUtil::MakeShapeWithLayout(
-      primitive_util::NativeToPrimitiveType<NativeT>(),
-      {values.planes(), values.depth(), values.height(), values.width()},
-      AsInt64Slice(layout.minor_to_major()));
-
-  CHECK_EQ(values.n1(), shape().dimensions(0));
-  CHECK_EQ(values.n2(), shape().dimensions(1));
-  CHECK_EQ(values.n3(), shape().dimensions(2));
-  CHECK_EQ(values.n4(), shape().dimensions(3));
-  Reserve(values.n1() * values.n2() * values.n3() * values.n4());
-  for (int64 dim0 = 0; dim0 < values.n1(); ++dim0) {
-    for (int64 dim1 = 0; dim1 < values.n2(); ++dim1) {
-      for (int64 dim2 = 0; dim2 < values.n3(); ++dim2) {
-        for (int64 dim3 = 0; dim3 < values.n4(); ++dim3) {
-          Set({dim0, dim1, dim2, dim3}, values(dim0, dim1, dim2, dim3));
-        }
-      }
-    }
-  }
+  PopulateFromArrayWithLayout(values, layout);
 }

 template <typename NativeT>
 void Literal::PopulateR4FromArray4D(const Array4D<NativeT>& values) {
-  PopulateR4FromArray4DWithLayout(values, LayoutUtil::GetDefaultLayoutForR4());
+  PopulateFromArray(values);
 }

 template <typename NativeT, typename FnType>
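Mirroring the ComputationBuilder change, literal construction funnels through the two rank-generic creators, and PopulateFromArrayWithLayout replaces the hand-rolled per-rank loops with a single Array::Each traversal. Illustrative sketch:

// A 2x2x2 array of zeros; Array3D converts to the rank-generic Array.
xla::Array3D<float> cube(2, 2, 2, 0.0f);
std::unique_ptr<xla::Literal> from_generic =
    xla::Literal::CreateFromArray(cube);
// The rank-specific creator now delegates, producing the same literal:
std::unique_ptr<xla::Literal> from_r3 = xla::Literal::CreateR3FromArray3D(cube);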
tensorflow/compiler/xla/protobuf_util.cc
@@ -37,20 +37,6 @@ bool ProtobufEquals(const tensorflow::protobuf::Message& m1,
   return (serialized1 == serialized2);
 }

-StatusOr<string> ToJson(const tensorflow::protobuf::Message& message) {
-  string json_output;
-  tensorflow::protobuf::util::JsonPrintOptions json_options;
-  json_options.add_whitespace = true;
-  json_options.always_print_primitive_fields = true;
-  auto status = tensorflow::protobuf::util::MessageToJsonString(
-      message, &json_output, json_options);
-  if (!status.ok()) {
-    return InternalError("MessageToJsonString failed: %s",
-                         status.error_message().data());
-  }
-  return json_output;
-}
-
 namespace {

 string SanitizeFilename(const string& file_name) {
@@ -65,17 +51,6 @@ string SanitizeFilename(const string& file_name) {

 }  // namespace

-Status DumpJsonToDirectory(const tensorflow::protobuf::Message& message,
-                           const string& directory, const string& file_name) {
-  TF_ASSIGN_OR_RETURN(const string json_output, ToJson(message));
-
-  tensorflow::Env* env = tensorflow::Env::Default();
-  TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(directory));
-  string safe_file_name = SanitizeFileName(file_name) + ".json";
-  const string path = tensorflow::io::JoinPath(directory, safe_file_name);
-  return tensorflow::WriteStringToFile(env, path, json_output);
-}
-
 Status DumpProtoToDirectory(const tensorflow::protobuf::Message& message,
                             const string& directory, const string& file_name) {
   tensorflow::Env* env = tensorflow::Env::Default();

tensorflow/compiler/xla/protobuf_util.h
@@ -32,17 +32,12 @@ namespace protobuf_util {
 extern bool ProtobufEquals(const tensorflow::protobuf::Message& m1,
                            const tensorflow::protobuf::Message& m2);

-// Returns 'message' as a JSON string.
-StatusOr<string> ToJson(const tensorflow::protobuf::Message& message);
-
-// Writes the given message in binary proto or JSON format to the path formed by
-// joining 'directory/file_name.pb' (or file_name.json). The 'directory' is
-// recursively created if it doesn't already exist, and the 'file_name' is
-// sanitized by replacing illegal characters with underscore '_'.
+// Writes the given message in binary proto to the path formed by joining
+// 'directory/file_name.pb'. The 'directory' is recursively created if it
+// doesn't already exist, and the 'file_name' is sanitized by replacing
+// illegal characters with underscore '_'.
 Status DumpProtoToDirectory(const tensorflow::protobuf::Message& message,
                             const string& directory, const string& file_name);
-Status DumpJsonToDirectory(const tensorflow::protobuf::Message& message,
-                           const string& directory, const string& file_name);

 }  // namespace protobuf_util
 }  // namespace xla
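Call sites below (the CPU and GPU compilers) switch accordingly. A sketch of the surviving entry point (illustrative; 'proto' stands for an xla::HloProto built elsewhere, e.g. with MakeHloProto):

// Writes <directory>/<sanitized file_name>.pb, creating the directory
// recursively if needed.
TF_RETURN_IF_ERROR(xla::protobuf_util::DumpProtoToDirectory(
    proto, /*directory=*/"/tmp/xla_dump", /*file_name=*/module->name()));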
tensorflow/compiler/xla/service/BUILD
@@ -2064,6 +2064,29 @@ tf_cc_test(
     ],
 )

+cc_library(
+    name = "hlo_runner",
+    srcs = ["hlo_runner.cc"],
+    hdrs = ["hlo_runner.h"],
+    deps = [
+        ":executable",
+        ":hlo",
+        ":transfer_manager",
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:status_macros",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla:types",
+        "//tensorflow/compiler/xla:util",
+        "//tensorflow/compiler/xla:xla_data_proto",
+        "//tensorflow/compiler/xla/service:backend",
+        "//tensorflow/compiler/xla/service:compiler",
+        "//tensorflow/core:core_cpu_internal",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:stream_executor_no_cuda",
+        "//third_party/eigen3",
+    ],
+)
+
 # -----------------------------------------------------------------------------

 filegroup(
tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -475,8 +475,8 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::Compile(
   // ownership is std::moved.
   const bool embed_ir_in_executable =
       module->config().debug_options().xla_embed_ir_in_executable();
-  const string dump_debug_json_to =
-      module->config().debug_options().xla_dump_debug_json_to();
+  const string xla_dump_hlo_proto_to =
+      module->config().debug_options().xla_dump_hlo_proto_to();

   if (options::CpuParallelBackendRequested(module->config())) {
     VLOG(1) << "Using parallel cpu backend";
@@ -496,10 +496,10 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::Compile(
     // print one ourselves.
     XLA_VLOG_LINES(2, assignment->ToString());

-    if (!dump_debug_json_to.empty()) {
+    if (!xla_dump_hlo_proto_to.empty()) {
       HloProto proto = MakeHloProto(*module, *assignment);
-      TF_RETURN_IF_ERROR(protobuf_util::DumpJsonToDirectory(
-          proto, dump_debug_json_to, module->name()));
+      TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory(
+          proto, xla_dump_hlo_proto_to, module->name()));
     }

     // If we are using the parallel CPU backend, we need to create map from
@@ -603,12 +603,11 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::Compile(
     // print one ourselves.
     XLA_VLOG_LINES(2, assignment->ToString());

-    if (!dump_debug_json_to.empty()) {
+    if (!xla_dump_hlo_proto_to.empty()) {
       HloProto proto = MakeHloProto(*module, *assignment);
-      TF_RETURN_IF_ERROR(protobuf_util::DumpJsonToDirectory(
-          proto, dump_debug_json_to, module->name()));
+      TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory(
+          proto, xla_dump_hlo_proto_to, module->name()));
     }

     // Each computation is a single function. Emit all embedded computations
     // before the entry computation. The order of computations returned from
     // GetEmbeddedComputations guarantees that a called computation occurs
@@ -775,12 +774,12 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
   // print one ourselves.
   XLA_VLOG_LINES(2, assignment->ToString());

-  const string dump_debug_json_to =
-      module->config().debug_options().xla_dump_debug_json_to();
-  if (!dump_debug_json_to.empty()) {
+  const string xla_dump_hlo_proto_to =
+      module->config().debug_options().xla_dump_hlo_proto_to();
+  if (!xla_dump_hlo_proto_to.empty()) {
     HloProto proto = MakeHloProto(*module, *assignment);
-    TF_RETURN_IF_ERROR(protobuf_util::DumpJsonToDirectory(
-        proto, dump_debug_json_to, module->name()));
+    TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory(
+        proto, xla_dump_hlo_proto_to, module->name()));
   }

   IrEmitter ir_emitter(*module, *assignment, &llvm_module,
tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc
@@ -136,6 +136,8 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount(
       instruction->opcode() == HloOpcode::kCall ||
       instruction->opcode() == HloOpcode::kCustomCall ||
       instruction->opcode() == HloOpcode::kSelectAndScatter ||
+      instruction->opcode() == HloOpcode::kGetTupleElement ||
+      instruction->opcode() == HloOpcode::kBitcast ||
       (instruction->opcode() == HloOpcode::kConvolution &&
        PotentiallyImplementedAsEigenConvolution(*instruction)) ||
       PotentiallyImplementedAsEigenDot(*instruction) ||
tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -318,12 +318,12 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::Compile(
   // print one ourselves.
   XLA_VLOG_LINES(2, buffer_assignment->ToString());

-  const string dump_debug_json_to =
-      module->config().debug_options().xla_dump_debug_json_to();
-  if (!dump_debug_json_to.empty()) {
+  const string xla_dump_hlo_proto_to =
+      module->config().debug_options().xla_dump_hlo_proto_to();
+  if (!xla_dump_hlo_proto_to.empty()) {
     HloProto proto = MakeHloProto(*module, *buffer_assignment);
-    TF_RETURN_IF_ERROR(protobuf_util::DumpJsonToDirectory(
-        proto, dump_debug_json_to, module->name()));
+    TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory(
+        proto, xla_dump_hlo_proto_to, module->name()));
   }

   IrEmitterContext ir_emitter_context(module.get(), buffer_assignment.get(),
tensorflow/compiler/xla/service/hlo_computation.cc
@@ -373,8 +373,8 @@ string HloComputation::ToString(int nested_level) const {
   for (int i = 0; i < nested_level; i++) {
     s << "  ";
   }
-  s << name() << " " << ShapeUtil::HumanString(ComputeProgramShape())
-    << " { \n";
+  s << "%" << name() << " " << ShapeUtil::HumanString(ComputeProgramShape())
+    << " {\n";
   for (const HloInstruction* instruction : MakeInstructionPostOrder()) {
     for (int i = 0; i < nested_level; i++) {
       s << "  ";
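The visible effect on HLO dumps: a computation header that previously printed as, e.g., `add (x: f32[], y: f32[]) -> f32[] { ` (note the stray space before the newline) now prints as `%add (x: f32[], y: f32[]) -> f32[] {`, giving computation names the same `%` sigil that instruction names carry. (The signature shown is a made-up example.)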
tensorflow/compiler/xla/service/hlo_runner.cc (new file, 199 lines)
@@ -0,0 +1,199 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/hlo_runner.h"
+
+#include <set>
+#include <string>
+#include <utility>
+
+#define EIGEN_USE_THREADS
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/compiler/xla/layout_util.h"
+#include "tensorflow/compiler/xla/ptr_util.h"
+#include "tensorflow/compiler/xla/service/backend.h"
+#include "tensorflow/compiler/xla/service/executable.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/transfer_manager.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace se = ::perftools::gputools;
+
+namespace xla {
+
+/*static*/ StatusOr<std::unique_ptr<HloModule>>
+HloRunner::ReadModuleFromHloProtoFile(const char* filename,
+                                      const DebugOptions& debug_options) {
+  HloProto proto;
+  TF_RETURN_IF_ERROR(tensorflow::ReadBinaryProto(tensorflow::Env::Default(),
+                                                 filename, &proto));
+  HloModuleConfig config;
+  config.set_debug_options(debug_options);
+  TF_ASSIGN_OR_RETURN(auto module, HloModule::CreateFromProto(
+                                       proto.hlo_module(),
+                                       VersionedComputationHandle(), config));
+  return std::move(module);
+}
+
+// Define this in .cc file to avoid having to include eigen or forward declare
+// these types in the header.
+struct HloRunner::EigenThreadPoolWrapper {
+  std::unique_ptr<EigenThreadPoolWrapper> pool;
+  std::unique_ptr<Eigen::ThreadPoolDevice> device;
+};
+
+HloRunner::HloRunner() {}
+
+HloRunner::HloRunner(se::Platform* platform) {
+  BackendOptions backend_options;
+  backend_options.set_platform(platform);
+  backend_ = Backend::CreateBackend(backend_options).ConsumeValueOrDie();
+  VLOG(1) << "Created HloRunner for platform: " << platform->Name();
+}
+
+HloRunner::~HloRunner() {
+  // Deallocate all the memory allocated during the tests.
+  for (auto& allocation : allocations_) {
+    backend().default_stream_executor()->Deallocate(&allocation);
+  }
+}
+
+StatusOr<se::DeviceMemoryBase> HloRunner::Execute(
+    std::unique_ptr<HloModule> module,
+    tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments,
+    Shape* result_shape) {
+  TF_ASSIGN_OR_RETURN(
+      std::unique_ptr<Executable> executable,
+      backend().compiler()->Compile(std::move(module),
+                                    backend().default_stream_executor()));
+
+  se::Stream stream(backend().default_stream_executor());
+  stream.Init();
+
+  ExecutableRunOptions run_options;
+  run_options.set_stream(&stream);
+  run_options.set_allocator(backend().memory_allocator());
+  run_options.set_inter_op_thread_pool(backend().inter_op_thread_pool());
+  run_options.set_intra_op_thread_pool(
+      backend().eigen_intra_op_thread_pool_device());
+
+  HloExecutionProfile hlo_execution_profile;
+  ServiceExecutableRunOptions service_run_options(
+      run_options, backend().StreamBorrower(),
+      backend().inter_op_thread_pool());
+  TF_ASSIGN_OR_RETURN(
+      se::DeviceMemoryBase result,
+      executable->ExecuteOnStream(&service_run_options, arguments,
+                                  &hlo_execution_profile));
+  TF_RET_CHECK(stream.BlockHostUntilDone());
+
+  allocations_.push_back(result);
+
+  *result_shape = executable->result_shape();
+
+  if (ShapeUtil::IsTuple(*result_shape)) {
+    // We must record element buffers of tuples as well to avoid leaks.
+    DCHECK(!ShapeUtil::IsNestedTuple(*result_shape));
+    TF_ASSIGN_OR_RETURN(
+        std::vector<se::DeviceMemoryBase> element_buffers,
+        backend().transfer_manager()->ShallowCopyTupleFromDevice(
+            backend().default_stream_executor(), result, *result_shape));
+
+    // A tuple may contain the same buffer in more than one element. Keep track
+    // of the buffers already added to avoid duplicates in allocations_.
+    std::set<void*> added_opaques;
+    for (auto element_buffer : element_buffers) {
+      if (added_opaques.count(element_buffer.opaque()) == 0) {
+        CHECK(element_buffer.opaque() != nullptr);
+        added_opaques.insert(element_buffer.opaque());
+        allocations_.push_back(element_buffer);
+      }
+    }
+  }
+
+  return result;
+}
+
+se::DeviceMemoryBase HloRunner::TransferToDevice(const Literal& literal) {
+  // Allocate memory on the device using the stream executor.
+  int64 allocation_size =
+      backend().transfer_manager()->GetByteSizeRequirement(literal.shape());
+  se::DeviceMemoryBase allocation =
+      backend().default_stream_executor()->AllocateArray<uint8>(
+          allocation_size);
+  allocations_.push_back(allocation);
+
+  TF_CHECK_OK(backend().transfer_manager()->TransferLiteralToDevice(
+      backend().default_stream_executor(), literal, &allocation));
+
+  return allocation;
+}
+
+std::unique_ptr<Literal> HloRunner::TransferFromDevice(
+    const Shape& shape, se::DeviceMemoryBase device_base) {
+  auto literal = MakeUnique<Literal>();
+  TF_CHECK_OK(backend().transfer_manager()->TransferLiteralFromDevice(
+      backend().default_stream_executor(), device_base, shape, shape,
+      literal.get()));
+  return literal;
+}
+
+std::unique_ptr<Literal> HloRunner::ExecuteAndTransfer(
+    std::unique_ptr<HloModule> module,
+    tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments) {
+  Shape result_shape;
+  se::DeviceMemoryBase device_base =
+      Execute(std::move(module), arguments, &result_shape).ValueOrDie();
+  return TransferFromDevice(result_shape, device_base);
+}
+
+template <>
+std::unique_ptr<Literal> HloRunner::Execute(
+    std::unique_ptr<HloModule> module,
+    const tensorflow::gtl::ArraySlice<std::unique_ptr<Literal>>& literals) {
+  std::vector<se::DeviceMemoryBase> arguments;
+  for (const auto& literal : literals) {
+    arguments.push_back(TransferToDevice(*literal));
+  }
+  return ExecuteAndTransfer(std::move(module), arguments);
+}
+
+template <>
+std::unique_ptr<Literal> HloRunner::Execute(
+    std::unique_ptr<HloModule> module,
+    const tensorflow::gtl::ArraySlice<Literal*>& literals) {
+  std::vector<se::DeviceMemoryBase> arguments;
+  for (const auto& literal : literals) {
+    arguments.push_back(TransferToDevice(*literal));
+  }
+  return ExecuteAndTransfer(std::move(module), arguments);
+}
+
+Backend& HloRunner::backend() {
+  if (!backend_) {
+    backend_ = Backend::CreateDefaultBackend().ConsumeValueOrDie();
+    VLOG(1) << "executing on platform " << backend().platform()->Name();
+  }
+  return *backend_;
+}
+
+}  // namespace xla
tensorflow/compiler/xla/service/hlo_runner.h (new file, 100 lines)
@@ -0,0 +1,100 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/backend.h"
+#include "tensorflow/compiler/xla/service/compiler.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/stream_executor_no_cuda.h"
+
+namespace xla {
+
+// A base class for running an HloModule. This executes the given HloModule on
+// a certain backend directly without using the client interface. HloModule
+// can be explicitly built, or loaded from a serialization file (e.g., hlo
+// proto file).
+class HloRunner {
+ public:
+  HloRunner();
+
+  HloRunner(::perftools::gputools::Platform* platform);
+
+  ~HloRunner();
+
+  // Reads the binary proto file in xla.HloProto format, creates and returns
+  // the HloModule.
+  static StatusOr<std::unique_ptr<HloModule>> ReadModuleFromHloProtoFile(
+      const char* filename, const DebugOptions& debug_options);
+
+  // Executes the given module with given literals as input and returns the
+  // result as a Literal. The LiteralPtr type accepts Literal* or
+  // std::unique_ptr<Literal>.
+  template <typename LiteralPtr>
+  std::unique_ptr<Literal> Execute(
+      std::unique_ptr<HloModule> module,
+      const tensorflow::gtl::ArraySlice<LiteralPtr>& literals);
+
+  // Executes the given module and returns a global data handle.
+  StatusOr<perftools::gputools::DeviceMemoryBase> Execute(
+      std::unique_ptr<HloModule> module,
+      tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
+          arguments,
+      Shape* result_shape);
+
+  // Transfers the given literal to the device and returns the data handle.
+  perftools::gputools::DeviceMemoryBase TransferToDevice(
+      const Literal& literal);
+
+  // Transfers the array referred to by the given handle from the device and
+  // returns as a Literal.
+  std::unique_ptr<Literal> TransferFromDevice(
+      const Shape& shape, perftools::gputools::DeviceMemoryBase device_base);
+
+  // Executes the given module and return the result as a Literal.
+  std::unique_ptr<Literal> ExecuteAndTransfer(
+      std::unique_ptr<HloModule> module,
+      tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
+          arguments);
+
+  // If backend is not created in the constructor, creates and returns the
+  // default backend. If creation fails, crashes the program.
+  //
+  // This creates the backend lazily so it's possible to instantiate an
+  // HloRunner in a program without any backends linked in.
+  Backend& backend();
+
+ private:
+  struct EigenThreadPoolWrapper;
+
+  std::vector<perftools::gputools::DeviceMemoryBase> allocations_;
+
+  std::unique_ptr<EigenThreadPoolWrapper> thread_pool_wrapper_;
+
+  std::unique_ptr<Backend> backend_;
+};
+
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_H_
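Taken together, the two new files support a dump-and-replay workflow: compile with xla_dump_hlo_proto_to set (see the flag and compiler changes above), then feed a dumped module back through HloRunner. An illustrative sketch (the path and input value are assumptions, not code from this commit):

#include "tensorflow/compiler/xla/service/hlo_runner.h"

xla::DebugOptions debug_options;  // proto defaults are fine for a replay
std::unique_ptr<xla::HloModule> module =
    xla::HloRunner::ReadModuleFromHloProtoFile("/tmp/xla_dump/my_module.pb",
                                               debug_options)
        .ValueOrDie();

xla::HloRunner runner;  // backend() is created lazily on first use
std::unique_ptr<xla::Literal> arg = xla::Literal::CreateR0<float>(2.0f);
std::vector<xla::Literal*> args = {arg.get()};
std::unique_ptr<xla::Literal> result = runner.Execute(
    std::move(module), tensorflow::gtl::ArraySlice<xla::Literal*>(args));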
tensorflow/compiler/xla/service/transpose_folding.cc
@@ -58,14 +58,32 @@ TransposeFolding::OperandIndices CanFoldOperandsIntoConvolution(
     return {};
   }

-  // We only support folding the RHS.
-  const int64 kRhsOperandIndex = 1;
-  auto& operand = *convolution.operand(kRhsOperandIndex);
-  if (operand.opcode() == HloOpcode::kTranspose && operand.user_count() == 1) {
-    return transposable_conv_operands(convolution, {kRhsOperandIndex});
+  const ConvolutionDimensionNumbers& dnums =
+      convolution.convolution_dimension_numbers();
+
+  TransposeFolding::OperandIndices operand_set;
+  for (int64 i = 0; i < convolution.operand_count(); ++i) {
+    auto& operand = *convolution.operand(i);
+    if (operand.opcode() == HloOpcode::kTranspose &&
+        operand.user_count() == 1) {
+      const auto& transpose_dimensions = operand.dimensions();
+      // We can transpose the LHS so long as it doesn't move around spatial
+      // dimensions because ConvolutionDimensionNumbers doesn't have different
+      // fields for input and output spatial dimensions.
+      if (i == 0 &&
+          std::any_of(dnums.spatial_dimensions().begin(),
+                      dnums.spatial_dimensions().end(),
+                      [&](const int64 spatial_dimension) {
+                        return transpose_dimensions[spatial_dimension] !=
+                               spatial_dimension;
+                      })) {
+        continue;
+      }
+      operand_set.push_back(i);
+    }
   }

-  return {};
+  return transposable_conv_operands(convolution, operand_set);
 }

 using InstructionOperandsPair =
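What the relaxed check buys, in HLO terms (illustrative, hypothetical shapes):

// Before this change only a transpose feeding the RHS (the kernel) could
// fold. Now a batch/feature transpose feeding the LHS folds as well:
//
//   %t = f32[8,16,5,5] transpose(f32[16,8,5,5] %x), dimensions={1,0,2,3}
//   %c = convolution(%t, %kernel), ...
//
// becomes a convolution reading %x directly, with input_batch_dimension and
// input_feature_dimension swapped in its ConvolutionDimensionNumbers. The
// any_of guard above rejects LHS transposes that move spatial dimensions
// (2 and 3 here), which must map to themselves.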
@@ -98,40 +116,61 @@ bool FoldTransposeIntoDot(InstructionOperandsPair pair) {
 // Returns whether the module is changed.
 bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) {
   auto& convolution = *pair.first;
-
-  // We only support fusing the RHS transpose into convolution.
-  //
-  // ConvolutionDimensionNumbers doesn't make enough of a distinction between
-  // the output and the activations.
-  //
-  // TODO(b/37125184): Support transposing the LHS too.
-  if (pair.second.size() != 1 || pair.second.front() != 1) {
-    return false;
-  }
+  auto& operand_indices = pair.second;

   const ConvolutionDimensionNumbers& dnums =
       convolution.convolution_dimension_numbers();
-  HloInstruction& transpose = *convolution.mutable_operand(1);
-  CHECK_EQ(transpose.opcode(), HloOpcode::kTranspose);
-  const auto& transpose_dimensions = transpose.dimensions();
-  HloInstruction& transpose_operand = *transpose.mutable_operand(0);
-
-  // Everything remains the same except for the kernel dimension numbers. We
-  // need to apply the transpose permutation to the original shape to figure out
-  // what the new logical dimensions are.
   ConvolutionDimensionNumbers new_dnums = dnums;
-  new_dnums.set_kernel_input_feature_dimension(
-      transpose_dimensions[dnums.kernel_input_feature_dimension()]);
-  new_dnums.set_kernel_output_feature_dimension(
-      transpose_dimensions[dnums.kernel_output_feature_dimension()]);
-  for (auto& kernel_spatial_dimension :
-       *new_dnums.mutable_kernel_spatial_dimensions()) {
-    kernel_spatial_dimension = transpose_dimensions[kernel_spatial_dimension];
+
+  HloInstruction* new_lhs;
+  const int64 kLhsIdx = 0;
+  if (std::find(operand_indices.begin(), operand_indices.end(), kLhsIdx) !=
+      operand_indices.end()) {
+    HloInstruction& transpose = *convolution.mutable_operand(kLhsIdx);
+    const auto& transpose_dimensions = transpose.dimensions();
+    HloInstruction& transpose_operand = *transpose.mutable_operand(0);
+
+    // Everything remains the same except for the input/output dimension
+    // numbers. We need to apply the transpose permutation to the original shape
+    // to figure out what the new logical dimensions are.
+    new_dnums.set_input_batch_dimension(
+        transpose_dimensions[dnums.input_batch_dimension()]);
+    new_dnums.set_input_feature_dimension(
+        transpose_dimensions[dnums.input_feature_dimension()]);
+    for (const auto& spatial_dimension : dnums.spatial_dimensions()) {
+      CHECK_EQ(spatial_dimension, transpose_dimensions[spatial_dimension]);
+    }
+    new_lhs = &transpose_operand;
+  } else {
+    new_lhs = convolution.mutable_operand(kLhsIdx);
+  }
+
+  HloInstruction* new_rhs;
+  const int64 kRhsIdx = 1;
+  if (std::find(operand_indices.begin(), operand_indices.end(), kRhsIdx) !=
+      operand_indices.end()) {
+    HloInstruction& transpose = *convolution.mutable_operand(kRhsIdx);
+    const auto& transpose_dimensions = transpose.dimensions();
+    HloInstruction& transpose_operand = *transpose.mutable_operand(0);
+
+    // Everything remains the same except for the kernel dimension numbers. We
+    // need to apply the transpose permutation to the original shape to figure
+    // out what the new logical dimensions are.
+    new_dnums.set_kernel_input_feature_dimension(
+        transpose_dimensions[dnums.kernel_input_feature_dimension()]);
+    new_dnums.set_kernel_output_feature_dimension(
+        transpose_dimensions[dnums.kernel_output_feature_dimension()]);
+    for (auto& kernel_spatial_dimension :
+         *new_dnums.mutable_kernel_spatial_dimensions()) {
+      kernel_spatial_dimension = transpose_dimensions[kernel_spatial_dimension];
+    }
+    new_rhs = &transpose_operand;
+  } else {
+    new_rhs = convolution.mutable_operand(kRhsIdx);
|
||||
}
|
||||
|
||||
auto new_conv = HloInstruction::CreateConvolve(
|
||||
convolution.shape(), convolution.mutable_operand(0), &transpose_operand,
|
||||
convolution.window(), new_dnums);
|
||||
convolution.shape(), new_lhs, new_rhs, convolution.window(), new_dnums);
|
||||
TF_CHECK_OK(convolution.parent()->ReplaceWithNewInstruction(
|
||||
&convolution, std::move(new_conv)));
|
||||
|
||||
|
@ -313,8 +313,7 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeRhs) {
      new_conv->convolution_dimension_numbers().kernel_spatial_dimensions(1));
}

// Test that a transpose of the activations does not get folded into
// convolution.
// Test that a transpose of the activations gets folded into convolution.
TEST_F(TransposeFoldingTest, FoldConvTransposeLhs) {
  auto builder = HloComputation::Builder("entry_computation");
  HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
@ -348,18 +347,25 @@ TEST_F(TransposeFoldingTest, FoldConvTransposeLhs) {
  module.AddEntryComputation(builder.Build(conv));
  FoldTranspose(&module);

  // Instructions after folding: transpose_x, y, and the convolution.
  // Instructions after folding: x, y, and the convolution.
  std::unordered_set<HloInstruction*> instruction_set(
      entry_computation->instructions().begin(),
      entry_computation->instructions().end());
  CHECK_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation.";
  CHECK_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation.";
  CHECK_EQ(1, instruction_set.erase(transpose_x))
      << "transpose_x is not in entry_computation.";
  CHECK_EQ(1, instruction_set.erase(conv))
      << "conv is not in entry_computation.";
  CHECK_EQ(0, instruction_set.size())
      << "entry_computation should contain exactly 4 instructions.";
  EXPECT_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation.";
  EXPECT_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation.";
  EXPECT_EQ(1, instruction_set.size())
      << "entry_computation should contain exactly 3 instructions.";
  HloInstruction* new_conv = *instruction_set.begin();
  EXPECT_EQ(HloOpcode::kConvolution, new_conv->opcode());
  EXPECT_EQ(dnums.input_feature_dimension(),
            new_conv->convolution_dimension_numbers().input_batch_dimension());
  EXPECT_EQ(
      dnums.input_batch_dimension(),
      new_conv->convolution_dimension_numbers().input_feature_dimension());
  EXPECT_EQ(dnums.spatial_dimensions(0),
            new_conv->convolution_dimension_numbers().spatial_dimensions(0));
  EXPECT_EQ(dnums.spatial_dimensions(1),
            new_conv->convolution_dimension_numbers().spatial_dimensions(1));
}

}  // namespace
@ -20,6 +20,7 @@ limitations under the License.
#include <stack>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
@ -1843,10 +1844,17 @@ UserComputation::GetEmbeddedComputations(
  XLA_VLOG_LINES(3, session_computation_.DebugString());

  std::vector<VersionedComputationHandle> computations;
  std::vector<int64> sorted_handles;
  for (const auto& handle_request : session_computation_.requests()) {
    int64 handle_value = handle_request.first;
    sorted_handles.push_back(handle_request.first);
  }
  std::sort(sorted_handles.begin(), sorted_handles.end());
  for (int64 handle : sorted_handles) {
    const auto& handle_request = session_computation_.requests().find(handle);
    CHECK(handle_request != session_computation_.requests().end());
    int64 handle_value = handle_request->first;
    if (handle_value <= version) {
      const OperationRequest& request = handle_request.second;
      const OperationRequest& request = handle_request->second;
      switch (request.request().op_case()) {
        case OpRequest::kCallRequest: {
          CHECK_EQ(1, request.embedded_computation_versions_size());
@ -102,6 +102,32 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) {
  return true;
}

// Constructs and returns the new shape with the given minor_to_major order in
// its Layout.
StatusOr<Shape> MakeShapeWithLayoutInternal(
    PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions,
    tensorflow::gtl::ArraySlice<int64> minor_to_major) {
  if (dimensions.size() != minor_to_major.size()) {
    return InvalidArgument("Dimensions size is %ld, but layout size is %ld.",
                           dimensions.size(), minor_to_major.size());
  }
  if (element_type == OPAQUE || element_type == TUPLE) {
    return InvalidArgument("Unsupported element type: %s",
                           PrimitiveType_Name(element_type).c_str());
  }
  Shape shape = ShapeUtil::MakeShape(element_type, dimensions);
  auto min2maj = shape.mutable_layout()->mutable_minor_to_major();
  min2maj->Clear();
  for (int64 value : minor_to_major) {
    min2maj->Add(value);
  }
  if (!shape.has_layout()) {
    return InvalidArgument("Shape has no layout.");
  }
  TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(shape));
  return shape;
}
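// Worked example (an editorial note, not part of the commit):
//   MakeShapeWithLayoutInternal(F32, /*dimensions=*/{2, 3},
//                               /*minor_to_major=*/{0, 1})
// yields an f32[2,3] shape whose layout lists dimension 0 as the
// fastest-varying (column-major storage); {1, 0} would give row-major.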

}  // namespace

/* static */ bool ShapeUtil::Equal(const Shape& lhs, const Shape& rhs) {
@ -152,16 +178,8 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) {
/* static */ Shape ShapeUtil::MakeShapeWithLayout(
    PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions,
    tensorflow::gtl::ArraySlice<int64> minor_to_major) {
  CHECK_EQ(dimensions.size(), minor_to_major.size());
  Shape shape = MakeShape(element_type, dimensions);
  auto min2maj = shape.mutable_layout()->mutable_minor_to_major();
  min2maj->Clear();
  for (int64 value : minor_to_major) {
    min2maj->Add(value);
  }
  DCHECK(shape.has_layout());
  TF_DCHECK_OK(ValidateShape(shape));
  return shape;
  return MakeShapeWithLayoutInternal(element_type, dimensions, minor_to_major)
      .ValueOrDie();
}

/* static */ Shape ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(
@ -499,11 +517,10 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
    // Extract the layout minor-to-major and set it.
    TF_ASSIGN_OR_RETURN(std::vector<int64> min2maj,
                        comma_list_to_int64s(layout_string));
    TF_RET_CHECK(dimensions.size() == min2maj.size());
    result =
        ShapeUtil::MakeShapeWithLayout(primitive_type, dimensions, min2maj);
    TF_ASSIGN_OR_RETURN(result, MakeShapeWithLayoutInternal(
                                    primitive_type, dimensions, min2maj));
  }
  TF_DCHECK_OK(ShapeUtil::ValidateShape(result));
  TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(result));
  return std::move(result);
}
@ -102,28 +102,18 @@ cc_library(
    deps = [
        ":literal_test_util",
        "//tensorflow/compiler/xla:shape_layout",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:types",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
        "//tensorflow/compiler/xla/service",
        "//tensorflow/compiler/xla/service:backend",
        "//tensorflow/compiler/xla/service:compiler",
        "//tensorflow/compiler/xla/service:computation_layout",
        "//tensorflow/compiler/xla/service:computation_placer",
        "//tensorflow/compiler/xla/service:executable",
        "//tensorflow/compiler/xla/service:hlo",
        "//tensorflow/compiler/xla/service:hlo_execution_profile",
        "//tensorflow/compiler/xla/service:hlo_graph_dumper",
        "//tensorflow/compiler/xla/service:transfer_manager",
        "//tensorflow/core:core_cpu_internal",
        "//tensorflow/compiler/xla/service:hlo_runner",
        "//tensorflow/core:lib",
        "//tensorflow/core:stream_executor_no_cuda",
        "//tensorflow/core:test",
        "//third_party/eigen3",
    ],
)
@ -19,24 +19,9 @@ limitations under the License.
#include <string>
#include <utility>

#define EIGEN_USE_THREADS

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
@ -45,22 +30,6 @@ namespace se = ::perftools::gputools;

namespace xla {

// Define this in .cc file to avoid having to include eigen or forward declare
// these types in the header.
struct HloTestBase::EigenThreadPoolWrapper {
  std::unique_ptr<EigenThreadPoolWrapper> pool;
  std::unique_ptr<Eigen::ThreadPoolDevice> device;
};

HloTestBase::HloTestBase() {}

HloTestBase::~HloTestBase() {
  // Deallocate all the memory allocated during the tests.
  for (auto& allocation : allocations_) {
    backend().default_stream_executor()->Deallocate(&allocation);
  }
}

/* static */
std::unique_ptr<HloModule> HloTestBase::CreateNewModule() {
  HloModuleConfig config;
@ -80,98 +49,25 @@ StatusOr<perftools::gputools::DeviceMemoryBase> HloTestBase::Execute(
    tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
        arguments,
    Shape* result_shape) {
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<Executable> executable,
      backend().compiler()->Compile(std::move(module),
                                    backend().default_stream_executor()));

  se::Stream stream(backend().default_stream_executor());
  stream.Init();

  ExecutableRunOptions run_options;
  run_options.set_stream(&stream);
  run_options.set_allocator(backend().memory_allocator());
  run_options.set_inter_op_thread_pool(backend().inter_op_thread_pool());
  run_options.set_intra_op_thread_pool(
      backend().eigen_intra_op_thread_pool_device());

  HloExecutionProfile hlo_execution_profile;
  ServiceExecutableRunOptions service_run_options(
      run_options, backend().StreamBorrower(),
      backend().inter_op_thread_pool());
  TF_ASSIGN_OR_RETURN(
      se::DeviceMemoryBase result,
      executable->ExecuteOnStream(&service_run_options, arguments,
                                  &hlo_execution_profile));
  TF_RET_CHECK(stream.BlockHostUntilDone());

  allocations_.push_back(result);

  *result_shape = executable->result_shape();

  if (ShapeUtil::IsTuple(*result_shape)) {
    // We must record element buffers of tuples as well to avoid leaks.
    DCHECK(!ShapeUtil::IsNestedTuple(*result_shape));
    TF_ASSIGN_OR_RETURN(
        std::vector<se::DeviceMemoryBase> element_buffers,
        backend().transfer_manager()->ShallowCopyTupleFromDevice(
            backend().default_stream_executor(), result, *result_shape));

    // A tuple may contain the same buffer in more than one element. Keep track
    // of the buffers already added to avoid duplicates in allocations_.
    std::set<void*> added_opaques;
    for (auto element_buffer : element_buffers) {
      if (added_opaques.count(element_buffer.opaque()) == 0) {
        CHECK(element_buffer.opaque() != nullptr);
        added_opaques.insert(element_buffer.opaque());
        allocations_.push_back(element_buffer);
      }
    }
  }

  return result;
  return runner_.Execute(std::move(module), arguments, result_shape);
}

se::DeviceMemoryBase HloTestBase::TransferToDevice(const Literal& literal) {
  // Allocate memory on the device using the stream executor.
  int64 allocation_size =
      backend().transfer_manager()->GetByteSizeRequirement(literal.shape());
  se::DeviceMemoryBase allocation =
      backend().default_stream_executor()->AllocateArray<uint8>(
          allocation_size);
  allocations_.push_back(allocation);

  TF_CHECK_OK(backend().transfer_manager()->TransferLiteralToDevice(
      backend().default_stream_executor(), literal, &allocation));

  return allocation;
  return runner_.TransferToDevice(literal);
}

std::unique_ptr<Literal> HloTestBase::TransferFromDevice(
    const Shape& shape, se::DeviceMemoryBase device_base) {
  auto literal = MakeUnique<Literal>();
  TF_CHECK_OK(backend().transfer_manager()->TransferLiteralFromDevice(
      backend().default_stream_executor(), device_base, shape, shape,
      literal.get()));
  return literal;
  return runner_.TransferFromDevice(shape, device_base);
}

std::unique_ptr<Literal> HloTestBase::ExecuteAndTransfer(
    std::unique_ptr<HloModule> module,
    tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments) {
  Shape result_shape;
  se::DeviceMemoryBase device_base =
      Execute(std::move(module), arguments, &result_shape).ValueOrDie();
  return TransferFromDevice(result_shape, device_base);
  return runner_.ExecuteAndTransfer(std::move(module), arguments);
}

Backend& HloTestBase::backend() {
  if (!backend_) {
    backend_ = Backend::CreateDefaultBackend().ConsumeValueOrDie();
    VLOG(1) << "executing on platform " << backend().platform()->Name();
  }
  return *backend_;
}
Backend& HloTestBase::backend() { return runner_.backend(); }

/* static */
string HloTestBase::TestName() {
@ -21,12 +21,12 @@ limitations under the License.
#include <vector>

#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_runner.h"
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
@ -39,10 +39,9 @@ namespace xla {
// building a graph of HLO instructions to run.
class HloTestBase : public ::testing::Test {
 protected:
  struct EigenThreadPoolWrapper;
  HloTestBase();
  HloTestBase() {}

  ~HloTestBase() override;
  ~HloTestBase() override {}

  // Creates a new HLO module for a test. The module created will have
  // TestName() for its name; it will also automatically populate its debug
@ -102,23 +101,12 @@ class HloTestBase : public ::testing::Test {

  static string TestName();

  // Creates (if necessary) and returns the default backend. If creation fails,
  // crashes the program.
  //
  // This creates the backend lazily so it's possible to instantiate an
  // HloTestBase in a program without any backends linked in.
  // Returns the backend owned by the HloRunner.
  Backend& backend();

  // This vector contains handles of all the device memory allocations performed
  // by the test. These are deallocated on destruction of the test object.
  std::vector<perftools::gputools::DeviceMemoryBase> allocations_;
  HloRunner runner_;

  ErrorSpec error_spec_{0.0001};

  std::unique_ptr<EigenThreadPoolWrapper> thread_pool_wrapper_;

 private:
  std::unique_ptr<Backend> backend_;  // Lazily populated. Access via backend().
};

}  // namespace xla
@ -210,6 +210,18 @@ tf_cc_binary(
    ],
)

tf_cc_binary(
    name = "hlo_proto_to_json",
    srcs = ["hlo_proto_to_json.cc"],
    deps = [
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla/service:hlo_proto",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:lib",
    ],
)

# -----------------------------------------------------------------------------

filegroup(
91  tensorflow/compiler/xla/tools/hlo_proto_to_json.cc  Normal file
@ -0,0 +1,91 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Usage:
//   hlo_proto_to_json --input_file=some_binary_proto
//   --output_file=path_to_dump_output
//
// Reads one serialized HLO module, converts it into JSON format, and dumps it
// to the output file. some_binary_proto is obtained by serializing an HLO
// module to disk using the --xla_dump_hlo_proto_to debug option.

#include <stdio.h>
#include <string>
#include <vector>

#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/command_line_flags.h"

using tensorflow::Env;
using xla::string;

namespace xla {
namespace tools {

StatusOr<string> ToJson(const tensorflow::protobuf::Message& message) {
  string json_output;
  tensorflow::protobuf::util::JsonPrintOptions json_options;
  json_options.add_whitespace = true;
  json_options.always_print_primitive_fields = true;
  auto status = tensorflow::protobuf::util::MessageToJsonString(
      message, &json_output, json_options);
  if (!status.ok()) {
    return InternalError("MessageToJsonString failed: %s",
                         status.error_message().data());
  }
  return json_output;
}

void RealMain(const string& input, const string& output) {
  HloProto hlo_proto;
  TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), input,
                                          &hlo_proto))
      << "Can't open, read, or parse input file " << input;

  auto statusor = ToJson(hlo_proto);
  QCHECK(statusor.ok()) << "Error converting " << input
                        << " to JSON: " << statusor.status();

  TF_CHECK_OK(tensorflow::WriteStringToFile(tensorflow::Env::Default(), output,
                                            statusor.ValueOrDie()));
}

}  // namespace tools
}  // namespace xla

int main(int argc, char** argv) {
  string input_file, output_file;
  const std::vector<tensorflow::Flag> flag_list = {
      tensorflow::Flag("input_file", &input_file, "file to convert."),
      tensorflow::Flag("output_file", &output_file, "converted file"),
  };
  const string usage = tensorflow::Flags::Usage(argv[0], flag_list);
  bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
  tensorflow::port::InitMain(usage.c_str(), &argc, &argv);
  QCHECK(parse_ok && argc == 1) << "\n" << usage;

  QCHECK(!input_file.empty()) << "--input_file is required";
  QCHECK(!output_file.empty()) << "--output_file is required";

  xla::tools::RealMain(input_file, output_file);

  return 0;
}
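
A typical invocation of the new tool looks like the following (an editorial illustration; the paths are made up):

```
hlo_proto_to_json --input_file=/tmp/module_0001.pb \
                  --output_file=/tmp/module_0001.json
```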
84  tensorflow/compiler/xla/tools/parser/BUILD  Normal file
@ -0,0 +1,84 @@
# Build file for the Hlo parser.

licenses(["notice"])  # Apache 2.0

package(
    default_visibility = [":friends"],
)

package_group(
    name = "friends",
    includes = [
        "//tensorflow/compiler/xla:friends",
    ],
)

# Filegroup used to collect source files for dependency checking.
filegroup(
    name = "c_srcs",
    data = glob([
        "**/*.cc",
        "**/*.h",
    ]),
)

load("//tensorflow:tensorflow.bzl", "tf_cc_test")

cc_library(
    name = "hlo_lexer",
    srcs = ["hlo_lexer.cc"],
    hdrs = [
        "hlo_lexer.h",
        "hlo_token.h",
    ],
    deps = [
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/service:hlo",
        "//tensorflow/core:lib",
        "//tensorflow/core:regexp_internal",
    ],
)

cc_library(
    name = "hlo_parser",
    srcs = ["hlo_parser.cc"],
    hdrs = ["hlo_parser.h"],
    deps = [
        ":hlo_lexer",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla:xla_data_proto",
        "//tensorflow/compiler/xla/service:hlo",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
    ],
)

tf_cc_test(
    name = "hlo_parser_test",
    size = "small",
    srcs = ["hlo_parser_test.cc"],
    deps = [
        ":hlo_parser",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
    ],
)

# -----------------------------------------------------------------------------

filegroup(
    name = "all_files",
    srcs = glob(
        ["**/*"],
        exclude = [
            "**/METADATA",
            "**/OWNERS",
        ],
    ),
    visibility = ["//tensorflow:__subpackages__"],
)
69  tensorflow/compiler/xla/tools/parser/README.md  Normal file
@ -0,0 +1,69 @@
# HloModule string syntax

TODO: Support subcomputations (for fusion, reduce, while, ...).

TODO: Support ops that require extra attributes, e.g. dimensions, strides.

```yacc
hlo_module
  : 'HloModule' name computation
  ;

computation
  : 'ENTRY' name param_list '->' shape instruction_list
  ;

instruction_list
  : '{' instruction_list1 '}'
  ;
instruction_list1
  : instruction
  | instruction_list1 instruction
  ;
instruction
  : name '=' shape opcode operands
  ;

operands
  : '(' operands1 ')'
  ;
operands1
  : /*empty*/
  | operand
  | operands1 ',' operand
  ;
operand
  : shape name
  ;

param_list
  : '(' param_list1 ')'
  ;
param_list1
  : /*empty*/
  | param
  | param_list1 ',' param
  ;
param
  : name shape
  ;

shape
  : shape_val_
  | '(' tuple_elements ')'
  ;
tuple_elements
  : /*empty*/
  | shape (',' shape)*
  ;

name
  : identifier ':'
  | '%' identifier
  ;

identifier
  : [a-zA-Z_][a-zA-Z0-9_.-]*
  ;
```
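
For concreteness, here is a small module string that the grammar above accepts (an editorial example; the names are made up, and the `parameter(0)` argument form follows the parser implementation rather than this grammar sketch):

```
HloModule add_module:

ENTRY %add.v3 (x: f32[2,2], y: f32[2,2]) -> f32[2,2] {
  %x = f32[2,2] parameter(0)
  %y = f32[2,2] parameter(1)
  %add = f32[2,2] add(f32[2,2] %x, f32[2,2] %y)
}
```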
270  tensorflow/compiler/xla/tools/parser/hlo_lexer.cc  Normal file
@ -0,0 +1,270 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/tools/parser/hlo_lexer.h"

#include <unordered_map>

#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/platform/regexp.h"

namespace xla {
namespace tools {

using tensorflow::StringPiece;

namespace {

constexpr int kEOF = -1;
constexpr int kError = -2;

// [a-zA-Z0-9_.-]
bool IsIdentifierChar(char c) {
  return isalnum(static_cast<unsigned char>(c)) || c == '-' || c == '.' ||
         c == '_';
}

}  // namespace

int HloLexer::GetNextChar() {
  int current_char = PeekCurrentChar();
  if (current_char != kEOF && current_char != kError) {
    current_ptr_++;
  }
  return current_char;
}

int HloLexer::PeekCurrentChar() const {
  if (current_ptr_ == buf_.end()) {
    return kEOF;
  }
  char current_char = *current_ptr_;
  if (current_char == 0) {
    // '\0' should not appear in the middle of the string.
    return kError;
  }
  return static_cast<unsigned char>(current_char);
}

bool HloLexer::CanDereference(const char* ptr) const {
  return ptr < buf_.end() && ptr >= buf_.begin();
}

StringPiece HloLexer::StringPieceFromPointers(const char* begin,
                                              const char* end) const {
  CHECK(begin <= end);
  CHECK(begin == buf_.end() || CanDereference(begin));
  CHECK(end == buf_.end() || CanDereference(end));
  return StringPiece(begin, end - begin);
}

tensorflow::RegexpStringPiece HloLexer::RegexpStringPieceFromPointers(
    const char* begin, const char* end) const {
  CHECK(begin <= end);
  CHECK(begin == buf_.end() || CanDereference(begin));
  CHECK(end == buf_.end() || CanDereference(end));
  return tensorflow::RegexpStringPiece(begin, end - begin);
}

TokKind HloLexer::LexToken() {
  while (true) {
    token_start_ = current_ptr_;

    int current_char = GetNextChar();
    switch (current_char) {
      default:
        // [a-zA-Z_]
        if (isalpha(static_cast<unsigned char>(current_char)) ||
            current_char == '_') {
          return LexIdentifier();
        }
        return TokKind::kError;
      case kEOF:
        // Hit the end of the input buffer.
        return TokKind::kEof;
      case kError:
        // Hit an invalid character in the input buffer.
        return TokKind::kError;
      case ' ':
      case '\t':
      case '\n':
      case '\r':
        // Ignore whitespace.
        continue;
      case '0':
      case '1':
      case '2':
      case '3':
      case '4':
      case '5':
      case '6':
      case '7':
      case '8':
      case '9':
      case '-':
        if (current_char == '-' && PeekCurrentChar() == '>') {
          current_ptr_++;
          return TokKind::kArrow;
        }
        return LexDigitOrNegative();
      case '=':
        return TokKind::kEqual;
      case ',':
        return TokKind::kComma;
      case '%':
        return LexPercent();
      case ':':
        return TokKind::kColon;
      case '[':
        return TokKind::kLsquare;
      case ']':
        return TokKind::kRsquare;
      case '{':
        return TokKind::kLbrace;
      case '}':
        return TokKind::kRbrace;
      case '(':
        return TokKind::kLparen;
      case ')':
        return TokKind::kRparen;
    }
  }
}

// Lex a shape, name, keyword, or opcode.
// shape   ::= ([a-zA-Z0-9_]*[0-9]*)\[([0-9,]*)\](?:\s*{([0-9,]*)})?
// name    ::= [a-zA-Z_][a-zA-Z0-9_.-]*:
// keyword ::= HloModule, ENTRY, ...
// opcode  ::= add, greater-than, ...
TokKind HloLexer::LexIdentifier() {
  {
    auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end());
    // 'consumable' will be advanced iff its prefix matches the pattern.
    static LazyRE2 shape_pattern = {
        R"(^(\w*\d*)\[([\d,]*)\](?:\s*{([\d,]*)})?)"};
    if (RE2::Consume(&consumable, *shape_pattern)) {
      auto status_or_shape = ShapeUtil::ParseShapeString(
          StringPieceFromPointers(token_start_, consumable.begin()));
      if (status_or_shape.ok()) {
        // This is a shape string.
        shape_val_ = status_or_shape.ValueOrDie();
        current_ptr_ = consumable.begin();
        return TokKind::kShape;
      }
    }
  }

  while (IsIdentifierChar(PeekCurrentChar())) {
    current_ptr_++;
  }

  // If followed by ':', it's a name.
  if (PeekCurrentChar() == ':') {
    str_val_.assign(token_start_, current_ptr_);
    current_ptr_++;  // skip ':'
    return TokKind::kName;
  }

  StringPiece identifier = StringPieceFromPointers(token_start_, current_ptr_);

  // See if this is a keyword.
#define KEYWORD(STR)            \
  do {                          \
    if (identifier == #STR) {   \
      return TokKind::kw_##STR; \
    }                           \
  } while (false)

  KEYWORD(true);
  KEYWORD(false);
  KEYWORD(HloModule);
  KEYWORD(ENTRY);

#undef KEYWORD

  // See if this is an opcode.
  auto opcode = StringToHloOpcode(identifier.ToString());
  if (opcode.ok()) {
    opcode_val_ = opcode.ValueOrDie();
    return TokKind::kOpcode;
  }

  current_ptr_ = token_start_ + 1;
  return TokKind::kError;
}

// Lex names after a % character.
// name ::= [a-zA-Z_][a-zA-Z0-9_.-]*
TokKind HloLexer::LexPercent() {
  const char* name_start = current_ptr_;
  if (isalpha(static_cast<unsigned char>(PeekCurrentChar())) ||
      PeekCurrentChar() == '_') {
    current_ptr_++;
    while (IsIdentifierChar(PeekCurrentChar())) {
      current_ptr_++;
    }
    str_val_.assign(name_start, current_ptr_);
    return TokKind::kName;
  }
  return TokKind::kError;
}

// Lex integer and floating-point values.
// int            [-]?[0-9]+
// fp with exp    [-]?([0-9]+|[0-9]+[.][0-9]*|[0-9]*[.][0-9]+)([eE][+-]?[0-9]+)
// fp without exp [-]?([0-9]+[.][0-9]*|[0-9]*[.][0-9]+)
TokKind HloLexer::LexDigitOrNegative() {
  auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end());
  static LazyRE2 float_pattern = {
      R"([-]?((\d+|\d+[.]\d*|\d*[.]\d+)([eE][+-]?\d+))|(\d+[.]\d*|\d*[.]\d+))"};
  if (RE2::Consume(&consumable, *float_pattern)) {
    current_ptr_ = consumable.begin();
    tensorflow::strings::safe_strtod(string(token_start_, current_ptr_).c_str(),
                                     &decimal_val_);
    return TokKind::kDecimal;
  }

  static LazyRE2 int_pattern = {R"([-]?\d+)"};
  if (RE2::Consume(&consumable, *int_pattern)) {
    current_ptr_ = consumable.begin();
    tensorflow::strings::safe_strto64(
        StringPieceFromPointers(token_start_, current_ptr_), &int64_val_);
    return TokKind::kInt;
  }

  return TokKind::kError;
}

StringPiece HloLexer::GetCurrentLine() const {
  const char* start = token_start_;
  const char* end = current_ptr_;
  if (!CanDereference(start) || !CanDereference(end)) {
    return "LINE OUT OF RANGE";
  }
  while (start > buf_.begin() && *start != '\n') {
    start--;
  }
  while (end < buf_.end() && *end != '\n') {
    end++;
  }
  return StringPieceFromPointers(start, end);
}

}  // namespace tools
}  // namespace xla
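// Illustration (an editorial note, not part of the commit): lexing the line
//   %dot = f32[2,4]{1,0} dot(f32[2,3] %x, f32[3,4] %y)
// produces the token stream
//   kName(%dot) kEqual kShape(f32[2,4]{1,0}) kOpcode(dot) kLparen
//   kShape(f32[2,3]) kName(%x) kComma kShape(f32[3,4]) kName(%y) kRparen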
108  tensorflow/compiler/xla/tools/parser/hlo_lexer.h  Normal file
@ -0,0 +1,108 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_LEXER_H_
#define TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_LEXER_H_

#include <string>

#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/tools/parser/hlo_token.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
namespace tools {

// Lexer for the HloModule::ToString() format text.
class HloLexer {
 public:
  explicit HloLexer(tensorflow::StringPiece buf) : buf_(buf) {
    current_ptr_ = buf_.begin();
  }

  TokKind Lex() { return current_kind_ = LexToken(); }
  TokKind GetKind() const { return current_kind_; }
  string GetStrVal() const {
    CHECK(GetKind() == TokKind::kName);
    return str_val_;
  }
  Shape GetShapeVal() const {
    CHECK(GetKind() == TokKind::kShape);
    return shape_val_;
  }
  HloOpcode GetOpcodeVal() const {
    CHECK(GetKind() == TokKind::kOpcode);
    return opcode_val_;
  }
  int64 GetInt64Val() const {
    CHECK(GetKind() == TokKind::kInt);
    return int64_val_;
  }
  double GetDecimalVal() const {
    CHECK(GetKind() == TokKind::kDecimal);
    return decimal_val_;
  }

  // Returns the line of text that is currently being lexed.
  tensorflow::StringPiece GetCurrentLine() const;

 private:
  // Returns the current character. If it's neither the end of input buffer nor
  // an invalid character, moves the pointer forward.
  int GetNextChar();

  // Returns the current character.
  int PeekCurrentChar() const;

  // Creates a StringPiece with the given begin and end. Exits if begin > end
  // or if either is out of the range of the current buffer.
  tensorflow::StringPiece StringPieceFromPointers(const char* begin,
                                                  const char* end) const;
  tensorflow::RegexpStringPiece RegexpStringPieceFromPointers(
      const char* begin, const char* end) const;

  // Returns true if the given ptr is dereferenceable within the range of the
  // current buffer.
  bool CanDereference(const char* ptr) const;

  TokKind LexToken();

  TokKind LexIdentifier();
  TokKind LexPercent();
  TokKind LexShape();
  TokKind LexConstant();
  TokKind LexDigitOrNegative();

  const tensorflow::StringPiece buf_;
  const char* current_ptr_;

  // Information about the current token.
  const char* token_start_;
  TokKind current_kind_;
  string str_val_;
  Shape shape_val_;
  HloOpcode opcode_val_;
  int64 int64_val_;
  double decimal_val_;
};

}  // namespace tools
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_LEXER_H_
502  tensorflow/compiler/xla/tools/parser/hlo_parser.cc  Normal file
@ -0,0 +1,502 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
|
||||
namespace xla {
|
||||
namespace tools {
|
||||
|
||||
namespace {
|
||||
|
||||
using tensorflow::StringPiece;
|
||||
using tensorflow::strings::StrCat;
|
||||
|
||||
// Parser for the HloModule::ToString() format text.
|
||||
class HloParser {
|
||||
public:
|
||||
explicit HloParser(StringPiece str) : lexer_(str) {}
|
||||
|
||||
// Runs the parser. Returns false if an error occurred.
|
||||
bool Run();
|
||||
|
||||
// Returns the parsed HloModule.
|
||||
std::unique_ptr<HloModule> ConsumeHloModule() { return std::move(module_); }
|
||||
|
||||
// Returns the error information.
|
||||
string GetError() const { return tensorflow::str_util::Join(error_, "\n"); }
|
||||
|
||||
private:
|
||||
// ParseXXX returns false if an error occurred.
|
||||
bool ParseHloModule();
|
||||
bool ParseComputation();
|
||||
bool ParseInstructionList(HloComputation::Builder* builder);
|
||||
bool ParseInstruction(HloComputation::Builder* builder);
|
||||
bool ParseLiteral(std::unique_ptr<Literal>* literal, const Shape& shape);
|
||||
bool ParseOperands(std::vector<HloInstruction*>* operands,
|
||||
const int expected_size);
|
||||
bool ParseParamList();
|
||||
bool ParseName(string* result);
|
||||
bool ParseShape(Shape* result);
|
||||
bool ParseOpcode(HloOpcode* result);
|
||||
bool ParseInt64(int64* result);
|
||||
bool ParseDecimal(double* result);
|
||||
bool ParseBool(bool* result);
|
||||
bool ParseToken(TokKind kind, const string& msg);
|
||||
|
||||
// Logs the current parsing line and the given message. Always returns false.
|
||||
bool TokenError(StringPiece msg);
|
||||
|
||||
// If the current token is 'kind', eats it (i.e. lexes the next token) and
|
||||
// returns true.
|
||||
bool EatIfPresent(TokKind kind);
|
||||
|
||||
// Adds the instruction to the pool. Returns false and emits an error if the
|
||||
// instruction already exists.
|
||||
bool AddInstruction(const string& name, HloInstruction* instruction);
|
||||
|
||||
// The map from the instruction name to the instruction. This does not own the
|
||||
// instructions.
|
||||
std::unordered_map<string, HloInstruction*> instruction_pool_;
|
||||
|
||||
HloLexer lexer_;
|
||||
std::unique_ptr<HloModule> module_;
|
||||
std::vector<string> error_;
|
||||
};
|
||||
|
||||
bool HloParser::TokenError(StringPiece msg) {
|
||||
error_.push_back(
|
||||
StrCat("was parsing \"", lexer_.GetCurrentLine(), "\"; ", msg));
|
||||
return false;
|
||||
}
|
||||
|
||||
bool HloParser::Run() {
|
||||
lexer_.Lex();
|
||||
return ParseHloModule();
|
||||
}
|
||||
|
||||
// ::= 'HloModule' name computation
|
||||
bool HloParser::ParseHloModule() {
|
||||
if (lexer_.GetKind() != TokKind::kw_HloModule) {
|
||||
return TokenError("expects HloModule");
|
||||
}
|
||||
// Eat 'HloModule'
|
||||
lexer_.Lex();
|
||||
|
||||
string name;
|
||||
if (!ParseName(&name)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
module_ = MakeUnique<HloModule>(name);
|
||||
|
||||
return ParseComputation();
|
||||
}
|
||||
|
||||
// computation ::= 'ENTRY' name param_list '->' shape instruction_list
|
||||
bool HloParser::ParseComputation() {
|
||||
string name;
|
||||
if (!ParseToken(TokKind::kw_ENTRY, "expects 'ENTRY'") || !ParseName(&name)) {
|
||||
return false;
|
||||
}
|
||||
auto builder = MakeUnique<HloComputation::Builder>(name);
|
||||
|
||||
Shape shape;
|
||||
if (!ParseParamList() || !ParseToken(TokKind::kArrow, "expects '->'") ||
|
||||
!ParseShape(&shape) || !ParseInstructionList(builder.get())) {
|
||||
return false;
|
||||
}
|
||||
module_->AddEntryComputation(builder->Build());
|
||||
return true;
|
||||
}
|
||||
|
||||
// instruction_list ::= '{' instruction_list1 '}'
|
||||
// instruction_list1 ::= (instruction)+
|
||||
bool HloParser::ParseInstructionList(HloComputation::Builder* builder) {
|
||||
if (!ParseToken(TokKind::kLbrace,
|
||||
"expects '{' at the beginning of instruction list.")) {
|
||||
return false;
|
||||
}
|
||||
do {
|
||||
if (!ParseInstruction(builder)) {
|
||||
return false;
|
||||
}
|
||||
} while (lexer_.GetKind() != TokKind::kRbrace);
|
||||
return ParseToken(TokKind::kRbrace,
|
||||
"expects '}' at the end of instruction list.");
|
||||
}
|
||||
|
||||
// instruction ::= name '=' shape opcode operands
|
||||
bool HloParser::ParseInstruction(HloComputation::Builder* builder) {
|
||||
string name;
|
||||
Shape shape;
|
||||
HloOpcode opcode;
|
||||
std::vector<HloInstruction*> operands;
|
||||
if (!ParseName(&name) ||
|
||||
!ParseToken(TokKind::kEqual, "expects '=' in instruction") ||
|
||||
!ParseShape(&shape) || !ParseOpcode(&opcode)) {
|
||||
return false;
|
||||
}
|
||||
switch (opcode) {
|
||||
case HloOpcode::kParameter: {
|
||||
int64 parameter_number;
|
||||
return ParseToken(TokKind::kLparen,
|
||||
"expects '(' before parameter number") &&
|
||||
ParseInt64(¶meter_number) &&
|
||||
ParseToken(TokKind::kRparen,
|
||||
"expects ')' after parameter number") &&
|
||||
AddInstruction(
|
||||
name, builder->AddInstruction(HloInstruction::CreateParameter(
|
||||
parameter_number, shape, name)));
|
||||
}
|
||||
case HloOpcode::kConstant: {
|
||||
std::unique_ptr<Literal> literal;
|
||||
return ParseToken(TokKind::kLparen,
|
||||
"expects '(' before parameter number") &&
|
||||
ParseLiteral(&literal, shape) &&
|
||||
ParseToken(TokKind::kRparen,
|
||||
"expects ')' after parameter number") &&
|
||||
AddInstruction(
|
||||
name, builder->AddInstruction(
|
||||
HloInstruction::CreateConstant(std::move(literal))));
|
||||
}
|
||||
// Unary ops.
|
||||
case HloOpcode::kAbs:
|
||||
case HloOpcode::kRoundNearestAfz:
|
||||
case HloOpcode::kBitcast:
|
||||
case HloOpcode::kCeil:
|
||||
case HloOpcode::kCopy:
|
||||
case HloOpcode::kCos:
|
||||
case HloOpcode::kExp:
|
||||
case HloOpcode::kIsFinite:
|
||||
case HloOpcode::kFloor:
|
||||
case HloOpcode::kLog:
|
||||
case HloOpcode::kNot:
|
||||
case HloOpcode::kNegate:
|
||||
case HloOpcode::kSign:
|
||||
case HloOpcode::kSin:
|
||||
case HloOpcode::kSort:
|
||||
case HloOpcode::kTanh: {
|
||||
return ParseOperands(&operands, /*expected_size=*/1) &&
|
||||
AddInstruction(name,
|
||||
builder->AddInstruction(HloInstruction::CreateUnary(
|
||||
shape, opcode, operands[0])));
|
||||
}
|
||||
// Binary ops.
|
||||
case HloOpcode::kAdd:
|
||||
case HloOpcode::kDivide:
|
||||
case HloOpcode::kMultiply:
|
||||
case HloOpcode::kSubtract:
|
||||
case HloOpcode::kEq:
|
||||
case HloOpcode::kGe:
|
||||
case HloOpcode::kGt:
|
||||
case HloOpcode::kLe:
|
||||
case HloOpcode::kLt:
|
||||
case HloOpcode::kNe:
|
||||
case HloOpcode::kDot:
|
||||
case HloOpcode::kMaximum:
|
||||
case HloOpcode::kMinimum:
|
||||
case HloOpcode::kPower:
|
||||
case HloOpcode::kRemainder:
|
||||
case HloOpcode::kAnd:
|
||||
case HloOpcode::kOr:
|
||||
case HloOpcode::kShiftLeft:
|
||||
case HloOpcode::kShiftRightArithmetic:
|
||||
case HloOpcode::kShiftRightLogical: {
|
||||
return ParseOperands(&operands, /*expected_size=*/2) &&
|
||||
AddInstruction(
|
||||
name, builder->AddInstruction(HloInstruction::CreateBinary(
|
||||
shape, opcode, operands[0], operands[1])));
|
||||
}
|
||||
// Ternary ops.
|
||||
case HloOpcode::kClamp:
|
||||
case HloOpcode::kSelect: {
|
||||
return ParseOperands(&operands, /*expected_size=*/3) &&
|
||||
AddInstruction(
|
||||
name,
|
||||
builder->AddInstruction(HloInstruction::CreateTernary(
|
||||
shape, opcode, operands[0], operands[1], operands[2])));
|
||||
}
|
||||
// Other supported ops.
|
||||
case HloOpcode::kConvert: {
|
||||
return ParseOperands(&operands, /*expected_size=*/1) &&
|
||||
AddInstruction(
|
||||
name, builder->AddInstruction(
|
||||
HloInstruction::CreateConvert(shape, operands[0])));
|
||||
}
|
||||
case HloOpcode::kCrossReplicaSum: {
|
||||
return ParseOperands(&operands, /*expected_size=*/1) &&
|
||||
AddInstruction(name, builder->AddInstruction(
|
||||
HloInstruction::CreateCrossReplicaSum(
|
||||
shape, operands[0])));
|
||||
}
|
||||
case HloOpcode::kReshape: {
|
||||
return ParseOperands(&operands, /*expected_size=*/1) &&
|
||||
AddInstruction(
|
||||
name, builder->AddInstruction(
|
||||
HloInstruction::CreateReshape(shape, operands[0])));
|
||||
}
|
||||
case HloOpcode::kBroadcast:
|
||||
case HloOpcode::kCall:
|
||||
case HloOpcode::kCustomCall:
|
||||
case HloOpcode::kConcatenate:
|
||||
case HloOpcode::kReducePrecision:
|
||||
case HloOpcode::kConvolution:
|
||||
case HloOpcode::kGetTupleElement:
|
||||
case HloOpcode::kMap:
|
||||
case HloOpcode::kPad:
|
||||
case HloOpcode::kReduce:
|
||||
case HloOpcode::kReduceWindow:
|
||||
case HloOpcode::kSelectAndScatter:
|
||||
case HloOpcode::kReverse:
|
||||
case HloOpcode::kRng:
|
||||
case HloOpcode::kSlice:
|
||||
case HloOpcode::kDynamicSlice:
|
||||
case HloOpcode::kDynamicUpdateSlice:
|
||||
case HloOpcode::kTranspose:
|
||||
case HloOpcode::kTuple:
|
||||
case HloOpcode::kWhile:
|
||||
case HloOpcode::kFusion:
|
||||
case HloOpcode::kBatchNormTraining:
|
||||
case HloOpcode::kBatchNormInference:
|
||||
case HloOpcode::kInfeed:
|
||||
case HloOpcode::kOutfeed:
|
||||
case HloOpcode::kBatchNormGrad:
|
||||
case HloOpcode::kRecv:
|
||||
case HloOpcode::kSend:
|
||||
case HloOpcode::kUpdate:
|
||||
case HloOpcode::kIndex:
|
||||
case HloOpcode::kTrace:
|
||||
return TokenError(StrCat("parsing not yet implemented for op: ",
|
||||
HloOpcodeString(opcode)));
|
||||
}
|
||||
}
|
||||
|
||||
bool HloParser::ParseLiteral(std::unique_ptr<Literal>* literal,
|
||||
const Shape& shape) {
|
||||
switch (shape.element_type()) {
|
||||
case PRED:
|
||||
bool b;
|
||||
if (!ParseBool(&b)) {
|
||||
return false;
|
||||
}
|
||||
*literal = Literal::CreateR0<bool>(b);
|
||||
return true;
|
||||
case S32:
|
||||
int64 i;
|
||||
if (!ParseInt64(&i)) {
|
||||
return false;
|
||||
}
|
||||
*literal = Literal::CreateR0<int32>(i);
|
||||
return true;
|
||||
case F32:
|
||||
double d;
|
||||
if (!ParseDecimal(&d)) {
|
||||
return false;
|
||||
}
|
||||
*literal = Literal::CreateR0<float>(d);
|
||||
return true;
|
||||
default:
|
||||
return TokenError(StrCat("unsupported constant in shape: ",
|
||||
ShapeUtil::HumanString(shape)));
|
||||
}
|
||||
}
|
||||
|
||||
// operands ::= '(' operands1 ')'
|
||||
// operands1
|
||||
// ::= /*empty*/
|
||||
// ::= operand (, operand)*
|
||||
// operand ::= shape name
|
||||
bool HloParser::ParseOperands(std::vector<HloInstruction*>* operands,
|
||||
const int expected_size) {
|
||||
if (!ParseToken(TokKind::kLparen,
|
||||
"expects '(' at the beginning of operands")) {
|
||||
return false;
|
||||
}
|
||||
if (lexer_.GetKind() == TokKind::kRparen) {
|
||||
// empty
|
||||
} else {
|
||||
do {
|
||||
Shape shape;
|
||||
string name;
|
||||
if (!ParseShape(&shape) || !ParseName(&name)) {
|
||||
return false;
|
||||
}
|
||||
HloInstruction* instruction =
|
||||
tensorflow::gtl::FindPtrOrNull(instruction_pool_, name);
|
||||
if (!instruction) {
|
||||
return TokenError(StrCat("instruction does not exist: ", name));
|
||||
}
|
||||
operands->push_back(instruction);
|
||||
} while (EatIfPresent(TokKind::kComma));
|
||||
}
|
||||
if (expected_size != operands->size()) {
|
||||
return TokenError(StrCat("expects ", expected_size, " operands, but has ",
|
||||
operands->size(), " operands"));
|
||||
}
|
||||
return ParseToken(TokKind::kRparen, "expects ')' at the end of operands");
|
||||
}
|
||||
|
||||
// param_list ::= '(' param_list1 ')'
|
||||
// param_list1
|
||||
// ::= /*empty*/
|
||||
// ::= param (',' param)*
|
||||
// param ::= name shape
|
||||
bool HloParser::ParseParamList() {
|
||||
if (!ParseToken(TokKind::kLparen,
|
||||
"expects '(' at the beginning of param list")) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (lexer_.GetKind() == TokKind::kRparen) {
|
||||
// empty
|
||||
} else {
|
||||
do {
|
||||
Shape shape;
|
||||
if (!ParseToken(TokKind::kName, "expects name in parameter") ||
|
||||
!ParseShape(&shape)) {
|
||||
return false;
|
||||
}
|
||||
} while (EatIfPresent(TokKind::kComma));
|
||||
}
|
||||
return ParseToken(TokKind::kRparen, "expects ')' at the end of param list");
|
||||
}
|
||||
|
||||
// shape ::= shape_val_
|
||||
// shape ::= '(' tuple_elements ')'
|
||||
// tuple_elements
|
||||
// ::= /*empty*/
|
||||
// ::= shape (',' shape)*
|
||||
bool HloParser::ParseShape(Shape* result) {
|
||||
if (EatIfPresent(TokKind::kLparen)) { // Tuple
|
||||
std::vector<Shape> shapes;
|
||||
if (lexer_.GetKind() == TokKind::kRparen) {
|
||||
/*empty*/
|
||||
} else {
|
||||
// shape (',' shape)*
|
||||
do {
|
||||
shapes.emplace_back();
|
||||
if (!ParseShape(&shapes.back())) {
|
||||
return false;
|
||||
}
|
||||
} while (EatIfPresent(TokKind::kComma));
|
||||
}
|
||||
*result = ShapeUtil::MakeTupleShape(shapes);
|
||||
return ParseToken(TokKind::kRparen, "expects ')' at the end of tuple.");
|
||||
}
|
||||
|
||||
if (lexer_.GetKind() != TokKind::kShape) {
|
||||
return TokenError("expects shape");
|
||||
}
|
||||
*result = lexer_.GetShapeVal();
|
||||
lexer_.Lex();
|
||||
return true;
|
||||
}
|
||||
|
||||
bool HloParser::ParseName(string* result) {
|
||||
VLOG(1) << "ParseName";
|
||||
if (lexer_.GetKind() != TokKind::kName) {
|
||||
return TokenError("expects name");
|
||||
}
|
||||
*result = lexer_.GetStrVal();
|
||||
lexer_.Lex();
|
||||
return true;
|
||||
}
|
||||
|
||||
bool HloParser::ParseOpcode(HloOpcode* result) {
|
||||
VLOG(1) << "ParseOpcode";
|
||||
if (lexer_.GetKind() != TokKind::kOpcode) {
|
||||
return TokenError("expects opcode");
|
||||
}
|
||||
*result = lexer_.GetOpcodeVal();
|
||||
lexer_.Lex();
|
||||
return true;
|
||||
}
|
||||
|
||||
bool HloParser::ParseInt64(int64* result) {
|
||||
VLOG(1) << "ParseInt64";
|
||||
if (lexer_.GetKind() != TokKind::kInt) {
|
||||
return TokenError("expects integer");
|
||||
}
|
||||
*result = lexer_.GetInt64Val();
|
||||
lexer_.Lex();
|
||||
return true;
|
||||
}
|
||||
|
||||
bool HloParser::ParseDecimal(double* result) {
|
||||
switch (lexer_.GetKind()) {
|
||||
case TokKind::kDecimal:
|
||||
*result = lexer_.GetDecimalVal();
|
||||
break;
|
||||
case TokKind::kInt:
|
||||
*result = static_cast<double>(lexer_.GetInt64Val());
|
||||
break;
|
||||
default:
|
||||
return TokenError("expects decimal or integer");
|
||||
}
|
||||
lexer_.Lex();
|
||||
return true;
|
||||
}
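// Example: in "f32[] constant(42)" the literal lexes as kInt, so it is
// widened here to the double value 42.0 (see the ConstantF32 test in
// hlo_parser_test.cc below).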
|
||||
|
||||
bool HloParser::ParseBool(bool* result) {
|
||||
if (lexer_.GetKind() != TokKind::kw_true &&
|
||||
lexer_.GetKind() != TokKind::kw_false) {
|
||||
return TokenError("expects true or false");
|
||||
}
|
||||
*result = lexer_.GetKind() == TokKind::kw_true;
|
||||
lexer_.Lex();
|
||||
return true;
|
||||
}
|
||||
|
||||
bool HloParser::ParseToken(TokKind kind, const string& msg) {
|
||||
if (lexer_.GetKind() != kind) {
|
||||
return TokenError(msg);
|
||||
}
|
||||
lexer_.Lex();
|
||||
return true;
|
||||
}
|
||||
|
||||
bool HloParser::EatIfPresent(TokKind kind) {
|
||||
if (lexer_.GetKind() != kind) {
|
||||
return false;
|
||||
}
|
||||
lexer_.Lex();
|
||||
return true;
|
||||
}
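// EatIfPresent drives the comma-separated lists above: the idiom
// "do { <parse element> } while (EatIfPresent(TokKind::kComma));"
// consumes a separator token only when one is actually present.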
|
||||
|
||||
bool HloParser::AddInstruction(const string& name,
|
||||
HloInstruction* instruction) {
|
||||
auto result = instruction_pool_.insert({name, instruction});
|
||||
if (!result.second) {
|
||||
return TokenError(StrCat("instruction already exists: ", name));
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
StatusOr<std::unique_ptr<HloModule>> Parse(StringPiece str) {
|
||||
HloParser parser(str);
|
||||
if (!parser.Run()) {
|
||||
return InvalidArgument("Syntax error: %s", parser.GetError().c_str());
|
||||
}
|
||||
return parser.ConsumeHloModule();
|
||||
}
|
||||
|
||||
} // namespace tools
|
||||
} // namespace xla
|
tensorflow/compiler/xla/tools/parser/hlo_parser.h (new file)

@ -0,0 +1,37 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_PARSER_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_PARSER_H_
|
||||
|
||||
#include "tensorflow/compiler/xla/ptr_util.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/tools/parser/hlo_lexer.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
|
||||
namespace xla {
|
||||
namespace tools {
|
||||
|
||||
// The API of the HLO parser. Given a string in the HloModule::ToString()
|
||||
// format, returns the parsed HloModule.
|
||||
StatusOr<std::unique_ptr<HloModule>> Parse(tensorflow::StringPiece str);
|
||||
|
||||
} // namespace tools
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_PARSER_H_
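A minimal usage sketch of this API (the driver below is hypothetical and
assumes the usual XLA build dependencies; only Parse() and
HloModule::ToString() come from this header):

#include <iostream>
#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"

int main() {
  // Parse an HloModule::ToString()-format string and print it back.
  auto result = xla::tools::Parse(R"(HloModule m:

ENTRY %add () -> f32[] {
  %constant = f32[] constant(1)
  %add = f32[] add(f32[] %constant, f32[] %constant)
}

)");
  if (!result.ok()) {
    std::cerr << result.status().error_message() << "\n";
    return 1;
  }
  std::cout << result.ValueOrDie()->ToString();
  return 0;
}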
|
tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc (new file)
@ -0,0 +1,240 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
|
||||
|
||||
#include <string>
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace xla {
|
||||
namespace tools {
|
||||
namespace {
|
||||
|
||||
struct TestData {
|
||||
string test_name;
|
||||
string module_string;
|
||||
};
|
||||
|
||||
string TestDataToString(const ::testing::TestParamInfo<TestData>& data) {
|
||||
return data.param.test_name;
|
||||
}
|
||||
|
||||
std::vector<TestData> CreateTestCases() {
|
||||
// clang-format off
|
||||
return std::vector<TestData>({
|
||||
// ax + y
|
||||
{
|
||||
"AxpyParam",
|
||||
R"(HloModule axpy_module:
|
||||
|
||||
ENTRY %axpy.v5 (alpha: f32[2,4], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
|
||||
%alpha = f32[2,4]{1,0} parameter(0)
|
||||
%x = f32[2,4]{1,0} parameter(1)
|
||||
%multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %alpha, f32[2,4]{1,0} %x)
|
||||
%y = f32[2,4]{1,0} parameter(2)
|
||||
%add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
|
||||
}
|
||||
|
||||
)"
|
||||
},
|
||||
// pred constant
|
||||
{
|
||||
"ConstantPred",
|
||||
R"(HloModule constant_pred_module:
|
||||
|
||||
ENTRY %constant_pred () -> pred[] {
|
||||
%constant = pred[] constant(true)
|
||||
}
|
||||
|
||||
)"
|
||||
},
|
||||
// s32 constant
|
||||
{
|
||||
"ConstantS32",
|
||||
R"(HloModule constant_s32_module:
|
||||
|
||||
ENTRY %constant_s32 () -> s32[] {
|
||||
%constant = s32[] constant(-42)
|
||||
}
|
||||
|
||||
)"
|
||||
},
|
||||
// f32 constant, but the value is not a decimal
|
||||
{
|
||||
"ConstantF32", R"(HloModule ConstantF32_module:
|
||||
|
||||
ENTRY %ConstantF32.v4 () -> f32[] {
|
||||
%constant = f32[] constant(42)
|
||||
}
|
||||
|
||||
)"
|
||||
},
|
||||
// constant + constant
|
||||
{
|
||||
"AddConstants",
|
||||
R"(HloModule add_constants_module:
|
||||
|
||||
ENTRY %add_constants () -> f32[] {
|
||||
%constant = f32[] constant(3.14)
|
||||
%add = f32[] add(f32[] %constant, f32[] %constant)
|
||||
}
|
||||
|
||||
)"
|
||||
},
|
||||
// v1 > v2 ? v1 : v2
|
||||
{
|
||||
"SelectR1F32",
|
||||
R"(HloModule SelectR1F32WithCmpR1F32sFromParamsSmall_module:
|
||||
|
||||
ENTRY %SelectR1F32WithCmpR1F32sFromParamsSmall.v4 (v1: f32[4], v2: f32[4]) -> f32[4] {
|
||||
%v1 = f32[4]{0} parameter(0)
|
||||
%v2 = f32[4]{0} parameter(1)
|
||||
%greater-than = pred[4]{0} greater-than(f32[4]{0} %v1, f32[4]{0} %v2)
|
||||
%select = f32[4]{0} select(pred[4]{0} %greater-than, f32[4]{0} %v1, f32[4]{0} %v2)
|
||||
}
|
||||
|
||||
)"
|
||||
}
|
||||
});
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
class HloParserTest : public ::testing::Test,
|
||||
public ::testing::WithParamInterface<TestData> {
|
||||
protected:
|
||||
void ExpectSuccess() {
|
||||
const string& original = GetParam().module_string;
|
||||
auto result = Parse(original);
|
||||
TF_EXPECT_OK(result.status());
|
||||
EXPECT_EQ(original, result.ValueOrDie()->ToString());
|
||||
}
|
||||
};
|
||||
|
||||
TEST_P(HloParserTest, Run) { ExpectSuccess(); }
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserTest,
|
||||
::testing::ValuesIn(CreateTestCases()),
|
||||
TestDataToString);
|
||||
|
||||
TEST_F(HloParserTest, Empty) {
|
||||
const string original = "";
|
||||
auto result = Parse(original);
|
||||
EXPECT_NE(tensorflow::Status::OK(), result.status());
|
||||
}
|
||||
|
||||
TEST_F(HloParserTest, Garbage) {
|
||||
const string original = "HloModule thi$ str1ng makes# N0 sen$e @all!*&^%$";
|
||||
auto result = Parse(original);
|
||||
EXPECT_NE(tensorflow::Status::OK(), result.status());
|
||||
}
|
||||
|
||||
TEST_F(HloParserTest, WrongOpcode) {
|
||||
const string original = R"(HloModule wrong_opcode:
|
||||
|
||||
ENTRY %blabla (x: f32[], y: f32[]) -> f32[] {
|
||||
%x = f32[]{} parameter(0)
|
||||
%y = f32[]{} parameter(1)
|
||||
%le = pred[]{} le(f32[]{} %x, f32[]{} %y)
|
||||
}
|
||||
|
||||
)";
|
||||
auto result = Parse(original);
|
||||
EXPECT_NE(tensorflow::Status::OK(), result.status());
|
||||
}
|
||||
|
||||
TEST_F(HloParserTest, WrongShape) {
|
||||
const string original = R"(HloModule wrong_opcode:
|
||||
|
||||
ENTRY %blabla (x: g32[]) -> g32[] {
|
||||
%x = g32[]{} parameter(0)
|
||||
}
|
||||
|
||||
)";
|
||||
auto result = Parse(original);
|
||||
EXPECT_NE(tensorflow::Status::OK(), result.status());
|
||||
}
|
||||
|
||||
TEST_F(HloParserTest, WrongOperandsSize) {
|
||||
const string original = R"(HloModule wrong_opcode:
|
||||
|
||||
ENTRY %blabla (x: f32[]) -> pred[] {
|
||||
%x = f32[]{} parameter(0)
|
||||
%eq = pred[]{} equal-to(f32[]{} %x)
|
||||
}
|
||||
|
||||
)";
|
||||
auto result = Parse(original);
|
||||
EXPECT_NE(tensorflow::Status::OK(), result.status());
|
||||
}
|
||||
|
||||
TEST_F(HloParserTest, OperandNotFound) {
|
||||
const string original = R"(HloModule operand_not_found:
|
||||
ENTRY %blabla (x: f32[]) -> pred[] {
|
||||
%x = f32[]{} parameter(0)
|
||||
%eq = pred[]{} equal-to(f32[]{} %x, f32[]{} %y)
|
||||
}
|
||||
)";
|
||||
auto result = Parse(original);
|
||||
EXPECT_NE(tensorflow::Status::OK(), result.status());
|
||||
}
|
||||
|
||||
TEST_F(HloParserTest, MoreConstants) {
|
||||
const string original = R"(HloModule SelectScalarS32True_module:
|
||||
|
||||
ENTRY %SelectScalarS32True.v4 () -> s32[] {
|
||||
%constant.2 = pred[] constant(true)
|
||||
%constant.1 = s32[] constant(-42)
|
||||
%constant = s32[] constant(42)
|
||||
%select = s32[] select(pred[] %constant.2, s32[] %constant.1, s32[] %constant)
|
||||
}
|
||||
|
||||
)";
|
||||
auto result = Parse(original);
|
||||
TF_EXPECT_OK(result.status());
|
||||
// Constant instructions have no name. The string will be parsed successfully
|
||||
// but the constant names will not be exactly the same.
|
||||
}
|
||||
|
||||
TEST_F(HloParserTest, ConstantWithExp) {
|
||||
const string original = R"(HloModule ConstantWithExp_module:
|
||||
|
||||
ENTRY %ConstantWithExp.v4 () -> f32[] {
|
||||
%constant.1 = f32[] constant(3e+2)
|
||||
}
|
||||
|
||||
)";
|
||||
auto result = Parse(original);
|
||||
TF_EXPECT_OK(result.status());
|
||||
// The string will be parsed successfully but the output strings are not
|
||||
// exactly the same, because "3e+2" is parsed into value 300 and will be
|
||||
// printed as "300".
|
||||
}
|
||||
|
||||
TEST_F(HloParserTest, Tuple) {
|
||||
const string original = R"(HloModule EmptyTupleCreate_module:
|
||||
|
||||
ENTRY %EmptyTupleCreate.v1 () -> () {
|
||||
%tuple = () tuple()
|
||||
}
|
||||
|
||||
)";
|
||||
auto result = Parse(original);
|
||||
EXPECT_NE(tensorflow::Status::OK(), result.status());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tools
|
||||
} // namespace xla
|
tensorflow/compiler/xla/tools/parser/hlo_token.h (new file)
@ -0,0 +1,58 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_TOKEN_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_TOKEN_H_
|
||||
|
||||
namespace xla {
|
||||
namespace tools {
|
||||
|
||||
// Defines the different kinds of tokens in an HLO module string.
|
||||
enum class TokKind {
|
||||
// Markers
|
||||
kEof,
|
||||
kError,
|
||||
|
||||
// Tokens with no info.
|
||||
kEqual, // =
|
||||
kComma, // ,
|
||||
kColon, // :
|
||||
kLsquare,
|
||||
kRsquare, // [ ]
|
||||
kLbrace,
|
||||
kRbrace, // { }
|
||||
kLparen,
|
||||
kRparen, // ( )
|
||||
|
||||
kArrow, // ->
|
||||
|
||||
// Keywords
|
||||
kw_HloModule,
|
||||
kw_ENTRY,
|
||||
kw_true,
|
||||
kw_false,
|
||||
|
||||
// Typed tokens.
|
||||
kName, // %foo
|
||||
kShape, // f32[2,3]{1,0}
|
||||
kOpcode, // add
|
||||
kInt, // 42
|
||||
kDecimal, // 4.2
|
||||
};
|
||||
|
||||
} // namespace tools
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_TOKEN_H_
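A sketch of how these token kinds are consumed (hypothetical helper; it uses
only the HloLexer calls -- GetKind() and Lex() -- already used by the parser
above):

// Walks the token stream until end-of-input or a lex error, mirroring the
// parser's ParseToken/EatIfPresent helpers.
void SkipToEnd(xla::tools::HloLexer* lexer) {
  while (lexer->GetKind() != xla::tools::TokKind::kEof &&
         lexer->GetKind() != xla::tools::TokKind::kError) {
    lexer->Lex();  // advance to the next token
  }
}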
|
@ -82,8 +82,8 @@ message DebugOptions {
|
||||
// Dump all HLO modules as text into the provided directory path.
|
||||
string xla_generate_hlo_text_to = 7;
|
||||
|
||||
// Dump compilation artifacts as JSON into this directory.
|
||||
string xla_dump_debug_json_to = 8;
|
||||
// Dump compilation artifacts in binary proto into this directory.
|
||||
string xla_dump_hlo_proto_to = 8;
|
||||
|
||||
// Instrument the computation to collect per-HLO cycle counts.
|
||||
bool xla_hlo_profile = 9;
|
||||
|
@ -69,6 +69,28 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "adaptive_shared_batch_scheduler",
|
||||
hdrs = ["adaptive_shared_batch_scheduler.h"],
|
||||
deps = [
|
||||
":batch_scheduler",
|
||||
"//tensorflow/contrib/batching/util:periodic_function_dynamic",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "adaptive_shared_batch_scheduler_test",
|
||||
srcs = ["adaptive_shared_batch_scheduler_test.cc"],
|
||||
deps = [
|
||||
":adaptive_shared_batch_scheduler",
|
||||
"//tensorflow/contrib/batching/test_util:fake_clock_env",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "basic_batch_scheduler",
|
||||
hdrs = ["basic_batch_scheduler.h"],
|
||||
|
tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h (new file)
@ -0,0 +1,463 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_
|
||||
#define THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <queue>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/contrib/batching/batch_scheduler.h"
|
||||
#include "tensorflow/contrib/batching/util/periodic_function.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/core/threadpool.h"
|
||||
#include "tensorflow/core/platform/cpu_info.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/thread_annotations.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace serving {
|
||||
namespace internal {
|
||||
template <typename TaskType>
|
||||
class ASBSBatch;
|
||||
|
||||
template <typename TaskType>
|
||||
class ASBSQueue;
|
||||
} // namespace internal
|
||||
|
||||
// Shared batch scheduler designed to minimize latency. The scheduler keeps
|
||||
// track of a number of queues (one per model or model version) which are
|
||||
// continuously enqueuing requests. The scheduler groups the requests into
|
||||
// batches which it periodically sends off for processing (see
|
||||
// shared_batch_scheduler.h for more details). The AdaptiveSharedBatchScheduler
|
||||
// prioritizes batches by age (i.e. the age of its oldest request) irrespective of
|
||||
// queue. The scheduler will process the oldest batch at an adjustable rate,
|
||||
// regardless of batch size. The user can provide feedback to help set this rate
|
||||
// to achieve some goal (e.g. minimize overall latency or limit CPU usage).
|
||||
//
|
||||
// The rate (or rather, the corresponding period) is adjusted each time a batch
|
||||
// is processed, using an exponentially weighted moving average to smooth
|
||||
// potentially noisy feedback:
|
||||
// ewma_feedback = ((N - 1) * ewma_feedback + feedback()) / N
|
||||
// period *= (1 + K * ewma_feedback)
|
||||
//
|
||||
// Some potential use cases:
|
||||
// Hardware Accelerators (GPUs & TPUs) - If some phase of batch processing
|
||||
// involves serial processing by a device, from a latency perspective it is
|
||||
// desirable to keep the device evenly loaded, avoiding the need to wait for
|
||||
// the device to process prior batches.
|
||||
// feedback = num_pending_on_device() - desired_pending.
|
||||
// CPU utilization - If the batch processing is CPU dominated, you can reap
|
||||
// latency gains when underutilized by increasing the processing rate, but
|
||||
// back the rate off when the load increases to avoid overload.
|
||||
// feedback = cpu_rate() - desired_cpu_rate.
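//
// Worked example (illustrative numbers): with feedback_smoothing_batches
// N = 10, ewma_feedback starting at 0, and a constant feedback() of +5, the
// first update gives ewma_feedback = (9 * 0 + 5) / 10 = 0.5; with K = .001
// the period then grows by a factor of 1 + .001 * 0.5 = 1.0005 on that batch.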
|
||||
|
||||
template <typename TaskType>
|
||||
class AdaptiveSharedBatchScheduler
|
||||
: public std::enable_shared_from_this<
|
||||
AdaptiveSharedBatchScheduler<TaskType>> {
|
||||
public:
|
||||
struct Options {
|
||||
// The name to use for the pool of batch threads.
|
||||
string thread_pool_name = {"batch_threads"};
|
||||
// Number of batch processing threads; equivalently the maximum number of
|
||||
// concurrently running batches.
|
||||
int64 num_batch_threads = port::NumSchedulableCPUs();
|
||||
// The environment to use (typically only overridden by test code).
|
||||
Env* env = Env::Default();
|
||||
// Initial batch scheduling period in microseconds. Will be altered for
|
||||
// non-zero scheduling_period_feedback.
|
||||
double initial_scheduling_period_micros = 500;
|
||||
// Minimum batch scheduling period in microseconds. Recommend setting this
|
||||
// value greater than 0, otherwise it may take a while to recover from a
|
||||
// sustained time of negative scheduling_period_feedback (which may occur
|
||||
// under low load).
|
||||
double min_scheduling_period_micros = 100;
|
||||
// Maximum batch scheduling period in microseconds.
|
||||
double max_scheduling_period_micros = 10000;
|
||||
// Feedback function used to modify the scheduling period each time a batch
|
||||
// is scheduled. Should return values roughly O(1), with positive values
|
||||
// resulting in an increased period.
|
||||
std::function<double()> scheduling_period_feedback = [] { return 0.; };
|
||||
// To handle potentially noisy scheduling_period_feedback, the period is
|
||||
// adjusted using an exponentially weighted moving average over the previous
|
||||
// feedback_smoothing_batches batches. Must be greater than 0.
|
||||
int64 feedback_smoothing_batches = 10;
|
||||
};
|
||||
|
||||
// Ownership is shared between the caller of Create() and any queues created
|
||||
// via AddQueue().
|
||||
static Status Create(
|
||||
const Options& options,
|
||||
std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>>* scheduler);
|
||||
|
||||
struct QueueOptions {
|
||||
// Maximum size of each batch.
|
||||
int max_batch_size = 1000;
|
||||
// Maximum number of enqueued (i.e. non-scheduled) batches.
|
||||
int max_enqueued_batches = 10;
|
||||
};
|
||||
|
||||
using BatchProcessor = std::function<void(std::unique_ptr<Batch<TaskType>>)>;
|
||||
|
||||
// Adds queue (and its callback) to be managed by this scheduler.
|
||||
Status AddQueue(const QueueOptions& options,
|
||||
BatchProcessor process_batch_callback,
|
||||
std::unique_ptr<BatchScheduler<TaskType>>* queue);
|
||||
|
||||
private:
|
||||
// Gives ASBSQueue access to AddBatch, RemoveQueue, and GetEnv.
|
||||
friend class internal::ASBSQueue<TaskType>;
|
||||
|
||||
explicit AdaptiveSharedBatchScheduler(const Options& options);
|
||||
|
||||
// Batch scheduling function which runs every scheduling_period_ microseconds.
|
||||
void ProcessOneBatch();
|
||||
|
||||
// Notifies scheduler of non-empty batch which is eligible for processing.
|
||||
void AddBatch(internal::ASBSBatch<TaskType>*);
|
||||
|
||||
// Removes queue from scheduler.
|
||||
void RemoveQueue(const internal::ASBSQueue<TaskType>* queue);
|
||||
|
||||
Env* GetEnv() const { return options_.env; }
|
||||
|
||||
const Options options_;
|
||||
|
||||
struct BatchCompare {
|
||||
bool operator()(const internal::ASBSBatch<TaskType>* a,
|
||||
const internal::ASBSBatch<TaskType>* b);
|
||||
};
|
||||
|
||||
// Collection of batches added by AddBatch, ordered by age. Owned by scheduler
|
||||
// until they are released for processing.
|
||||
std::priority_queue<internal::ASBSBatch<TaskType>*,
|
||||
std::vector<internal::ASBSBatch<TaskType>*>, BatchCompare>
|
||||
batches_ GUARDED_BY(mu_);
|
||||
|
||||
// Unowned queues and callbacks added by AddQueue.
|
||||
std::unordered_map<const internal::ASBSQueue<TaskType>*, BatchProcessor>
|
||||
queues_and_callbacks_ GUARDED_BY(mu_);
|
||||
|
||||
mutex mu_;
|
||||
|
||||
// Responsible for running ProcessOneBatch. A PeriodicFunction is used so that
|
||||
// the scheduling thread is shut down cleanly when this scheduler is deleted.
|
||||
std::unique_ptr<PeriodicFunction> scheduling_thread_;
|
||||
|
||||
// Responsible for running the batch processing callbacks.
|
||||
std::unique_ptr<thread::ThreadPool> batch_thread_pool_;
|
||||
|
||||
// Time interval in microseconds between successive ProcessOneBatch calls.
|
||||
double scheduling_period_;
|
||||
|
||||
// Exponentially weighted moving average of
|
||||
// options_.scheduling_period_feedback() evaluated in each ProcessOneBatch
|
||||
// call.
|
||||
double ewma_feedback_ = 0;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(AdaptiveSharedBatchScheduler);
|
||||
};
|
||||
|
||||
//////////////////////////////////////////////////////////
|
||||
// Implementation details follow. API users need not read.
|
||||
|
||||
namespace internal {
|
||||
// Consolidates tasks into batches, passing them off to the
|
||||
// AdaptiveSharedBatchScheduler for processing.
|
||||
template <typename TaskType>
|
||||
class ASBSQueue : public BatchScheduler<TaskType> {
|
||||
public:
|
||||
using QueueOptions =
|
||||
typename AdaptiveSharedBatchScheduler<TaskType>::QueueOptions;
|
||||
|
||||
ASBSQueue(std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>> scheduler,
|
||||
const QueueOptions& options);
|
||||
|
||||
~ASBSQueue() override;
|
||||
|
||||
// Adds task to current batch. Fails if the task size is larger than the batch
|
||||
// size or if the current batch is full and this queue's number of outstanding
|
||||
// batches is at its maximum.
|
||||
Status Schedule(std::unique_ptr<TaskType>* task) override;
|
||||
|
||||
// Number of tasks waiting to be scheduled.
|
||||
size_t NumEnqueuedTasks() const override;
|
||||
|
||||
// Number of size 1 tasks which could currently be scheduled without failing.
|
||||
size_t SchedulingCapacity() const override;
|
||||
|
||||
// Notifies queue that a batch is about to be scheduled; the queue should not
|
||||
// place any more tasks in this batch.
|
||||
void ReleaseBatch(const ASBSBatch<TaskType>* batch);
|
||||
|
||||
private:
|
||||
std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>> scheduler_;
|
||||
const QueueOptions options_;
|
||||
// Owned by scheduler_.
|
||||
ASBSBatch<TaskType>* current_batch_ GUARDED_BY(mu_) = nullptr;
|
||||
int64 num_enqueued_batches_ GUARDED_BY(mu_) = 0;
|
||||
int64 num_enqueued_tasks_ GUARDED_BY(mu_) = 0;
|
||||
mutable mutex mu_;
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(ASBSQueue);
|
||||
};
|
||||
|
||||
// Batch which remembers when and by whom it was created.
|
||||
template <typename TaskType>
|
||||
class ASBSBatch : public Batch<TaskType> {
|
||||
public:
|
||||
ASBSBatch(ASBSQueue<TaskType>* queue, int64 creation_time_micros)
|
||||
: queue_(queue), creation_time_micros_(creation_time_micros) {}
|
||||
|
||||
~ASBSBatch() override {}
|
||||
|
||||
ASBSQueue<TaskType>* queue() const { return queue_; }
|
||||
|
||||
int64 creation_time_micros() const { return creation_time_micros_; }
|
||||
|
||||
private:
|
||||
ASBSQueue<TaskType>* queue_;
|
||||
const int64 creation_time_micros_;
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(ASBSBatch);
|
||||
};
|
||||
} // namespace internal
|
||||
|
||||
// ---------------- AdaptiveSharedBatchScheduler ----------------
|
||||
|
||||
template <typename TaskType>
|
||||
Status AdaptiveSharedBatchScheduler<TaskType>::Create(
|
||||
const Options& options,
|
||||
std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>>* scheduler) {
|
||||
if (options.num_batch_threads < 1) {
|
||||
return errors::InvalidArgument("num_batch_threads must be positive; was ",
|
||||
options.num_batch_threads);
|
||||
}
|
||||
if (options.min_scheduling_period_micros < 0) {
|
||||
return errors::InvalidArgument(
|
||||
"min_scheduling_period_micros must be >= 0; was ",
|
||||
options.min_scheduling_period_micros);
|
||||
}
|
||||
if (options.min_scheduling_period_micros >
|
||||
options.initial_scheduling_period_micros) {
|
||||
return errors::InvalidArgument(
|
||||
"initial_scheduling_period_micros (",
|
||||
options.initial_scheduling_period_micros,
|
||||
") must be >= min_scheduling_period_micros (",
|
||||
options.min_scheduling_period_micros, ")");
|
||||
}
|
||||
if (options.initial_scheduling_period_micros >
|
||||
options.max_scheduling_period_micros) {
|
||||
return errors::InvalidArgument(
|
||||
"initial_scheduling_period_micros (",
|
||||
options.initial_scheduling_period_micros,
|
||||
") must be <= max_scheduling_period_micros (",
|
||||
options.max_scheduling_period_micros, ")");
|
||||
}
|
||||
if (options.feedback_smoothing_batches < 1) {
|
||||
return errors::InvalidArgument(
|
||||
"feedback_smoothing_batches must be positive; was ",
|
||||
options.feedback_smoothing_batches);
|
||||
}
|
||||
scheduler->reset(new AdaptiveSharedBatchScheduler<TaskType>(options));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename TaskType>
|
||||
AdaptiveSharedBatchScheduler<TaskType>::AdaptiveSharedBatchScheduler(
|
||||
const Options& options)
|
||||
: options_(options),
|
||||
scheduling_period_(options.initial_scheduling_period_micros) {
|
||||
PeriodicFunction::Options opts;
|
||||
opts.thread_name_prefix = "scheduling_thread";
|
||||
opts.env = GetEnv();
|
||||
scheduling_thread_.reset(
|
||||
new PeriodicFunction([this] { ProcessOneBatch(); }, 0, opts));
|
||||
batch_thread_pool_.reset(new thread::ThreadPool(
|
||||
GetEnv(), options.thread_pool_name, options.num_batch_threads));
|
||||
}
|
||||
|
||||
template <typename TaskType>
|
||||
Status AdaptiveSharedBatchScheduler<TaskType>::AddQueue(
|
||||
const QueueOptions& options, BatchProcessor process_batch_callback,
|
||||
std::unique_ptr<BatchScheduler<TaskType>>* queue) {
|
||||
if (options.max_batch_size <= 0) {
|
||||
return errors::InvalidArgument("max_batch_size must be positive; was ",
|
||||
options.max_batch_size);
|
||||
}
|
||||
if (options.max_enqueued_batches <= 0) {
|
||||
return errors::InvalidArgument(
|
||||
"max_enqueued_batches must be positive; was ",
|
||||
options.max_enqueued_batches);
|
||||
}
|
||||
internal::ASBSQueue<TaskType>* asbs_queue_raw;
|
||||
queue->reset(asbs_queue_raw = new internal::ASBSQueue<TaskType>(
|
||||
this->shared_from_this(), options));
|
||||
mutex_lock l(mu_);
|
||||
queues_and_callbacks_[asbs_queue_raw] = process_batch_callback;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename TaskType>
|
||||
void AdaptiveSharedBatchScheduler<TaskType>::AddBatch(
|
||||
internal::ASBSBatch<TaskType>* batch) {
|
||||
mutex_lock l(mu_);
|
||||
batches_.push(batch);
|
||||
}
|
||||
|
||||
template <typename TaskType>
|
||||
void AdaptiveSharedBatchScheduler<TaskType>::RemoveQueue(
|
||||
const internal::ASBSQueue<TaskType>* queue) {
|
||||
mutex_lock l(mu_);
|
||||
queues_and_callbacks_.erase(queue);
|
||||
}
|
||||
|
||||
template <typename TaskType>
|
||||
void AdaptiveSharedBatchScheduler<TaskType>::ProcessOneBatch() {
|
||||
static const double kFeedbackMultiplier = .001;
|
||||
internal::ASBSBatch<TaskType>* batch = nullptr;
|
||||
BatchProcessor callback;
|
||||
const int64 start_time_micros = GetEnv()->NowMicros();
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
if (!batches_.empty()) {
|
||||
batch = batches_.top();
|
||||
batches_.pop();
|
||||
callback = queues_and_callbacks_[batch->queue()];
|
||||
}
|
||||
}
|
||||
if (batch != nullptr) {
|
||||
double feedback = options_.scheduling_period_feedback();
|
||||
const int64 N = options_.feedback_smoothing_batches;
|
||||
ewma_feedback_ = ((N - 1) * ewma_feedback_ + feedback) / N;
|
||||
scheduling_period_ *= (1 + kFeedbackMultiplier * ewma_feedback_);
|
||||
if (scheduling_period_ < options_.min_scheduling_period_micros) {
|
||||
scheduling_period_ = options_.min_scheduling_period_micros;
|
||||
} else if (scheduling_period_ > options_.max_scheduling_period_micros) {
|
||||
scheduling_period_ = options_.max_scheduling_period_micros;
|
||||
}
|
||||
// Queue may destroy itself after ReleaseBatch is called.
|
||||
batch->queue()->ReleaseBatch(batch);
|
||||
batch_thread_pool_->Schedule([callback, batch] {
|
||||
callback(std::unique_ptr<Batch<TaskType>>(batch));
|
||||
});
|
||||
}
|
||||
const int64 sleep_time =
|
||||
scheduling_period_ - (GetEnv()->NowMicros() - start_time_micros);
|
||||
if (sleep_time > 0) {
|
||||
GetEnv()->SleepForMicroseconds(sleep_time);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename TaskType>
|
||||
bool AdaptiveSharedBatchScheduler<TaskType>::BatchCompare::operator()(
|
||||
const internal::ASBSBatch<TaskType>* a,
|
||||
const internal::ASBSBatch<TaskType>* b) {
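  // Ordering by '>' makes the priority_queue a min-heap on creation time,
  // so batches_.top() in ProcessOneBatch() is always the oldest batch.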
|
||||
return a->creation_time_micros() > b->creation_time_micros();
|
||||
}
|
||||
|
||||
// ---------------- ASBSQueue ----------------
|
||||
|
||||
namespace internal {
|
||||
template <typename TaskType>
|
||||
ASBSQueue<TaskType>::ASBSQueue(
|
||||
std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>> scheduler,
|
||||
const QueueOptions& options)
|
||||
: scheduler_(scheduler), options_(options) {}
|
||||
|
||||
template <typename TaskType>
|
||||
ASBSQueue<TaskType>::~ASBSQueue() {
|
||||
// Wait until last batch has been scheduled.
|
||||
const int kSleepMicros = 1000;
|
||||
for (;;) {
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
if (num_enqueued_batches_ == 0) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
scheduler_->GetEnv()->SleepForMicroseconds(kSleepMicros);
|
||||
}
|
||||
scheduler_->RemoveQueue(this);
|
||||
}
|
||||
|
||||
template <typename TaskType>
|
||||
Status ASBSQueue<TaskType>::Schedule(std::unique_ptr<TaskType>* task) {
|
||||
bool added_new_batch = false;
|
||||
size_t size = (*task)->size();
|
||||
if (size > options_.max_batch_size) {
|
||||
return errors::InvalidArgument("Task size ", size,
|
||||
" is larger than maximum batch size ",
|
||||
options_.max_batch_size);
|
||||
}
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
// Current batch is full, create another if allowed.
|
||||
if (current_batch_ &&
|
||||
current_batch_->size() + size > options_.max_batch_size) {
|
||||
if (num_enqueued_batches_ >= options_.max_enqueued_batches) {
|
||||
return errors::Unavailable("The batch scheduling queue is full");
|
||||
}
|
||||
current_batch_->Close();
|
||||
current_batch_ = nullptr;
|
||||
}
|
||||
if (!current_batch_) {
|
||||
added_new_batch = true;
|
||||
num_enqueued_batches_++;
|
||||
current_batch_ =
|
||||
new ASBSBatch<TaskType>(this, scheduler_->GetEnv()->NowMicros());
|
||||
}
|
||||
current_batch_->AddTask(std::move(*task));
|
||||
num_enqueued_tasks_++;
|
||||
}
|
||||
if (added_new_batch) scheduler_->AddBatch(current_batch_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename TaskType>
|
||||
void ASBSQueue<TaskType>::ReleaseBatch(const ASBSBatch<TaskType>* batch) {
|
||||
mutex_lock l(mu_);
|
||||
num_enqueued_batches_--;
|
||||
num_enqueued_tasks_ -= batch->num_tasks();
|
||||
if (batch == current_batch_) {
|
||||
current_batch_->Close();
|
||||
current_batch_ = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename TaskType>
|
||||
size_t ASBSQueue<TaskType>::NumEnqueuedTasks() const {
|
||||
mutex_lock l(mu_);
|
||||
return num_enqueued_tasks_;
|
||||
}
|
||||
|
||||
template <typename TaskType>
|
||||
size_t ASBSQueue<TaskType>::SchedulingCapacity() const {
|
||||
mutex_lock l(mu_);
|
||||
const int current_batch_capacity =
|
||||
current_batch_ ? options_.max_batch_size - current_batch_->size() : 0;
|
||||
const int spare_batches =
|
||||
options_.max_enqueued_batches - num_enqueued_batches_;
|
||||
return spare_batches * options_.max_batch_size + current_batch_capacity;
|
||||
}
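// Worked example (numbers from the QueueCapacityInfo test): with
// max_batch_size = 10, max_enqueued_batches = 10, and one open batch holding
// a size-5 task, capacity = (10 - 1) * 10 + (10 - 5) = 95.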
|
||||
} // namespace internal
|
||||
} // namespace serving
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_
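A hedged usage sketch for the scheduler above (MyTask and the callback body
are hypothetical; Create(), AddQueue(), and Schedule() mirror the
declarations in this header):

#include "tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h"

namespace ts = tensorflow::serving;

// A trivial unit-size task type.
struct MyTask : public ts::BatchTask {
  size_t size() const override { return 1; }
};

void Example() {
  std::shared_ptr<ts::AdaptiveSharedBatchScheduler<MyTask>> scheduler;
  TF_CHECK_OK(ts::AdaptiveSharedBatchScheduler<MyTask>::Create({}, &scheduler));

  std::unique_ptr<ts::BatchScheduler<MyTask>> queue;
  TF_CHECK_OK(scheduler->AddQueue(
      {}, [](std::unique_ptr<ts::Batch<MyTask>> batch) {
        // The batch arrives closed; process its batch->num_tasks() tasks here.
      },
      &queue));

  std::unique_ptr<MyTask> task(new MyTask);
  TF_CHECK_OK(queue->Schedule(&task));  // 'task' is consumed on success
}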
|
tensorflow/contrib/batching/adaptive_shared_batch_scheduler_test.cc (new file)
@ -0,0 +1,438 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h"
|
||||
|
||||
#include "tensorflow/contrib/batching/test_util/fake_clock_env.h"
|
||||
#include "tensorflow/core/lib/core/notification.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace serving {
|
||||
namespace {
|
||||
|
||||
class FakeTask : public BatchTask {
|
||||
public:
|
||||
explicit FakeTask(size_t size) : size_(size) {}
|
||||
|
||||
~FakeTask() override = default;
|
||||
|
||||
size_t size() const override { return size_; }
|
||||
|
||||
private:
|
||||
const size_t size_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(FakeTask);
|
||||
};
|
||||
|
||||
// Creates a FakeTask of size 'task_size', and calls 'scheduler->Schedule()' on
|
||||
// that task. Returns the resulting status.
|
||||
Status ScheduleTask(size_t task_size, BatchScheduler<FakeTask>* scheduler) {
|
||||
std::unique_ptr<FakeTask> task(new FakeTask(task_size));
|
||||
Status status = scheduler->Schedule(&task);
|
||||
// Schedule() should have consumed 'task' iff it returned Status::OK.
|
||||
CHECK_EQ(status.ok(), task == nullptr);
|
||||
return status;
|
||||
}
|
||||
|
||||
// Creates a thread that waits on 'start' and then advances the fake clock in
|
||||
// 'env' in a loop until 'stop' is notified. Useful for allowing objects that
|
||||
// use the clock to be destroyed.
|
||||
std::unique_ptr<Thread> CreateFakeClockAdvancerThread(
|
||||
test_util::FakeClockEnv* env, Notification* start, Notification* stop) {
|
||||
return std::unique_ptr<Thread>(Env::Default()->StartThread(
|
||||
{}, "FakeClockAdvancerThread", [env, start, stop] {
|
||||
start->WaitForNotification();
|
||||
while (!stop->HasBeenNotified()) {
|
||||
env->AdvanceByMicroseconds(10);
|
||||
Env::Default()->SleepForMicroseconds(10);
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
TEST(AdaptiveSharedBatchSchedulerTest, Basic) {
|
||||
for (const bool delete_scheduler_early : {false, true}) {
|
||||
for (const bool delete_queue_1_early : {false, true}) {
|
||||
int queue_0_tasks = 0;
|
||||
auto queue_0_callback =
|
||||
[&queue_0_tasks](std::unique_ptr<Batch<FakeTask>> batch) {
|
||||
ASSERT_TRUE(batch->IsClosed());
|
||||
EXPECT_GT(batch->num_tasks(), 0);
|
||||
for (int i = 0; i < batch->num_tasks(); i++) {
|
||||
queue_0_tasks += batch->task(i).size();
|
||||
}
|
||||
};
|
||||
int queue_1_tasks = 0;
|
||||
auto queue_1_callback =
|
||||
[&queue_1_tasks](std::unique_ptr<Batch<FakeTask>> batch) {
|
||||
ASSERT_TRUE(batch->IsClosed());
|
||||
EXPECT_GT(batch->num_tasks(), 0);
|
||||
for (int i = 0; i < batch->num_tasks(); i++) {
|
||||
queue_1_tasks += batch->task(i).size();
|
||||
}
|
||||
};
|
||||
{
|
||||
std::shared_ptr<AdaptiveSharedBatchScheduler<FakeTask>> scheduler;
|
||||
TF_ASSERT_OK(
|
||||
AdaptiveSharedBatchScheduler<FakeTask>::Create({}, &scheduler));
|
||||
|
||||
// Create two queues.
|
||||
std::unique_ptr<BatchScheduler<FakeTask>> queue_0;
|
||||
TF_ASSERT_OK(scheduler->AddQueue({}, queue_0_callback, &queue_0));
|
||||
std::unique_ptr<BatchScheduler<FakeTask>> queue_1;
|
||||
TF_ASSERT_OK(scheduler->AddQueue({}, queue_1_callback, &queue_1));
|
||||
|
||||
if (delete_scheduler_early) {
|
||||
// Delete our copy of the scheduler. The queues should keep it alive
|
||||
// under the covers.
|
||||
scheduler = nullptr;
|
||||
}
|
||||
// Submit tasks to the two queues, and (optionally) remove the queues.
|
||||
TF_ASSERT_OK(ScheduleTask(1, queue_0.get()));
|
||||
TF_ASSERT_OK(ScheduleTask(2, queue_1.get()));
|
||||
TF_ASSERT_OK(ScheduleTask(3, queue_0.get()));
|
||||
TF_ASSERT_OK(ScheduleTask(4, queue_1.get()));
|
||||
if (delete_queue_1_early) {
|
||||
queue_1 = nullptr;
|
||||
}
|
||||
TF_ASSERT_OK(ScheduleTask(5, queue_0.get()));
|
||||
}
|
||||
EXPECT_EQ(queue_0_tasks, 9);
|
||||
EXPECT_EQ(queue_1_tasks, 6);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
TEST(AdaptiveSharedBatchSchedulerTest, BadOptions) {
|
||||
using Scheduler = AdaptiveSharedBatchScheduler<FakeTask>;
|
||||
std::shared_ptr<Scheduler> scheduler;
|
||||
Scheduler::Options options;
|
||||
options.num_batch_threads = 0;
|
||||
EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok());
|
||||
options = Scheduler::Options();
|
||||
options.min_scheduling_period_micros = 50;
|
||||
options.max_scheduling_period_micros = 100;
|
||||
options.initial_scheduling_period_micros = 1;
|
||||
EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok());
|
||||
options = Scheduler::Options();
|
||||
options.min_scheduling_period_micros = 50;
|
||||
options.max_scheduling_period_micros = 100;
|
||||
options.initial_scheduling_period_micros = 1000;
|
||||
EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok());
|
||||
options = Scheduler::Options();
|
||||
options.min_scheduling_period_micros = 100;
|
||||
options.max_scheduling_period_micros = 50;
|
||||
options.initial_scheduling_period_micros = 75;
|
||||
EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok());
|
||||
options = Scheduler::Options();
|
||||
options.feedback_smoothing_batches = 0;
|
||||
EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok());
|
||||
}
|
||||
|
||||
TEST(AdaptiveSharedBatchSchedulerTest, ObeysQueueOptions) {
|
||||
test_util::FakeClockEnv env(Env::Default());
|
||||
Notification start_teardown, stop_teardown;
|
||||
std::unique_ptr<Thread> teardown_thread =
|
||||
CreateFakeClockAdvancerThread(&env, &start_teardown, &stop_teardown);
|
||||
{
|
||||
AdaptiveSharedBatchScheduler<FakeTask>::Options options;
|
||||
options.initial_scheduling_period_micros = 1000;
|
||||
options.env = &env;
|
||||
std::shared_ptr<AdaptiveSharedBatchScheduler<FakeTask>> scheduler;
|
||||
TF_ASSERT_OK(
|
||||
AdaptiveSharedBatchScheduler<FakeTask>::Create(options, &scheduler));
|
||||
std::unique_ptr<BatchScheduler<FakeTask>> queue_0;
|
||||
std::unique_ptr<BatchScheduler<FakeTask>> queue_1;
|
||||
int queue_0_tasks = 0;
|
||||
int queue_1_tasks = 0;
|
||||
auto queue_0_callback = [&queue_0_tasks,
|
||||
&env](std::unique_ptr<Batch<FakeTask>> batch) {
|
||||
ASSERT_TRUE(batch->IsClosed());
|
||||
EXPECT_GT(batch->num_tasks(), 0);
|
||||
for (int i = 0; i < batch->num_tasks(); i++) {
|
||||
queue_0_tasks += batch->task(i).size();
|
||||
}
|
||||
env.SleepForMicroseconds(1);
|
||||
};
|
||||
auto queue_1_callback = [&queue_1_tasks,
|
||||
&env](std::unique_ptr<Batch<FakeTask>> batch) {
|
||||
ASSERT_TRUE(batch->IsClosed());
|
||||
EXPECT_GT(batch->num_tasks(), 0);
|
||||
for (int i = 0; i < batch->num_tasks(); i++) {
|
||||
queue_1_tasks += batch->task(i).size();
|
||||
}
|
||||
env.SleepForMicroseconds(1);
|
||||
};
|
||||
AdaptiveSharedBatchScheduler<FakeTask>::QueueOptions queue_options;
|
||||
queue_options.max_batch_size = 10;
|
||||
queue_options.max_enqueued_batches = 0;
|
||||
// Queue must have max_enqueued_batches > 0.
|
||||
EXPECT_FALSE(
|
||||
scheduler->AddQueue(queue_options, queue_0_callback, &queue_0).ok());
|
||||
queue_options.max_enqueued_batches = 2;
|
||||
TF_ASSERT_OK(
|
||||
scheduler->AddQueue(queue_options, queue_0_callback, &queue_0));
|
||||
queue_options.max_batch_size = 0;
|
||||
// Queue must have max_batch_size > 0.
|
||||
EXPECT_FALSE(
|
||||
scheduler->AddQueue(queue_options, queue_1_callback, &queue_1).ok());
|
||||
queue_options.max_batch_size = 2;
|
||||
queue_options.max_enqueued_batches = 1;
|
||||
TF_ASSERT_OK(
|
||||
scheduler->AddQueue(queue_options, queue_1_callback, &queue_1));
|
||||
|
||||
// Wait for scheduling_thread to sleep.
|
||||
env.BlockUntilThreadsAsleep(1);
|
||||
// Task larger than max_batch_size shouldn't schedule.
|
||||
EXPECT_FALSE(ScheduleTask(15, queue_0.get()).ok());
|
||||
TF_ASSERT_OK(ScheduleTask(5, queue_0.get()));
|
||||
TF_ASSERT_OK(ScheduleTask(5, queue_0.get()));
|
||||
env.AdvanceByMicroseconds(1);
|
||||
|
||||
// Task larger than max_batch_size shouldn't schedule.
|
||||
EXPECT_FALSE(ScheduleTask(3, queue_1.get()).ok());
|
||||
TF_ASSERT_OK(ScheduleTask(1, queue_1.get()));
|
||||
TF_ASSERT_OK(ScheduleTask(1, queue_1.get()));
|
||||
env.AdvanceByMicroseconds(1);
|
||||
// Exceeds max_enqueued_batches, shouldn't schedule.
|
||||
EXPECT_FALSE(ScheduleTask(1, queue_1.get()).ok());
|
||||
|
||||
TF_ASSERT_OK(ScheduleTask(5, queue_0.get()));
|
||||
// Exceeds max_enqueued_batches, shouldn't schedule.
|
||||
EXPECT_FALSE(ScheduleTask(6, queue_0.get()).ok());
|
||||
TF_ASSERT_OK(ScheduleTask(4, queue_0.get()));
|
||||
|
||||
// Batches should be processed in order from oldest to newest.
|
||||
env.AdvanceByMicroseconds(1000);
|
||||
env.BlockUntilThreadsAsleep(2);
|
||||
EXPECT_EQ(queue_0_tasks, 10);
|
||||
EXPECT_EQ(queue_1_tasks, 0);
|
||||
|
||||
env.AdvanceByMicroseconds(1000);
|
||||
env.BlockUntilThreadsAsleep(2);
|
||||
EXPECT_EQ(queue_0_tasks, 10);
|
||||
EXPECT_EQ(queue_1_tasks, 2);
|
||||
|
||||
env.AdvanceByMicroseconds(1000);
|
||||
env.BlockUntilThreadsAsleep(2);
|
||||
EXPECT_EQ(queue_0_tasks, 19);
|
||||
EXPECT_EQ(queue_1_tasks, 2);
|
||||
start_teardown.Notify();
|
||||
}
|
||||
stop_teardown.Notify();
|
||||
}
|
||||
|
||||
TEST(AdaptiveSharedBatchSchedulerTest, RateFeedback) {
|
||||
test_util::FakeClockEnv env(Env::Default());
|
||||
Notification start_teardown, stop_teardown;
|
||||
std::unique_ptr<Thread> teardown_thread =
|
||||
CreateFakeClockAdvancerThread(&env, &start_teardown, &stop_teardown);
|
||||
{
|
||||
double feedback = 0;
|
||||
AdaptiveSharedBatchScheduler<FakeTask>::Options options;
|
||||
options.initial_scheduling_period_micros = 1000;
|
||||
options.min_scheduling_period_micros = 200;
|
||||
options.max_scheduling_period_micros = 2000;
|
||||
options.env = &env;
|
||||
options.scheduling_period_feedback = [&feedback] { return feedback; };
|
||||
options.feedback_smoothing_batches = 1;
|
||||
std::shared_ptr<AdaptiveSharedBatchScheduler<FakeTask>> scheduler;
|
||||
TF_ASSERT_OK(
|
||||
AdaptiveSharedBatchScheduler<FakeTask>::Create(options, &scheduler));
|
||||
std::unique_ptr<BatchScheduler<FakeTask>> queue;
|
||||
int scheduled_items = 0;
|
||||
auto queue_callback = [&scheduled_items,
|
||||
&env](std::unique_ptr<Batch<FakeTask>> batch) {
|
||||
ASSERT_TRUE(batch->IsClosed());
|
||||
EXPECT_GT(batch->num_tasks(), 0);
|
||||
scheduled_items = 0;
|
||||
for (int i = 0; i < batch->num_tasks(); i++) {
|
||||
scheduled_items += batch->task(i).size();
|
||||
}
|
||||
env.SleepForMicroseconds(1);
|
||||
};
|
||||
|
||||
TF_ASSERT_OK(scheduler->AddQueue({}, queue_callback, &queue));
|
||||
|
||||
// Wait for scheduling_thread to sleep.
|
||||
env.BlockUntilThreadsAsleep(1);
|
||||
// Enqueue 6 batches.
|
||||
for (int i = 0; i < 6; i++) {
|
||||
TF_ASSERT_OK(ScheduleTask(900 + i, queue.get()));
|
||||
env.AdvanceByMicroseconds(1);
|
||||
}
|
||||
feedback = -500;
|
||||
env.AdvanceByMicroseconds(994);
|
||||
env.BlockUntilThreadsAsleep(2); // scheduling period = 500 usec.
|
||||
EXPECT_EQ(scheduled_items, 900);
|
||||
env.AdvanceByMicroseconds(500);
|
||||
env.BlockUntilThreadsAsleep(2); // scheduling period = 250 usec.
|
||||
EXPECT_EQ(scheduled_items, 901);
|
||||
feedback = 0;
|
||||
env.AdvanceByMicroseconds(250);
|
||||
env.BlockUntilThreadsAsleep(2); // scheduling period = 250 usec.
|
||||
EXPECT_EQ(scheduled_items, 902);
|
||||
feedback = 10000; // large feedback should hit max_scheduling_period.
|
||||
env.AdvanceByMicroseconds(250);
|
||||
env.BlockUntilThreadsAsleep(2); // scheduling period = 2000 usec.
|
||||
EXPECT_EQ(scheduled_items, 903);
|
||||
feedback = -10000; // large feedback should hit min_scheduling_period.
|
||||
env.AdvanceByMicroseconds(1999);
|
||||
// No callback scheduled, only scheduling thread sleeping.
|
||||
env.BlockUntilThreadsAsleep(1);
|
||||
EXPECT_EQ(scheduled_items, 903);
|
||||
env.AdvanceByMicroseconds(1);
|
||||
env.BlockUntilThreadsAsleep(2); // scheduling period = 200 usec.
|
||||
EXPECT_EQ(scheduled_items, 904);
|
||||
env.AdvanceByMicroseconds(200);
|
||||
env.BlockUntilThreadsAsleep(2);
|
||||
EXPECT_EQ(scheduled_items, 905);
|
||||
start_teardown.Notify();
|
||||
}
|
||||
stop_teardown.Notify();
|
||||
}
|
||||
|
||||
TEST(AdaptiveSharedBatchSchedulerTest, FeedbackSmoothing) {
|
||||
test_util::FakeClockEnv env(Env::Default());
|
||||
Notification start_teardown, stop_teardown;
|
||||
std::unique_ptr<Thread> teardown_thread =
|
||||
CreateFakeClockAdvancerThread(&env, &start_teardown, &stop_teardown);
|
||||
{
|
||||
double feedback = 0;
|
||||
AdaptiveSharedBatchScheduler<FakeTask>::Options options;
|
||||
options.initial_scheduling_period_micros = 1000;
|
||||
options.env = &env;
|
||||
options.scheduling_period_feedback = [&feedback] { return feedback; };
|
||||
options.feedback_smoothing_batches = 3;
|
||||
std::shared_ptr<AdaptiveSharedBatchScheduler<FakeTask>> scheduler;
|
||||
TF_ASSERT_OK(
|
||||
AdaptiveSharedBatchScheduler<FakeTask>::Create(options, &scheduler));
|
||||
std::unique_ptr<BatchScheduler<FakeTask>> queue;
|
||||
int scheduled_items = 0;
|
||||
auto queue_callback = [&scheduled_items,
|
||||
&env](std::unique_ptr<Batch<FakeTask>> batch) {
|
||||
ASSERT_TRUE(batch->IsClosed());
|
||||
EXPECT_GT(batch->num_tasks(), 0);
|
||||
scheduled_items = 0;
|
||||
for (int i = 0; i < batch->num_tasks(); i++) {
|
||||
scheduled_items += batch->task(i).size();
|
||||
}
|
||||
env.SleepForMicroseconds(1);
|
||||
};
|
||||
|
||||
TF_ASSERT_OK(scheduler->AddQueue({}, queue_callback, &queue));
|
||||
|
||||
// Wait for scheduling_thread to sleep.
|
||||
env.BlockUntilThreadsAsleep(1);
|
||||
// Enqueue 4 batches.
|
||||
for (int i = 0; i < 4; i++) {
|
||||
TF_ASSERT_OK(ScheduleTask(900 + i, queue.get()));
|
||||
env.AdvanceByMicroseconds(1);
|
||||
}
|
||||
feedback = -300;
|
||||
env.AdvanceByMicroseconds(996);
|
||||
env.BlockUntilThreadsAsleep(2);
|
||||
// ewma_feedback = -100, scheduling_period = 900.
|
||||
EXPECT_EQ(scheduled_items, 900);
|
||||
env.AdvanceByMicroseconds(899);
|
||||
// No callback scheduled, only scheduling thread sleeping.
|
||||
env.BlockUntilThreadsAsleep(1);
|
||||
EXPECT_EQ(scheduled_items, 900);
|
||||
env.AdvanceByMicroseconds(1);
|
||||
env.BlockUntilThreadsAsleep(2);
|
||||
// ewma_feedback = -167, scheduling_period = 750.
|
||||
EXPECT_EQ(scheduled_items, 901);
|
||||
env.AdvanceByMicroseconds(749);
|
||||
// No callback scheduled, only scheduling thread sleeping.
|
||||
env.BlockUntilThreadsAsleep(1);
|
||||
EXPECT_EQ(scheduled_items, 901);
|
||||
feedback = 1000 / 3.;
|
||||
env.AdvanceByMicroseconds(1);
|
||||
env.BlockUntilThreadsAsleep(2);
|
||||
// ewma_feedback = 0, scheduling_period = 750.
|
||||
EXPECT_EQ(scheduled_items, 902);
|
||||
env.AdvanceByMicroseconds(749);
|
||||
// No callback scheduled, only scheduling thread sleeping.
|
||||
env.BlockUntilThreadsAsleep(1);
|
||||
EXPECT_EQ(scheduled_items, 902);
|
||||
env.AdvanceByMicroseconds(1);
|
||||
env.BlockUntilThreadsAsleep(2);
|
||||
EXPECT_EQ(scheduled_items, 903);
|
||||
start_teardown.Notify();
|
||||
}
|
||||
stop_teardown.Notify();
|
||||
}
|
||||
|
||||
TEST(AdaptiveSharedBatchSchedulerTest, QueueCapacityInfo) {
|
||||
test_util::FakeClockEnv env(Env::Default());
|
||||
Notification start_teardown, stop_teardown;
|
||||
std::unique_ptr<Thread> teardown_thread =
|
||||
CreateFakeClockAdvancerThread(&env, &start_teardown, &stop_teardown);
|
||||
{
|
||||
AdaptiveSharedBatchScheduler<FakeTask>::Options options;
|
||||
options.initial_scheduling_period_micros = 1000;
|
||||
options.env = &env;
|
||||
std::shared_ptr<AdaptiveSharedBatchScheduler<FakeTask>> scheduler;
|
||||
TF_ASSERT_OK(
|
||||
AdaptiveSharedBatchScheduler<FakeTask>::Create(options, &scheduler));
|
||||
std::unique_ptr<BatchScheduler<FakeTask>> queue;
|
||||
int scheduled_items = 0;
|
||||
auto queue_callback = [&scheduled_items,
|
||||
&env](std::unique_ptr<Batch<FakeTask>> batch) {
|
||||
ASSERT_TRUE(batch->IsClosed());
|
||||
EXPECT_GT(batch->num_tasks(), 0);
|
||||
scheduled_items = 0;
|
||||
for (int i = 0; i < batch->num_tasks(); i++) {
|
||||
scheduled_items += batch->task(i).size();
|
||||
}
|
||||
env.SleepForMicroseconds(1);
|
||||
};
|
||||
AdaptiveSharedBatchScheduler<FakeTask>::QueueOptions queue_options;
|
||||
queue_options.max_batch_size = 10;
|
||||
queue_options.max_enqueued_batches = 10;
|
||||
TF_ASSERT_OK(scheduler->AddQueue(queue_options, queue_callback, &queue));
|
||||
|
||||
// Wait for scheduling_thread to sleep.
|
||||
env.BlockUntilThreadsAsleep(1);
|
||||
// Enqueue 3 tasks.
|
||||
EXPECT_EQ(queue->NumEnqueuedTasks(), 0);
|
||||
EXPECT_EQ(queue->SchedulingCapacity(), 100);
|
||||
TF_ASSERT_OK(ScheduleTask(5, queue.get()));
|
||||
EXPECT_EQ(queue->NumEnqueuedTasks(), 1);
|
||||
EXPECT_EQ(queue->SchedulingCapacity(), 95);
|
||||
env.AdvanceByMicroseconds(1);
|
||||
TF_ASSERT_OK(ScheduleTask(6, queue.get()));
|
||||
EXPECT_EQ(queue->NumEnqueuedTasks(), 2);
|
||||
EXPECT_EQ(queue->SchedulingCapacity(), 84);
|
||||
env.AdvanceByMicroseconds(1);
|
||||
TF_ASSERT_OK(ScheduleTask(1, queue.get()));
|
||||
EXPECT_EQ(queue->NumEnqueuedTasks(), 3);
|
||||
EXPECT_EQ(queue->SchedulingCapacity(), 83);
|
||||
|
||||
env.AdvanceByMicroseconds(998);
|
||||
env.BlockUntilThreadsAsleep(2);
|
||||
EXPECT_EQ(scheduled_items, 5);
|
||||
env.AdvanceByMicroseconds(1000);
|
||||
env.BlockUntilThreadsAsleep(2);
|
||||
EXPECT_EQ(scheduled_items, 7);
|
||||
start_teardown.Notify();
|
||||
}
|
||||
stop_teardown.Notify();
|
||||
}
|
||||
} // namespace
|
||||
} // namespace serving
|
||||
} // namespace tensorflow
|
@ -78,7 +78,7 @@ template <typename TaskType>
|
||||
class Batch {
|
||||
public:
|
||||
Batch() = default;
|
||||
~Batch(); // Blocks until the batch is closed.
|
||||
virtual ~Batch(); // Blocks until the batch is closed.
|
||||
|
||||
// Appends 'task' to the batch. After calling AddTask(), the newly-added task
|
||||
// can be accessed via task(num_tasks()-1) or mutable_task(num_tasks()-1).
|
||||
|
tensorflow/contrib/cmake/external/cub.cmake (vendored)
@ -14,7 +14,7 @@
|
||||
# ==============================================================================
|
||||
include (ExternalProject)
|
||||
|
||||
set(cub_URL https://github.com/NVlabs/cub/archive/1.7.4.zip)
|
||||
set(cub_URL https://mirror.bazel.build/github.com/NVlabs/cub/archive/1.7.4.zip)
|
||||
set(cub_HASH SHA256=20a1a39fd97e5da7f40f5f2e7fd73fd2ea59f9dc4bb8a6c5f228aa543e727e31)
|
||||
set(cub_BUILD ${CMAKE_CURRENT_BINARY_DIR}/cub/src/cub)
|
||||
set(cub_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/cub/src/cub)
|
||||
|
tensorflow/contrib/cmake/external/gif.cmake (vendored)
@ -15,7 +15,7 @@
|
||||
include (ExternalProject)
|
||||
|
||||
set(gif_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/gif_archive/giflib-5.1.4/)
|
||||
set(gif_URL http://mirror.bazel.build/ufpr.dl.sourceforge.net/project/giflib/giflib-5.1.4.tar.gz)
|
||||
set(gif_URL https://mirror.bazel.build/ufpr.dl.sourceforge.net/project/giflib/giflib-5.1.4.tar.gz)
|
||||
set(gif_HASH SHA256=34a7377ba834397db019e8eb122e551a49c98f49df75ec3fcc92b9a794a4f6d1)
|
||||
set(gif_INSTALL ${CMAKE_BINARY_DIR}/gif/install)
|
||||
set(gif_BUILD ${CMAKE_BINARY_DIR}/gif/src/gif)
|
||||
|
tensorflow/contrib/cmake/external/jpeg.cmake (vendored)
@ -15,7 +15,7 @@
|
||||
include (ExternalProject)
|
||||
|
||||
set(jpeg_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/jpeg_archive)
|
||||
set(jpeg_URL http://mirror.bazel.build/www.ijg.org/files/jpegsrc.v9a.tar.gz)
|
||||
set(jpeg_URL https://mirror.bazel.build/www.ijg.org/files/jpegsrc.v9a.tar.gz)
|
||||
set(jpeg_HASH SHA256=3a753ea48d917945dd54a2d97de388aa06ca2eb1066cbfdc6652036349fe05a7)
|
||||
set(jpeg_BUILD ${CMAKE_CURRENT_BINARY_DIR}/jpeg/src/jpeg)
|
||||
set(jpeg_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/jpeg/install)
|
||||
|
tensorflow/contrib/cmake/external/lmdb.cmake (vendored)
@ -15,7 +15,7 @@
|
||||
include (ExternalProject)
|
||||
|
||||
set(lmdb_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/lmdb)
|
||||
set(lmdb_URL http://mirror.bazel.build/github.com/LMDB/lmdb/archive/LMDB_0.9.19.tar.gz)
|
||||
set(lmdb_URL https://mirror.bazel.build/github.com/LMDB/lmdb/archive/LMDB_0.9.19.tar.gz)
|
||||
set(lmdb_HASH SHA256=108532fb94c6f227558d45be3f3347b52539f0f58290a7bb31ec06c462d05326)
|
||||
set(lmdb_BUILD ${CMAKE_BINARY_DIR}/lmdb/src/lmdb)
|
||||
set(lmdb_INSTALL ${CMAKE_BINARY_DIR}/lmdb/install)
|
||||
|
@ -47,4 +47,4 @@ ExternalProject_Add(snappy
|
||||
)
|
||||
|
||||
# actually enables snappy in the source code
|
||||
add_definitions(-DSNAPPY)
|
||||
add_definitions(-DTF_USE_SNAPPY)
|
||||
|
@ -86,7 +86,7 @@ cuda_py_test(
|
||||
"//tensorflow/python:client",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python/eager:graph_callable",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python/eager:test",
|
||||
"//tensorflow/python:variables",
|
||||
],
|
||||
)
|
||||
@ -132,11 +132,12 @@ py_library(
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:init_ops",
|
||||
"//tensorflow/python:layers_base",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python:variable_scope",
|
||||
"//tensorflow/python/eager:context",
|
||||
"//tensorflow/python/eager:function",
|
||||
],
|
||||
)
|
||||
|
||||
@ -146,6 +147,10 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":metrics",
"//tensorflow/python:array_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:variables",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:test",
],
)
@ -160,6 +165,8 @@ py_library(
deps = [
":datasets",
":metrics",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:function",
],
)
@ -86,7 +86,7 @@ class EvaluatorTest(test.TestCase):
for v in e.metric_variables:
p = v.name.split("/")[0]
prefix_count[p] = prefix_count.get(p, 0) + 1
self.assertEqual({"outer-mean": 2, "mean": 2}, prefix_count)
self.assertEqual({"outer_mean": 2, "mean": 2}, prefix_count)

def testDataset(self):
e = SimpleEvaluator(IdentityModel())
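The expectation changes from "outer-mean" to "outer_mean" because of the scope-name sanitization added to metrics.py below, where _to_replace = re.compile("[^A-Za-z0-9.]") maps every character outside [A-Za-z0-9.] to an underscore. A quick illustrative check of that regex (not part of the commit):

import re

_to_replace = re.compile("[^A-Za-z0-9.]")  # same pattern as in metrics.py below
print(_to_replace.sub("_", "outer-mean"))  # -> outer_mean
print(_to_replace.sub("_", "has space"))   # -> has_space (cf. testNamesWithSpaces below)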
@ -18,6 +18,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import re

from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
@ -25,55 +29,69 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope


_to_replace = re.compile("[^A-Za-z0-9.]")


class Metric(object):
"""A metric holds state for aggregating statistics over an evaluation run.

Users will use Evaluator.add_metric() to add Metric objects to their
evaluation, call them in each step, and then use
Evaluator.all_metric_results() at the end.
evaluation, call them in each step (treating the object as a callable),
and then use Evaluator.all_metric_results() at the end.

Descendants will implement:
* call(): Should follow this pattern:
if not self.built:
self.var = self.add_variable(...)
self.add_update(self.var.assign_add(...))
* aggregate(): Adds in the state from a list of metrics of the same type
as `self`. (Default of summing all the variables will be fine for most
descendants.)
* result(): Computes and returns a final value for the metric
* `build()`: All variables should be created in this method, by calling
`self.add_variable()` as in: `self.var = self.add_variable(...)`
build() will be called in the first invocation of `__call__()`, with
the same arguments passed `call()`.
* `call()`: Has all updates to variables, as in:
self.var.assign_add(...)
* `result()`: Computes and returns a final value for the metric
from the variables in `self`.

Descendants may override, but usually won't need to:
* `aggregate()`: Adds in the state from a list of metrics of the same type
as `self`. (Default is to sum all the variables.)
* `reset()`: Reset all variables to their initial state. (Default is to
zero all the variables.)
Note that users should not call `aggregate()` or `reset()`, they are for
use by TensorFlow infrastructure.
"""

def __init__(self, name=None):
self.built = False
self._built = False
self._vars = []
self._updates = []
self._name = name or self.__class__.__name__
# TODO(josh11b): Need some way to make sure two Metrics in the same
# Network have distinct names. Maybe we can get a unique name from
# a name/variable scope?
# TODO(josh11b): self._in_graph_mode = context.in_graph_mode()
name = name or self.__class__.__name__
# Replace things like spaces in name to create a valid scope name.
scope_name = _to_replace.sub("_", name)
# We create the variable scope now to get the unique name that will
# be used as a variable prefix when build() calls add_variable().
with variable_scope.variable_scope(
None, default_name=scope_name, use_resource=True, reuse=False) as scope:
pos = scope.name.rfind(scope_name)
self._name = name + scope.name[pos + len(scope_name):]
self._scope = scope
if context.in_graph_mode():
# We make self.call() into a graph callable here, so that we can
# return a single op that performs all of the variable updates.
self.call = function.defun(self.call)

# ---- API for users ----
def __call__(self, *args, **kwargs):
# TODO(josh11b): If self._in_graph_mode is true, make self.call() into a
# graph callable here, so that variable updates happen without requiring
# a separate fetch.
# TODO(josh11b): Do we need a separate build() method to separate
# initialization from each update? If so, how do we get the arguments
# to it? We *could* just pass in *args and **kwargs...
if not self.built:
# TODO(ashankar): Set up container isolation so there is no chance
# distinct metrics objects accidentally share variables.
# TODO(josh11b): Replace things like spaces in self._name to create
# a valid scope name.
with variable_scope.variable_scope(
self._name, use_resource=True, reuse=False):
ret = self.call(*args, **kwargs)
self.built = True
else:
ret = self.call(*args, **kwargs)
return ret
"""Returns op to execute to update this metric for these inputs.

Returns None if eager execution is enabled.

Args:
*args:
**kwargs: A mini-batch of inputs to the Metric, passed on to `call()`.
"""
if not self._built:
with variable_scope.variable_scope(self._scope):
self.build(*args, **kwargs)
self._built = True
return self.call(*args, **kwargs)

@property
def name(self):
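The docstring above fully specifies the new three-method contract. A minimal sketch of a descendant, assuming the tf.contrib.eager metrics module from this commit (the metric itself, SquaredSum, is hypothetical and not part of the commit):

from tensorflow.contrib.eager.python import metrics
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops


class SquaredSum(metrics.Metric):
  """Accumulates the sum of squares of every value it sees."""

  def build(self, values):
    # All variables are created here via add_variable(); build() runs once,
    # inside the scope chosen in __init__, with the first call's arguments.
    del values  # Only shape/dtype could matter; unused in this sketch.
    self.total = self.add_variable(
        name="total", shape=(), dtype=dtypes.float64,
        initializer=init_ops.zeros_initializer)

  def call(self, values):
    # Only variable updates belong here; in graph mode this body runs as a
    # graph function, so __call__() returns a single update op.
    values = math_ops.cast(values, dtypes.float64)
    self.total.assign_add(math_ops.reduce_sum(values * values))

  def result(self):
    return self.total.read_value()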
@ -84,10 +102,43 @@ class Metric(object):
return self._vars

# ---- To be implemented by descendants ---
def build(self, *args, **kwargs):
"""Method to create variables.

Called by `__call__()` before `call()` for the first time.

Args:
*args:
**kwargs: The arguments to the first invocation of `__call__()`.
`build()` may use the shape and/or dtype of these arguments
when deciding how to create variables.
"""
raise NotImplementedError("Metrics must define a build() member function")

def call(self, *args, **kwargs):
"""Accumulates statistics for the metric."""
"""Accumulates statistics for the metric. Users should use __call__ instead.

Note: This function is executed as a graph function in graph mode.
This means:
a) Operations on the same resource are executed in textual order.
This should make it easier to do things like add the updated
value of a variable to another, for example.
b) You don't need to worry about collecting the update ops to execute.
All update ops added to the graph by this function will be executed.
As a result, code should generally work the same way with graph or
eager execution.

Args:
*args:
**kwargs: A mini-batch of inputs to the Metric, as passed to
`__call__()`.
"""
raise NotImplementedError("Metrics must define a call() member function")

def result(self):  # TODO(josh11b): Add an optional summary_writer parameter.
"""Computes and returns a final value for the metric."""
raise NotImplementedError("Metrics must define a result() member function")

# We can support two different strategies for doing data-parallel
# distributed metric computations:
# * Put metric variables on the first device and rely on small
@ -123,16 +174,19 @@ class Metric(object):
self._vars[i].assign_add(math_ops.add_n([m._vars[i] for m in metrics]))
# pylint: enable=protected-access

def result(self):  # TODO(josh11b): Add an optional summary_writer parameter.
"""Computes and returns a final value for the metric."""
raise NotImplementedError("Metrics must define a result() member function")
def reset(self):
"""Reset this metric to a freshly initialized state.

Default implementation zeros all the metric variables.
"""
for v in self._vars:
v.assign(math_ops.zeros_like(v))

# ---- For use by descendants ---
def add_variable(self, name, shape=None, dtype=None, initializer=None):
"""***Only for use by descendants of Metric***."""
if self.built:
raise RuntimeError("Can't call add_variable() after a Metric has been "
"built in the first call().")
if self._built:
raise RuntimeError("Can't call add_variable() except in build().")
v = variable_scope.get_variable(name, shape, dtype, initializer,
trainable=False, use_resource=True)
self._vars.append(v)
@ -144,6 +198,15 @@ class Mean(Metric):
# TODO(josh11b): Maybe have a dtype argument that defaults to tf.float64?
# Or defaults to type of the input if it is tf.float32, else tf.float64?

def build(self, values, weights=None):
del values, weights  # build() does not use call's arguments
self.numer = self.add_variable(name="numer", shape=(),
dtype=dtypes.float64,
initializer=init_ops.zeros_initializer)
self.denom = self.add_variable(name="denom", shape=(),
dtype=dtypes.float64,
initializer=init_ops.zeros_initializer)

def call(self, values, weights=None):
"""Accumulate statistics for computing the mean.

@ -154,13 +217,6 @@ class Mean(Metric):
values: Tensor with the per-example value.
weights: Optional weighting of each example. Defaults to 1.
"""
if not self.built:  # False only in the first call().
self.numer = self.add_variable(name="numer", shape=(),
dtype=dtypes.float64,
initializer=init_ops.zeros_initializer)
self.denom = self.add_variable(name="denom", shape=(),
dtype=dtypes.float64,
initializer=init_ops.zeros_initializer)
if weights is None:
self.denom.assign_add(
math_ops.cast(array_ops.size(values), dtypes.float64))
@ -179,6 +235,10 @@
class Accuracy(Mean):
"""Calculates how often `predictions` matches `labels`."""

def build(self, labels, predictions, weights=None):
del labels, predictions, weights
super(Accuracy, self).build(None)  # Arguments are unused

def call(self, labels, predictions, weights=None):
"""Accumulate accuracy statistics.

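Concretely, Mean keeps numer = sum(w * v) and denom = sum(w), with w = 1 when weights is None, so result() is just numer / denom. A quick eager-mode check of that arithmetic (a sketch assuming eager execution was enabled at startup and the usual weighted-mean semantics):

from tensorflow.contrib.eager.python import metrics

m = metrics.Mean()
m([1.0, 3.0])              # numer += 4.0,  denom += 2
m([5.0], weights=[3.0])    # numer += 15.0, denom += 3
print(m.result().numpy())  # (4.0 + 15.0) / (2 + 3) = 3.8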
@ -19,7 +19,11 @@ from __future__ import division
from __future__ import print_function

from tensorflow.contrib.eager.python import metrics
from tensorflow.python.eager import context
from tensorflow.python.eager import test
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables


class MetricsTest(test.TestCase):
@ -56,6 +60,53 @@ class MetricsTest(test.TestCase):
m([7], [2])  # 0 correct, weight 1
self.assertEqual(2.5/5, m.result().numpy())

def testTwoMeans(self):
# Verify two metrics with the same class and name don't
# accidentally share state.
m1 = metrics.Mean()
m2 = metrics.Mean()
m1(0)
m2(2)
self.assertEqual(0, m1.result().numpy())
self.assertEqual(2, m2.result().numpy())
self.assertNotEqual(m1.name, m2.name)

def testNamesWithSpaces(self):
# Verify two metrics with the same class and name don't
# accidentally share state.
m1 = metrics.Mean("has space")
m2 = metrics.Mean("has space")
m2(2)
m1(0)
self.assertEqual(m1.name, "has space")
self.assertEqual(m1.numer.name, "has_space/numer:0")
self.assertEqual(m2.name, "has space_1")
self.assertEqual(m2.numer.name, "has_space_1/numer:0")

def testGraph(self):
with context.graph_mode(), self.test_session() as sess:
m = metrics.Mean()
p = array_ops.placeholder(dtypes.float32)
accumulate = m(p)
variables.global_variables_initializer().run()
sess.run(accumulate, feed_dict={p: [1, 10, 100]})
sess.run(accumulate, feed_dict={p: 1000})
sess.run(accumulate, feed_dict={p: [10000, 100000]})
self.assertAllEqual(m.result().eval(), 111111.0/6)

def testTwoMeansGraph(self):
# Verify two metrics with the same class and name don't
# accidentally share state.
with context.graph_mode(), self.test_session() as sess:
m1 = metrics.Mean()
m2 = metrics.Mean()
accumulate1 = m1(0)
accumulate2 = m2(2)
variables.global_variables_initializer().run()
sess.run([accumulate1, accumulate2])
self.assertEqual(0, m1.result().eval())
self.assertEqual(2, m2.result().eval())


if __name__ == "__main__":
test.main()
@ -22,6 +22,7 @@ import os
from tensorflow.contrib.eager.python import saver as _saver
from tensorflow.python.eager import context
from tensorflow.python.eager import graph_callable
from tensorflow.python.eager import test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
@ -29,7 +30,6 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test


class SaverTest(test.TestCase):
@ -38,7 +38,7 @@ class SaverTest(test.TestCase):
return '/device:GPU:0' if context.num_gpus() else '/device:CPU:0'

def testBasics(self):
with context.eager_mode(), ops.device(self._dev()):
with ops.device(self._dev()):
v1 = resource_variable_ops.ResourceVariable(1.0, name='v1')
def model():
return array_ops.constant(2.0) * v1
@ -54,8 +54,42 @@ class SaverTest(test.TestCase):
saver.restore(ckpt_prefix)
self.assertEqual(v1.read_value().numpy(), 1.0)

def testRestoreOnCreate(self):
def testSameNameNoClobbering(self):
with context.eager_mode(), ops.device(self._dev()):
# Note that this test purposefully uses Graphs rather than
# IsolateTest. Users are more likely to accidentally create the same
# variable name this way.
first_graph = ops.Graph()
with first_graph.as_default():
v1_first_graph = resource_variable_ops.ResourceVariable(1.0, name='v1')
with ops.Graph().as_default():
v1_second_graph = resource_variable_ops.ResourceVariable(2.0, name='v1')
saver = _saver.Saver([v1_first_graph, v1_second_graph])
ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt')
with self.assertRaisesRegexp(ValueError, 'v1'):
saver.save(ckpt_prefix)

def testDifferentGraphError(self):
with context.eager_mode(), ops.device(self._dev()):
with ops.Graph().as_default():
v1 = resource_variable_ops.ResourceVariable(1.0, name='v1')
with ops.Graph().as_default():
saver = _saver.Saver([v1])
ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt')
with self.assertRaisesRegexp(ValueError, 'Graph'):
saver.save(ckpt_prefix)

def testSameObjectOK(self):
with context.eager_mode(), ops.device(self._dev()):
v1 = resource_variable_ops.ResourceVariable(1.0, name='v1')
# While different objects with the same shared_name are not good, passing
# in the same object multiple times is fine.
saver = _saver.Saver([v1, v1])
ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt')
saver.save(ckpt_prefix)

def testRestoreOnCreate(self):
with ops.device(self._dev()):
def model(init_val):
v1 = resource_variable_ops.ResourceVariable(init_val, name='v1')
return array_ops.constant(1.0) * v1, v1
@ -71,12 +105,9 @@ class SaverTest(test.TestCase):
# Value is from checkpoint, but not from argument.
ret, _ = model(2.0)
self.assertEqual(ret.numpy(), 1.0)
# Creating it a second time won't re-assign the checkpoint value.
v1_2 = resource_variable_ops.ResourceVariable(3.0, name='v1')
self.assertEqual(v1_2.read_value().numpy(), 3.0)

def testRestoreNotFound(self):
with context.eager_mode(), ops.device(self._dev()):
with ops.device(self._dev()):
def model(v):
return array_ops.constant(1.0) * v

@ -92,7 +123,7 @@ class SaverTest(test.TestCase):
_ = model(resource_variable_ops.ResourceVariable(1.0, name='v2'))

def testSaveRestoreGraphCallable(self):
with context.eager_mode(), ops.device(self._dev()):
with ops.device(self._dev()):
@graph_callable.graph_callable(
[graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)])
def model(x):
@ -53,6 +53,7 @@ To use, at program startup, call `tfe.enable_eager_execution()`.
@@in_eager_mode
@@in_graph_mode

@@IsolateTest
@@run_test_in_graph_and_eager_modes
"""

@ -84,6 +85,7 @@ from tensorflow.python.eager.execution_callbacks import nan_callback
from tensorflow.python.eager.execution_callbacks import seterr
from tensorflow.python.framework.ops import enable_eager_execution
from tensorflow.python.framework.ops import eager_run as run
from tensorflow.python.framework.test_util import IsolateTest
from tensorflow.python.framework.test_util import run_in_graph_and_eager_modes as run_test_in_graph_and_eager_modes
from tensorflow.python.ops.resource_variable_ops import ResourceVariable as Variable
from tensorflow.python.util.all_util import remove_undocumented
|
||||
approach for computing the initial cluster assignments that is expensive but is
|
||||
typically less prone to getting stuck in bad local minima.
|
||||
|
||||
We provide distributed implementations of both full-batch and mini-batch
|
||||
K-Means algorithm. Both K-Means++ and random initialization are supported.
|
||||
The user can also choose between **Cosine** and **Squared Euclidean** distance
|
||||
metrics.
|
||||
**[k-MC2](https://www.aaai.org/ocs/index.php/AAAI/AAAI16/paper/view/12147/11759)**
|
||||
provides a very fast seeding method that provides high quality centers
|
||||
comparable to K-Means++ seeding. k-MC2 works particularly well if it is combined
|
||||
with Mini-batch K-Means.
|
||||
|
||||
We provide distributed implementations of both full-batch and mini-batch K-Means
|
||||
algorithm. K-Means++, k-MC2 and random initialization are supported. The user
|
||||
can also choose between **Cosine** and **Squared Euclidean** distance metrics.
|
||||
|
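The Markov chain at the heart of k-MC2 (implemented by the C++ kernel below) is small enough to sketch in NumPy. This is an illustrative re-implementation for intuition only, not the shipped kernel:

import numpy as np

def kmc2_chain(distances, seed):
    """One k-MC2 Markov chain over squared distances to current centers.

    Mirrors the kernel below: start at candidate 0 and move to candidate i
    with probability min(1, d_i / d_current); return the chain's final state.
    """
    rng = np.random.RandomState(seed)
    selected_index = 0
    selected_distance = distances[0]
    for i in range(1, len(distances)):
        if distances[i] > rng.rand() * selected_distance:
            selected_index = i
            selected_distance = distances[i]
    return selected_index

# A candidate far from all existing centers is picked almost surely:
d = np.zeros(1000, dtype=np.float32)
d[6] = 1e8
print(kmc2_chain(d, seed=0))  # -> 6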
@ -224,6 +224,58 @@ class KmeansPlusPlusInitializationOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("KmeansPlusPlusInitialization").Device(DEVICE_CPU),
KmeansPlusPlusInitializationOp);

// Implementation of one single Markov Chain for the k-MC^2 algorithm
class KMC2ChainInitializationOp : public OpKernel {
public:
explicit KMC2ChainInitializationOp(OpKernelConstruction* context)
: OpKernel(context) {
OP_REQUIRES_OK(context,
context->MatchSignature({DT_FLOAT, DT_INT64}, {DT_INT64}));
}

void Compute(OpKernelContext* context) override {
const Tensor& distances_tensor = context->input(0);
const Tensor& seed_tensor = context->input(1);
OP_REQUIRES(context, TensorShapeUtils::IsVector(distances_tensor.shape()),
InvalidArgument("Input distances should be a vector."));
OP_REQUIRES(context, TensorShapeUtils::IsScalar(seed_tensor.shape()),
InvalidArgument("Input seed should be a scalar."));
const int64 num_points = distances_tensor.dim_size(0);
const int64 seed = seed_tensor.scalar<int64>()();
OP_REQUIRES(context, num_points > 0,
InvalidArgument("Expected distances_tensor.size() > 0."));

random::PhiloxRandom random(seed);
random::SimplePhilox rng(&random);

auto distances = distances_tensor.flat<float>();
// Set the initial state of the Markov chain to be the first candidate.
int64 selected_index = 0;
float selected_distance = distances(selected_index);
// Build a Markov chain of length num_points.
for (int64 i = 1; i < num_points; ++i) {
const float candidate_distance = distances(i);
// Set the next state of the Markov chain to be the candidate with
// probability min(1, candidate_distance/selected_distance).
if (candidate_distance > rng.RandFloat() * selected_distance) {
selected_index = i;
selected_distance = candidate_distance;
}
}

Tensor* output_sampled_index_tensor;
OP_REQUIRES_OK(context,
context->allocate_output(0, TensorShape({}),
&output_sampled_index_tensor));
auto output = output_sampled_index_tensor->scalar<int64>();
// Return the last state of the Markov chain as the new center.
output() = selected_index;
}
};

REGISTER_KERNEL_BUILDER(Name("KMC2ChainInitialization").Device(DEVICE_CPU),
KMC2ChainInitializationOp);

// Operator for computing the nearest neighbors for a set of points.
class NearestNeighborsOp : public OpKernel {
public:
@ -116,6 +116,62 @@ RUN_BM_KmeansPlusPlusInitialization(k3RetriesPerSample);
#undef RUN_BM_KmeansPlusPlusInitialization
#undef BENCHMARK_KMEANS_PLUS_PLUS

Graph* SetUpKMC2Initialization(int num_points) {
Graph* g = new Graph(OpRegistry::Global());
Tensor distances(DT_FLOAT, TensorShape({num_points}));
Tensor seed(DT_INT64, TensorShape({}));
distances.flat<float>().setRandom();
seed.flat<int64>().setConstant(12345);

TF_CHECK_OK(
NodeBuilder("KMC2ChainInitializationOp", "KMC2ChainInitialization")
.Input(test::graph::Constant(g, distances))
.Input(test::graph::Constant(g, seed))
.Finalize(g, nullptr /* node */));
return g;
}

template <int num_points, int num_to_sample, int num_dims>
void BM_KMC2Initialization(int iters) {
testing::StopTiming();
testing::ItemsProcessed(static_cast<int64>(iters) * num_points * num_dims *
num_to_sample);
testing::UseRealTime();
Graph* g = SetUpKMC2Initialization(num_points);
testing::StartTiming();
test::Benchmark("cpu", g).Run(iters);
}
#define BENCHMARK_KMC2(p, c, d) \
void BM_KMC2Initialization_##p##_##c##_##d(int iters) { \
BM_KMC2Initialization<p, c, d>(iters); \
} \
BENCHMARK(BM_KMC2Initialization_##p##_##c##_##d);

#define RUN_BM_KMC2Initialization \
BENCHMARK_KMC2(k10Points, k2Centers, k100Dim); \
BENCHMARK_KMC2(k10Points, k5Centers, k100Dim); \
BENCHMARK_KMC2(k10Points, k10Centers, k100Dim); \
BENCHMARK_KMC2(k100Points, k10Centers, k100Dim); \
BENCHMARK_KMC2(k100Points, k20Centers, k100Dim); \
BENCHMARK_KMC2(k100Points, k50Centers, k100Dim); \
BENCHMARK_KMC2(k100Points, k100Centers, k100Dim); \
BENCHMARK_KMC2(k1kPoints, k100Centers, k100Dim); \
BENCHMARK_KMC2(k1kPoints, k200Centers, k100Dim); \
BENCHMARK_KMC2(k1kPoints, k500Centers, k100Dim); \
BENCHMARK_KMC2(k1kPoints, k1kCenters, k100Dim); \
BENCHMARK_KMC2(k10kPoints, k100Centers, k100Dim); \
BENCHMARK_KMC2(k10kPoints, k200Centers, k100Dim); \
BENCHMARK_KMC2(k10kPoints, k500Centers, k100Dim); \
BENCHMARK_KMC2(k10kPoints, k1kCenters, k100Dim); \
BENCHMARK_KMC2(k1MPoints, k100Centers, k100Dim); \
BENCHMARK_KMC2(k1MPoints, k200Centers, k100Dim); \
BENCHMARK_KMC2(k1MPoints, k500Centers, k100Dim); \
BENCHMARK_KMC2(k1MPoints, k1kCenters, k100Dim)

RUN_BM_KMC2Initialization;
#undef RUN_BM_KMC2Initialization
#undef BENCHMARK_KMC2

Graph* SetUpNearestNeighbors(int num_dims, int num_points, int num_centers,
int k) {
Graph* g = new Graph(OpRegistry::Global());
@ -44,6 +44,25 @@ num_retries_per_sample: Scalar. For each row that is sampled, this parameter
samples: Matrix of shape (num_to_sample, d). The sampled rows.
)");

REGISTER_OP("KMC2ChainInitialization")
.Input("distances: float32")
.Input("seed: int64")
.Output("index: int64")
.SetShapeFn(shape_inference::ScalarShape)
.Doc(R"(
Returns the index of a data point that should be added to the seed set.

Entries in distances are assumed to be squared distances of candidate points to
the already sampled centers in the seed set. The op constructs one Markov chain
of the k-MC^2 algorithm and returns the index of one candidate point to be added
as an additional cluster center.

distances: Vector with squared distances to the closest previously sampled
cluster center for each candidate point.
seed: Scalar. Seed for initializing the random number generator.
index: Scalar with the index of the sampled point.
)");

REGISTER_OP("NearestNeighbors")
.Input("points: float32")
.Input("centers: float32")
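A hedged usage sketch of the Python wrapper for this op, assuming the same clustering_ops module the tests below import (the distance values are made up for illustration):

import numpy as np
import tensorflow as tf
from tensorflow.contrib.factorization.python.ops import clustering_ops

distances = np.zeros(1000, dtype=np.float32)
distances[6] = 10e7  # one candidate far from every current center
index = clustering_ops.kmc2_chain_initialization(distances, seed=3)
with tf.Session() as sess:
  print(sess.run(index))  # -> 6, the distant candidate wins the chain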
@ -55,6 +55,63 @@ class KmeansPlusPlusInitializationTest(test.TestCase):
self.runTestWithSeed(seed)


class KMC2InitializationTest(test.TestCase):

def runTestWithSeed(self, seed):
with self.test_session():
distances = np.zeros(1000).astype(np.float32)
distances[6] = 10e7
distances[4] = 10e3

sampled_point = clustering_ops.kmc2_chain_initialization(distances, seed)
self.assertEquals(sampled_point.eval(), 6)
distances[6] = 0.0
sampled_point = clustering_ops.kmc2_chain_initialization(distances, seed)
self.assertEquals(sampled_point.eval(), 4)

def testBasic(self):
for seed in range(100):
self.runTestWithSeed(seed)


class KMC2InitializationLargeTest(test.TestCase):

def setUp(self):
self._distances = np.zeros(1001)
self._distances[500] = 100.0
self._distances[1000] = 50.0

def testBasic(self):
with self.test_session():
counts = {}
seed = 0
for i in range(50):
sample = clustering_ops.kmc2_chain_initialization(
self._distances, seed + i).eval()
counts[sample] = counts.get(sample, 0) + 1
self.assertEquals(len(counts), 2)
self.assertTrue(500 in counts)
self.assertTrue(1000 in counts)
self.assertGreaterEqual(counts[500], 5)
self.assertGreaterEqual(counts[1000], 5)


class KMC2InitializationCornercaseTest(test.TestCase):

def setUp(self):
self._distances = np.zeros(10)

def runTestWithSeed(self, seed):
with self.test_session():
sampled_point = clustering_ops.kmc2_chain_initialization(
self._distances, seed)
self.assertEquals(sampled_point.eval(), 0)

def testBasic(self):
for seed in range(100):
self.runTestWithSeed(seed)


# A simple test that can be verified by hand.
class NearestCentersTest(test.TestCase):

@ -50,6 +50,7 @@ COSINE_DISTANCE = 'cosine'

RANDOM_INIT = 'random'
KMEANS_PLUS_PLUS_INIT = 'kmeans_plus_plus'
KMC2_INIT = 'kmc2'

# The name of the variable holding the cluster centers. Used by the Estimator.
CLUSTERS_VAR_NAME = 'clusters'
@ -66,7 +67,8 @@ class KMeans(object):
use_mini_batch=False,
mini_batch_steps_per_iteration=1,
random_seed=0,
kmeans_plus_plus_num_retries=2):
kmeans_plus_plus_num_retries=2,
kmc2_chain_length=200):
"""Creates an object for generating KMeans clustering graph.

This class implements the following variants of K-means algorithm:
@ -95,7 +97,8 @@ class KMeans(object):
exactly like a full-batch version.

Args:
inputs: An input tensor or list of input tensors
inputs: An input tensor or list of input tensors. It is assumed that the
data points have been previously randomly permuted.
num_clusters: An integer tensor specifying the number of clusters. This
argument is ignored if initial_clusters is a tensor or numpy array.
initial_clusters: Specifies the clusters used during initialization. One
@ -104,6 +107,7 @@ class KMeans(object):
- a function f(inputs, k) that returns up to k centers from `inputs`.
- "random": Choose centers randomly from `inputs`.
- "kmeans_plus_plus": Use kmeans++ to choose centers from `inputs`.
- "kmc2": Use the fast k-MC2 algorithm to choose centers from `inputs`.
In the last three cases, one batch of `inputs` may not yield
`num_clusters` centers, in which case initialization will require
multiple batches until enough centers are chosen. In the case of
@ -121,13 +125,17 @@ class KMeans(object):
additional points to draw from the current distribution before selecting
the best. If a negative value is specified, a heuristic is used to
sample O(log(num_to_sample)) additional points.
kmc2_chain_length: Determines how many candidate points are used by the
k-MC2 algorithm to produce one new cluster center. If a (mini-)batch
contains fewer points, one new cluster center is generated from the
(mini-)batch.

Raises:
ValueError: An invalid argument was passed to initial_clusters or
distance_metric.
"""
if isinstance(initial_clusters, str) and initial_clusters not in [
RANDOM_INIT, KMEANS_PLUS_PLUS_INIT
RANDOM_INIT, KMEANS_PLUS_PLUS_INIT, KMC2_INIT
]:
raise ValueError(
"Unsupported initialization algorithm '%s'" % initial_clusters)
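A hedged sketch of constructing the clustering graph with the new k-MC2 seeding, using the constructor arguments added above (the input array is a stand-in for a real, already-shuffled dataset):

import numpy as np
from tensorflow.contrib.factorization.python.ops import clustering_ops
from tensorflow.python.framework import ops

with ops.Graph().as_default():
  points = np.random.randn(1000, 10).astype(np.float32)  # assumed permuted
  kmeans = clustering_ops.KMeans(
      inputs=points,
      num_clusters=8,
      initial_clusters=clustering_ops.KMC2_INIT,  # 'kmc2'
      use_mini_batch=True,
      kmc2_chain_length=200)  # candidates per Markov chain, per new center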
@ -141,6 +149,7 @@ class KMeans(object):
self._mini_batch_steps_per_iteration = int(mini_batch_steps_per_iteration)
self._random_seed = random_seed
self._kmeans_plus_plus_num_retries = kmeans_plus_plus_num_retries
self._kmc2_chain_length = kmc2_chain_length

@classmethod
def _distance_graph(cls, inputs, clusters, distance_metric):
@ -302,9 +311,10 @@ class KMeans(object):
else:
cluster_centers_updated = cluster_centers
update_in_steps = None
cluster_counts = (variable_scope.variable(
array_ops.ones([num_clusters], dtype=dtypes.int64))
if self._use_mini_batch else None)
cluster_counts = (
variable_scope.variable(
array_ops.ones([num_clusters], dtype=dtypes.int64))
if self._use_mini_batch else None)
return (cluster_centers, cluster_centers_initialized, cluster_counts,
cluster_centers_updated, update_in_steps)

@ -359,7 +369,7 @@ class KMeans(object):
init_op = _InitializeClustersOpFactory(
self._inputs, num_clusters, initial_clusters, self._distance_metric,
self._random_seed, self._kmeans_plus_plus_num_retries,
cluster_centers_var, cluster_centers_updated,
self._kmc2_chain_length, cluster_centers_var, cluster_centers_updated,
cluster_centers_initialized).op()
cluster_centers = cluster_centers_var

@ -520,8 +530,9 @@ class KMeans(object):
array_ops.reshape(array_ops.shape(inp)[0], [-1])),
[-1, 1]), cluster_idx, num_clusters))
with ops.colocate_with(cluster_centers, ignore_existing=True):
new_clusters_centers = math_ops.add_n(cluster_sums) / (math_ops.cast(
math_ops.add_n(cluster_counts), cluster_sums[0].dtype) + epsilon)
new_clusters_centers = math_ops.add_n(cluster_sums) / (
math_ops.cast(math_ops.add_n(cluster_counts), cluster_sums[0].dtype) +
epsilon)
if self._clusters_l2_normalized():
new_clusters_centers = nn_impl.l2_normalize(new_clusters_centers, dim=1)
return state_ops.assign(cluster_centers, new_clusters_centers)
@ -548,9 +559,12 @@ class _InitializeClustersOpFactory(object):
cluster_centers_initialized := true
"""

# TODO(ccolby): Refactor this class so that kmc2 isn't so much a special case.

def __init__(self, inputs, num_clusters, initial_clusters, distance_metric,
random_seed, kmeans_plus_plus_num_retries, cluster_centers,
cluster_centers_updated, cluster_centers_initialized):
random_seed, kmeans_plus_plus_num_retries, kmc2_chain_length,
cluster_centers, cluster_centers_updated,
cluster_centers_initialized):
"""Creates an op factory.

Args:
@ -560,6 +574,7 @@ class _InitializeClustersOpFactory(object):
distance_metric: See KMeans constructor.
random_seed: See KMeans constructor.
kmeans_plus_plus_num_retries: See KMeans constructor.
kmc2_chain_length: See KMeans constructor.
cluster_centers: The TF variable holding the initial centers. It may
already contain some centers when the op is executed.
cluster_centers_updated: A second TF variable to hold a copy of the
@ -575,6 +590,7 @@ class _InitializeClustersOpFactory(object):
self._distance_metric = distance_metric
self._random_seed = random_seed
self._kmeans_plus_plus_num_retries = kmeans_plus_plus_num_retries
self._kmc2_chain_length = kmc2_chain_length
self._cluster_centers = cluster_centers
self._cluster_centers_updated = cluster_centers_updated
self._cluster_centers_initialized = cluster_centers_initialized
@ -604,6 +620,90 @@ class _InitializeClustersOpFactory(object):
math_ops.to_int64(self._num_remaining), self._random_seed,
self._kmeans_plus_plus_num_retries)

def _kmc2_multiple_centers(self):
"""Adds new initial cluster centers using the k-MC2 algorithm.

In each call to the op, the provided batch is split into subsets based on
the specified `kmc2_chain_length`. On each subset, a single Markov chain of
the k-MC2 algorithm is used to add *one* new cluster center. If there
are fewer than `kmc2_chain_length` points in the subset, a single center is
added using one Markov chain on the full input. It is assumed that the
provided batch has previously been randomly permuted. Otherwise, k-MC2 may
return suboptimal centers.

Returns:
An op that adds new cluster centers.
"""
# The op only operates on the first shard of data.
first_shard = self._inputs[0]
# Number of points in the input that can be used.
batch_size = array_ops.shape(first_shard)[0]
# Maximum number of subsets such that the size of each subset is at least
# `kmc2_chain_length`. Final subsets may be larger.
max_to_sample = math_ops.cast(
batch_size / self._kmc2_chain_length, dtype=dtypes.int32)
# We sample at least one new center and at most all remaining centers.
num_to_sample = math_ops.maximum(
math_ops.minimum(self._num_remaining, max_to_sample), 1)

def _cond(i, _):
"""Stopping condition for the while loop."""
return math_ops.less(i, num_to_sample)

def _body(i, _):
"""Body that adds a single new center based on a subset."""

def _sample_random():
"""Returns a random point as a cluster center."""
# By assumption the batch is reshuffled and _sample_random is always
# called for i=0. Hence, we simply return the first point.
new_center = array_ops.reshape(first_shard[0], [1, -1])
if self._distance_metric == COSINE_DISTANCE:
new_center = nn_impl.l2_normalize(new_center, dim=1)
return new_center

def _sample_kmc2_chain():
"""Returns previous centers as well as a new center sampled using k-MC2.
"""
# Extract the subset from the underlying batch.
start = i * self._kmc2_chain_length
end = start + self._kmc2_chain_length
subset = first_shard[start:end]
# Compute the distances from points in the subset to previous centers.
_, distances = gen_clustering_ops.nearest_neighbors(
subset, self._cluster_centers, 1)
# Sample index of new center using k-MC2 Markov chain.
new_center_index = gen_clustering_ops.kmc2_chain_initialization(
array_ops.squeeze(distances), self._random_seed)
# Extract actual new center.
newly_sampled_center = array_ops.reshape(subset[new_center_index],
[1, -1])
# Return concatenation with previously sampled centers.
if self._distance_metric == COSINE_DISTANCE:
newly_sampled_center = nn_impl.l2_normalize(
newly_sampled_center, dim=1)
return array_ops.concat([self._cluster_centers, newly_sampled_center],
0)

# Obtain a random point if there are no previously sampled centers.
# Otherwise, construct a k-MC2 Markov chain.
new_centers = control_flow_ops.cond(
math_ops.equal(self._num_selected, 0), _sample_random,
_sample_kmc2_chain)
# Assign new cluster centers to underlying variable.
assigned_centers = state_ops.assign(
self._cluster_centers, new_centers, validate_shape=False)
if self._cluster_centers_updated is not self._cluster_centers:
assigned_centers = state_ops.assign(
self._cluster_centers_updated,
assigned_centers,
validate_shape=False)
return i + 1, self._num_clusters - array_ops.shape(assigned_centers)[0]

# Add num_to_sample new data points.
_, num_remaining = control_flow_ops.while_loop(_cond, _body, [0, 0])
return num_remaining

def _greedy_batch_sampler(self, sampler):
# If the input dataset size is smaller than the number of centers
# remaining, choose the entire input dataset as centers. This can happen
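A worked example of the subset arithmetic in `_kmc2_multiple_centers` above, with made-up numbers:

batch_size = 1000
kmc2_chain_length = 200
num_remaining = 8          # centers still to be chosen

# Largest number of chains whose subsets each hold >= kmc2_chain_length points.
max_to_sample = batch_size // kmc2_chain_length            # -> 5
# At least one center is added per call, at most all remaining ones.
num_to_sample = max(min(num_remaining, max_to_sample), 1)  # -> 5
# Chain i reads points [i*200, (i+1)*200); a later call adds the last 3 centers.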
@ -657,7 +757,10 @@ class _InitializeClustersOpFactory(object):
with ops.control_dependencies([
check_ops.assert_positive(self._num_remaining),
]):
num_now_remaining = self._add_new_centers()
if self._initial_clusters == KMC2_INIT:
num_now_remaining = self._kmc2_multiple_centers()
else:
num_now_remaining = self._add_new_centers()
return control_flow_ops.cond(
math_ops.equal(num_now_remaining, 0),
lambda: state_ops.assign(self._cluster_centers_initialized, True),
@ -37,6 +37,7 @@ See the @{$python/contrib.framework} guide.

@@arg_scope
@@add_arg_scope
@@current_arg_scope
@@has_arg_scope
@@arg_scoped_arguments

@ -67,6 +67,7 @@ from tensorflow.python.util import tf_decorator

__all__ = ['arg_scope',
'add_arg_scope',
'current_arg_scope',
'has_arg_scope',
'arg_scoped_arguments']

@ -83,7 +84,7 @@ def _get_arg_stack():
return _ARGSTACK


def _current_arg_scope():
def current_arg_scope():
stack = _get_arg_stack()
return stack[-1]

@ -144,7 +145,7 @@ def arg_scope(list_ops_or_scope, **kwargs):
raise TypeError('list_ops_or_scope must either be a list/tuple or reused'
'scope (i.e. dict)')
try:
current_scope = _current_arg_scope().copy()
current_scope = current_arg_scope().copy()
for op in list_ops_or_scope:
key_op = _key_op(op)
if not has_arg_scope(op):
@ -172,7 +173,7 @@ def add_arg_scope(func):
A tuple with the decorated function func_with_args().
"""
def func_with_args(*args, **kwargs):
current_scope = _current_arg_scope()
current_scope = current_arg_scope()
current_args = kwargs
key_func = _key_op(func)
if key_func in current_scope:
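A small sketch of the newly public current_arg_scope(), assuming it is re-exported from tf.contrib.framework as the @@current_arg_scope doc entry above suggests (conv_block is a hypothetical decorated op for illustration):

from tensorflow.contrib.framework import add_arg_scope, arg_scope, current_arg_scope

@add_arg_scope
def conv_block(x, padding='VALID'):
  return padding  # stand-in body, just echoes the resolved argument

with arg_scope([conv_block], padding='SAME'):
  # The active scope maps each decorated op to its default overrides.
  print(current_arg_scope())  # {<key for conv_block>: {'padding': 'SAME'}}
  print(conv_block(None))     # -> 'SAME'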
@ -442,7 +442,8 @@ def read_keyed_batch_features(file_pattern,
feature_queue_capacity=100,
num_enqueue_threads=2,
parse_fn=None,
name=None):
name=None,
read_batch_size=None):
"""Adds operations to read, queue, batch and parse `Example` protos.

Given file pattern (or list of files), will set up a queue for file names,
@ -482,6 +483,8 @@ def read_keyed_batch_features(file_pattern,
parse_fn: Parsing function, takes `Example` Tensor returns parsed
representation. If `None`, no parsing is done.
name: Name of resulting op.
read_batch_size: An int or scalar `Tensor` specifying the number of
records to read at once. If `None`, defaults to `batch_size`.

Returns:
Returns tuple of:
@ -493,6 +496,7 @@ def read_keyed_batch_features(file_pattern,
"""

with ops.name_scope(name, 'read_batch_features', [file_pattern]) as scope:
if read_batch_size is None: read_batch_size = batch_size
keys, examples = read_keyed_batch_examples(
file_pattern,
batch_size,
@ -501,7 +505,7 @@ def read_keyed_batch_features(file_pattern,
num_epochs=num_epochs,
queue_capacity=queue_capacity,
num_threads=reader_num_threads,
read_batch_size=batch_size,
read_batch_size=read_batch_size,
parse_fn=parse_fn,
name=scope)
# Parse the example.
@ -727,7 +731,8 @@ def read_batch_features(file_pattern,
reader_num_threads=1,
num_enqueue_threads=2,
parse_fn=None,
name=None):
name=None,
read_batch_size=None):
"""Adds operations to read, queue, batch and parse `Example` protos.

Given file pattern (or list of files), will set up a queue for file names,
@ -768,6 +773,8 @@ def read_batch_features(file_pattern,
parse_fn: Parsing function, takes `Example` Tensor returns parsed
representation. If `None`, no parsing is done.
name: Name of resulting op.
read_batch_size: An int or scalar `Tensor` specifying the number of
records to read at once. If `None`, defaults to `batch_size`.

Returns:
A dict of `Tensor` or `SparseTensor` objects for each in `features`.
@ -786,6 +793,7 @@ def read_batch_features(file_pattern,
reader_num_threads=reader_num_threads,
feature_queue_capacity=feature_queue_capacity,
num_enqueue_threads=num_enqueue_threads,
read_batch_size=read_batch_size,
parse_fn=parse_fn,
name=name)
return features
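A hedged sketch of the new read_batch_size knob, which decouples how many records each I/O call reads from the training batch size (the file pattern and feature spec here are placeholders, not paths from this commit):

import tensorflow as tf
from tensorflow.contrib.learn.python.learn.learn_io import graph_io

features = graph_io.read_batch_features(
    file_pattern='/tmp/data/train-*.tfrecord',  # hypothetical files
    batch_size=128,
    features={'x': tf.FixedLenFeature([10], tf.float32)},
    reader=tf.TFRecordReader,
    read_batch_size=512)  # read more records per I/O call than one batch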
@ -502,6 +502,7 @@ $(wildcard tensorflow/core/platform/google/*) \
$(wildcard tensorflow/core/platform/google/*/*) \
$(wildcard tensorflow/core/platform/jpeg.*) \
$(wildcard tensorflow/core/platform/png.*) \
$(wildcard tensorflow/core/platform/s3/*) \
$(wildcard tensorflow/core/platform/stream_executor.*) \
$(wildcard tensorflow/core/platform/windows/*) \
$(wildcard tensorflow/core/user_ops/*.cu.cc) \
@ -20,11 +20,11 @@ DOWNLOADS_DIR=tensorflow/contrib/makefile/downloads
BZL_FILE_PATH=tensorflow/workspace.bzl

EIGEN_URL="$(grep -o 'http.*bitbucket.org/eigen/eigen/get/.*tar\.gz' "${BZL_FILE_PATH}" | grep -v bazel-mirror | head -n1)"
GEMMLOWP_URL="$(grep -o 'http://mirror.bazel.build/github.com/google/gemmlowp/.*zip' "${BZL_FILE_PATH}" | head -n1)"
GEMMLOWP_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/gemmlowp/.*zip' "${BZL_FILE_PATH}" | head -n1)"
GOOGLETEST_URL="https://github.com/google/googletest/archive/release-1.8.0.tar.gz"
NSYNC_URL="$(grep -o 'http://mirror.bazel.build/github.com/google/nsync/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)"
PROTOBUF_URL="$(grep -o 'http://mirror.bazel.build/github.com/google/protobuf/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)"
RE2_URL="$(grep -o 'http://mirror.bazel.build/github.com/google/re2/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)"
NSYNC_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/nsync/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)"
PROTOBUF_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/protobuf/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)"
RE2_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/re2/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)"
FFT2D_URL="$(grep -o 'http.*fft\.tgz' "${BZL_FILE_PATH}" | grep -v bazel-mirror | head -n1)"

# TODO(petewarden): Some new code in Eigen triggers a clang bug with iOS arm64,
(File diff suppressed because it is too large.)
@ -1101,7 +1101,7 @@ class StreamingPrecisionTest(test.TestCase):
predictions = random_ops.random_uniform(
(10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1)
labels = random_ops.random_uniform(
(10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2)
(10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
precision, update_op = metrics.streaming_precision(predictions, labels)

with self.test_session() as sess:
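The maxval bump from 1 to 2 in this hunk (and the ones that follow) matters because random_uniform samples integers in the half-open range [minval, maxval), so maxval=1 can only ever produce zeros and the labels carry no signal. A quick illustrative check:

import tensorflow as tf

with tf.Session() as sess:
  all_zero = tf.random_uniform((10, 3), maxval=1, dtype=tf.int64)
  zero_or_one = tf.random_uniform((10, 3), maxval=2, dtype=tf.int64)
  print(sess.run(tf.reduce_max(all_zero)))     # -> 0, always
  print(sess.run(tf.reduce_max(zero_or_one)))  # -> usually 1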
@ -1265,7 +1265,7 @@ class StreamingRecallTest(test.TestCase):
predictions = random_ops.random_uniform(
(10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1)
labels = random_ops.random_uniform(
(10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2)
(10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
recall, update_op = metrics.streaming_recall(predictions, labels)

with self.test_session() as sess:
@ -1388,7 +1388,7 @@ class StreamingFPRTest(test.TestCase):
predictions = random_ops.random_uniform(
(10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1)
labels = random_ops.random_uniform(
(10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2)
(10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
fpr, update_op = metrics.streaming_false_positive_rate(
predictions, labels)

@ -1516,7 +1516,7 @@ class StreamingFNRTest(test.TestCase):
predictions = random_ops.random_uniform(
(10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1)
labels = random_ops.random_uniform(
(10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2)
(10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
fnr, update_op = metrics.streaming_false_negative_rate(
predictions, labels)

@ -1737,7 +1737,7 @@ class StreamingAUCTest(test.TestCase):
predictions = random_ops.random_uniform(
(10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1)
labels = random_ops.random_uniform(
(10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2)
(10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
auc, update_op = metrics.streaming_auc(predictions, labels)

with self.test_session() as sess:
@ -2009,7 +2009,7 @@ class StreamingSpecificityAtSensitivityTest(test.TestCase):
predictions = random_ops.random_uniform(
(10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1)
labels = random_ops.random_uniform(
(10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2)
(10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
specificity, update_op = metrics.streaming_specificity_at_sensitivity(
predictions, labels, sensitivity=0.7)

@ -2271,7 +2271,7 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
predictions = random_ops.random_uniform(
(10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1)
labels = random_ops.random_uniform(
(10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2)
(10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
thresholds = [0, 0.5, 1.0]
prec, prec_op = metrics.streaming_precision_at_thresholds(predictions,
labels,
@ -2282,12 +2282,14 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
with self.test_session() as sess:
sess.run(variables.local_variables_initializer())

# Run several updates, then verify idempotency.
sess.run([prec_op, rec_op])
# Run several updates.
for _ in range(10):
sess.run([prec_op, rec_op])

# Then verify idempotency.
initial_prec = prec.eval()
initial_rec = rec.eval()
for _ in range(10):
sess.run([prec_op, rec_op])
self.assertAllClose(initial_prec, prec.eval())
self.assertAllClose(initial_rec, rec.eval())

@ -2361,14 +2363,10 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
rec, rec_op = metrics.streaming_recall_at_thresholds(
predictions, labels, thresholds, weights=weights)

[prec_low, prec_high] = array_ops.split(
value=prec, num_or_size_splits=2, axis=0)
prec_low = array_ops.reshape(prec_low, shape=())
prec_high = array_ops.reshape(prec_high, shape=())
[rec_low, rec_high] = array_ops.split(
value=rec, num_or_size_splits=2, axis=0)
rec_low = array_ops.reshape(rec_low, shape=())
rec_high = array_ops.reshape(rec_high, shape=())
prec_low = prec[0]
prec_high = prec[1]
rec_low = rec[0]
rec_high = rec[1]

sess.run(variables.local_variables_initializer())
sess.run([prec_op, rec_op])
@ -2391,14 +2389,10 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
rec, rec_op = metrics.streaming_recall_at_thresholds(
predictions, labels, thresholds, weights=weights)

[prec_low, prec_high] = array_ops.split(
value=prec, num_or_size_splits=2, axis=0)
prec_low = array_ops.reshape(prec_low, shape=())
prec_high = array_ops.reshape(prec_high, shape=())
[rec_low, rec_high] = array_ops.split(
value=rec, num_or_size_splits=2, axis=0)
rec_low = array_ops.reshape(rec_low, shape=())
rec_high = array_ops.reshape(rec_high, shape=())
prec_low = prec[0]
prec_high = prec[1]
rec_low = rec[0]
rec_high = rec[1]

sess.run(variables.local_variables_initializer())
sess.run([prec_op, rec_op])
@ -2420,10 +2414,10 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
rec, rec_op = metrics.streaming_recall_at_thresholds(predictions, labels,
thresholds)

[prec_low, prec_high] = array_ops.split(
value=prec, num_or_size_splits=2, axis=0)
[rec_low, rec_high] = array_ops.split(
value=rec, num_or_size_splits=2, axis=0)
prec_low = prec[0]
prec_high = prec[1]
rec_low = rec[0]
rec_high = rec[1]

sess.run(variables.local_variables_initializer())
sess.run([prec_op, rec_op])
@ -2562,7 +2556,7 @@ class StreamingFPRThresholdsTest(test.TestCase):
predictions = random_ops.random_uniform(
(10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1)
labels = random_ops.random_uniform(
(10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2)
(10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
thresholds = [0, 0.5, 1.0]
fpr, fpr_op = metrics.streaming_false_positive_rate_at_thresholds(
predictions, labels, thresholds)
@ -2794,7 +2788,7 @@ class StreamingFNRThresholdsTest(test.TestCase):
predictions = random_ops.random_uniform(
(10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1)
labels = random_ops.random_uniform(
(10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2)
(10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
thresholds = [0, 0.5, 1.0]
fnr, fnr_op = metrics.streaming_false_negative_rate_at_thresholds(
predictions, labels, thresholds)
||||
|
@ -13,6 +13,34 @@ py_library(
|
||||
deps = [],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "graph_matcher",
|
||||
srcs = [
|
||||
"python/graph_matcher.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "graph_matcher_test",
|
||||
size = "small",
|
||||
srcs = ["python/graph_matcher_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":graph_matcher",
|
||||
"//tensorflow/contrib/layers:layers_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:init_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:nn_ops",
|
||||
"//tensorflow/python:platform_test",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "input_to_ops",
|
||||
srcs = ["python/input_to_ops.py"],
|
||||
@ -43,6 +71,7 @@ py_library(
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":common",
|
||||
":graph_matcher",
|
||||
":input_to_ops",
|
||||
"//tensorflow/contrib/graph_editor:graph_editor_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
@ -58,6 +87,7 @@ py_test(
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":fold_batch_norms",
|
||||
":graph_matcher",
|
||||
"//tensorflow/contrib/layers:layers_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:dtypes",
|
||||
@ -147,10 +177,11 @@ py_test(
|
||||
|
||||
py_test(
|
||||
name = "quantize_parameterized_test",
|
||||
size = "medium",
|
||||
size = "large",
|
||||
srcs = ["python/quantize_parameterized_test.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":fold_batch_norms",
|
||||
":quantize",
|
||||
"//tensorflow/contrib/layers:layers_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
|
@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.quantized.mangle.copy_graph."""
"""Tests for copy_graph."""

from __future__ import absolute_import
from __future__ import division
@ -21,7 +21,9 @@ from __future__ import print_function
import re
from tensorflow.contrib import graph_editor
from tensorflow.contrib.quantize.python import common
from tensorflow.contrib.quantize.python import graph_matcher
from tensorflow.contrib.quantize.python import input_to_ops
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
@ -29,7 +31,7 @@ from tensorflow.python.ops import nn_ops


def FoldBatchNorms(graph):
"""Finds batch norm layers in the graph, folds them into preceding layers.
"""Finds batch norm layers and folds them into preceding layers.

Folding only affects the following layers: Conv2D, fully connected, depthwise
convolution.
@ -40,10 +42,269 @@ def FoldBatchNorms(graph):
Raises:
ValueError: When batch norm folding fails.
"""
# Fail immediately when the graph contains unsupported fused batch norm ops.
if any(op for op in graph.get_operations() if op.type == 'FusedBatchNorm'):
raise ValueError('Fused batch norm is not supported')
_FoldFusedBatchNorms(graph)
_FoldUnfusedBatchNorms(graph)


def _FoldFusedBatchNorms(graph):
"""Finds fused batch norm layers and folds them into preceding layers.

Folding only affects the following layers: Conv2D, fully connected, depthwise
convolution.

Args:
graph: Graph to walk and modify.

Raises:
ValueError: When batch norm folding fails.
"""
for match in _FindFusedBatchNorms(graph):
scope, sep, _ = match.layer_op.name.rpartition('/')
# Make sure new ops are added to `graph` and put on the same device as
# `bn_op`. The '/' (i.e. `sep`) ensures that we reuse the existing scope
# named `scope`. Otherwise, TF creates a unique scope whose name starts with
# `scope`.
with graph.as_default(), graph.name_scope(scope + sep), ops.device(
match.bn_op.device):
# new weights = old weights * gamma / sqrt(variance + epsilon)
# new biases = -mean * gamma / sqrt(variance + epsilon) + beta
multiplier_tensor = match.gamma_tensor * math_ops.rsqrt(
match.variance_tensor + match.bn_op.get_attr('epsilon'))
bias_tensor = math_ops.subtract(
match.beta_tensor, match.mean_tensor * multiplier_tensor, name='bias')

||||
# The shape of depthwise weights is different, so we need to reshape the
|
||||
# multiplier_tensor to ensure that the scaled_weight_tensor has the
|
||||
# expected shape.
|
||||
if match.layer_op.type == 'DepthwiseConv2dNative':
|
||||
new_shape = [
|
||||
match.weight_tensor.get_shape().as_list()[2],
|
||||
match.weight_tensor.get_shape().as_list()[3]
|
||||
]
|
||||
multiplier_tensor = array_ops.reshape(
|
||||
multiplier_tensor, new_shape, name='scale_reshape')
|
||||
|
||||
# TODO(suharshs): This naming of the following ops needs to carefully
|
||||
# follow the naming expected by quantize.py. Generalize the quantize code
|
||||
# to not require these delicate naming conventions.
|
||||
scaled_weight_tensor = math_ops.multiply(
|
||||
match.weight_tensor, multiplier_tensor, name='mul_fold')
|
||||
|
||||
new_layer_tensor = _CloneWithNewOperands(
|
||||
match.layer_op, match.input_tensor, scaled_weight_tensor)
|
||||
|
||||
bias_add_tensor = math_ops.add(
|
||||
new_layer_tensor, bias_tensor, name='add_fold')
|
||||
|
||||
nodes_modified_count = graph_editor.reroute_ts(bias_add_tensor,
|
||||
match.output_tensor)
|
||||
if nodes_modified_count != 1:
|
||||
raise ValueError(
|
||||
'Unexpected inputs to op: %s' % match.output_tensor.name)
|
||||
|
||||
|
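Spelled out, the folding identity that the two comment lines above compress is (a sketch of the algebra; \(*\) stands for the conv/matmul op, which is linear in the weights \(W\)):

y = \gamma\,\frac{(x * W) - \mu}{\sqrt{\sigma^{2} + \epsilon}} + \beta
  = x * \underbrace{\Bigl(W\,\tfrac{\gamma}{\sqrt{\sigma^{2}+\epsilon}}\Bigr)}_{\texttt{mul\_fold}}
    + \underbrace{\Bigl(\beta - \mu\,\tfrac{\gamma}{\sqrt{\sigma^{2}+\epsilon}}\Bigr)}_{\texttt{bias}}

So `multiplier_tensor` is \(\gamma/\sqrt{\sigma^{2}+\epsilon}\), the scaled weights are the `mul_fold` product, and `add_fold` applies the corrected bias, reproducing the batch norm output on the rewired graph.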

def _CloneWithNewOperands(layer_op, input_tensor, weight_tensor):
  """Clones layer_op with input_tensor and weight_tensor as new inputs."""
  new_layer_name = layer_op.name.split('/')[-1] + '_Fold'
  if layer_op.type == 'Conv2D':
    return nn_ops.conv2d(
        input_tensor,
        weight_tensor,
        strides=layer_op.get_attr('strides'),
        padding=layer_op.get_attr('padding'),
        use_cudnn_on_gpu=layer_op.get_attr('use_cudnn_on_gpu'),
        data_format=layer_op.get_attr('data_format'),
        name=new_layer_name)
  elif layer_op.type == 'MatMul':
    return math_ops.matmul(
        input_tensor,
        weight_tensor,
        transpose_a=layer_op.get_attr('transpose_a'),
        transpose_b=layer_op.get_attr('transpose_b'),
        name=new_layer_name)
  elif layer_op.type == 'DepthwiseConv2dNative':
    return nn.depthwise_conv2d(
        input_tensor,
        weight_tensor,
        strides=layer_op.get_attr('strides'),
        padding=layer_op.get_attr('padding'),
        name=new_layer_name)
  else:
    raise ValueError('Cannot handle operation of type: %s' % layer_op.type)


def _FindFusedBatchNorms(graph):
  """Finds all ops and tensors related to found FusedBatchNorms.

  Args:
    graph: Graph to inspect.

  Yields:
    _FusedBatchNormMatches.
  """
  input_pattern = graph_matcher.OpTypePattern('*')
  weight_pattern = graph_matcher.OpTypePattern('*')
  gamma_pattern = graph_matcher.OpTypePattern('*')
  beta_pattern = graph_matcher.OpTypePattern('*')
  mean_pattern = graph_matcher.OpTypePattern('*')
  variance_pattern = graph_matcher.OpTypePattern('*')

  conv_pattern = graph_matcher.OpTypePattern(
      'Conv2D|DepthwiseConv2dNative', inputs=[input_pattern, weight_pattern])
  # MatMul has a Reshape between it and FusedBatchNorm.
  matmul_pattern = graph_matcher.OpTypePattern(
      'MatMul', inputs=[input_pattern, weight_pattern])
  matmul_reshape_pattern = graph_matcher.OpTypePattern(
      'Reshape', inputs=[matmul_pattern,
                         graph_matcher.OpTypePattern('*')])

  conv_batch_norm_pattern = graph_matcher.OpTypePattern(
      'FusedBatchNorm',
      inputs=[
          conv_pattern, gamma_pattern, beta_pattern, mean_pattern,
          variance_pattern
      ])
  matmul_batch_norm_pattern = graph_matcher.OpTypePattern(
      'FusedBatchNorm',
      inputs=[
          matmul_reshape_pattern, gamma_pattern, beta_pattern, mean_pattern,
          variance_pattern
      ])
  matmul_bn_output_reshape_pattern = graph_matcher.OpTypePattern(
      'Reshape',
      inputs=[matmul_batch_norm_pattern,
              graph_matcher.OpTypePattern('*')])

  conv_matcher = graph_matcher.GraphMatcher(conv_batch_norm_pattern)
  matmul_matcher = graph_matcher.GraphMatcher(matmul_bn_output_reshape_pattern)

  def _GetCommonTensors(match_result):
    """Gets tensors needed for FusedBatchNormMatch from match_result."""
    input_tensor = match_result.get_tensor(input_pattern)
    weight_tensor = match_result.get_tensor(weight_pattern)
    gamma_tensor = match_result.get_tensor(gamma_pattern)
    beta_tensor = match_result.get_tensor(beta_pattern)
    # FusedBatchNorm in training is different from that in inference. It takes
    # empty 'mean' and empty 'variance', and produces the mean and the variance
    # of the batch. Therefore, when is_training is true, mean_tensor and
    # variance_tensor point to 1st and 2nd (0-based) output of bn_op,
    # respectively; when is_training is false, they point to bn_op's inputs.
    is_training = bn_op.get_attr('is_training')
    if is_training:
      mean_tensor = bn_op.outputs[1]
      variance_tensor = bn_op.outputs[2]
    else:
      mean_tensor = match_result.get_tensor(mean_pattern)
      variance_tensor = match_result.get_tensor(variance_pattern)
    return (input_tensor, weight_tensor, gamma_tensor, beta_tensor, mean_tensor,
            variance_tensor)

  for match_result in conv_matcher.match_graph(graph):
    layer_op = match_result.get_op(conv_pattern)
    bn_op = match_result.get_op(conv_batch_norm_pattern)
    # In the case of convolution the output_tensor is the output of bn_op.
    output_tensor = bn_op.outputs[0]

    (input_tensor, weight_tensor, gamma_tensor, beta_tensor, mean_tensor,
     variance_tensor) = _GetCommonTensors(match_result)
    yield _FusedBatchNormMatch(
        layer_op=layer_op,
        bn_op=bn_op,
        output_tensor=output_tensor,
        input_tensor=input_tensor,
        weight_tensor=weight_tensor,
        gamma_tensor=gamma_tensor,
        beta_tensor=beta_tensor,
        mean_tensor=mean_tensor,
        variance_tensor=variance_tensor)

  for match_result in matmul_matcher.match_graph(graph):
    layer_op = match_result.get_op(matmul_pattern)
    bn_op = match_result.get_op(matmul_batch_norm_pattern)
    # In the MatMul case, the output of batch norm is reshaped back into a
    # 2D tensor, so the output_tensor is the output of the Reshape op.
    output_reshape_op = match_result.get_op(matmul_bn_output_reshape_pattern)
    output_tensor = output_reshape_op.outputs[0]

    (input_tensor, weight_tensor, gamma_tensor, beta_tensor, mean_tensor,
     variance_tensor) = _GetCommonTensors(match_result)
    yield _FusedBatchNormMatch(
        layer_op=layer_op,
        bn_op=bn_op,
        output_tensor=output_tensor,
        input_tensor=input_tensor,
        weight_tensor=weight_tensor,
        gamma_tensor=gamma_tensor,
        beta_tensor=beta_tensor,
        mean_tensor=mean_tensor,
        variance_tensor=variance_tensor)


class _FusedBatchNormMatch(object):
  """Contains all information related to a found FusedBatchNorm."""

  def __init__(self, layer_op, bn_op, output_tensor, input_tensor,
               weight_tensor, gamma_tensor, beta_tensor, mean_tensor,
               variance_tensor):
    self._layer_op = layer_op
    self._bn_op = bn_op
    self._output_tensor = output_tensor
    self._input_tensor = input_tensor
    self._weight_tensor = weight_tensor
    self._gamma_tensor = gamma_tensor
    self._beta_tensor = beta_tensor
    self._mean_tensor = mean_tensor
    self._variance_tensor = variance_tensor

  @property
  def layer_op(self):
    return self._layer_op

  @property
  def bn_op(self):
    return self._bn_op

  @property
  def output_tensor(self):
    return self._output_tensor

  @property
  def input_tensor(self):
    return self._input_tensor

  @property
  def weight_tensor(self):
    return self._weight_tensor

  @property
  def gamma_tensor(self):
    return self._gamma_tensor

  @property
  def beta_tensor(self):
    return self._beta_tensor

  @property
  def mean_tensor(self):
    return self._mean_tensor

  @property
  def variance_tensor(self):
    return self._variance_tensor


def _FoldUnfusedBatchNorms(graph):
  """Finds unfused batch norm layers and folds them into preceding layers.

  Folding only affects the following layers: Conv2D, fully connected, depthwise
  convolution.

  Args:
    graph: Graph to walk and modify.

  Raises:
    ValueError: When batch norm folding fails.
  """
  input_to_ops_map = input_to_ops.InputToOps(graph)

  for bn in common.BatchNormGroups(graph):

@ -18,7 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
from tensorflow.contrib.layers.python.layers import layers
from tensorflow.contrib.quantize.python import fold_batch_norms
from tensorflow.python.framework import dtypes
@ -35,57 +34,32 @@ conv2d = layers.conv2d
fully_connected = layers.fully_connected
separable_conv2d = layers.separable_conv2d

_DEFAULT_BATCH_NORM_PARAMS = {
    'center': True,
    'scale': True,
    'decay': 1.0 - 0.003,
    'fused': False,
}


# TODO(suharshs): Use parameterized test once OSS TF supports it.
class FoldBatchNormsTest(test_util.TensorFlowTestCase):

  def _RunTestOverParameters(self, test_fn):
    parameters_list = [
        # (relu, relu_op_name, with_bypass)
        (nn_ops.relu6, 'Relu6', False),
        (nn_ops.relu, 'Relu', False),
        (nn_ops.relu6, 'Relu6', True),
        (nn_ops.relu, 'Relu', True),
        # (relu, relu_op_name, with_bypass, has_scaling, fused_batch_norm)
        (nn_ops.relu6, 'Relu6', False, False, False),
        (nn_ops.relu, 'Relu', False, False, False),
        (nn_ops.relu6, 'Relu6', True, False, False),
        (nn_ops.relu, 'Relu', True, False, False),
        (nn_ops.relu6, 'Relu6', False, True, False),
        (nn_ops.relu, 'Relu', False, True, False),
        (nn_ops.relu6, 'Relu6', True, True, False),
        (nn_ops.relu, 'Relu', True, True, False),
        # Fused batch norm always has scaling enabled.
        (nn_ops.relu6, 'Relu6', False, True, True),
        (nn_ops.relu, 'Relu', False, True, True),
        (nn_ops.relu6, 'Relu6', True, True, True),
        (nn_ops.relu, 'Relu', True, True, True),
    ]
    for parameters in parameters_list:
      test_fn(parameters[0], parameters[1], parameters[2])
    for params in parameters_list:
      test_fn(params[0], params[1], params[2], params[3], params[4])

  def testFailsWithFusedBatchNorm(self):
    self._RunTestOverParameters(self._TestFailsWithFusedBatchNorm)

  def _TestFailsWithFusedBatchNorm(self, relu, relu_op_name, with_bypass):
    """Tests that batch norm fails when fused batch norm ops are present."""
    g = ops.Graph()
    with g.as_default():
      batch_size, height, width = 5, 128, 128
      inputs = array_ops.zeros((batch_size, height, width, 3))
      out_depth = 3 if with_bypass else 32
      stride = 1 if with_bypass else 2
      activation_fn = None if with_bypass else relu
      batch_norm_params = _DEFAULT_BATCH_NORM_PARAMS.copy()
      batch_norm_params['fused'] = True
      scope = 'test/test2' if with_bypass else 'test'
      node = conv2d(inputs, out_depth, [5, 5], stride=stride, padding='SAME',
                    weights_initializer=self._WeightInit(0.09),
                    activation_fn=activation_fn,
                    normalizer_fn=batch_norm,
                    normalizer_params=batch_norm_params,
                    scope=scope)
      if with_bypass:
        node = math_ops.add(inputs, node, name='test/Add')
      relu(node, name='test/' + relu_op_name)

    with self.assertRaises(ValueError):
      fold_batch_norms.FoldBatchNorms(g)

  def _TestFoldConv2d(self, relu, relu_op_name, with_bypass):
  def _TestFoldConv2d(self, relu, relu_op_name, with_bypass, has_scaling,
                      fused_batch_norm):
    """Tests folding cases: inputs -> Conv2d with batch norm -> Relu*.

    Args:
@ -93,6 +67,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
      relu_op_name: String, name of the Relu* operation.
      with_bypass: Bool, when true there is an extra connection added from
        inputs to just before Relu*.
      has_scaling: Bool, when true the batch norm has scaling.
      fused_batch_norm: Bool, when true the batch norm is fused.
    """
    g = ops.Graph()
    with g.as_default():
@ -102,12 +78,17 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
      stride = 1 if with_bypass else 2
      activation_fn = None if with_bypass else relu
      scope = 'test/test2' if with_bypass else 'test'
      node = conv2d(inputs, out_depth, [5, 5], stride=stride, padding='SAME',
                    weights_initializer=self._WeightInit(0.09),
                    activation_fn=activation_fn,
                    normalizer_fn=batch_norm,
                    normalizer_params=_DEFAULT_BATCH_NORM_PARAMS,
                    scope=scope)
      node = conv2d(
          inputs,
          out_depth, [5, 5],
          stride=stride,
          padding='SAME',
          weights_initializer=self._WeightInit(0.09),
          activation_fn=activation_fn,
          normalizer_fn=batch_norm,
          normalizer_params=self._BatchNormParams(
              scale=has_scaling, fused=fused_batch_norm),
          scope=scope)
      if with_bypass:
        node = math_ops.add(inputs, node, name='test/Add')
      relu(node, name='test/' + relu_op_name)
@ -116,9 +97,10 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):

    folded_mul = g.get_operation_by_name(scope + '/mul_fold')
    self.assertEqual(folded_mul.type, 'Mul')
    self._AssertInputOpsAre(folded_mul,
                            [scope + '/weights/read',
                             scope + '/BatchNorm/batchnorm/mul'])
    self._AssertInputOpsAre(folded_mul, [
        scope + '/weights/read',
        self._BatchNormMultiplierName(scope, has_scaling, fused_batch_norm)
    ])
    self._AssertOutputGoesToOps(folded_mul, g, [scope + '/convolution_Fold'])

    folded_conv = g.get_operation_by_name(scope + '/convolution_Fold')
@ -129,16 +111,18 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):

    folded_add = g.get_operation_by_name(scope + '/add_fold')
    self.assertEqual(folded_add.type, 'Add')
    self._AssertInputOpsAre(folded_add,
                            [scope + '/convolution_Fold',
                             scope + '/BatchNorm/batchnorm/sub'])
    self._AssertInputOpsAre(folded_add, [
        scope + '/convolution_Fold',
        self._BatchNormBiasName(scope, fused_batch_norm)
    ])
    output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name]
    self._AssertOutputGoesToOps(folded_add, g, output_op_names)

  def testFoldConv2d(self):
    self._RunTestOverParameters(self._TestFoldConv2d)

  def _TestFoldConv2dUnknownShape(self, relu, relu_op_name, with_bypass):
  def _TestFoldConv2dUnknownShape(self, relu, relu_op_name, with_bypass,
                                  has_scaling, fused_batch_norm):
    """Tests folding cases: inputs -> Conv2d with batch norm -> Relu*.

    Tests that folding works even with an input shape where some dimensions are
@ -149,6 +133,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
      relu_op_name: String, name of the Relu* operation.
      with_bypass: Bool, when true there is an extra connection added from
        inputs to just before Relu*.
      has_scaling: Bool, when true the batch norm has scaling.
      fused_batch_norm: Bool, when true the batch norm is fused.
    """
    g = ops.Graph()
    with g.as_default():
@ -165,7 +151,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
          weights_initializer=self._WeightInit(0.09),
          activation_fn=activation_fn,
          normalizer_fn=batch_norm,
          normalizer_params=_DEFAULT_BATCH_NORM_PARAMS,
          normalizer_params=self._BatchNormParams(
              scale=has_scaling, fused=fused_batch_norm),
          scope=scope)
      if with_bypass:
        node = math_ops.add(inputs, node, name='test/Add')
@ -176,7 +163,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
    folded_mul = g.get_operation_by_name(scope + '/mul_fold')
    self.assertEqual(folded_mul.type, 'Mul')
    self._AssertInputOpsAre(folded_mul, [
        scope + '/weights/read', scope + '/BatchNorm/batchnorm/mul'
        scope + '/weights/read',
        self._BatchNormMultiplierName(scope, has_scaling, fused_batch_norm)
    ])
    self._AssertOutputGoesToOps(folded_mul, g, [scope + '/convolution_Fold'])

@ -188,7 +176,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
    folded_add = g.get_operation_by_name(scope + '/add_fold')
    self.assertEqual(folded_add.type, 'Add')
    self._AssertInputOpsAre(folded_add, [
        scope + '/convolution_Fold', scope + '/BatchNorm/batchnorm/sub'
        scope + '/convolution_Fold',
        self._BatchNormBiasName(scope, fused_batch_norm)
    ])
    output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name]
    self._AssertOutputGoesToOps(folded_add, g, output_op_names)
@ -196,62 +185,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
  def testFoldConv2dUnknownShape(self):
    self._RunTestOverParameters(self._TestFoldConv2dUnknownShape)

  def _TestFoldConv2dWithoutScale(self, relu, relu_op_name, with_bypass):
    """Tests folding cases: inputs -> Conv2d with batch norm -> Relu*.

    Args:
      relu: Callable that returns an Operation, a factory method for the Relu*.
      relu_op_name: String, name of the Relu* operation.
      with_bypass: Bool, when true there is an extra connection added from
        inputs to just before Relu*.
    """
    g = ops.Graph()
    with g.as_default():
      batch_size, height, width = 5, 128, 128
      inputs = array_ops.zeros((batch_size, height, width, 3))
      out_depth = 3 if with_bypass else 32
      stride = 1 if with_bypass else 2
      activation_fn = None if with_bypass else relu
      bn_params = copy.copy(_DEFAULT_BATCH_NORM_PARAMS)
      bn_params['scale'] = False
      scope = 'test/test2' if with_bypass else 'test'
      node = conv2d(inputs, out_depth, [5, 5], stride=stride, padding='SAME',
                    weights_initializer=self._WeightInit(0.09),
                    activation_fn=activation_fn,
                    normalizer_fn=batch_norm,
                    normalizer_params=bn_params,
                    scope=scope)
      if with_bypass:
        node = math_ops.add(inputs, node, name='test/Add')
      relu(node, name='test/' + relu_op_name)

    fold_batch_norms.FoldBatchNorms(g)

    folded_mul = g.get_operation_by_name(scope + '/mul_fold')
    self.assertEqual(folded_mul.type, 'Mul')
    self._AssertInputOpsAre(folded_mul,
                            [scope + '/weights/read',
                             scope + '/BatchNorm/batchnorm/Rsqrt'])
    self._AssertOutputGoesToOps(folded_mul, g, [scope + '/convolution_Fold'])

    folded_conv = g.get_operation_by_name(scope + '/convolution_Fold')
    self.assertEqual(folded_conv.type, 'Conv2D')
    self._AssertInputOpsAre(folded_conv,
                            [scope + '/mul_fold', inputs.op.name])
    self._AssertOutputGoesToOps(folded_conv, g, [scope + '/add_fold'])

    folded_add = g.get_operation_by_name(scope + '/add_fold')
    self.assertEqual(folded_add.type, 'Add')
    self._AssertInputOpsAre(folded_add,
                            [scope + '/convolution_Fold',
                             scope + '/BatchNorm/batchnorm/sub'])
    output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name]
    self._AssertOutputGoesToOps(folded_add, g, output_op_names)

  def testFoldConv2dWithoutScale(self):
    self._RunTestOverParameters(self._TestFoldConv2dWithoutScale)

  def _TestFoldFullyConnectedLayer(self, relu, relu_op_name, with_bypass):
  def _TestFoldFullyConnectedLayer(self, relu, relu_op_name, with_bypass,
                                   has_scaling, fused_batch_norm):
    """Tests folding cases: inputs -> FC with batch norm -> Relu*.

    Args:
@ -259,6 +194,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
      relu_op_name: String, name of the Relu* operation.
      with_bypass: Bool, when true there is an extra connection added from
        inputs to just before Relu*.
      has_scaling: Bool, when true the batch norm has scaling.
      fused_batch_norm: Bool, when true the batch norm is fused.
    """
    g = ops.Graph()
    with g.as_default():
@ -267,12 +204,15 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
      out_depth = 256 if with_bypass else 128
      activation_fn = None if with_bypass else relu
      scope = 'test/test2' if with_bypass else 'test'
      node = fully_connected(inputs, out_depth,
                             weights_initializer=self._WeightInit(0.03),
                             activation_fn=activation_fn,
                             normalizer_fn=batch_norm,
                             normalizer_params=_DEFAULT_BATCH_NORM_PARAMS,
                             scope=scope)
      node = fully_connected(
          inputs,
          out_depth,
          weights_initializer=self._WeightInit(0.03),
          activation_fn=activation_fn,
          normalizer_fn=batch_norm,
          normalizer_params=self._BatchNormParams(
              scale=has_scaling, fused=fused_batch_norm),
          scope=scope)
      if with_bypass:
        node = math_ops.add(inputs, node, name='test/Add')
      relu(node, name='test/' + relu_op_name)
@ -281,9 +221,10 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):

    folded_mul = g.get_operation_by_name(scope + '/mul_fold')
    self.assertEqual(folded_mul.type, 'Mul')
    self._AssertInputOpsAre(folded_mul,
                            [scope + '/weights/read',
                             scope + '/BatchNorm/batchnorm/mul'])
    self._AssertInputOpsAre(folded_mul, [
        scope + '/weights/read',
        self._BatchNormMultiplierName(scope, has_scaling, fused_batch_norm)
    ])
    self._AssertOutputGoesToOps(folded_mul, g, [scope + '/MatMul_Fold'])

    folded_conv = g.get_operation_by_name(scope + '/MatMul_Fold')
@ -294,71 +235,18 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):

    folded_add = g.get_operation_by_name(scope + '/add_fold')
    self.assertEqual(folded_add.type, 'Add')
    self._AssertInputOpsAre(folded_add,
                            [scope + '/MatMul_Fold',
                             scope + '/BatchNorm/batchnorm/sub'])
    self._AssertInputOpsAre(folded_add, [
        scope + '/MatMul_Fold',
        self._BatchNormBiasName(scope, fused_batch_norm)
    ])
    output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name]
    self._AssertOutputGoesToOps(folded_add, g, output_op_names)

  def testFoldFullyConnectedLayer(self):
    self._RunTestOverParameters(self._TestFoldFullyConnectedLayer)

  def _TestFoldFullyConnectedLayerWithoutScale(self, relu, relu_op_name,
                                               with_bypass):
    """Tests folding cases: inputs -> FC with batch norm -> Relu*.

    Args:
      relu: Callable that returns an Operation, a factory method for the Relu*.
      relu_op_name: String, name of the Relu* operation.
      with_bypass: Bool, when true there is an extra connection added from
        inputs to just before Relu*.
    """
    g = ops.Graph()
    with g.as_default():
      batch_size, depth = 5, 256
      inputs = array_ops.zeros((batch_size, depth))
      out_depth = 256 if with_bypass else 128
      activation_fn = None if with_bypass else relu
      bn_params = copy.copy(_DEFAULT_BATCH_NORM_PARAMS)
      bn_params['scale'] = False
      scope = 'test/test2' if with_bypass else 'test'
      node = fully_connected(inputs, out_depth,
                             weights_initializer=self._WeightInit(0.03),
                             activation_fn=activation_fn,
                             normalizer_fn=batch_norm,
                             normalizer_params=bn_params,
                             scope=scope)
      if with_bypass:
        node = math_ops.add(inputs, node, name='test/Add')
      relu(node, name='test/' + relu_op_name)

    fold_batch_norms.FoldBatchNorms(g)

    folded_mul = g.get_operation_by_name(scope + '/mul_fold')
    self.assertEqual(folded_mul.type, 'Mul')
    self._AssertInputOpsAre(folded_mul,
                            [scope + '/weights/read',
                             scope + '/BatchNorm/batchnorm/Rsqrt'])
    self._AssertOutputGoesToOps(folded_mul, g, [scope + '/MatMul_Fold'])

    folded_conv = g.get_operation_by_name(scope + '/MatMul_Fold')
    self.assertEqual(folded_conv.type, 'MatMul')
    self._AssertInputOpsAre(folded_conv,
                            [scope + '/mul_fold', inputs.op.name])
    self._AssertOutputGoesToOps(folded_conv, g, [scope + '/add_fold'])

    folded_add = g.get_operation_by_name(scope + '/add_fold')
    self.assertEqual(folded_add.type, 'Add')
    self._AssertInputOpsAre(folded_add,
                            [scope + '/MatMul_Fold',
                             scope + '/BatchNorm/batchnorm/sub'])
    output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name]
    self._AssertOutputGoesToOps(folded_add, g, output_op_names)

  def testFoldFullyConnectedLayerWithoutScale(self):
    self._RunTestOverParameters(self._TestFoldFullyConnectedLayerWithoutScale)

  def _TestFoldDepthwiseConv2d(self, relu, relu_op_name, with_bypass):
  def _TestFoldDepthwiseConv2d(self, relu, relu_op_name, with_bypass,
                               has_scaling, fused_batch_norm):
    """Tests folding: inputs -> DepthwiseConv2d with batch norm -> Relu*.

    Args:
@ -366,6 +254,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
      relu_op_name: String, name of the Relu* operation.
      with_bypass: Bool, when true there is an extra connection added from
        inputs to just before Relu*.
      has_scaling: Bool, when true the batch norm has scaling.
      fused_batch_norm: Bool, when true the batch norm is fused.
    """
    g = ops.Graph()
    with g.as_default():
@ -374,13 +264,18 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
      stride = 1 if with_bypass else 2
      activation_fn = None if with_bypass else relu
      scope = 'test/test2' if with_bypass else 'test'
      node = separable_conv2d(inputs, None, [5, 5], stride=stride,
                              depth_multiplier=1.0, padding='SAME',
                              weights_initializer=self._WeightInit(0.09),
                              activation_fn=activation_fn,
                              normalizer_fn=batch_norm,
                              normalizer_params=_DEFAULT_BATCH_NORM_PARAMS,
                              scope=scope)
      node = separable_conv2d(
          inputs,
          None, [5, 5],
          stride=stride,
          depth_multiplier=1.0,
          padding='SAME',
          weights_initializer=self._WeightInit(0.09),
          activation_fn=activation_fn,
          normalizer_fn=batch_norm,
          normalizer_params=self._BatchNormParams(
              scale=has_scaling, fused=fused_batch_norm),
          scope=scope)
      if with_bypass:
        node = math_ops.add(inputs, node, name='test/Add')
      relu(node, name='test/' + relu_op_name)
@ -396,9 +291,10 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):

    scale_reshape = g.get_operation_by_name(scope + '/scale_reshape')
    self.assertEqual(scale_reshape.type, 'Reshape')
    self._AssertInputOpsAre(scale_reshape,
                            [scope + '/BatchNorm/batchnorm/mul',
                             scope + '/scale_reshape/shape'])
    self._AssertInputOpsAre(scale_reshape, [
        self._BatchNormMultiplierName(scope, has_scaling, fused_batch_norm),
        scope + '/scale_reshape/shape'
    ])
    self._AssertOutputGoesToOps(scale_reshape, g, [scope + '/mul_fold'])

    folded_conv = g.get_operation_by_name(scope + '/depthwise_Fold')
@ -409,77 +305,35 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):

    folded_add = g.get_operation_by_name(scope + '/add_fold')
    self.assertEqual(folded_add.type, 'Add')
    self._AssertInputOpsAre(folded_add,
                            [scope + '/depthwise_Fold',
                             scope + '/BatchNorm/batchnorm/sub'])
    self._AssertInputOpsAre(folded_add, [
        scope + '/depthwise_Fold',
        self._BatchNormBiasName(scope, fused_batch_norm)
    ])
    output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name]
    self._AssertOutputGoesToOps(folded_add, g, output_op_names)

  def testFoldDepthwiseConv2d(self):
    self._RunTestOverParameters(self._TestFoldDepthwiseConv2d)

  def _TestFoldDepthwiseConv2dWithoutScale(self, relu, relu_op_name,
                                           with_bypass):
    """Tests folding: inputs -> DepthwiseConv2d with batch norm -> Relu*.
  def _BatchNormParams(self, scale=True, fused=False):
    return {
        'center': True,
        'scale': scale,
        'decay': 1.0 - 0.003,
        'fused': fused
    }

    Args:
      relu: Callable that returns an Operation, a factory method for the Relu*.
      relu_op_name: String, name of the Relu* operation.
      with_bypass: Bool, when true there is an extra connection added from
        inputs to just before Relu*.
    """
    g = ops.Graph()
    with g.as_default():
      batch_size, height, width = 5, 128, 128
      inputs = array_ops.zeros((batch_size, height, width, 3))
      stride = 1 if with_bypass else 2
      activation_fn = None if with_bypass else relu
      bn_params = copy.copy(_DEFAULT_BATCH_NORM_PARAMS)
      bn_params['scale'] = False
      scope = 'test/test2' if with_bypass else 'test'
      node = separable_conv2d(inputs, None, [5, 5], stride=stride,
                              depth_multiplier=1.0, padding='SAME',
                              weights_initializer=self._WeightInit(0.09),
                              activation_fn=activation_fn,
                              normalizer_fn=batch_norm,
                              normalizer_params=bn_params,
                              scope=scope)
      if with_bypass:
        node = math_ops.add(inputs, node, name='test/Add')
      relu(node, name='test/' + relu_op_name)
  def _BatchNormMultiplierName(self, scope, has_scaling, fused):
    if has_scaling:
      if fused:
        return scope + '/mul'
      return scope + '/BatchNorm/batchnorm/mul'
    return scope + '/BatchNorm/batchnorm/Rsqrt'

    fold_batch_norms.FoldBatchNorms(g)

    folded_mul = g.get_operation_by_name(scope + '/mul_fold')
    self.assertEqual(folded_mul.type, 'Mul')
    self._AssertInputOpsAre(folded_mul,
                            [scope + '/depthwise_weights/read',
                             scope + '/scale_reshape'])
    self._AssertOutputGoesToOps(folded_mul, g, [scope + '/depthwise_Fold'])

    scale_reshape = g.get_operation_by_name(scope + '/scale_reshape')
    self.assertEqual(scale_reshape.type, 'Reshape')
    self._AssertInputOpsAre(scale_reshape,
                            [scope + '/BatchNorm/batchnorm/Rsqrt',
                             scope + '/scale_reshape/shape'])
    self._AssertOutputGoesToOps(scale_reshape, g, [scope + '/mul_fold'])

    folded_conv = g.get_operation_by_name(scope + '/depthwise_Fold')
    self.assertEqual(folded_conv.type, 'DepthwiseConv2dNative')
    self._AssertInputOpsAre(folded_conv,
                            [scope + '/mul_fold', inputs.op.name])
    self._AssertOutputGoesToOps(folded_conv, g, [scope + '/add_fold'])

    folded_add = g.get_operation_by_name(scope + '/add_fold')
    self.assertEqual(folded_add.type, 'Add')
    self._AssertInputOpsAre(folded_add,
                            [scope + '/depthwise_Fold',
                             scope + '/BatchNorm/batchnorm/sub'])
    output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name]
    self._AssertOutputGoesToOps(folded_add, g, output_op_names)

  def testFoldDepthwiseConv2dWithoutScale(self):
    self._RunTestOverParameters(self._TestFoldDepthwiseConv2dWithoutScale)
  def _BatchNormBiasName(self, scope, fused):
    if fused:
      return scope + '/bias'
    return scope + '/BatchNorm/batchnorm/sub'

  def _WeightInit(self, stddev):
    """Returns a truncated normal variable initializer.
tensorflow/contrib/quantize/python/graph_matcher.py (new file, 200 lines)
@ -0,0 +1,200 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities that match patterns in a tf.Graph."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


class OpTypePattern(object):
  """A tree pattern that matches TF expressions with certain op types."""

  def __init__(self, op_type, name=None, inputs=None):
    """Initializes an OpTypePattern.

    Args:
      op_type: string that specifies the allowed types of the root. It can be
        (1) an op type, e.g. 'Conv2D',
        (2) '*', i.e. wildcard, or
        (3) multiple op types separated by '|', e.g., 'Relu|Relu6'.
        We could use regex strings, which might be worthwhile when we have many
        similar TF op types.
      name: Optional string. The name of the pattern that can be looked up in
        MatchResult.
      inputs: Optional list of `OpTypePattern`s or strings that specify the
        patterns for the inputs of a matching op. If None, this pattern accepts
        any inputs of a matching op.
    """
    self._op_type = op_type
    self._name = name
    if inputs is None:
      inputs = []
    self._inputs = [
        input_pattern if isinstance(input_pattern, OpTypePattern) else
        OpTypePattern(input_pattern) for input_pattern in inputs
    ]

  @property
  def op_type(self):
    return self._op_type

  @property
  def inputs(self):
    return self._inputs

  @property
  def name(self):
    return self._name


class MatchResult(object):
  r"""Encapsulates the result of a match done by GraphMatcher.

  MatchResult contains a map from OpTypePattern to the matching op and tensor.
  When the matching op has multiple output tensors, the matching tensor is the
  output tensor used by the matching op of the parent pattern. E.g., when we
  match graph

        -         +
       / \y0   y1/ \
      x   split    z
            |
            y (nodes are ops; edges are going up)

  against add_pattern defined as

    y1_pattern = OpTypePattern('*')
    z_pattern = OpTypePattern('*')
    add_pattern = OpTypePattern('+', inputs=[y1_pattern, z_pattern])

  the matching op of `y1_pattern` is `split`, and the matching tensor of
  `y1_pattern` is `y1` not `y0`.
  """

  def __init__(self):
    self._pattern_to_op_tensor = {}
    self._name_to_pattern = {}

  def add(self, pattern, op, tensor):
    self._pattern_to_op_tensor[pattern] = op, tensor
    if pattern.name is not None:
      if pattern.name in self._name_to_pattern:
        raise ValueError(
            'Name %s is already bound to another pattern' % pattern.name)
      self._name_to_pattern[pattern.name] = pattern

  def _to_pattern(self, pattern_or_name):
    if isinstance(pattern_or_name, OpTypePattern):
      return pattern_or_name

    if isinstance(pattern_or_name, str):
      return self._name_to_pattern[pattern_or_name]

    raise ValueError('pattern_or_name has type %s. Expect OpTypePattern or str.'
                     % type(pattern_or_name))

  def get_op(self, pattern_or_name):
    return self._pattern_to_op_tensor[self._to_pattern(pattern_or_name)][0]

  def get_tensor(self, pattern_or_name):
    return self._pattern_to_op_tensor[self._to_pattern(pattern_or_name)][1]


class GraphMatcher(object):
  """Checks if a particular subgraph matches a given pattern."""

  def __init__(self, pattern):
    """Initializes a GraphMatcher.

    Args:
      pattern: The `OpTypePattern` against which `GraphMatcher` matches
        subgraphs.
    """
    self._pattern = pattern

  def _match_pattern(self, pattern, op, tensor):
    """Returns whether a TF expression rooted at `op` matches `pattern`.

    If there is a match, adds to `self._match_result` the matching op and
    tensor with key `pattern`.

    Args:
      pattern: An `OpTypePattern`.
      op: A `tf.Operation` to match against the pattern.
      tensor: the output `tf.Tensor` of `op` that is used by the matching op of
        `pattern`'s parent. Can be None if `pattern` is already the root of the
        pattern tree.

    Returns:
      True if a TF expression rooted at `op` matches `pattern`.
    """
    if pattern.op_type != '*':
      if op.type not in pattern.op_type.split('|'):
        return False

    self._match_result.add(pattern, op, tensor)

    if not pattern.inputs:
      # If pattern.inputs is empty, skips the rest and accepts all the inputs.
      return True

    return len(op.inputs) == len(pattern.inputs) and all([
        self._match_pattern(input_pattern, input_tensor.op, input_tensor)
        for input_tensor, input_pattern in zip(op.inputs, pattern.inputs)
    ])

  def match_op(self, op):
    """Matches `op` against `self._pattern`.

    Args:
      op: `tf.Operation` to match against the pattern.

    Returns:
      Returns a `MatchResult` if `op` matches the pattern; otherwise, returns
      None.
    """
    self._match_result = MatchResult()
    if not self._match_pattern(self._pattern, op, tensor=None):
      return None
    return self._match_result

  def match_ops(self, ops):
    """Matches each operation in `ops` against `self._pattern`.

    Args:
      ops: collection of `tf.Operation` to match against the pattern.

    Yields:
      `MatchResult` for each `tf.Operation` that matches the pattern.
    """
    for op in ops:
      match_result = self.match_op(op)
      if match_result:
        yield match_result

  def match_graph(self, graph):
    """Matches each operation in `graph` against `self._pattern`.

    Args:
      graph: `tf.Graph` containing operations to match.

    Yields:
      `MatchResult` for each `tf.Operation` in `graph` that matches the
      pattern.
    """
    # Python 3.3.2+ implements `yield from`, but for now:
    for match_result in self.match_ops(graph.get_operations()):
      yield match_result
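To make the API above concrete, here is a small hypothetical sketch of how the classes compose; the pattern itself is illustrative and `graph` stands for an existing `tf.Graph`. Patterns are built bottom-up, wrapped in a `GraphMatcher`, and matches are read back either by pattern object or by pattern name:

from tensorflow.contrib.quantize.python import graph_matcher

# Match a Conv2D whose output feeds a Relu or Relu6, capturing the conv input.
inputs_pattern = graph_matcher.OpTypePattern('*', name='inputs')
conv_pattern = graph_matcher.OpTypePattern(
    'Conv2D', inputs=[inputs_pattern, '*'])  # strings become wildcard patterns
relu_pattern = graph_matcher.OpTypePattern('Relu|Relu6', inputs=[conv_pattern])

matcher = graph_matcher.GraphMatcher(relu_pattern)
for match in matcher.match_graph(graph):
  conv_op = match.get_op(conv_pattern)     # look up by the pattern object...
  conv_input = match.get_tensor('inputs')  # ...or by the pattern's name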
tensorflow/contrib/quantize/python/graph_matcher_test.py (new file, 130 lines)
@ -0,0 +1,130 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for graph_matcher."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.framework.python import ops as contrib_ops
from tensorflow.contrib.layers.python.layers import initializers
from tensorflow.contrib.layers.python.layers import layers
from tensorflow.contrib.quantize.python import graph_matcher
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import googletest


class GraphMatcherTest(test_util.TensorFlowTestCase):

  def test_conv_layer(self):
    g = ops.Graph()
    with g.as_default():
      inputs = array_ops.placeholder(dtypes.float32, shape=[8, 5, 5, 3])

      with contrib_ops.arg_scope(
          [layers.batch_norm], fused=True, is_training=True, trainable=True):
        layers.convolution(
            inputs,
            num_outputs=16,
            kernel_size=3,
            stride=1,
            padding='VALID',
            activation_fn=nn_ops.relu,
            normalizer_fn=layers.batch_norm,
            normalizer_params={},
            weights_initializer=initializers.xavier_initializer(),
            weights_regularizer=None,
            biases_initializer=init_ops.zeros_initializer(),
            biases_regularizer=None,
            reuse=None,
            trainable=True,
            scope=None)

    inputs_pattern = graph_matcher.OpTypePattern('*', name='inputs')
    relu_pattern = graph_matcher.OpTypePattern(
        'Relu',
        name='relu',
        inputs=[
            graph_matcher.OpTypePattern(
                'FusedBatchNorm',
                inputs=[
                    graph_matcher.OpTypePattern(
                        'Conv2D', inputs=[inputs_pattern, '*']), '*', '*', '*',
                    '*'
                ])
        ])
    matcher = graph_matcher.GraphMatcher(relu_pattern)
    match_results = list(matcher.match_graph(g))
    self.assertEqual(1, len(match_results))
    match_result = match_results[0]
    self.assertEqual(match_result.get_tensor(inputs_pattern), inputs)
    self.assertEqual(match_result.get_tensor('inputs'), inputs)

  def test_multiple_outputs(self):
    #   +         -
    #  / \y0   y1/ \
    # x   split    z
    #       |
    #       y (nodes are ops; edges are going up)
    g = ops.Graph()
    with g.as_default():
      x = array_ops.placeholder(dtypes.float32, shape=[1], name='x')
      y = array_ops.placeholder(dtypes.float32, shape=[2], name='y')
      y0, y1 = array_ops.split(y, num_or_size_splits=2, axis=0)
      z = array_ops.placeholder(dtypes.float32, shape=[1], name='z')
      math_ops.add(x, y0)
      math_ops.subtract(y1, z)

    y1_pattern = graph_matcher.OpTypePattern('*')
    minus_pattern = graph_matcher.OpTypePattern('Sub', inputs=[y1_pattern, '*'])
    matcher = graph_matcher.GraphMatcher(minus_pattern)

    match_results = list(matcher.match_graph(g))
    self.assertEqual(1, len(match_results))
    match_result = match_results[0]

    self.assertEqual(y0.op, y1.op)
    self.assertEqual(match_result.get_op(y1_pattern), y1.op)
    self.assertEqual(match_result.get_tensor(y1_pattern), y1)

  def test_oneof_pattern(self):
    #   +     -
    #  / \   / \
    # x   y     z
    g = ops.Graph()
    with g.as_default():
      x = array_ops.placeholder(dtypes.float32, shape=[], name='x')
      y = array_ops.placeholder(dtypes.float32, shape=[], name='y')
      z = array_ops.placeholder(dtypes.float32, shape=[], name='z')
      plus = x + y
      minus = y - z

    add_or_sub_pattern = graph_matcher.OpTypePattern(
        'Add|Sub', inputs=['*', '*'])
    matcher = graph_matcher.GraphMatcher(add_or_sub_pattern)
    self.assertEqual([
        match_result.get_op(add_or_sub_pattern)
        for match_result in matcher.match_graph(g)
    ], [plus.op, minus.op])


if __name__ == '__main__':
  googletest.main()
@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function

from tensorflow.contrib.layers.python.layers import layers
from tensorflow.contrib.quantize.python import fold_batch_norms
from tensorflow.contrib.quantize.python import quantize
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
@ -35,18 +36,11 @@ conv2d = layers.conv2d
fully_connected = layers.fully_connected
separable_conv2d = layers.separable_conv2d

_DEFAULT_BATCH_NORM_PARAMS = {
    'center': True,
    'scale': True,
    'decay': 1.0 - 0.003,
    'fused': False,
}


# TODO(suharshs): Use parameterized test once OSS TF supports it.
class QuantizeTest(test_util.TensorFlowTestCase):

  def _RunTestOverParameters(self, test_fn):
  def _RunWithoutBatchNormTestOverParameters(self, test_fn):
    # TODO(suharshs): Use parameterized test once OSS TF supports it.
    parameters_list = [
        # (activation, activation_op_name, with_bypass, delay)
        (nn_ops.relu6, 'Relu6', False, None),
@ -60,10 +54,10 @@ class QuantizeTest(test_util.TensorFlowTestCase):
        (array_ops.identity, 'Identity', True, None),
        (nn_ops.relu6, 'Relu6', True, 5000),
        (nn_ops.relu, 'Relu', True, 5000),
        (array_ops.identity, 'Identity', True, 5000)
        (array_ops.identity, 'Identity', True, 5000),
    ]
    for parameters in parameters_list:
      test_fn(parameters[0], parameters[1], parameters[2], parameters[3])
    for params in parameters_list:
      test_fn(params[0], params[1], params[2], params[3])

  def _TestQuantize_Conv2dWithoutBatchNorm(self, activation, activation_op_name,
                                           with_bypass, delay):
@ -137,7 +131,8 @@ class QuantizeTest(test_util.TensorFlowTestCase):
    self._AssertOutputGoesToOps(act_quant, graph, [output_op_name])

  def testQuantize_Conv2dWithoutBatchNorm(self):
    self._RunTestOverParameters(self._TestQuantize_Conv2dWithoutBatchNorm)
    self._RunWithoutBatchNormTestOverParameters(
        self._TestQuantize_Conv2dWithoutBatchNorm)

  def _TestQuantize_FCWithoutBatchNorm(self, activation, activation_op_name,
                                       with_bypass, delay):
@ -210,7 +205,8 @@ class QuantizeTest(test_util.TensorFlowTestCase):
    self._AssertOutputGoesToOps(act_quant, graph, [output_op_name])

  def testQuantize_FCWithoutBatchNorm(self):
    self._RunTestOverParameters(self._TestQuantize_FCWithoutBatchNorm)
    self._RunWithoutBatchNormTestOverParameters(
        self._TestQuantize_FCWithoutBatchNorm)

  def _TestQuantize_DepthwiseConv2dWithoutBatchNorm(
      self, activation, activation_op_name, with_bypass, delay):
@ -284,11 +280,43 @@ class QuantizeTest(test_util.TensorFlowTestCase):
    self._AssertOutputGoesToOps(act_quant, graph, [output_op_name])

  def testQuantize_DepthwiseConv2dWithoutBatchNorm(self):
    self._RunTestOverParameters(
    self._RunWithoutBatchNormTestOverParameters(
        self._TestQuantize_DepthwiseConv2dWithoutBatchNorm)

  def _RunBatchNormTestOverParameters(self, test_fn):
    # TODO(suharshs): Use parameterized test once OSS TF supports it.
    parameters_list = [
        # (activation, activation_op_name, with_bypass, delay, fused_batch_norm)
        (nn_ops.relu6, 'Relu6', False, None, False),
        (nn_ops.relu, 'Relu', False, None, False),
        (array_ops.identity, 'Identity', False, None, False),
        (nn_ops.relu6, 'Relu6', False, 5000, False),
        (nn_ops.relu, 'Relu', False, 5000, False),
        (array_ops.identity, 'Identity', False, 5000, False),
        (nn_ops.relu6, 'Relu6', True, None, False),
        (nn_ops.relu, 'Relu', True, None, False),
        (array_ops.identity, 'Identity', True, None, False),
        (nn_ops.relu6, 'Relu6', True, 5000, False),
        (nn_ops.relu, 'Relu', True, 5000, False),
        (array_ops.identity, 'Identity', True, 5000, False),
        (nn_ops.relu6, 'Relu6', False, None, True),
        (nn_ops.relu, 'Relu', False, None, True),
        (array_ops.identity, 'Identity', False, None, True),
        (nn_ops.relu6, 'Relu6', False, 5000, True),
        (nn_ops.relu, 'Relu', False, 5000, True),
        (array_ops.identity, 'Identity', False, 5000, True),
        (nn_ops.relu6, 'Relu6', True, None, True),
        (nn_ops.relu, 'Relu', True, None, True),
        (array_ops.identity, 'Identity', True, None, True),
        (nn_ops.relu6, 'Relu6', True, 5000, True),
        (nn_ops.relu, 'Relu', True, 5000, True),
        (array_ops.identity, 'Identity', True, 5000, True)
    ]
    for params in parameters_list:
      test_fn(params[0], params[1], params[2], params[3], params[4])

  def _TestQuantize_Conv2dWithBatchNorm(self, activation, activation_op_name,
                                        with_bypass, delay):
                                        with_bypass, delay, fused_batch_norm):
    """Tests quantization: inputs -> Conv2d with batch norm -> Activation.

    Args:
@ -298,25 +326,29 @@ class QuantizeTest(test_util.TensorFlowTestCase):
      with_bypass: Bool, when true there is an extra connection added from
        inputs to just before Activation.
      delay: Int (optional), delay in number of steps until quantization starts.
      fused_batch_norm: Bool, when true use FusedBatchNorm.
    """
    self._testQuantize_Conv2dWithBatchNorm(
        activation,
        activation_op_name,
        with_bypass,
        delay,
        fused_batch_norm,
        use_ema=True)
    self._testQuantize_Conv2dWithBatchNorm(
        activation,
        activation_op_name,
        with_bypass,
        delay,
        fused_batch_norm,
        use_ema=False)

  def testQuantize_Conv2dWithBatchNorm(self):
    self._RunTestOverParameters(self._TestQuantize_Conv2dWithBatchNorm)
    self._RunBatchNormTestOverParameters(self._TestQuantize_Conv2dWithBatchNorm)

  def _testQuantize_Conv2dWithBatchNorm(self, activation, activation_op_name,
                                        with_bypass, delay, use_ema):
                                        with_bypass, delay, fused_batch_norm,
                                        use_ema):
    """Tests quantization: inputs -> Conv2d with batch norm -> Activation.

    Args:
@ -326,6 +358,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
      with_bypass: Bool, when true there is an extra connection added from
        inputs to just before Activation.
      delay: Int (optional), delay in number of steps until quantization starts.
      fused_batch_norm: Bool, when true use FusedBatchNorm.
      use_ema: Bool, when true uses EMA quantization for BN folded weights.
    """
    graph = ops.Graph()
@ -337,39 +370,29 @@ class QuantizeTest(test_util.TensorFlowTestCase):
      stride = 1 if with_bypass else 2
      out_depth = 3 if with_bypass else 32
      scope = 'test/test2' if with_bypass else 'test'
      node = conv2d(inputs, out_depth, [5, 5], stride=stride, padding='SAME',
                    weights_initializer=self._WeightInit(0.09),
                    activation_fn=None,
                    normalizer_fn=batch_norm,
                    normalizer_params=_DEFAULT_BATCH_NORM_PARAMS,
                    scope=scope)
      # Manually fold the batch norm.
      weights = graph.get_operation_by_name(scope + '/weights/read').outputs[0]
      bn_mult = (graph.get_operation_by_name(scope + '/BatchNorm/batchnorm/mul')
                 .outputs[0])
      mul_fold = math_ops.multiply(weights, bn_mult, name=scope + '/mul_fold')
      stride = [stride, stride]
      conv_fold = nn_ops.convolution(
          input=inputs,
          filter=mul_fold,
      node = conv2d(
          inputs,
          out_depth, [5, 5],
          stride=stride,
          padding='SAME',
          strides=stride,
          data_format='NHWC',
          name=scope + '/convolution_Fold')
      bn_bias = (graph.get_operation_by_name(scope + '/BatchNorm/batchnorm/sub')
                 .outputs[0])
      add_fold = math_ops.add(conv_fold, bn_bias, name=scope + '/add_fold')
          weights_initializer=self._WeightInit(0.09),
          activation_fn=None,
          normalizer_fn=batch_norm,
          normalizer_params=self._BatchNormParams(fused_batch_norm),
          scope=scope)

      # Manually add a bypass (optionally) and an activation.
      if with_bypass:
        node = math_ops.add(inputs, add_fold, name='test/Add')
      else:
        node = add_fold
        node = math_ops.add(inputs, node, name='test/Add')

      node = activation(node, name='test/' + activation_op_name)

      update_barrier = control_flow_ops.no_op(name='update_barrier')
      with ops.control_dependencies([update_barrier]):
        array_ops.identity(node, name='control_dependency')

    fold_batch_norms.FoldBatchNorms(graph)

    quantize.Quantize(
        graph, quant_delay=delay, quantize_folded_weights_use_ema=use_ema)

@ -413,7 +436,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
    self._AssertOutputGoesToOps(act_quant, graph, [output_op_name])

  def _TestQuantize_FCWithBatchNorm(self, activation, activation_op_name,
                                    with_bypass, delay):
                                    with_bypass, delay, fused_batch_norm):
    """Tests quantization: inputs -> FC with batch norm -> Activation.

    Args:
@ -423,25 +446,29 @@ class QuantizeTest(test_util.TensorFlowTestCase):
      with_bypass: Bool, when true there is an extra connection added from
        inputs to just before Activation.
      delay: Int (optional), delay in number of steps until quantization starts.
      fused_batch_norm: Bool, when true use FusedBatchNorm.
    """
    self._testQuantize_FCWithBatchNorm(
        activation,
        activation_op_name,
        with_bypass,
        delay,
        fused_batch_norm,
        use_ema=True)
    self._testQuantize_FCWithBatchNorm(
        activation,
        activation_op_name,
        with_bypass,
        delay,
        fused_batch_norm,
        use_ema=False)

  def testQuantize_FCWithBatchNorm(self):
    self._RunTestOverParameters(self._TestQuantize_FCWithBatchNorm)
    self._RunBatchNormTestOverParameters(self._TestQuantize_FCWithBatchNorm)

  def _testQuantize_FCWithBatchNorm(self, activation, activation_op_name,
                                    with_bypass, delay, use_ema):
                                    with_bypass, delay, fused_batch_norm,
                                    use_ema):
    """Tests quantization: inputs -> FC with batch norm -> Activation.

    Args:
@ -451,6 +478,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
      with_bypass: Bool, when true there is an extra connection added from
        inputs to just before Activation.
      delay: Int (optional), delay in number of steps until quantization starts.
      fused_batch_norm: Bool, when true use FusedBatchNorm.
      use_ema: Bool, when true uses EMA quantization for BN folded weights.
    """
    graph = ops.Graph()
@ -461,32 +489,27 @@ class QuantizeTest(test_util.TensorFlowTestCase):
      inputs = array_ops.zeros((batch_size, depth))
      out_depth = 256 if with_bypass else 128
      scope = 'test/test2' if with_bypass else 'test'
      node = fully_connected(inputs, out_depth,
                             weights_initializer=self._WeightInit(0.03),
                             activation_fn=None,
                             normalizer_fn=batch_norm,
                             normalizer_params=_DEFAULT_BATCH_NORM_PARAMS,
                             scope=scope)
      # Manually fold the batch norm.
      weights = graph.get_operation_by_name(scope + '/weights/read').outputs[0]
      bn_mult = (graph.get_operation_by_name(scope + '/BatchNorm/batchnorm/mul')
                 .outputs[0])
      mul_fold = math_ops.multiply(weights, bn_mult, name=scope + '/mul_fold')
      fc_fold = math_ops.matmul(inputs, mul_fold, name=scope + '/MatMul_Fold')
      bn_bias = (graph.get_operation_by_name(scope + '/BatchNorm/batchnorm/sub')
                 .outputs[0])
      add_fold = math_ops.add(fc_fold, bn_bias, name=scope + '/add_fold')
      node = fully_connected(
          inputs,
          out_depth,
          weights_initializer=self._WeightInit(0.03),
          activation_fn=None,
          normalizer_fn=batch_norm,
          normalizer_params=self._BatchNormParams(fused_batch_norm),
          scope=scope)

      # Manually add a bypass (optionally) and an activation.
      if with_bypass:
        node = math_ops.add(inputs, add_fold, name='test/Add')
      else:
        node = add_fold
        node = math_ops.add(inputs, node, name='test/Add')

      node = activation(node, name='test/' + activation_op_name)

      update_barrier = control_flow_ops.no_op(name='update_barrier')
      with ops.control_dependencies([update_barrier]):
        array_ops.identity(node, name='control_dependency')

    fold_batch_norms.FoldBatchNorms(graph)

    quantize.Quantize(
        graph, quant_delay=delay, quantize_folded_weights_use_ema=use_ema)

@ -530,7 +553,8 @@ class QuantizeTest(test_util.TensorFlowTestCase):
    self._AssertOutputGoesToOps(act_quant, graph, [output_op_name])

  def _TestQuantize_DepthwiseConv2dWithBatchNorm(
      self, activation, activation_op_name, with_bypass, delay):
      self, activation, activation_op_name, with_bypass, delay,
      fused_batch_norm):
    """Tests quantization: inputs -> DWConv2d with batch norm -> Activation.

    Args:
@ -540,26 +564,30 @@ class QuantizeTest(test_util.TensorFlowTestCase):
      with_bypass: Bool, when true there is an extra connection added from
        inputs to just before Activation.
      delay: Int (optional), delay in number of steps until quantization starts.
      fused_batch_norm: Bool, when true use FusedBatchNorm.
    """
    self._testQuantize_DepthwiseConv2dWithBatchNorm(
        activation,
        activation_op_name,
        with_bypass,
        delay,
        fused_batch_norm,
        use_ema=True)
    self._testQuantize_DepthwiseConv2dWithBatchNorm(
        activation,
        activation_op_name,
        with_bypass,
        delay,
        fused_batch_norm,
        use_ema=False)

  def testQuantize_DepthwiseConv2dWithBatchNorm(self):
    self._RunTestOverParameters(
        self._TestQuantize_DepthwiseConv2dWithoutBatchNorm)
    self._RunBatchNormTestOverParameters(
        self._TestQuantize_DepthwiseConv2dWithBatchNorm)

  def _testQuantize_DepthwiseConv2dWithBatchNorm(
      self, activation, activation_op_name, with_bypass, delay, use_ema):
      self, activation, activation_op_name, with_bypass, delay,
      fused_batch_norm, use_ema):
    """Tests quantization: inputs -> DWConv2d with batch norm -> Activation.

    Args:
@ -569,6 +597,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
      with_bypass: Bool, when true there is an extra connection added from
|
||||
inputs to just before Activation.
|
||||
delay: Int (optional), delay in number of steps until quantization starts.
|
||||
fused_batch_norm: Bool, when true use FusedBatchNorm.
|
||||
use_ema: Bool, when true uses EMA quantization for BN folded weights.
|
||||
"""
|
||||
graph = ops.Graph()
|
||||
@ -579,46 +608,30 @@ class QuantizeTest(test_util.TensorFlowTestCase):
|
||||
inputs = array_ops.zeros((batch_size, height, width, depth))
|
||||
stride = 1 if with_bypass else 2
|
||||
scope = 'test/test2' if with_bypass else 'test'
|
||||
node = separable_conv2d(inputs, None, [5, 5], stride=stride,
|
||||
depth_multiplier=1.0, padding='SAME',
|
||||
weights_initializer=self._WeightInit(0.09),
|
||||
activation_fn=None,
|
||||
normalizer_fn=batch_norm,
|
||||
normalizer_params=_DEFAULT_BATCH_NORM_PARAMS,
|
||||
scope=scope)
|
||||
# Manually fold the batch norm.
|
||||
weights = (graph.get_operation_by_name(scope + '/depthwise_weights/read')
|
||||
.outputs[0])
|
||||
bn_mult = (graph.get_operation_by_name(scope + '/BatchNorm/batchnorm/mul')
|
||||
.outputs[0])
|
||||
new_shape = [
|
||||
weights.get_shape().as_list()[2], weights.get_shape().as_list()[3]
|
||||
]
|
||||
bn_mult_reshaped = array_ops.reshape(
|
||||
bn_mult, new_shape, name=scope + '/gamma_reshape')
|
||||
mul_fold = math_ops.multiply(
|
||||
weights, bn_mult_reshaped, name=scope + '/mul_fold')
|
||||
stride = [1, stride, stride, 1]
|
||||
conv_fold = nn_ops.depthwise_conv2d(
|
||||
input=inputs,
|
||||
filter=mul_fold,
|
||||
node = separable_conv2d(
|
||||
inputs,
|
||||
None, [5, 5],
|
||||
stride=stride,
|
||||
depth_multiplier=1.0,
|
||||
padding='SAME',
|
||||
strides=stride,
|
||||
name=scope + '/depthwise_Fold')
|
||||
bn_bias = (graph.get_operation_by_name(scope + '/BatchNorm/batchnorm/sub')
|
||||
.outputs[0])
|
||||
add_fold = math_ops.add(conv_fold, bn_bias, name=scope + '/add_fold')
|
||||
weights_initializer=self._WeightInit(0.09),
|
||||
activation_fn=None,
|
||||
normalizer_fn=batch_norm,
|
||||
normalizer_params=self._BatchNormParams(fused_batch_norm),
|
||||
scope=scope)
|
||||
|
||||
# Manually add a bypass (optionaly) and an activation.
|
||||
if with_bypass:
|
||||
node = math_ops.add(inputs, add_fold, name='test/Add')
|
||||
else:
|
||||
node = add_fold
|
||||
node = math_ops.add(inputs, node, name='test/Add')
|
||||
|
||||
node = activation(node, name='test/' + activation_op_name)
|
||||
|
||||
update_barrier = control_flow_ops.no_op(name='update_barrier')
|
||||
with ops.control_dependencies([update_barrier]):
|
||||
array_ops.identity(node, name='control_dependency')
|
||||
|
||||
fold_batch_norms.FoldBatchNorms(graph)
|
||||
|
||||
quantize.Quantize(
|
||||
graph, quant_delay=delay, quantize_folded_weights_use_ema=use_ema)
|
||||
quantization_node_name = 'FakeQuantWithMinMaxVars'
|
||||
@ -660,6 +673,9 @@ class QuantizeTest(test_util.TensorFlowTestCase):
|
||||
if delay else 'control_dependency')
|
||||
self._AssertOutputGoesToOps(act_quant, graph, [output_op_name])
|
||||
|
||||
def _BatchNormParams(self, fused=False):
|
||||
return {'center': True, 'scale': True, 'decay': 1.0 - 0.003, 'fused': fused}
|
||||
|
||||
def _WeightInit(self, stddev):
|
||||
"""Returns truncated normal variable initializer.
|
||||
|
||||
|
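The manual `/mul_fold` and `/add_fold` ops in the tests above implement the standard batch-norm folding algebra: scaling the weights by gamma / sqrt(variance + epsilon) and adding beta - mean * gamma / sqrt(variance + epsilon) reproduces the layer followed by batch norm. A minimal NumPy sketch of that identity (names and shapes are illustrative, not taken from the test):

```python
import numpy as np

def fold_batch_norm(weights, gamma, beta, mean, variance, eps=1e-3):
  """Folds batch-norm parameters into fully-connected weights."""
  scale = gamma / np.sqrt(variance + eps)
  folded_weights = weights * scale   # plays the role of the '/mul_fold' op
  folded_bias = beta - mean * scale  # plays the role of the '/add_fold' bias
  return folded_weights, folded_bias

# The identity the folded graph depends on:
rng = np.random.RandomState(0)
x = rng.randn(4, 8)
w = rng.randn(8, 16)
gamma, beta = rng.rand(16), rng.rand(16)
mean, var = rng.randn(16), rng.rand(16)
y_bn = gamma * (x.dot(w) - mean) / np.sqrt(var + 1e-3) + beta
w_fold, b_fold = fold_batch_norm(w, gamma, beta, mean, var)
assert np.allclose(y_bn, x.dot(w_fold) + b_fold)
```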
@@ -156,6 +156,7 @@ cuda_py_tests(
"//tensorflow/python:client_testlib",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:gradients",
"//tensorflow/python:init_ops",
"//tensorflow/python:math_ops",
@@ -165,6 +166,7 @@ cuda_py_tests(
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
"//tensorflow/python/eager:context",
],
shard_count = 10,
)
@@ -25,10 +25,12 @@ from six.moves import xrange # pylint: disable=redefined-builtin

from tensorflow.contrib import rnn as rnn_lib
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops as ops_lib
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients_impl
@@ -881,6 +883,7 @@ class LSTMTest(test.TestCase):
# Smoke test, this should not raise an error
rnn.dynamic_rnn(cell, inputs, dtype=dtypes.float32)

@test_util.run_in_graph_and_eager_modes()
def testDynamicRNNWithTupleStates(self):
num_units = 3
input_size = 5
@@ -888,13 +891,20 @@ class LSTMTest(test.TestCase):
num_proj = 4
max_length = 8
sequence_length = [4, 6]
in_graph_mode = context.in_graph_mode()
with self.test_session(graph=ops_lib.Graph()) as sess:
initializer = init_ops.random_uniform_initializer(
-0.01, 0.01, seed=self._seed)
inputs = max_length * [
array_ops.placeholder(
dtypes.float32, shape=(None, input_size))
]
if in_graph_mode:
inputs = max_length * [
array_ops.placeholder(
dtypes.float32, shape=(None, input_size))
]
else:
inputs = max_length * [
constant_op.constant(
np.random.randn(batch_size, input_size).astype(np.float32))
]
inputs_c = array_ops.stack(inputs)
cell = rnn_cell.LSTMCell(
num_units,
@@ -924,21 +934,34 @@ class LSTMTest(test.TestCase):
self.assertEqual(state_dynamic[0], state_dynamic.c)
self.assertEqual(state_dynamic[1], state_dynamic.h)

variables_lib.global_variables_initializer().run()
if in_graph_mode:
variables_lib.global_variables_initializer().run()
input_value = np.random.randn(batch_size, input_size)
outputs_static = sess.run(
outputs_static, feed_dict={
inputs[0]: input_value
})
outputs_dynamic = sess.run(
outputs_dynamic, feed_dict={
inputs[0]: input_value
})
state_static = sess.run(
state_static, feed_dict={
inputs[0]: input_value
})
state_dynamic = sess.run(
state_dynamic, feed_dict={
inputs[0]: input_value
})

input_value = np.random.randn(batch_size, input_size)
outputs_static_v = sess.run(outputs_static,
feed_dict={inputs[0]: input_value})
outputs_dynamic_v = sess.run(outputs_dynamic,
feed_dict={inputs[0]: input_value})
self.assertAllEqual(outputs_static_v, outputs_dynamic_v)

state_static_v = sess.run(state_static,
feed_dict={inputs[0]: input_value})
state_dynamic_v = sess.run(state_dynamic,
feed_dict={inputs[0]: input_value})
self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_dynamic_v))
if in_graph_mode:
self.assertAllEqual(outputs_static, outputs_dynamic)
else:
self.assertAllEqual(
array_ops.stack(outputs_static).numpy(), outputs_dynamic.numpy())
self.assertAllEqual(np.hstack(state_static), np.hstack(state_dynamic))

@test_util.run_in_graph_and_eager_modes()
def testDynamicRNNWithNestedTupleStates(self):
num_units = 3
input_size = 5
@@ -946,13 +969,20 @@ class LSTMTest(test.TestCase):
num_proj = 4
max_length = 8
sequence_length = [4, 6]
in_graph_mode = context.in_graph_mode()
with self.test_session(graph=ops_lib.Graph()) as sess:
initializer = init_ops.random_uniform_initializer(
-0.01, 0.01, seed=self._seed)
inputs = max_length * [
array_ops.placeholder(
dtypes.float32, shape=(None, input_size))
]
if in_graph_mode:
inputs = max_length * [
array_ops.placeholder(
dtypes.float32, shape=(None, input_size))
]
else:
inputs = max_length * [
constant_op.constant(
np.random.randn(batch_size, input_size).astype(np.float32))
]
inputs_c = array_ops.stack(inputs)

def _cell(i):
@@ -993,20 +1023,34 @@ class LSTMTest(test.TestCase):
sequence_length=sequence_length,
scope=scope)

variables_lib.global_variables_initializer().run()
if in_graph_mode:
input_value = np.random.randn(batch_size, input_size)
variables_lib.global_variables_initializer().run()
outputs_static = sess.run(
outputs_static, feed_dict={
inputs[0]: input_value
})
outputs_dynamic = sess.run(
outputs_dynamic, feed_dict={
inputs[0]: input_value
})
state_static = sess.run(
nest.flatten(state_static), feed_dict={
inputs[0]: input_value
})
state_dynamic = sess.run(
nest.flatten(state_dynamic), feed_dict={
inputs[0]: input_value
})

input_value = np.random.randn(batch_size, input_size)
outputs_static_v = sess.run(outputs_static,
feed_dict={inputs[0]: input_value})
outputs_dynamic_v = sess.run(outputs_dynamic,
feed_dict={inputs[0]: input_value})
self.assertAllEqual(outputs_static_v, outputs_dynamic_v)

state_static_v = sess.run(nest.flatten(state_static),
feed_dict={inputs[0]: input_value})
state_dynamic_v = sess.run(nest.flatten(state_dynamic),
feed_dict={inputs[0]: input_value})
self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_dynamic_v))
if in_graph_mode:
self.assertAllEqual(outputs_static, outputs_dynamic)
else:
self.assertAllEqual(
array_ops.stack(outputs_static).numpy(), outputs_dynamic.numpy())
state_static = [s.numpy() for s in nest.flatten(state_static)]
state_dynamic = [s.numpy() for s in nest.flatten(state_dynamic)]
self.assertAllEqual(np.hstack(state_static), np.hstack(state_dynamic))

def _testDynamicEquivalentToStaticRNN(self, use_gpu, use_sequence_length):
time_steps = 8
@@ -1015,21 +1059,22 @@ class LSTMTest(test.TestCase):
input_size = 5
batch_size = 2

input_values = np.random.randn(time_steps, batch_size, input_size)
input_values = np.random.randn(time_steps, batch_size, input_size).astype(
np.float32)

if use_sequence_length:
sequence_length = np.random.randint(0, time_steps, size=batch_size)
else:
sequence_length = None

########### Step 1: Run static graph and generate readouts
with self.test_session(use_gpu=use_gpu, graph=ops_lib.Graph()) as sess:
concat_inputs = array_ops.placeholder(
dtypes.float32, shape=(time_steps, batch_size, input_size))
inputs = array_ops.unstack(concat_inputs)
in_graph_mode = context.in_graph_mode()

# TODO(b/68017812): Eager ignores operation seeds, so we need to create a
# single cell and reuse it across the static and dynamic RNNs. Remove this
# special case once it is fixed.
if not in_graph_mode:
initializer = init_ops.random_uniform_initializer(
-0.01, 0.01, seed=self._seed)

cell = rnn_cell.LSTMCell(
num_units,
use_peepholes=True,
@@ -1037,63 +1082,85 @@ class LSTMTest(test.TestCase):
num_proj=num_proj,
state_is_tuple=False)

########### Step 1: Run static graph and generate readouts
with self.test_session(use_gpu=use_gpu, graph=ops_lib.Graph()) as sess:
if in_graph_mode:
concat_inputs = array_ops.placeholder(
dtypes.float32, shape=(time_steps, batch_size, input_size))
else:
concat_inputs = constant_op.constant(input_values)
inputs = array_ops.unstack(concat_inputs)
initializer = init_ops.random_uniform_initializer(
-0.01, 0.01, seed=self._seed)

# TODO(akshayka): Remove special case once b/68017812 is fixed.
if in_graph_mode:
cell = rnn_cell.LSTMCell(
num_units,
use_peepholes=True,
initializer=initializer,
num_proj=num_proj,
state_is_tuple=False)

with variable_scope.variable_scope("dynamic_scope"):
outputs_static, state_static = rnn.static_rnn(
cell, inputs, sequence_length=sequence_length, dtype=dtypes.float32)

feeds = {concat_inputs: input_values}
if in_graph_mode:
# Generate gradients and run sessions to obtain outputs
feeds = {concat_inputs: input_values}
# Initialize
variables_lib.global_variables_initializer().run(feed_dict=feeds)
# Generate gradients of sum of outputs w.r.t. inputs
static_gradients = gradients_impl.gradients(
outputs_static + [state_static], [concat_inputs])
# Generate gradients of individual outputs w.r.t. inputs
static_individual_gradients = nest.flatten([
gradients_impl.gradients(y, [concat_inputs])
for y in [outputs_static[0], outputs_static[-1], state_static]
])
# Generate gradients of individual variables w.r.t. inputs
trainable_variables = ops_lib.get_collection(
ops_lib.GraphKeys.TRAINABLE_VARIABLES)
assert len(trainable_variables) > 1, (
"Count of trainable variables: %d" % len(trainable_variables))
# pylint: disable=bad-builtin
static_individual_variable_gradients = nest.flatten([
gradients_impl.gradients(y, trainable_variables)
for y in [outputs_static[0], outputs_static[-1], state_static]
])
# Test forward pass
values_static = sess.run(outputs_static, feed_dict=feeds)
(state_value_static,) = sess.run((state_static,), feed_dict=feeds)

# Initialize
variables_lib.global_variables_initializer().run(feed_dict=feeds)
# Test gradients to inputs and variables w.r.t. outputs & final state
static_grad_values = sess.run(static_gradients, feed_dict=feeds)

# Generate gradients of sum of outputs w.r.t. inputs
static_gradients = gradients_impl.gradients(
outputs_static + [state_static], [concat_inputs])
static_individual_grad_values = sess.run(static_individual_gradients,
feed_dict=feeds)

# Generate gradients of individual outputs w.r.t. inputs
static_individual_gradients = nest.flatten([
gradients_impl.gradients(y, [concat_inputs])
for y in [outputs_static[0], outputs_static[-1], state_static]
])

# Generate gradients of individual variables w.r.t. inputs
trainable_variables = ops_lib.get_collection(
ops_lib.GraphKeys.TRAINABLE_VARIABLES)
assert len(trainable_variables) > 1, ("Count of trainable variables: %d" %
len(trainable_variables))
# pylint: disable=bad-builtin
static_individual_variable_gradients = nest.flatten([
gradients_impl.gradients(y, trainable_variables)
for y in [outputs_static[0], outputs_static[-1], state_static]
])

# Test forward pass
values_static = sess.run(outputs_static, feed_dict=feeds)
(state_value_static,) = sess.run((state_static,), feed_dict=feeds)

# Test gradients to inputs and variables w.r.t. outputs & final state
static_grad_values = sess.run(static_gradients, feed_dict=feeds)

static_individual_grad_values = sess.run(static_individual_gradients,
feed_dict=feeds)

static_individual_var_grad_values = sess.run(
static_individual_variable_gradients, feed_dict=feeds)
static_individual_var_grad_values = sess.run(
static_individual_variable_gradients, feed_dict=feeds)

########## Step 2: Run dynamic graph and generate readouts
with self.test_session(use_gpu=use_gpu, graph=ops_lib.Graph()) as sess:
concat_inputs = array_ops.placeholder(
dtypes.float32, shape=(time_steps, batch_size, input_size))
inputs = array_ops.unstack(concat_inputs)
if in_graph_mode:
concat_inputs = array_ops.placeholder(
dtypes.float32, shape=(time_steps, batch_size, input_size))
else:
concat_inputs = constant_op.constant(input_values)
initializer = init_ops.random_uniform_initializer(
-0.01, 0.01, seed=self._seed)

cell = rnn_cell.LSTMCell(
num_units,
use_peepholes=True,
initializer=initializer,
num_proj=num_proj,
state_is_tuple=False)
# TODO(akshayka): Remove this special case once b/68017812 is
# fixed.
if in_graph_mode:
cell = rnn_cell.LSTMCell(
num_units,
use_peepholes=True,
initializer=initializer,
num_proj=num_proj,
state_is_tuple=False)

with variable_scope.variable_scope("dynamic_scope"):
outputs_dynamic, state_dynamic = rnn.dynamic_rnn(
@@ -1104,72 +1171,83 @@ class LSTMTest(test.TestCase):
dtype=dtypes.float32)
split_outputs_dynamic = array_ops.unstack(outputs_dynamic, time_steps)

feeds = {concat_inputs: input_values}
if in_graph_mode:
feeds = {concat_inputs: input_values}

# Initialize
variables_lib.global_variables_initializer().run(feed_dict=feeds)
# Initialize
variables_lib.global_variables_initializer().run(feed_dict=feeds)

# Generate gradients of sum of outputs w.r.t. inputs
dynamic_gradients = gradients_impl.gradients(
split_outputs_dynamic + [state_dynamic], [concat_inputs])
# Generate gradients of sum of outputs w.r.t. inputs
dynamic_gradients = gradients_impl.gradients(
split_outputs_dynamic + [state_dynamic], [concat_inputs])

# Generate gradients of several individual outputs w.r.t. inputs
dynamic_individual_gradients = nest.flatten([
gradients_impl.gradients(y, [concat_inputs])
for y in
[split_outputs_dynamic[0], split_outputs_dynamic[-1], state_dynamic]
])
# Generate gradients of several individual outputs w.r.t. inputs
dynamic_individual_gradients = nest.flatten([
gradients_impl.gradients(y, [concat_inputs])
for y in
[split_outputs_dynamic[0], split_outputs_dynamic[-1], state_dynamic]
])

# Generate gradients of individual variables w.r.t. inputs
trainable_variables = ops_lib.get_collection(
ops_lib.GraphKeys.TRAINABLE_VARIABLES)
assert len(trainable_variables) > 1, ("Count of trainable variables: %d" %
len(trainable_variables))
dynamic_individual_variable_gradients = nest.flatten([
gradients_impl.gradients(y, trainable_variables)
for y in
[split_outputs_dynamic[0], split_outputs_dynamic[-1], state_dynamic]
])
# Generate gradients of individual variables w.r.t. inputs
trainable_variables = ops_lib.get_collection(
ops_lib.GraphKeys.TRAINABLE_VARIABLES)
assert len(trainable_variables) > 1, (
"Count of trainable variables: %d" % len(trainable_variables))
dynamic_individual_variable_gradients = nest.flatten([
gradients_impl.gradients(y, trainable_variables)
for y in
[split_outputs_dynamic[0], split_outputs_dynamic[-1], state_dynamic]
])

# Test forward pass
values_dynamic = sess.run(split_outputs_dynamic, feed_dict=feeds)
(state_value_dynamic,) = sess.run((state_dynamic,), feed_dict=feeds)
# Test forward pass
values_dynamic = sess.run(split_outputs_dynamic, feed_dict=feeds)
(state_value_dynamic,) = sess.run((state_dynamic,), feed_dict=feeds)

# Test gradients to inputs and variables w.r.t. outputs & final state
dynamic_grad_values = sess.run(dynamic_gradients, feed_dict=feeds)
# Test gradients to inputs and variables w.r.t. outputs & final state
dynamic_grad_values = sess.run(dynamic_gradients, feed_dict=feeds)

dynamic_individual_grad_values = sess.run(dynamic_individual_gradients,
feed_dict=feeds)
dynamic_individual_grad_values = sess.run(dynamic_individual_gradients,
feed_dict=feeds)

dynamic_individual_var_grad_values = sess.run(
dynamic_individual_variable_gradients, feed_dict=feeds)
dynamic_individual_var_grad_values = sess.run(
dynamic_individual_variable_gradients, feed_dict=feeds)

######### Step 3: Comparisons
if not in_graph_mode:
values_static = outputs_static
values_dynamic = split_outputs_dynamic
state_value_static = state_static
state_value_dynamic = state_dynamic

self.assertEqual(len(values_static), len(values_dynamic))
for (value_static, value_dynamic) in zip(values_static, values_dynamic):
self.assertAllEqual(value_static, value_dynamic)
self.assertAllEqual(state_value_static, state_value_dynamic)

self.assertAllEqual(static_grad_values, dynamic_grad_values)
if in_graph_mode:

self.assertEqual(
len(static_individual_grad_values), len(dynamic_individual_grad_values))
self.assertEqual(
len(static_individual_var_grad_values),
len(dynamic_individual_var_grad_values))
self.assertAllEqual(static_grad_values, dynamic_grad_values)

for i, (a, b) in enumerate(
zip(static_individual_grad_values, dynamic_individual_grad_values)):
tf_logging.info("Comparing individual gradients iteration %d" % i)
self.assertAllEqual(a, b)
self.assertEqual(
len(static_individual_grad_values),
len(dynamic_individual_grad_values))
self.assertEqual(
len(static_individual_var_grad_values),
len(dynamic_individual_var_grad_values))

for i, (a, b) in enumerate(
zip(static_individual_var_grad_values,
dynamic_individual_var_grad_values)):
tf_logging.info("Comparing individual variable gradients iteration %d" %
i)
self.assertAllEqual(a, b)
for i, (a, b) in enumerate(
zip(static_individual_grad_values, dynamic_individual_grad_values)):
tf_logging.info("Comparing individual gradients iteration %d" % i)
self.assertAllEqual(a, b)

for i, (a, b) in enumerate(
zip(static_individual_var_grad_values,
dynamic_individual_var_grad_values)):
tf_logging.info("Comparing individual variable gradients iteration %d" %
i)
self.assertAllEqual(a, b)

@test_util.run_in_graph_and_eager_modes()
def testDynamicEquivalentToStaticRNN(self):
self._testDynamicEquivalentToStaticRNN(
use_gpu=False, use_sequence_length=False)
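The recurring change in these RNN tests is to build feedable placeholders only in graph mode and to bind concrete constants in eager mode, where there are no sessions or feeds. Condensed from the test bodies above (shapes are the tests' own):

```python
import numpy as np

from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops

max_length, batch_size, input_size = 8, 2, 5

if context.in_graph_mode():
  # Graph mode: placeholders are fed through a session at run time.
  inputs = max_length * [
      array_ops.placeholder(dtypes.float32, shape=(None, input_size))
  ]
else:
  # Eager mode: there is no feed mechanism, so bind concrete values now.
  inputs = max_length * [
      constant_op.constant(
          np.random.randn(batch_size, input_size).astype(np.float32))
  ]
```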
@@ -112,7 +112,7 @@ struct GatherTree<CPUDevice, int32> {
const int32 max_time = parent_ids.dimension(0);
const int32 batch_size = parent_ids.dimension(1);
const int32 beam_width = parent_ids.dimension(2);
beams.setConstant(-1);
beams.setConstant(end_token);

auto DoWork = [&, ctx, end_token](int start_batch_beam,
int limit_batch_beam) {
@@ -138,10 +138,13 @@ struct GatherTree<CPUDevice, int32> {
beams(level, batch, beam) = step_ids(level, batch, parent);
parent = parent_ids(level, batch, parent);
}
// Not necessary when using a BeamSearchDecoder, but necessary
// when a user feeds in possibly broken trajectory (i.e., non-eos
// entries in a beam following eos entries).
bool finished = false;
for (int32 time = 0; time < max_seq_len_b; ++time) {
if (finished) {
beams(time, batch, beam) = -1;
beams(time, batch, beam) = end_token;
} else if (beams(time, batch, beam) == end_token) {
finished = true;
}
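The CPU-side change replaces the old `-1` padding with `end_token`: once the first `end_token` appears along a beam, every later step is overwritten with `end_token`. The per-beam post-processing loop reduces to something like this NumPy sketch (not the kernel itself):

```python
import numpy as np

def mask_after_end_token(beam, end_token):
  """Copies `beam`, overwriting everything after the first end_token."""
  beam = np.array(beam)
  hits = np.where(beam == end_token)[0]
  if hits.size:
    beam[hits[0] + 1:] = end_token
  return beam

print(mask_after_end_token([2, 6, 10, 7, 9], end_token=10))  # [ 2  6 10 10 10]
```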
@@ -46,24 +46,31 @@ __global__ void GatherTreeOpKernel(const int32 batch_size, const int32 max_time,
const int32 initial_beam_ix = GET_IX(max_seq_len_b - 1, beam);
beams[initial_beam_ix] = ldg(step_ids + initial_beam_ix);
int32 parent = ldg(parent_ids + initial_beam_ix);
bool found_bad = false;
for (int32 level = max_seq_len_b - 2; level >= 0; --level) {
const int32 level_beam_ix = GET_IX(level, beam);
const int32 level_parent_ix = GET_IX(level, parent);
if (parent < 0 || parent > beam_width) {
beams[level_beam_ix] = -1;
parent = -1;
found_bad = true;
} else {
beams[level_beam_ix] = ldg(step_ids + level_parent_ix);
parent = ldg(parent_ids + level_parent_ix);
}
}
bool finished = false;
for (int32 time = 0; time < max_seq_len_b; ++time) {
const int32 level_beam_ix = GET_IX(time, beam);
if (finished) {
beams[level_beam_ix] = -1;
} else if (beams[level_beam_ix] == end_token) {
finished = true;
// Not necessary when using a BeamSearchDecoder, but necessary
// when a user feeds in possibly broken trajectory (i.e., non-eos
// entries in a beam following eos entries).
if (!found_bad) {
bool finished = false;
for (int32 time = 0; time < max_seq_len_b; ++time) {
const int32 level_beam_ix = GET_IX(time, beam);
if (finished) {
beams[level_beam_ix] = end_token;
} else if (beams[level_beam_ix] == end_token) {
finished = true;
}
}
}
#undef GET_IX
@@ -80,8 +87,8 @@ struct GatherTree<GPUDevice, T> {
const int32 max_time = parent_ids.dimension(0);
const int32 batch_size = parent_ids.dimension(1);
const int32 beam_width = parent_ids.dimension(2);
// First kernel launch to zero things out
beams.device(d) = beams.constant(T(-1));
// First kernel launch to "zero" things out
beams.device(d) = beams.constant(end_token);

CudaLaunchConfig config = GetCudaLaunchConfig(batch_size * beam_width, d);
// clang-format off
@@ -53,11 +53,14 @@ REGISTER_OP("GatherTree")
.Doc(R"doc(
Calculates the full beams from the per-step ids and parent beam ids.

This op implements the following mathematical equations:
On CPU, if an out of bound parent id is found, an error is returned.
On GPU, if an out of bound parent id is found, a -1 is stored in the
corresponding output value and the execution for that beam returns early.

```python
TODO(ebrevdo): fill in
```
For a given beam, past the time step containing the first decoded `end_token`
all values are filled in with `end_token`.

TODO(ebrevdo): fill in the remainder of this docstring.

step_ids: `[max_time, batch_size, beam_width]`.
parent_ids: `[max_time, batch_size, beam_width]`.
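Behind the registered op, gathering one beam amounts to walking the parent pointers backwards from step `max_sequence_lengths[b] - 1`, then applying the end-token masking sketched earlier. A reference implementation for a single batch entry might look like the following (illustrative only; the real kernels additionally handle the out-of-range parent ids described above):

```python
import numpy as np

def gather_tree_single_batch(step_ids, parent_ids, max_seq_len, end_token):
  """step_ids, parent_ids: [max_time, beam_width] arrays for one batch entry."""
  max_time, beam_width = step_ids.shape
  beams = np.full((max_time, beam_width), end_token, dtype=step_ids.dtype)
  for beam in range(beam_width):
    parent = beam
    for level in range(max_seq_len - 1, -1, -1):
      beams[level, beam] = step_ids[level, parent]
      parent = parent_ids[level, parent]
  return beams

step_ids = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]])
parent_ids = np.array([[0, 0, 0], [0, 1, 1], [2, 1, 2], [-1, -1, -1]])
print(gather_tree_single_batch(step_ids, parent_ids, 3, end_token=10))
# [[ 2  2  2]
#  [ 6  5  6]
#  [ 7  8  9]
#  [10 10 10]]
```

This reproduces the expected result in `testGatherTreeOne` below.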
@@ -36,24 +36,26 @@ class GatherTreeTest(test.TestCase):

def testGatherTreeOne(self):
# (max_time = 4, batch_size = 1, beams = 3)
end_token = 10
step_ids = _transpose_batch_time(
[[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]])
parent_ids = _transpose_batch_time(
[[[0, 0, 0], [0, 1, 1], [2, 1, 2], [-1, -1, -1]]])
max_sequence_lengths = [3]
expected_result = _transpose_batch_time(
[[[2, 2, 2], [6, 5, 6], [7, 8, 9], [-1, -1, -1]]])
expected_result = _transpose_batch_time([[[2, 2, 2], [6, 5, 6], [7, 8, 9],
[10, 10, 10]]])
beams = beam_search_ops.gather_tree(
step_ids=step_ids,
parent_ids=parent_ids,
max_sequence_lengths=max_sequence_lengths,
end_token=10)
end_token=end_token)
with self.test_session(use_gpu=True):
self.assertAllEqual(expected_result, beams.eval())

def testBadParentValuesOnCPU(self):
# (batch_size = 1, max_time = 4, beams = 3)
# bad parent in beam 1 time 1
end_token = 10
step_ids = _transpose_batch_time(
[[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]])
parent_ids = _transpose_batch_time(
@@ -64,7 +66,7 @@ class GatherTreeTest(test.TestCase):
step_ids=step_ids,
parent_ids=parent_ids,
max_sequence_lengths=max_sequence_lengths,
end_token=10)
end_token=end_token)
with self.test_session():
with self.assertRaisesOpError(
r"parent id -1 at \(batch, time, beam\) == \(0, 0, 1\)"):
@@ -77,19 +79,20 @@ class GatherTreeTest(test.TestCase):
return
# (max_time = 4, batch_size = 1, beams = 3)
# bad parent in beam 1 time 1; appears as a negative index at time 0
end_token = 10
step_ids = _transpose_batch_time(
[[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]])
parent_ids = _transpose_batch_time(
[[[0, 0, 0], [0, -1, 1], [2, 1, 2], [-1, -1, -1]]])
max_sequence_lengths = [3]
expected_result = _transpose_batch_time(
[[[2, -1, 2], [6, 5, 6], [7, 8, 9], [-1, -1, -1]]])
expected_result = _transpose_batch_time([[[2, -1, 2], [6, 5, 6], [7, 8, 9],
[10, 10, 10]]])
with ops.device("/device:GPU:0"):
beams = beam_search_ops.gather_tree(
step_ids=step_ids,
parent_ids=parent_ids,
max_sequence_lengths=max_sequence_lengths,
end_token=10)
end_token=end_token)
with self.test_session(use_gpu=True):
self.assertAllEqual(expected_result, beams.eval())

@@ -115,24 +118,24 @@ class GatherTreeTest(test.TestCase):
self.assertEqual((max_time, batch_size, beam_width), beams.shape)
beams_value = beams.eval()
for b in range(batch_size):
# Past max_sequence_lengths[b], we emit all -1s.
# Past max_sequence_lengths[b], we emit all end tokens.
b_value = beams_value[max_sequence_lengths[b]:, b, :]
self.assertAllClose(b_value, -1. * np.ones_like(b_value))
self.assertAllClose(b_value, end_token * np.ones_like(b_value))
for batch, beam in itertools.product(
range(batch_size), range(beam_width)):
v = np.squeeze(beams_value[:, batch, beam])
if end_token in v:
found_bad = np.where(v == -1)[0]
self.assertEqual(0, len(found_bad))
found = np.where(v == end_token)[0]
# Should be up to 1 instance of end_token per beam.
self.assertEqual(len(found), 1)
found = found[0]
found = found[0]  # First occurrence of end_token.
# If an end_token is found, everything before it should be a
# valid id and everything after it should be -1.
if found > 0:
self.assertAllEqual(
v[:found - 1] >= 0, np.ones_like(v[:found - 1], dtype=bool))
self.assertAllClose(
v[found + 1:], -1 * np.ones_like(v[found + 1:]))
self.assertAllClose(v[found + 1:],
end_token * np.ones_like(v[found + 1:]))


if __name__ == "__main__":
File diff suppressed because it is too large
@@ -31,6 +31,7 @@ from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.layers import utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
@@ -47,7 +48,6 @@ _dtypes = input_py._dtypes
_store_sparse_tensors = input_py._store_sparse_tensors
_validate_keep_input = input_py._validate_keep_input
_shapes = input_py._shapes
_smart_cond = input_py._smart_cond
_which_queue = input_py._which_queue

# pylint: enable=protected-access
@@ -239,7 +239,7 @@ def bucket(tensors,
]
return control_flow_ops.group(*enqueues, name="group_enqueues")

maybe_enqueue = _smart_cond(
maybe_enqueue = utils.smart_cond(
keep_input,
enqueue_which,
control_flow_ops.no_op)
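`utils.smart_cond`, which replaces the private `_smart_cond` helper here, chooses a branch at graph-construction time when the predicate is statically known and only emits a real conditional otherwise. A minimal sketch of that idea (not the library implementation itself):

```python
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import control_flow_ops

def sketch_smart_cond(pred, true_fn, false_fn):
  """Resolve statically known predicates at graph-construction time."""
  pred_value = tensor_util.constant_value(ops.convert_to_tensor(pred))
  if pred_value is not None:
    # Predicate known now: call the chosen branch directly; no cond op.
    return true_fn() if pred_value else false_fn()
  # Predicate only known at run time: emit a real conditional.
  return control_flow_ops.cond(pred, true_fn, false_fn)
```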
@@ -1411,7 +1411,7 @@ cc_library(
hdrs = LIB_INTERNAL_PUBLIC_HEADERS,
copts = tf_copts(),
defines = tf_additional_lib_defines() + [
"SNAPPY",
"TF_USE_SNAPPY",
] + tf_additional_verbs_lib_defines() +
tf_additional_mpi_lib_defines() +
tf_additional_gdr_lib_defines(),
@@ -51,7 +51,8 @@ message ApiDef {
// endpoints are deprecated).
message Endpoint {
// Name should be either like "CamelCaseName" or
// "Package.CamelCaseName".
// "Package.CamelCaseName". Client-language-specific ApiDefs may
// use a snake_case convention instead of CamelCase.
string name = 1;

// First GraphDef version at which the op is disallowed.
@@ -74,7 +75,7 @@ message ApiDef {
}
repeated Arg in_arg = 4;
repeated Arg out_arg = 5;
// List of post-rename in_arg names to specify new argument order.
// List of original in_arg names to specify new argument order.
// Length of arg_order should be either empty to keep current order
// or match size of in_arg.
repeated string arg_order = 11;
@@ -412,6 +412,8 @@ void InitApiDefFromOpDef(const OpDef& op_def, ApiDef* api_def) {
api_in_arg->set_name(op_in_arg.name());
api_in_arg->set_rename_to(op_in_arg.name());
api_in_arg->set_description(op_in_arg.description());

*api_def->add_arg_order() = op_in_arg.name();
}
for (const auto& op_out_arg : op_def.output_arg()) {
auto* api_out_arg = api_def->add_out_arg();
@@ -503,6 +505,22 @@ Status MergeApiDefs(ApiDef* base_api_def, const ApiDef& new_api_def) {
}
// Merge arg order
if (new_api_def.arg_order_size() > 0) {
// Validate that new arg_order is correct.
if (new_api_def.arg_order_size() != base_api_def->arg_order_size()) {
return errors::FailedPrecondition(
"Invalid number of arguments ", new_api_def.arg_order_size(), " for ",
base_api_def->graph_op_name(),
". Expected: ", base_api_def->arg_order_size());
}
if (!std::is_permutation(new_api_def.arg_order().begin(),
new_api_def.arg_order().end(),
base_api_def->arg_order().begin())) {
return errors::FailedPrecondition(
"Invalid arg_order: ", str_util::Join(new_api_def.arg_order(), ", "),
" for ", base_api_def->graph_op_name(),
". All elements in arg_order override must match base arg_order: ",
str_util::Join(base_api_def->arg_order(), ", "));
}
base_api_def->clear_arg_order();
std::copy(
new_api_def.arg_order().begin(), new_api_def.arg_order().end(),
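The merge accepts an `arg_order` override only when it has the same length as the base order and is a permutation of it, which jointly rejects unknown names and duplicates. The same validation in a few lines of Python (illustrative):

```python
def validate_arg_order(base_order, new_order):
  """Mirrors the FailedPrecondition checks in MergeApiDefs."""
  if len(new_order) != len(base_order):
    raise ValueError('Invalid number of arguments %d. Expected: %d'
                     % (len(new_order), len(base_order)))
  # Equal-length sorted comparison is the multiset (is_permutation) test.
  if sorted(new_order) != sorted(base_order):
    raise ValueError('arg_order must be a permutation of the base arg_order')
  return list(new_order)

validate_arg_order(['arg_a', 'arg_b'], ['arg_b', 'arg_a'])    # ok
# validate_arg_order(['arg_a', 'arg_b'], ['arg_a', 'arg_a'])  # raises
# validate_arg_order(['arg_a', 'arg_b'], ['arg_a'])           # raises
```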
@@ -207,6 +207,8 @@ attr {
name: "attr_a"
rename_to: "attr_a"
}
arg_order: "arg_a"
arg_order: "arg_b"
)";
OpList op_list;
protobuf::TextFormat::ParseFromString(kTestOpList, &op_list);  // NOLINT
@@ -331,8 +333,8 @@ op {
name: "arg_c"
rename_to: "arg_cc"
}
arg_order: "arg_aa"
arg_order: "arg_b"
arg_order: "arg_a"
}
)";
OpList op_list;
@@ -351,8 +353,8 @@ op {
EXPECT_EQ("arg_cc", api_def->out_arg(0).rename_to());

ASSERT_EQ(2, api_def->arg_order_size());
EXPECT_EQ("arg_aa", api_def->arg_order(0));
EXPECT_EQ("arg_b", api_def->arg_order(1));
EXPECT_EQ("arg_b", api_def->arg_order(0));
EXPECT_EQ("arg_a", api_def->arg_order(1));
}

TEST(OpGenLibTest, ApiDefOverrideDescriptions) {
@@ -411,5 +413,47 @@ op {
auto status = api_map.LoadApiDef(api_def1);
ASSERT_EQ(tensorflow::error::FAILED_PRECONDITION, status.code());
}

TEST(OpGenLibTest, ApiDefInvalidArgOrder) {
const string api_def1 = R"(
op {
graph_op_name: "testop"
arg_order: "arg_a"
arg_order: "unexpected_arg"
}
)";

const string api_def2 = R"(
op {
graph_op_name: "testop"
arg_order: "arg_a"
}
)";

const string api_def3 = R"(
op {
graph_op_name: "testop"
arg_order: "arg_a"
arg_order: "arg_a"
}
)";

OpList op_list;
protobuf::TextFormat::ParseFromString(kTestOpList, &op_list);  // NOLINT
ApiDefMap api_map(op_list);
TF_CHECK_OK(api_map.LoadApiDef(kTestApiDef));

// Loading with incorrect arg name in arg_order should fail.
auto status = api_map.LoadApiDef(api_def1);
ASSERT_EQ(tensorflow::error::FAILED_PRECONDITION, status.code());

// Loading with incorrect number of args in arg_order should fail.
status = api_map.LoadApiDef(api_def2);
ASSERT_EQ(tensorflow::error::FAILED_PRECONDITION, status.code());

// Loading with the same argument twice in arg_order should fail.
status = api_map.LoadApiDef(api_def3);
ASSERT_EQ(tensorflow::error::FAILED_PRECONDITION, status.code());
}
}  // namespace
}  // namespace tensorflow
@@ -1068,10 +1068,16 @@ Status ImportGraphDef(const ImportGraphDefOptions& opts, const GraphDef& gdef,
refiner->set_graph_def_version(
std::min(refiner->graph_def_version(), gdef.versions().producer()));

return GraphConstructor::Construct(
opts, gdef.node(), &gdef.versions(), &gdef.library(), g, refiner,
&results->return_tensors, &results->return_nodes,
&results->unused_input_map_keys);
if (results == nullptr) {
return GraphConstructor::Construct(opts, gdef.node(), &gdef.versions(),
&gdef.library(), g, refiner, nullptr,
nullptr, nullptr);
} else {
return GraphConstructor::Construct(
opts, gdef.node(), &gdef.versions(), &gdef.library(), g, refiner,
&results->return_tensors, &results->return_nodes,
&results->unused_input_map_keys);
}
}

void CopyGraph(const Graph& src, Graph* dest) {
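The new `nullptr` branch lets callers that request no returned tensors or nodes skip allocating a results struct. At the Python level the same split shows up in whether `return_elements` is passed to `tf.import_graph_def` (standard TF 1.x API, shown for context):

```python
import tensorflow as tf

# Build a trivial GraphDef to import.
with tf.Graph().as_default() as g:
  tf.constant(1.0, name='c')
graph_def = g.as_graph_def()

with tf.Graph().as_default():
  # No return_elements: corresponds to the results == nullptr path.
  tf.import_graph_def(graph_def, name='imported')

with tf.Graph().as_default():
  # With return_elements: the importer collects and returns tensors.
  (c,) = tf.import_graph_def(
      graph_def, return_elements=['c:0'], name='imported')
```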
@@ -450,12 +450,16 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
}

// Optimize the graph (function inlining, l1 optimizations, etc).
VLOG(1) << "Number of nodes in graph before OptimizeGraph: "
<< new_item->graph.node_size();
Status optimize_status =
OptimizeGraph(new_item->graph, &new_item->graph, cfg);
if (!optimize_status.ok()) {
LOG(ERROR) << "Graph preprocessing failed: " << optimize_status;
return nullptr;
}
VLOG(1) << "Number of nodes in graph after OptimizeGraph: "
<< new_item->graph.node_size();

if (cfg.prune_graph) {
VLOG(1) << "Pruning graph...";
@@ -464,7 +468,8 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
LOG(ERROR) << "Pruning failed: " << status.error_message();
return nullptr;
}
VLOG(1) << "Pruning ran successfully.";
VLOG(1) << "Number of nodes in graph after pruning: "
<< new_item->graph.node_size();
}

// Validate feed, fetch and init nodes
@@ -18,6 +18,9 @@ limitations under the License.
#ifdef INTEL_MKL
#define EIGEN_USE_THREADS

#include "tensorflow/core/framework/numeric_types.h"
#define MKL_Complex8 tensorflow::complex64
#define MKL_Complex16 tensorflow::complex128
#include "mkl_trans.h"
#include "tensorflow/core/kernels/transpose_functor.h"
#include "tensorflow/core/kernels/transpose_op.h"
@@ -41,7 +44,7 @@ namespace tensorflow {

namespace {
template <typename T>
void MKLTranspose2D(const char trans, const Tensor& in, Tensor* out) {}
Status MKLTranspose2D(const char trans, const Tensor& in, Tensor* out);

// Documentation here: https://software.intel.com/en-us/node/520863
// Parameters: (ordering:row-major, operation:transpose, num_rows, num_cols,
@@ -54,70 +57,73 @@ void MKLTranspose2D(const char trans, const Tensor& in, Tensor* out) {}
mkl_##PREFIX##omatcopy('R', trans, in.dim_size(0), in.dim_size(1), 1, \
in.flat<T>().data(), in.dim_size(1), \
out->flat<T>().data(), in.dim_size(0)); \
return Status::OK();
return Status::OK(); \
}

INSTANTIATE(float, s)
INSTANTIATE(double, d)
INSTANTIATE(complex64, c)
INSTANTIATE(complex128, z)
INSTANTIATE(float, s)
INSTANTIATE(double, d)
INSTANTIATE(complex64, c)
INSTANTIATE(complex128, z)
#undef INSTANTIATE

static const char kMKLTranspose = 'T';
static const char kMKLConjugateTranspose = 'C';
static const char kMKLTranspose = 'T';
static const char kMKLConjugateTranspose = 'C';

}  // namespace tensorflow
}  // namespace

Status MklTransposeCpuOp::DoTranspose(OpKernelContext* ctx, const Tensor& in,
gtl::ArraySlice<int32> perm,
Tensor* out) {
if (in.dims() == 2) {
switch (in.dtype()) {
case DT_FLOAT:
return MKLTranspose2D<float>(kMKLTranspose, in, out);
case DT_DOUBLE:
return MKLTranspose2D<double>(kMKLTranspose, in, out);
case DT_COMPLEX64:
return MKLTranspose2D<complex64>(kMKLTranspose, in, out);
case DT_COMPLEX128:
return MKLTranspose2D<complex128>(kMKLTranspose, in, out);
default:
break;
}
Status MklTransposeCpuOp::DoTranspose(OpKernelContext* ctx, const Tensor& in,
gtl::ArraySlice<int32> perm,
Tensor* out) {
if (in.dims() == 2) {
if (perm[0] == 0 && perm[1] == 1) {
return Status::OK();
}
// Fallback to eigen if transpose parameters not supported by MKL
typedef Eigen::ThreadPoolDevice CPUDevice;
return ::tensorflow::DoTranspose(ctx->eigen_device<CPUDevice>(), in, perm,
out);
}

Status MklConjugateTransposeCpuOp::DoTranspose(OpKernelContext* ctx,
const Tensor& in,
gtl::ArraySlice<int32> perm,
Tensor* out) {
if (in.dims() == 2) {
// TODO(rmlarsen): By setting lda and ldb, we could use the MKL kernels
// for any transpose that can be reduced to swapping the last two
// dimensions in a rank-3 tensor. We can even run each outer dimension in
// a separate thread.
switch (in.dtype()) {
case DT_FLOAT:
return MKLTranspose2D<float>(kMKLTranspose, in, out);
case DT_DOUBLE:
return MKLTranspose2D<double>(kMKLTranspose, in, out);
case DT_COMPLEX64:
return MKLTranspose2D<complex64>(kMKLConjugateTranspose, in, out);
case DT_COMPLEX128:
return MKLTranspose2D<complex128>(kMKLConjugateTranspose, in, out);
default:
break;
}
switch (in.dtype()) {
case DT_FLOAT:
return MKLTranspose2D<float>(kMKLTranspose, in, out);
case DT_DOUBLE:
return MKLTranspose2D<double>(kMKLTranspose, in, out);
case DT_COMPLEX64:
return MKLTranspose2D<complex64>(kMKLTranspose, in, out);
case DT_COMPLEX128:
return MKLTranspose2D<complex128>(kMKLTranspose, in, out);
default:
break;
}
// Fallback to eigen if transpose parameters not supported by MKL
typedef Eigen::ThreadPoolDevice CPUDevice;
return ::tensorflow::DoConjugateTranspose(ctx->eigen_device<CPUDevice>(),
in, perm, out);
}
// Fallback to eigen if transpose parameters not supported by MKL
typedef Eigen::ThreadPoolDevice CPUDevice;
return ::tensorflow::DoTranspose(ctx->eigen_device<CPUDevice>(), in, perm,
out);
}

Status MklConjugateTransposeCpuOp::DoTranspose(OpKernelContext* ctx,
const Tensor& in,
gtl::ArraySlice<int32> perm,
Tensor* out) {
if (in.dims() == 2 && perm[0] == 1 && perm[1] == 0) {
// TODO(rmlarsen): By setting lda and ldb, we could use the MKL kernels
// for any transpose that can be reduced to swapping the last two
// dimensions in a rank-3 tensor. We can even run each outer dimension in
// a separate thread.
switch (in.dtype()) {
case DT_FLOAT:
return MKLTranspose2D<float>(kMKLTranspose, in, out);
case DT_DOUBLE:
return MKLTranspose2D<double>(kMKLTranspose, in, out);
case DT_COMPLEX64:
return MKLTranspose2D<complex64>(kMKLConjugateTranspose, in, out);
case DT_COMPLEX128:
return MKLTranspose2D<complex128>(kMKLConjugateTranspose, in, out);
default:
break;
}
}
// Fallback to eigen if transpose parameters not supported by MKL
typedef Eigen::ThreadPoolDevice CPUDevice;
return ::tensorflow::DoConjugateTranspose(ctx->eigen_device<CPUDevice>(), in,
perm, out);
}

}  // namespace tensorflow
@@ -201,17 +201,26 @@ Status DoTransposeImpl(const Device& d, const Tensor& in,

case DT_COMPLEX64:
if (conjugate) {
Transpose<Device, complex64, true>::run(d, in, perm, out);
#if defined(__ANDROID__) and !defined(__clang__)
// Workaround for GCC compiler bug in Android toolchain.
return errors::Unimplemented(
"Conjugate transpose of complex64 not supported for GCC on "
"Android.");
#else
Transpose<Device, complex64, /*conjugate=*/true>::run(d, in, perm, out);
#endif
} else {
Transpose<Device, complex64, false>::run(d, in, perm, out);
Transpose<Device, uint64>::run(d, in, perm, out);
}
break;

case DT_COMPLEX128:
if (conjugate) {
Transpose<Device, complex128, true>::run(d, in, perm, out);
Transpose<Device, complex128, /*conjugate=*/true>::run(d, in, perm,
out);
} else {
Transpose<Device, complex128, false>::run(d, in, perm, out);
Transpose<Device, complex128, /*conjugate=*/false>::run(d, in, perm,
out);
}
break;
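For the complex types, the `/*conjugate=*/true` instantiation computes the conjugate (Hermitian) transpose rather than a plain axis permutation. A NumPy reference check of that semantics (not the Eigen code path):

```python
import numpy as np

x = (np.arange(6) + 1j * np.arange(6)).reshape(2, 3).astype(np.complex64)
perm = (1, 0)

plain = np.transpose(x, perm)               # conjugate=false
hermitian = np.conj(np.transpose(x, perm))  # conjugate=true

assert hermitian.dtype == np.complex64
assert np.allclose(hermitian, x.conj().T)
```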
@@ -467,7 +467,7 @@ def tf_additional_core_deps():
"//conditions:default": [],
}) + select({
"//tensorflow:with_s3_support": [
"//tensorflow/contrib/s3:s3_file_system",
"//tensorflow/core/platform/s3:s3_file_system",
],
"//conditions:default": [],
})
@@ -29,7 +29,7 @@ limitations under the License.
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#ifdef SNAPPY
#ifdef TF_USE_SNAPPY
#include "snappy.h"
#endif
#if (defined(__APPLE__) && defined(__MACH__)) || defined(__FreeBSD__)
@@ -126,7 +126,7 @@ void AdjustFilenameForLogging(string* filename) {
}

bool Snappy_Compress(const char* input, size_t length, string* output) {
#ifdef SNAPPY
#ifdef TF_USE_SNAPPY
output->resize(snappy::MaxCompressedLength(length));
size_t outlen;
snappy::RawCompress(input, length, &(*output)[0], &outlen);
@@ -139,7 +139,7 @@ bool Snappy_Compress(const char* input, size_t length, string* output) {

bool Snappy_GetUncompressedLength(const char* input, size_t length,
size_t* result) {
#ifdef SNAPPY
#ifdef TF_USE_SNAPPY
return snappy::GetUncompressedLength(input, length, result);
#else
return false;
@@ -147,7 +147,7 @@ bool Snappy_GetUncompressedLength(const char* input, size_t length,
}

bool Snappy_Uncompress(const char* input, size_t length, char* output) {
#ifdef SNAPPY
#ifdef TF_USE_SNAPPY
return snappy::RawUncompress(input, length, output);
#else
return false;
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/s3/s3_crypto.h"
#include "tensorflow/core/platform/s3/s3_crypto.h"
#include <openssl/hmac.h>
#include <openssl/sha.h>
@@ -12,10 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/s3/s3_file_system.h"
#include "tensorflow/contrib/s3/s3_crypto.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/s3/s3_file_system.h"
#include "tensorflow/core/platform/s3/s3_crypto.h"

#include <aws/core/Aws.h>
#include <aws/core/utils/FileSystemUtils.h>
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/contrib/s3/s3_file_system.h"
#include "tensorflow/core/platform/s3/s3_file_system.h"

#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
@@ -20,7 +20,7 @@ limitations under the License.
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#ifdef SNAPPY
#ifdef TF_USE_SNAPPY
#include "snappy.h"
#endif

@@ -118,7 +118,7 @@ void AdjustFilenameForLogging(string* filename) {
}

bool Snappy_Compress(const char* input, size_t length, string* output) {
#ifdef SNAPPY
#ifdef TF_USE_SNAPPY
output->resize(snappy::MaxCompressedLength(length));
size_t outlen;
snappy::RawCompress(input, length, &(*output)[0], &outlen);
@@ -131,7 +131,7 @@ bool Snappy_Compress(const char* input, size_t length, string* output) {

bool Snappy_GetUncompressedLength(const char* input, size_t length,
size_t* result) {
#ifdef SNAPPY
#ifdef TF_USE_SNAPPY
return snappy::GetUncompressedLength(input, length, result);
#else
return false;
@@ -139,7 +139,7 @@ bool Snappy_GetUncompressedLength(const char* input, size_t length,
}

bool Snappy_Uncompress(const char* input, size_t length, char* output) {
#ifdef SNAPPY
#ifdef TF_USE_SNAPPY
return snappy::RawUncompress(input, length, output);
#else
return false;
@@ -17,47 +17,94 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
from sklearn import datasets
from sklearn import metrics
from sklearn import model_selection
import os
import urllib

import tensorflow as tf

# Data sets
IRIS_TRAINING = 'iris_training.csv'
IRIS_TRAINING_URL = 'http://download.tensorflow.org/data/iris_training.csv'

X_FEATURE = 'x'  # Name of the input feature.
IRIS_TEST = 'iris_test.csv'
IRIS_TEST_URL = 'http://download.tensorflow.org/data/iris_test.csv'

FEATURE_KEYS = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']


def maybe_download_iris_data(file_name, download_url):
  """Downloads the file and returns the number of data."""
  if not os.path.exists(file_name):
    raw = urllib.urlopen(download_url).read()
    with open(file_name, 'w') as f:
      f.write(raw)

  # The first line is a comma-separated string. The first one is the number of
  # total data in the file.
  with open(file_name, 'r') as f:
    first_line = f.readline()
  num_elements = first_line.split(',')[0]
  return int(num_elements)


def input_fn(file_name, num_data, batch_size, is_training):
  """Creates an input_fn required by Estimator train/evaluate."""
# If the data sets aren't stored locally, download them.

  def _parse_csv(rows_string_tensor):
    """Takes the string input tensor and returns tuple of (features, labels)."""
    # Last dim is the label.
    num_features = len(FEATURE_KEYS)
    num_columns = num_features + 1
    columns = tf.decode_csv(rows_string_tensor,
                            record_defaults=[[]] * num_columns)
    features = dict(zip(FEATURE_KEYS, columns[:num_features]))
    labels = tf.cast(columns[num_features], tf.int32)
    return features, labels

  def _input_fn():
    """The input_fn."""
    dataset = tf.data.TextLineDataset([file_name])
    # Skip the first line (which does not have data).
    dataset = dataset.skip(1)
    dataset = dataset.map(_parse_csv)

    if is_training:
      # For this small dataset, which can fit into memory, to achieve true
      # randomness, the shuffle buffer size is set as the total number of
      # elements in the dataset.
      dataset = dataset.shuffle(num_data)
      dataset = dataset.repeat()

    dataset = dataset.batch(batch_size)
    iterator = dataset.make_one_shot_iterator()
    features, labels = iterator.get_next()
    return features, labels

  return _input_fn


def main(unused_argv):
# Load dataset.
iris = datasets.load_iris()
x_train, x_test, y_train, y_test = model_selection.train_test_split(
iris.data, iris.target, test_size=0.2, random_state=42)
  tf.logging.set_verbosity(tf.logging.INFO)

  num_training_data = maybe_download_iris_data(
      IRIS_TRAINING, IRIS_TRAINING_URL)
  num_test_data = maybe_download_iris_data(IRIS_TEST, IRIS_TEST_URL)

  # Build 3 layer DNN with 10, 20, 10 units respectively.
  feature_columns = [
tf.feature_column.numeric_column(
X_FEATURE, shape=np.array(x_train).shape[1:])]
      tf.feature_column.numeric_column(key, shape=1) for key in FEATURE_KEYS]
  classifier = tf.estimator.DNNClassifier(
      feature_columns=feature_columns, hidden_units=[10, 20, 10], n_classes=3)

  # Train.
train_input_fn = tf.estimator.inputs.numpy_input_fn(
x={X_FEATURE: x_train}, y=y_train, num_epochs=None, shuffle=True)
classifier.train(input_fn=train_input_fn, steps=200)
  train_input_fn = input_fn(IRIS_TRAINING, num_training_data, batch_size=32,
                            is_training=True)
  classifier.train(input_fn=train_input_fn, steps=400)

# Predict.
test_input_fn = tf.estimator.inputs.numpy_input_fn(
x={X_FEATURE: x_test}, y=y_test, num_epochs=1, shuffle=False)
predictions = classifier.predict(input_fn=test_input_fn)
y_predicted = np.array(list(p['class_ids'] for p in predictions))
y_predicted = y_predicted.reshape(np.array(y_test).shape)

# Score with sklearn.
score = metrics.accuracy_score(y_test, y_predicted)
print('Accuracy (sklearn): {0:f}'.format(score))

# Score with tensorflow.
  # Eval.
  test_input_fn = input_fn(IRIS_TEST, num_test_data, batch_size=32,
                           is_training=False)
  scores = classifier.evaluate(input_fn=test_input_fn)
  print('Accuracy (tensorflow): {0:f}'.format(scores['accuracy']))
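The rewritten example assumes the standard iris CSV layout, where the header's first field is the row count and each data row holds four features followed by an integer label; `_parse_csv` turns one such line into a `(features, label)` pair. A pure-Python analogue with a hypothetical sample row:

```python
FEATURE_KEYS = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']

def parse_csv_row(row):
  """Pure-Python analogue of _parse_csv for one data line."""
  values = [float(v) for v in row.split(',')]
  features = dict(zip(FEATURE_KEYS, values[:len(FEATURE_KEYS)]))
  label = int(values[len(FEATURE_KEYS)])
  return features, label

# A hypothetical row from iris_training.csv:
print(parse_csv_row('6.4,2.8,5.6,2.2,2'))
```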
Some files were not shown because too many files have changed in this diff.