Merge pull request from vrv/branch_172924803

Branch 172924803
Vijay Vasudevan 2017-10-20 16:06:41 -07:00 committed by GitHub
commit b3d5ec90bc
128 changed files with 6825 additions and 2227 deletions
WORKSPACE
tensorflow/BUILD
tensorflow/c/eager
tensorflow/compiler/xla
tensorflow/contrib
tensorflow/core
tensorflow/examples/learn

View File

@ -5,7 +5,7 @@ http_archive(
sha256 = "110fe68753413777944b473c25eed6368c4a0487cee23a7bac1b13cc49d3e257",
strip_prefix = "rules_closure-4af89ef1db659eb41f110df189b67d4cf14073e1",
urls = [
"http://mirror.bazel.build/github.com/bazelbuild/rules_closure/archive/4af89ef1db659eb41f110df189b67d4cf14073e1.tar.gz",
"https://mirror.bazel.build/github.com/bazelbuild/rules_closure/archive/4af89ef1db659eb41f110df189b67d4cf14073e1.tar.gz",
"https://github.com/bazelbuild/rules_closure/archive/4af89ef1db659eb41f110df189b67d4cf14073e1.tar.gz", # 2017-08-28
],
)

View File

@ -348,6 +348,7 @@ filegroup(
"//tensorflow/compiler/xla/service/llvm_ir:all_files",
"//tensorflow/compiler/xla/tests:all_files",
"//tensorflow/compiler/xla/tools:all_files",
"//tensorflow/compiler/xla/tools/parser:all_files",
"//tensorflow/contrib:all_files",
"//tensorflow/contrib/all_reduce:all_files",
"//tensorflow/contrib/android:all_files",
@ -421,7 +422,6 @@ filegroup(
"//tensorflow/contrib/remote_fused_graph/pylib:all_files",
"//tensorflow/contrib/resampler:all_files",
"//tensorflow/contrib/rnn:all_files",
"//tensorflow/contrib/s3:all_files",
"//tensorflow/contrib/saved_model:all_files",
"//tensorflow/contrib/saved_model/cc/saved_model:all_files",
"//tensorflow/contrib/seq2seq:all_files",
@ -475,6 +475,7 @@ filegroup(
"//tensorflow/core/platform/cloud:all_files",
"//tensorflow/core/platform/default/build_config:all_files",
"//tensorflow/core/platform/hadoop:all_files",
"//tensorflow/core/platform/s3:all_files",
"//tensorflow/core/profiler:all_files",
"//tensorflow/core/profiler/internal:all_files",
"//tensorflow/core/profiler/internal/advisor:all_files",

View File

@ -3,6 +3,7 @@ licenses(["notice"]) # Apache 2.0
load(
"//tensorflow:tensorflow.bzl",
"tf_cuda_cc_test",
"tf_cc_test",
"tf_copts",
"tf_cuda_library",
@ -50,7 +51,7 @@ tf_cuda_library(
],
)
tf_cc_test(
tf_cuda_cc_test(
name = "c_api_test",
srcs = ["c_api_test.cc"],
deps = [

View File

@ -54,9 +54,23 @@ string DeviceName(tensorflow::Device* d) {
extern "C" {
TFE_Context* TFE_NewContext(const TF_SessionOptions* opts, TF_Status* status) {
TFE_ContextOptions* TFE_NewContextOptions() { return new TFE_ContextOptions; }
void TFE_ContextOptionsSetConfig(TFE_ContextOptions* options, const void* proto,
size_t proto_len, TF_Status* status) {
TF_SetConfig(&options->session_options, proto, proto_len, status);
}
void TFE_ContextOptionsSetDevicePlacementPolicy(
TFE_ContextOptions* options, TFE_ContextDevicePlacementPolicy policy) {
options->policy = policy;
}
void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
TF_Graph* graph = TF_NewGraph();
TF_Session* session = TF_NewSession(graph, opts, status);
TF_Session* session = TF_NewSession(graph, &opts->session_options, status);
if (status->status.ok()) {
if (session->device_mgr == nullptr || session->devices.empty()) {
status->status = tensorflow::errors::InvalidArgument(
@ -71,9 +85,10 @@ TFE_Context* TFE_NewContext(const TF_SessionOptions* opts, TF_Status* status) {
}
TFE_Context* ret = new TFE_Context(session);
ret->policy = opts->policy;
ret->pflr.reset(new tensorflow::ProcessFunctionLibraryRuntime(
ret->session->device_mgr, opts->options.env, TF_GRAPH_DEF_VERSION,
&ret->func_lib_def, {}));
ret->session->device_mgr, opts->session_options.options.env,
TF_GRAPH_DEF_VERSION, &ret->func_lib_def, {}));
ret->rendezvous =
new tensorflow::IntraProcessRendezvous(ret->session->device_mgr);
@ -408,8 +423,10 @@ void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name,
namespace {
tensorflow::Status ValidateInputTypeAndPlacement(
tensorflow::Device* host_device, tensorflow::Device* op_device, TFE_Op* op,
const tensorflow::OpKernel* kernel) {
TFE_Context* ctx, tensorflow::Device* host_device,
tensorflow::Device* op_device, TFE_Op* op,
const tensorflow::OpKernel* kernel,
std::vector<TFE_TensorHandle*>* copied_tensors) {
const tensorflow::MemoryTypeVector& memtypes = kernel->input_memory_types();
if (memtypes.size() != op->inputs.size()) {
return tensorflow::errors::InvalidArgument(
@ -421,11 +438,42 @@ tensorflow::Status ValidateInputTypeAndPlacement(
const tensorflow::Device* actual_device =
op->input_devices[i] == nullptr ? host_device : op->input_devices[i];
if (expected_device != actual_device) {
return tensorflow::errors::InvalidArgument(
"cannot compute ", op->name, " as input #", i,
" was expected to be on ", expected_device->name(),
" but is actually on ", actual_device->name(),
" (operation running on ", op_device->name(), ")");
switch (ctx->policy) {
case TFE_DEVICE_PLACEMENT_EXPLICIT:
return tensorflow::errors::InvalidArgument(
"cannot compute ", op->name, " as input #", i,
" was expected to be on ", expected_device->name(),
" but is actually on ", actual_device->name(),
" (operation running on ", op_device->name(), ")");
case TFE_DEVICE_PLACEMENT_WARN:
LOG(WARNING) << "before computing " << op->name << " input #" << i
<< " was expected to be on " << expected_device->name()
<< " but is actually on " << actual_device->name()
<< " (operation running on " << op_device->name()
<< "). This triggers a copy which can be a performance "
"bottleneck.";
break;
case TFE_DEVICE_PLACEMENT_SILENT: // Do nothing.
break;
}
// We are only here if the policy is warn or silent copies, so we should
// trigger a copy.
TFE_TensorHandle original{op->inputs[i], op->input_devices[i]};
TF_Status* s = TF_NewStatus();
TFE_TensorHandle* copied_tensor = TFE_TensorHandleCopyToDevice(
&original, ctx, expected_device->name().c_str(), s);
if (!s->status.ok()) {
tensorflow::Status status = s->status;
delete s;
return tensorflow::errors::Internal(
"Failed copying input tensor from ", actual_device->name(), " to ",
expected_device->name(), " in order to run ", op->name, ": ",
status.error_message());
}
op->inputs[i] = copied_tensor->t;
copied_tensors->push_back(copied_tensor);
op->input_devices[i] = copied_tensor->d;
delete s;
}
if (op->inputs[i].dtype() != kernel->input_type(i)) {
return tensorflow::errors::InvalidArgument(
@ -468,10 +516,14 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
}
tensorflow::gtl::InsertOrUpdate(&(ctx->kernel_cache), cache_key, kernel);
}
status->status = ValidateInputTypeAndPlacement(ctx->devices()[0], device, op,
kernel->kernel());
std::vector<TFE_TensorHandle*> copied_tensors;
status->status = ValidateInputTypeAndPlacement(
ctx, ctx->devices()[0], device, op, kernel->kernel(), &copied_tensors);
output_memory_types = &kernel->kernel()->output_memory_types();
if (!status->status.ok()) {
for (auto* t : copied_tensors) {
TFE_DeleteTensorHandle(t);
}
return;
}
// WARNING: kernel->Run utilizes the FunctionLibraryRuntime
@ -483,6 +535,9 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
// sense for FunctionLibraryRuntime to ensure thread-safe access to
// FunctionLibraryDefinition?).
status->status = kernel->Run(&op->inputs, &outputs);
for (auto* t : copied_tensors) {
TFE_DeleteTensorHandle(t);
}
if (!status->status.ok()) return;
*num_retvals = std::min<int>(*num_retvals, outputs.size());
for (int i = 0; i < *num_retvals; ++i) {

View File

@ -43,14 +43,46 @@ limitations under the License.
extern "C" {
#endif
typedef struct TFE_ContextOptions TFE_ContextOptions;
// Return a new options object.
TF_CAPI_EXPORT extern TFE_ContextOptions* TFE_NewContextOptions();
// Set the config in TF_ContextOptions.options.
// config should be a serialized tensorflow.ConfigProto proto.
// If config was not parsed successfully as a ConfigProto, record the
// error information in *status.
TF_CAPI_EXPORT extern void TFE_ContextOptionsSetConfig(
TFE_ContextOptions* options, const void* proto, size_t proto_len,
TF_Status* status);
// Controls how to act when we try to run an operation on a given device but
// some input tensors are not on that device.
typedef enum TFE_ContextDevicePlacementPolicy {
// The default: running operations with input tensors on the wrong device will
// fail.
TFE_DEVICE_PLACEMENT_EXPLICIT = 0,
// Copy the tensor to the right device but log a warning.
TFE_DEVICE_PLACEMENT_WARN = 1,
// Silently copy the tensor, which has a performance cost since the
// operation will be blocked till the copy completes.
TFE_DEVICE_PLACEMENT_SILENT = 2,
} TFE_ContextDevicePlacementPolicy;
TF_CAPI_EXPORT extern void TFE_ContextOptionsSetDevicePlacementPolicy(
TFE_ContextOptions*, TFE_ContextDevicePlacementPolicy);
// Destroy an options object.
TF_CAPI_EXPORT extern void TFE_DeleteContextOptions(TFE_ContextOptions*);
// "Context" under which operations/functions are executed. It encapsulates
// things like the available devices, resource manager etc.
//
// TODO(ashankar): Merge with TF_Session?
typedef struct TFE_Context TFE_Context;
TF_CAPI_EXPORT extern TFE_Context* TFE_NewContext(const TF_SessionOptions* opts,
TF_Status* status);
TF_CAPI_EXPORT extern TFE_Context* TFE_NewContext(
const TFE_ContextOptions* opts, TF_Status* status);
TF_CAPI_EXPORT extern void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status);
TF_CAPI_EXPORT extern TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx,
TF_Status* status);
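
A minimal caller-side sketch of the new options API, pieced together from the declarations above and the updated c_api_test.cc below; error handling is reduced to the status checks used there, and the op-building step is elided.

// Sketch: create a context that silently copies mis-placed input tensors.
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);  // The context copies what it needs from opts.
// ... build ops with TFE_NewOp and run them with TFE_Execute ...
TFE_DeleteContext(ctx, status);
TF_DeleteStatus(status);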

View File

@ -35,9 +35,16 @@ limitations under the License.
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
struct TFE_ContextOptions {
TF_SessionOptions session_options;
TFE_ContextDevicePlacementPolicy policy{TFE_DEVICE_PLACEMENT_EXPLICIT};
};
struct TFE_Context {
explicit TFE_Context(TF_Session* s) : session(s) {}
TFE_ContextDevicePlacementPolicy policy;
// TFE_Context is an extension of TF_Session. And TF_Session needs a TF_Graph.
TF_Session* session;
tensorflow::Rendezvous* rendezvous;

View File

@ -62,10 +62,10 @@ TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) {
void BM_InitOp(int iters) {
tensorflow::testing::StopTiming();
TF_Status* status = TF_NewStatus();
TF_SessionOptions* opts = TF_NewSessionOptions();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteSessionOptions(opts);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* m = TestMatrixTensorHandle();
tensorflow::testing::StartTiming();
@ -84,10 +84,10 @@ BENCHMARK(BM_InitOp);
void BM_Execute(int iters) {
tensorflow::testing::StopTiming();
TF_Status* status = TF_NewStatus();
TF_SessionOptions* opts = TF_NewSessionOptions();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteSessionOptions(opts);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* m = TestMatrixTensorHandle();
TFE_Op* matmul = MatMulOp(ctx, m, m);
@ -109,9 +109,9 @@ BENCHMARK(BM_Execute);
TEST(CAPI, Context) {
TF_Status* status = TF_NewStatus();
TF_SessionOptions* opts = TF_NewSessionOptions();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
TF_DeleteSessionOptions(opts);
TFE_DeleteContextOptions(opts);
TF_DeviceList* devices = TFE_ContextListDevices(ctx, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
@ -150,9 +150,9 @@ TEST(CAPI, TensorHandle) {
TEST(CAPI, TensorHandleCopyBetweenDevices) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_SessionOptions* opts = TF_NewSessionOptions();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status.get());
TF_DeleteSessionOptions(opts);
TFE_DeleteContextOptions(opts);
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
@ -216,12 +216,58 @@ TEST(CAPI, TensorHandleCopyBetweenDevices) {
EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
}
TEST(CAPI, TensorHandleSilentCopy) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT);
TFE_Context* ctx = TFE_NewContext(opts, status.get());
TFE_DeleteContextOptions(opts);
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TFE_TensorHandle* hcpu = TestMatrixTensorHandle();
TF_Tensor* t = TFE_TensorHandleResolve(hcpu, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
const int num_devices = TF_DeviceListCount(devices);
// Disable the test if no GPU is present.
if (num_devices > 1) {
const int device_to_use = 1;
const string name(TF_DeviceListName(devices, device_to_use, status.get()));
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_TensorHandle* hgpu =
TFE_TensorHandleCopyToDevice(hcpu, ctx, name.c_str(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_Op* matmul = MatMulOp(ctx, hcpu, hgpu);
TFE_OpSetDevice(matmul, name.c_str(), status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_TensorHandle* retvals[1];
int num_retvals = 1;
TFE_Execute(matmul, &retvals[0], &num_retvals, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_DeleteOp(matmul);
TFE_DeleteTensorHandle(retvals[0]);
TFE_DeleteTensorHandle(hgpu);
}
TF_DeleteDeviceList(devices);
TF_DeleteTensor(t);
TFE_DeleteTensorHandle(hcpu);
TFE_DeleteContext(ctx, status.get());
EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
}
TEST(CAPI, Execute) {
TF_Status* status = TF_NewStatus();
TF_SessionOptions* opts = TF_NewSessionOptions();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteSessionOptions(opts);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* m = TestMatrixTensorHandle();
TFE_Op* matmul = MatMulOp(ctx, m, m);
@ -285,10 +331,10 @@ string MatMulFunction() {
TEST(CAPI, FunctionDefAndExecute) {
TF_Status* status = TF_NewStatus();
TF_SessionOptions* opts = TF_NewSessionOptions();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteSessionOptions(opts);
TFE_DeleteContextOptions(opts);
string function_def = MatMulFunction();
TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
@ -326,10 +372,10 @@ TEST(CAPI, FunctionDefAndExecute) {
void BM_ExecuteFunction(int iters) {
tensorflow::testing::StopTiming();
TF_Status* status = TF_NewStatus();
TF_SessionOptions* opts = TF_NewSessionOptions();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteSessionOptions(opts);
TFE_DeleteContextOptions(opts);
string function_def = MatMulFunction();
TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(),
@ -406,10 +452,10 @@ TEST(CAPI, Variables) {
// Variables use resource handles, so this is really a test for resource
// tensor handling.
TF_Status* status = TF_NewStatus();
TF_SessionOptions* opts = TF_NewSessionOptions();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteSessionOptions(opts);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* var_handle = CreateVariable(ctx, 12.0, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
@ -446,10 +492,10 @@ TEST(CAPI, Variables) {
void BM_ReadVariable(int iters) {
tensorflow::testing::StopTiming();
TF_Status* status = TF_NewStatus();
TF_SessionOptions* opts = TF_NewSessionOptions();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_Context* ctx = TFE_NewContext(opts, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteSessionOptions(opts);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* var_handle = CreateVariable(ctx, 5.0, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

View File

@ -138,6 +138,11 @@ class ComputationBuilder {
ComputationDataHandle ConstantR2(
std::initializer_list<std::initializer_list<NativeT>> values);
template <typename NativeT>
ComputationDataHandle ConstantFromArrayWithLayout(
const Array<NativeT>& values, const Layout& layout);
template <typename NativeT>
ComputationDataHandle ConstantFromArray(const Array<NativeT>& values);
template <typename NativeT>
ComputationDataHandle ConstantR2FromArray2DWithLayout(
const Array2D<NativeT>& values, const Layout& layout);
template <typename NativeT>
@ -909,49 +914,55 @@ ComputationDataHandle ComputationBuilder::ConstantR2(
[&values](Literal* literal) { literal->PopulateR2(values); });
}
template <typename NativeT>
ComputationDataHandle ComputationBuilder::ConstantFromArrayWithLayout(
const Array<NativeT>& values, const Layout& layout) {
return ConstantOp([&values, &layout](Literal* literal) {
literal->PopulateFromArrayWithLayout(values, layout);
});
}
template <typename NativeT>
ComputationDataHandle ComputationBuilder::ConstantFromArray(
const Array<NativeT>& values) {
return ConstantOp(
[&values](Literal* literal) { literal->PopulateFromArray(values); });
}
template <typename NativeT>
ComputationDataHandle ComputationBuilder::ConstantR2FromArray2DWithLayout(
const Array2D<NativeT>& values, const Layout& layout) {
return ConstantOp([&values, &layout](Literal* literal) {
literal->PopulateR2FromArray2DWithLayout(values, layout);
});
return ConstantFromArrayWithLayout(values, layout);
}
template <typename NativeT>
ComputationDataHandle ComputationBuilder::ConstantR2FromArray2D(
const Array2D<NativeT>& values) {
return ConstantOp(
[&values](Literal* literal) { literal->PopulateR2FromArray2D(values); });
return ConstantFromArray(values);
}
template <typename NativeT>
ComputationDataHandle ComputationBuilder::ConstantR3FromArray3DWithLayout(
const Array3D<NativeT>& values, const Layout& layout) {
return ConstantOp([&values, &layout](Literal* literal) {
literal->PopulateR3FromArray3DWithLayout(values, layout);
});
return ConstantFromArrayWithLayout(values, layout);
}
template <typename NativeT>
ComputationDataHandle ComputationBuilder::ConstantR3FromArray3D(
const Array3D<NativeT>& values) {
return ConstantOp(
[&values](Literal* literal) { literal->PopulateR3FromArray3D(values); });
return ConstantFromArray(values);
}
template <typename NativeT>
ComputationDataHandle ComputationBuilder::ConstantR4FromArray4DWithLayout(
const Array4D<NativeT>& values, const Layout& layout) {
return ConstantOp([&values, &layout](Literal* literal) {
literal->PopulateR4FromArray4DWithLayout(values, layout);
});
return ConstantFromArrayWithLayout(values, layout);
}
template <typename NativeT>
ComputationDataHandle ComputationBuilder::ConstantR4FromArray4D(
const Array4D<NativeT>& values) {
return ConstantOp(
[&values](Literal* literal) { literal->PopulateR4FromArray4D(values); });
return ConstantFromArray(values);
}
} // namespace xla
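
For illustration, a caller-side sketch of the consolidated constant helpers, assuming an existing xla::ComputationBuilder named builder; the rank-specific methods now simply forward to the generic Array-based entry points.

// Sketch: both handles describe the same 2x2 f32 constant; the WithLayout
// variant pins the literal's minor-to-major order explicitly.
xla::Array2D<float> values({{1.0f, 2.0f}, {3.0f, 4.0f}});
xla::ComputationDataHandle c1 = builder.ConstantR2FromArray2D<float>(values);
xla::ComputationDataHandle c2 = builder.ConstantFromArrayWithLayout<float>(
    values, xla::LayoutUtil::GetDefaultLayoutForR2());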

View File

@ -83,6 +83,10 @@ Layout CreateDefaultLayoutForRank(int64 rank) {
return CreateDefaultLayoutForRank(shape.dimensions_size());
}
/* static */ Layout LayoutUtil::GetDefaultLayoutForRank(int64 rank) {
return CreateDefaultLayoutForRank(rank);
}
/* static */ Layout LayoutUtil::GetDefaultLayoutForR2() {
return CreateDefaultLayoutForRank(2);
}

View File

@ -40,6 +40,7 @@ class LayoutUtil {
static Layout GetDefaultLayoutForShape(const Shape& shape);
// Helper functions that create default layouts for various ranks.
static Layout GetDefaultLayoutForRank(int64 rank);
static Layout GetDefaultLayoutForR2();
static Layout GetDefaultLayoutForR3();
static Layout GetDefaultLayoutForR4();

View File

@ -206,9 +206,9 @@ void AllocateFlags() {
flag_values->xla_gpu_disable_multi_streaming(),
"If true, multi-streaming in the GPU backend is disabled."),
tensorflow::Flag(
"xla_dump_debug_json_to",
flag_values->mutable_xla_dump_debug_json_to(),
"Dump compilation artifacts as JSON into this directory."),
"xla_dump_hlo_proto_to",
flag_values->mutable_xla_dump_hlo_proto_to(),
"Dump compilation artifacts as proto binary into this directory."),
tensorflow::Flag(
"xla_test_all_output_layouts",
bool_setter_for(&DebugOptions::set_xla_test_all_output_layouts),

View File

@ -334,6 +334,11 @@ class Literal {
// WithLayout use the default XLA layout for the literal's linear
// representation in memory.
template <typename NativeT>
static std::unique_ptr<Literal> CreateFromArray(const Array<NativeT>& values);
template <typename NativeT>
static std::unique_ptr<Literal> CreateFromArrayWithLayout(
const Array<NativeT>& values, const Layout& layout);
template <typename NativeT>
static std::unique_ptr<Literal> CreateR2FromArray2D(
const Array2D<NativeT>& values);
template <typename NativeT>
@ -481,6 +486,11 @@ class Literal {
std::initializer_list<std::initializer_list<NativeT>> values,
const Layout& layout);
template <typename NativeT>
void PopulateFromArray(const Array<NativeT>& values);
template <typename NativeT>
void PopulateFromArrayWithLayout(const Array<NativeT>& values,
const Layout& layout);
template <typename NativeT>
void PopulateR2FromArray2D(const Array2D<NativeT>& values);
template <typename NativeT>
void PopulateR2FromArray2DWithLayout(const Array2D<NativeT>& values,
@ -815,34 +825,43 @@ template <typename NativeT>
return CreateR4WithLayout(values, LayoutUtil::GetDefaultLayoutForR4());
}
template <typename NativeT>
/* static */ std::unique_ptr<Literal> Literal::CreateFromArrayWithLayout(
const Array<NativeT>& values, const Layout& layout) {
auto literal = MakeUnique<Literal>();
literal->PopulateFromArrayWithLayout(values, layout);
return literal;
}
template <typename NativeT>
/* static */ std::unique_ptr<Literal> Literal::CreateFromArray(
const Array<NativeT>& values) {
return CreateFromArrayWithLayout(
values, LayoutUtil::GetDefaultLayoutForRank(values.num_dimensions()));
}
template <typename NativeT>
/* static */ std::unique_ptr<Literal> Literal::CreateR2FromArray2DWithLayout(
const Array2D<NativeT>& values, const Layout& layout) {
auto literal = MakeUnique<Literal>();
literal->PopulateR2FromArray2DWithLayout(values, layout);
return literal;
return CreateFromArrayWithLayout(values, layout);
}
template <typename NativeT>
/* static */ std::unique_ptr<Literal> Literal::CreateR2FromArray2D(
const Array2D<NativeT>& values) {
return CreateR2FromArray2DWithLayout(values,
LayoutUtil::GetDefaultLayoutForR2());
return CreateFromArray(values);
}
template <typename NativeT>
/* static */ std::unique_ptr<Literal> Literal::CreateR3FromArray3DWithLayout(
const Array3D<NativeT>& values, const Layout& layout) {
auto literal = MakeUnique<Literal>();
literal->PopulateR3FromArray3DWithLayout(values, layout);
return literal;
return CreateFromArrayWithLayout(values, layout);
}
template <typename NativeT>
/* static */ std::unique_ptr<Literal> Literal::CreateR3FromArray3D(
const Array3D<NativeT>& values) {
return CreateR3FromArray3DWithLayout(values,
LayoutUtil::GetDefaultLayoutForR3());
return CreateFromArray(values);
}
template <typename NativeT>
@ -901,16 +920,13 @@ template <typename NativeT>
template <typename NativeT>
/* static */ std::unique_ptr<Literal> Literal::CreateR4FromArray4D(
const Array4D<NativeT>& values) {
return CreateR4FromArray4DWithLayout(values,
LayoutUtil::GetDefaultLayoutForR4());
return CreateFromArray(values);
}
template <typename NativeT>
/* static */ std::unique_ptr<Literal> Literal::CreateR4FromArray4DWithLayout(
const Array4D<NativeT>& values, const Layout& layout) {
auto literal = MakeUnique<Literal>();
literal->PopulateR4FromArray4DWithLayout(values, layout);
return literal;
return CreateFromArrayWithLayout(values, layout);
}
template <typename NativeT>
@ -1069,83 +1085,54 @@ void Literal::PopulateR2(
PopulateR2WithLayout(values, LayoutUtil::GetDefaultLayoutForR2());
}
template <typename NativeT>
void Literal::PopulateFromArrayWithLayout(const Array<NativeT>& values,
const Layout& layout) {
*mutable_shape() = ShapeUtil::MakeShapeWithLayout(
primitive_util::NativeToPrimitiveType<NativeT>(), values.dimensions(),
AsInt64Slice(layout.minor_to_major()));
Reserve(values.num_elements());
values.Each([this](tensorflow::gtl::ArraySlice<int64> indices,
NativeT value) { this->Set(indices, value); });
}
template <typename NativeT>
void Literal::PopulateFromArray(const Array<NativeT>& values) {
PopulateFromArrayWithLayout(
values, LayoutUtil::GetDefaultLayoutForRank(values.num_dimensions()));
}
template <typename NativeT>
void Literal::PopulateR2FromArray2DWithLayout(const Array2D<NativeT>& values,
const Layout& layout) {
*mutable_shape() = ShapeUtil::MakeShapeWithLayout(
primitive_util::NativeToPrimitiveType<NativeT>(),
{values.height(), values.width()}, AsInt64Slice(layout.minor_to_major()));
const int64 dim1_size = values.width();
const int64 dim0_size = values.height();
CHECK_EQ(dim0_size, shape().dimensions(0));
CHECK_EQ(dim1_size, shape().dimensions(1));
Reserve(dim1_size * dim0_size);
for (int64 dim0 = 0; dim0 < dim0_size; ++dim0) {
for (int64 dim1 = 0; dim1 < dim1_size; ++dim1) {
Set({dim0, dim1}, values(dim0, dim1));
}
}
PopulateFromArrayWithLayout(values, layout);
}
template <typename NativeT>
void Literal::PopulateR2FromArray2D(const Array2D<NativeT>& values) {
PopulateR2FromArray2DWithLayout(values, LayoutUtil::GetDefaultLayoutForR2());
PopulateFromArray(values);
}
template <typename NativeT>
void Literal::PopulateR3FromArray3DWithLayout(const Array3D<NativeT>& values,
const Layout& layout) {
*mutable_shape() = ShapeUtil::MakeShapeWithLayout(
primitive_util::NativeToPrimitiveType<NativeT>(),
{values.n1(), values.n2(), values.n3()},
AsInt64Slice(layout.minor_to_major()));
CHECK_EQ(values.n1(), shape().dimensions(0));
CHECK_EQ(values.n2(), shape().dimensions(1));
CHECK_EQ(values.n3(), shape().dimensions(2));
Reserve(values.n1() * values.n2() * values.n3());
for (int64 dim0 = 0; dim0 < values.n1(); ++dim0) {
for (int64 dim1 = 0; dim1 < values.n2(); ++dim1) {
for (int64 dim2 = 0; dim2 < values.n3(); ++dim2) {
Set({dim0, dim1, dim2}, values(dim0, dim1, dim2));
}
}
}
PopulateFromArrayWithLayout(values, layout);
}
template <typename NativeT>
void Literal::PopulateR3FromArray3D(const Array3D<NativeT>& values) {
PopulateR3FromArray3DWithLayout(values, LayoutUtil::GetDefaultLayoutForR3());
PopulateFromArray(values);
}
template <typename NativeT>
void Literal::PopulateR4FromArray4DWithLayout(const Array4D<NativeT>& values,
const Layout& layout) {
*mutable_shape() = ShapeUtil::MakeShapeWithLayout(
primitive_util::NativeToPrimitiveType<NativeT>(),
{values.planes(), values.depth(), values.height(), values.width()},
AsInt64Slice(layout.minor_to_major()));
CHECK_EQ(values.n1(), shape().dimensions(0));
CHECK_EQ(values.n2(), shape().dimensions(1));
CHECK_EQ(values.n3(), shape().dimensions(2));
CHECK_EQ(values.n4(), shape().dimensions(3));
Reserve(values.n1() * values.n2() * values.n3() * values.n4());
for (int64 dim0 = 0; dim0 < values.n1(); ++dim0) {
for (int64 dim1 = 0; dim1 < values.n2(); ++dim1) {
for (int64 dim2 = 0; dim2 < values.n3(); ++dim2) {
for (int64 dim3 = 0; dim3 < values.n4(); ++dim3) {
Set({dim0, dim1, dim2, dim3}, values(dim0, dim1, dim2, dim3));
}
}
}
}
PopulateFromArrayWithLayout(values, layout);
}
template <typename NativeT>
void Literal::PopulateR4FromArray4D(const Array4D<NativeT>& values) {
PopulateR4FromArray4DWithLayout(values, LayoutUtil::GetDefaultLayoutForR4());
PopulateFromArray(values);
}
template <typename NativeT, typename FnType>

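Similarly, a small sketch of the consolidated Literal factories; for a rank-2 array the old rank-specific call and the new generic one are interchangeable.

// Sketch: both calls produce an f32[2,2] literal with the default layout.
xla::Array2D<float> arr({{1.0f, 2.0f}, {3.0f, 4.0f}});
std::unique_ptr<xla::Literal> a = xla::Literal::CreateR2FromArray2D<float>(arr);
std::unique_ptr<xla::Literal> b = xla::Literal::CreateFromArray<float>(arr);
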
View File

@ -37,20 +37,6 @@ bool ProtobufEquals(const tensorflow::protobuf::Message& m1,
return (serialized1 == serialized2);
}
StatusOr<string> ToJson(const tensorflow::protobuf::Message& message) {
string json_output;
tensorflow::protobuf::util::JsonPrintOptions json_options;
json_options.add_whitespace = true;
json_options.always_print_primitive_fields = true;
auto status = tensorflow::protobuf::util::MessageToJsonString(
message, &json_output, json_options);
if (!status.ok()) {
return InternalError("MessageToJsonString failed: %s",
status.error_message().data());
}
return json_output;
}
namespace {
string SanitizeFilename(const string& file_name) {
@ -65,17 +51,6 @@ string SanitizeFilename(const string& file_name) {
} // namespace
Status DumpJsonToDirectory(const tensorflow::protobuf::Message& message,
const string& directory, const string& file_name) {
TF_ASSIGN_OR_RETURN(const string json_output, ToJson(message));
tensorflow::Env* env = tensorflow::Env::Default();
TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(directory));
string safe_file_name = SanitizeFileName(file_name) + ".json";
const string path = tensorflow::io::JoinPath(directory, safe_file_name);
return tensorflow::WriteStringToFile(env, path, json_output);
}
Status DumpProtoToDirectory(const tensorflow::protobuf::Message& message,
const string& directory, const string& file_name) {
tensorflow::Env* env = tensorflow::Env::Default();

View File

@ -32,17 +32,12 @@ namespace protobuf_util {
extern bool ProtobufEquals(const tensorflow::protobuf::Message& m1,
const tensorflow::protobuf::Message& m2);
// Returns 'message' as a JSON string.
StatusOr<string> ToJson(const tensorflow::protobuf::Message& message);
// Writes the given message in binary proto or JSON format to the path formed by
// joining 'directory/file_name.pb' (or file_name.json). The 'directory' is
// recursively created if it doesn't already exist, and the 'file_name' is
// sanitized by replacing illegal characters with underscore '_'.
// Writes the given message in binary proto to the path formed by joining
// 'directory/file_name.pb'. The 'directory' is recursively created if it
// doesn't already exist, and the 'file_name' is sanitized by replacing
// illegal characters with underscore '_'.
Status DumpProtoToDirectory(const tensorflow::protobuf::Message& message,
const string& directory, const string& file_name);
Status DumpJsonToDirectory(const tensorflow::protobuf::Message& message,
const string& directory, const string& file_name);
} // namespace protobuf_util
} // namespace xla
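
The compiler call sites below follow this pattern; a standalone sketch (the directory path is illustrative and module/assignment are assumed to be in scope):

// Sketch: write <directory>/<sanitized module name>.pb as a binary HloProto.
xla::HloProto proto = xla::MakeHloProto(*module, *assignment);
TF_RETURN_IF_ERROR(xla::protobuf_util::DumpProtoToDirectory(
    proto, "/tmp/hlo_dumps", module->name()));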

View File

@ -2064,6 +2064,29 @@ tf_cc_test(
],
)
cc_library(
name = "hlo_runner",
srcs = ["hlo_runner.cc"],
hdrs = ["hlo_runner.h"],
deps = [
":executable",
":hlo",
":transfer_manager",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:backend",
"//tensorflow/compiler/xla/service:compiler",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"//third_party/eigen3",
],
)
# -----------------------------------------------------------------------------
filegroup(

View File

@ -475,8 +475,8 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::Compile(
// ownership is std::moved.
const bool embed_ir_in_executable =
module->config().debug_options().xla_embed_ir_in_executable();
const string dump_debug_json_to =
module->config().debug_options().xla_dump_debug_json_to();
const string xla_dump_hlo_proto_to =
module->config().debug_options().xla_dump_hlo_proto_to();
if (options::CpuParallelBackendRequested(module->config())) {
VLOG(1) << "Using parallel cpu backend";
@ -496,10 +496,10 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::Compile(
// print one ourselves.
XLA_VLOG_LINES(2, assignment->ToString());
if (!dump_debug_json_to.empty()) {
if (!xla_dump_hlo_proto_to.empty()) {
HloProto proto = MakeHloProto(*module, *assignment);
TF_RETURN_IF_ERROR(protobuf_util::DumpJsonToDirectory(
proto, dump_debug_json_to, module->name()));
TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory(
proto, xla_dump_hlo_proto_to, module->name()));
}
// If we are using the parallel CPU backend, we need to create map from
@ -603,12 +603,11 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::Compile(
// print one ourselves.
XLA_VLOG_LINES(2, assignment->ToString());
if (!dump_debug_json_to.empty()) {
if (!xla_dump_hlo_proto_to.empty()) {
HloProto proto = MakeHloProto(*module, *assignment);
TF_RETURN_IF_ERROR(protobuf_util::DumpJsonToDirectory(
proto, dump_debug_json_to, module->name()));
TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory(
proto, xla_dump_hlo_proto_to, module->name()));
}
// Each computation is a single function. Emit all embedded computations
// before the entry computation. The order of computations returned from
// GetEmbeddedComputations guarantees that a called computation occurs
@ -775,12 +774,12 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
// print one ourselves.
XLA_VLOG_LINES(2, assignment->ToString());
const string dump_debug_json_to =
module->config().debug_options().xla_dump_debug_json_to();
if (!dump_debug_json_to.empty()) {
const string xla_dump_hlo_proto_to =
module->config().debug_options().xla_dump_hlo_proto_to();
if (!xla_dump_hlo_proto_to.empty()) {
HloProto proto = MakeHloProto(*module, *assignment);
TF_RETURN_IF_ERROR(protobuf_util::DumpJsonToDirectory(
proto, dump_debug_json_to, module->name()));
TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory(
proto, xla_dump_hlo_proto_to, module->name()));
}
IrEmitter ir_emitter(*module, *assignment, &llvm_module,

View File

@ -136,6 +136,8 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount(
instruction->opcode() == HloOpcode::kCall ||
instruction->opcode() == HloOpcode::kCustomCall ||
instruction->opcode() == HloOpcode::kSelectAndScatter ||
instruction->opcode() == HloOpcode::kGetTupleElement ||
instruction->opcode() == HloOpcode::kBitcast ||
(instruction->opcode() == HloOpcode::kConvolution &&
PotentiallyImplementedAsEigenConvolution(*instruction)) ||
PotentiallyImplementedAsEigenDot(*instruction) ||

View File

@ -318,12 +318,12 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::Compile(
// print one ourselves.
XLA_VLOG_LINES(2, buffer_assignment->ToString());
const string dump_debug_json_to =
module->config().debug_options().xla_dump_debug_json_to();
if (!dump_debug_json_to.empty()) {
const string xla_dump_hlo_proto_to =
module->config().debug_options().xla_dump_hlo_proto_to();
if (!xla_dump_hlo_proto_to.empty()) {
HloProto proto = MakeHloProto(*module, *buffer_assignment);
TF_RETURN_IF_ERROR(protobuf_util::DumpJsonToDirectory(
proto, dump_debug_json_to, module->name()));
TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory(
proto, xla_dump_hlo_proto_to, module->name()));
}
IrEmitterContext ir_emitter_context(module.get(), buffer_assignment.get(),

View File

@ -373,8 +373,8 @@ string HloComputation::ToString(int nested_level) const {
for (int i = 0; i < nested_level; i++) {
s << " ";
}
s << name() << " " << ShapeUtil::HumanString(ComputeProgramShape())
<< " { \n";
s << "%" << name() << " " << ShapeUtil::HumanString(ComputeProgramShape())
<< " {\n";
for (const HloInstruction* instruction : MakeInstructionPostOrder()) {
for (int i = 0; i < nested_level; i++) {
s << " ";

View File

@ -0,0 +1,199 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_runner.h"
#include <set>
#include <string>
#include <utility>
#define EIGEN_USE_THREADS
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
namespace se = ::perftools::gputools;
namespace xla {
/*static*/ StatusOr<std::unique_ptr<HloModule>>
HloRunner::ReadModuleFromHloProtoFile(const char* filename,
const DebugOptions& debug_options) {
HloProto proto;
TF_RETURN_IF_ERROR(tensorflow::ReadBinaryProto(tensorflow::Env::Default(),
filename, &proto));
HloModuleConfig config;
config.set_debug_options(debug_options);
TF_ASSIGN_OR_RETURN(auto module, HloModule::CreateFromProto(
proto.hlo_module(),
VersionedComputationHandle(), config));
return std::move(module);
}
// Define this in .cc file to avoid having to include eigen or forward declare
// these types in the header.
struct HloRunner::EigenThreadPoolWrapper {
std::unique_ptr<EigenThreadPoolWrapper> pool;
std::unique_ptr<Eigen::ThreadPoolDevice> device;
};
HloRunner::HloRunner() {}
HloRunner::HloRunner(se::Platform* platform) {
BackendOptions backend_options;
backend_options.set_platform(platform);
backend_ = Backend::CreateBackend(backend_options).ConsumeValueOrDie();
VLOG(1) << "Created HloRunner for platform: " << platform->Name();
}
HloRunner::~HloRunner() {
// Deallocate all the memory allocated during the tests.
for (auto& allocation : allocations_) {
backend().default_stream_executor()->Deallocate(&allocation);
}
}
StatusOr<se::DeviceMemoryBase> HloRunner::Execute(
std::unique_ptr<HloModule> module,
tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments,
Shape* result_shape) {
TF_ASSIGN_OR_RETURN(
std::unique_ptr<Executable> executable,
backend().compiler()->Compile(std::move(module),
backend().default_stream_executor()));
se::Stream stream(backend().default_stream_executor());
stream.Init();
ExecutableRunOptions run_options;
run_options.set_stream(&stream);
run_options.set_allocator(backend().memory_allocator());
run_options.set_inter_op_thread_pool(backend().inter_op_thread_pool());
run_options.set_intra_op_thread_pool(
backend().eigen_intra_op_thread_pool_device());
HloExecutionProfile hlo_execution_profile;
ServiceExecutableRunOptions service_run_options(
run_options, backend().StreamBorrower(),
backend().inter_op_thread_pool());
TF_ASSIGN_OR_RETURN(
se::DeviceMemoryBase result,
executable->ExecuteOnStream(&service_run_options, arguments,
&hlo_execution_profile));
TF_RET_CHECK(stream.BlockHostUntilDone());
allocations_.push_back(result);
*result_shape = executable->result_shape();
if (ShapeUtil::IsTuple(*result_shape)) {
// We must record element buffers of tuples as well to avoid leaks.
DCHECK(!ShapeUtil::IsNestedTuple(*result_shape));
TF_ASSIGN_OR_RETURN(
std::vector<se::DeviceMemoryBase> element_buffers,
backend().transfer_manager()->ShallowCopyTupleFromDevice(
backend().default_stream_executor(), result, *result_shape));
// A tuple may contain the same buffer in more than one element. Keep track
// of the buffers already added to avoid duplicates in allocations_.
std::set<void*> added_opaques;
for (auto element_buffer : element_buffers) {
if (added_opaques.count(element_buffer.opaque()) == 0) {
CHECK(element_buffer.opaque() != nullptr);
added_opaques.insert(element_buffer.opaque());
allocations_.push_back(element_buffer);
}
}
}
return result;
}
se::DeviceMemoryBase HloRunner::TransferToDevice(const Literal& literal) {
// Allocate memory on the device using the stream executor.
int64 allocation_size =
backend().transfer_manager()->GetByteSizeRequirement(literal.shape());
se::DeviceMemoryBase allocation =
backend().default_stream_executor()->AllocateArray<uint8>(
allocation_size);
allocations_.push_back(allocation);
TF_CHECK_OK(backend().transfer_manager()->TransferLiteralToDevice(
backend().default_stream_executor(), literal, &allocation));
return allocation;
}
std::unique_ptr<Literal> HloRunner::TransferFromDevice(
const Shape& shape, se::DeviceMemoryBase device_base) {
auto literal = MakeUnique<Literal>();
TF_CHECK_OK(backend().transfer_manager()->TransferLiteralFromDevice(
backend().default_stream_executor(), device_base, shape, shape,
literal.get()));
return literal;
}
std::unique_ptr<Literal> HloRunner::ExecuteAndTransfer(
std::unique_ptr<HloModule> module,
tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments) {
Shape result_shape;
se::DeviceMemoryBase device_base =
Execute(std::move(module), arguments, &result_shape).ValueOrDie();
return TransferFromDevice(result_shape, device_base);
}
template <>
std::unique_ptr<Literal> HloRunner::Execute(
std::unique_ptr<HloModule> module,
const tensorflow::gtl::ArraySlice<std::unique_ptr<Literal>>& literals) {
std::vector<se::DeviceMemoryBase> arguments;
for (const auto& literal : literals) {
arguments.push_back(TransferToDevice(*literal));
}
return ExecuteAndTransfer(std::move(module), arguments);
}
template <>
std::unique_ptr<Literal> HloRunner::Execute(
std::unique_ptr<HloModule> module,
const tensorflow::gtl::ArraySlice<Literal*>& literals) {
std::vector<se::DeviceMemoryBase> arguments;
for (const auto& literal : literals) {
arguments.push_back(TransferToDevice(*literal));
}
return ExecuteAndTransfer(std::move(module), arguments);
}
Backend& HloRunner::backend() {
if (!backend_) {
backend_ = Backend::CreateDefaultBackend().ConsumeValueOrDie();
VLOG(1) << "executing on platform " << backend().platform()->Name();
}
return *backend_;
}
} // namespace xla

View File

@ -0,0 +1,100 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_H_
#include <memory>
#include <string>
#include <vector>
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
namespace xla {
// A base class for running an HloModule. This executes the given HloModule on a
// certain backend directly without using the client interface. HloModule can be
// explicitly built, or loaded from a serialization file (e.g., hlo proto file).
class HloRunner {
public:
HloRunner();
HloRunner(::perftools::gputools::Platform* platform);
~HloRunner();
// Reads the binary proto file in xla.HloProto format, creates and returns the
// HloModule.
static StatusOr<std::unique_ptr<HloModule>> ReadModuleFromHloProtoFile(
const char* filename, const DebugOptions& debug_options);
// Executes the given module with given literals as input and returns the
// result as a Literal. The LiteralPtr type accepts Literal* or
// std::unique_ptr<Literal>.
template <typename LiteralPtr>
std::unique_ptr<Literal> Execute(
std::unique_ptr<HloModule> module,
const tensorflow::gtl::ArraySlice<LiteralPtr>& literals);
// Executes the given module and returns a global data handle.
StatusOr<perftools::gputools::DeviceMemoryBase> Execute(
std::unique_ptr<HloModule> module,
tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
arguments,
Shape* result_shape);
// Transfers the given literal to the device and returns the data handle.
perftools::gputools::DeviceMemoryBase TransferToDevice(
const Literal& literal);
// Transfers the array referred to by the given handle from the device and
// returns as a Literal.
std::unique_ptr<Literal> TransferFromDevice(
const Shape& shape, perftools::gputools::DeviceMemoryBase device_base);
// Executes the given module and return the result as a Literal.
std::unique_ptr<Literal> ExecuteAndTransfer(
std::unique_ptr<HloModule> module,
tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
arguments);
// If backend is not created in the constructor, creates and returns the
// default backend. If creation fails, crashes the program.
//
// This creates the backend lazily so it's possible to instantiate an
// HloRunner in a program without any backends linked in.
Backend& backend();
private:
struct EigenThreadPoolWrapper;
std::vector<perftools::gputools::DeviceMemoryBase> allocations_;
std::unique_ptr<EigenThreadPoolWrapper> thread_pool_wrapper_;
std::unique_ptr<Backend> backend_;
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_H_
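
A rough end-to-end sketch of the new runner, assuming a module previously dumped with --xla_dump_hlo_proto_to; the file path and argument values are placeholders, and error handling is collapsed with ConsumeValueOrDie.

// Sketch: load a dumped HloProto and execute it on the default backend.
xla::HloRunner runner;
std::unique_ptr<xla::HloModule> module =
    xla::HloRunner::ReadModuleFromHloProtoFile("/tmp/hlo_dumps/my_module.pb",
                                               xla::DebugOptions())
        .ConsumeValueOrDie();
auto arg = xla::Literal::CreateR2<float>({{1.0f, 2.0f}, {3.0f, 4.0f}});
std::vector<xla::Literal*> args = {arg.get()};
std::unique_ptr<xla::Literal> result =
    runner.Execute<xla::Literal*>(std::move(module), args);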

View File

@ -58,14 +58,32 @@ TransposeFolding::OperandIndices CanFoldOperandsIntoConvolution(
return {};
}
// We only support folding the RHS.
const int64 kRhsOperandIndex = 1;
auto& operand = *convolution.operand(kRhsOperandIndex);
if (operand.opcode() == HloOpcode::kTranspose && operand.user_count() == 1) {
return transposable_conv_operands(convolution, {kRhsOperandIndex});
const ConvolutionDimensionNumbers& dnums =
convolution.convolution_dimension_numbers();
TransposeFolding::OperandIndices operand_set;
for (int64 i = 0; i < convolution.operand_count(); ++i) {
auto& operand = *convolution.operand(i);
if (operand.opcode() == HloOpcode::kTranspose &&
operand.user_count() == 1) {
const auto& transpose_dimensions = operand.dimensions();
// We can transpose the LHS so long as it doesn't move around spatial
// dimensions because ConvolutionDimensionNumbers doesn't have different
// fields for input and output spatial dimensions.
if (i == 0 &&
std::any_of(dnums.spatial_dimensions().begin(),
dnums.spatial_dimensions().end(),
[&](const int64 spatial_dimension) {
return transpose_dimensions[spatial_dimension] !=
spatial_dimension;
})) {
continue;
}
operand_set.push_back(i);
}
}
return {};
return transposable_conv_operands(convolution, operand_set);
}
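For illustration (dimension numbers assumed NHWC-like: input batch = 0, spatial = {1, 2}, feature = 3): an activation transpose with permutation {3, 1, 2, 0} leaves both spatial dimensions in place, so it can be folded by rewriting input_batch_dimension to 3 and input_feature_dimension to 0. A permutation such as {0, 2, 1, 3} moves a spatial dimension and is skipped, since the shared spatial_dimensions field would then disagree between input and output.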
using InstructionOperandsPair =
@ -98,40 +116,61 @@ bool FoldTransposeIntoDot(InstructionOperandsPair pair) {
// Returns whether the module is changed.
bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) {
auto& convolution = *pair.first;
// We only support fusing the RHS transpose into convolution.
//
// ConvolutionDimensionNumbers doesn't make enough of a distinction between
// the output and the activations.
//
// TODO(b/37125184): Support transposing the LHS too.
if (pair.second.size() != 1 || pair.second.front() != 1) {
return false;
}
auto& operand_indices = pair.second;
const ConvolutionDimensionNumbers& dnums =
convolution.convolution_dimension_numbers();
HloInstruction& transpose = *convolution.mutable_operand(1);
CHECK_EQ(transpose.opcode(), HloOpcode::kTranspose);
const auto& transpose_dimensions = transpose.dimensions();
HloInstruction& transpose_operand = *transpose.mutable_operand(0);
// Everything remains the same except for the kernel dimension numbers. We
// need to apply the transpose permutation to the original shape to figure out
// what the new logical dimensions are.
ConvolutionDimensionNumbers new_dnums = dnums;
new_dnums.set_kernel_input_feature_dimension(
transpose_dimensions[dnums.kernel_input_feature_dimension()]);
new_dnums.set_kernel_output_feature_dimension(
transpose_dimensions[dnums.kernel_output_feature_dimension()]);
for (auto& kernel_spatial_dimension :
*new_dnums.mutable_kernel_spatial_dimensions()) {
kernel_spatial_dimension = transpose_dimensions[kernel_spatial_dimension];
HloInstruction* new_lhs;
const int64 kLhsIdx = 0;
if (std::find(operand_indices.begin(), operand_indices.end(), kLhsIdx) !=
operand_indices.end()) {
HloInstruction& transpose = *convolution.mutable_operand(kLhsIdx);
const auto& transpose_dimensions = transpose.dimensions();
HloInstruction& transpose_operand = *transpose.mutable_operand(0);
// Everything remains the same except for the input/output dimension
// numbers. We need to apply the transpose permutation to the original shape
// to figure out what the new logical dimensions are.
new_dnums.set_input_batch_dimension(
transpose_dimensions[dnums.input_batch_dimension()]);
new_dnums.set_input_feature_dimension(
transpose_dimensions[dnums.input_feature_dimension()]);
for (const auto& spatial_dimension : dnums.spatial_dimensions()) {
CHECK_EQ(spatial_dimension, transpose_dimensions[spatial_dimension]);
}
new_lhs = &transpose_operand;
} else {
new_lhs = convolution.mutable_operand(kLhsIdx);
}
HloInstruction* new_rhs;
const int64 kRhsIdx = 1;
if (std::find(operand_indices.begin(), operand_indices.end(), kRhsIdx) !=
operand_indices.end()) {
HloInstruction& transpose = *convolution.mutable_operand(kRhsIdx);
const auto& transpose_dimensions = transpose.dimensions();
HloInstruction& transpose_operand = *transpose.mutable_operand(0);
// Everything remains the same except for the kernel dimension numbers. We
// need to apply the transpose permutation to the original shape to figure
// out what the new logical dimensions are.
new_dnums.set_kernel_input_feature_dimension(
transpose_dimensions[dnums.kernel_input_feature_dimension()]);
new_dnums.set_kernel_output_feature_dimension(
transpose_dimensions[dnums.kernel_output_feature_dimension()]);
for (auto& kernel_spatial_dimension :
*new_dnums.mutable_kernel_spatial_dimensions()) {
kernel_spatial_dimension = transpose_dimensions[kernel_spatial_dimension];
}
new_rhs = &transpose_operand;
} else {
new_rhs = convolution.mutable_operand(kRhsIdx);
}
auto new_conv = HloInstruction::CreateConvolve(
convolution.shape(), convolution.mutable_operand(0), &transpose_operand,
convolution.window(), new_dnums);
convolution.shape(), new_lhs, new_rhs, convolution.window(), new_dnums);
TF_CHECK_OK(convolution.parent()->ReplaceWithNewInstruction(
&convolution, std::move(new_conv)));

View File

@ -313,8 +313,7 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeRhs) {
new_conv->convolution_dimension_numbers().kernel_spatial_dimensions(1));
}
// Test that a transpose of the activations does not get folded into
// convolution.
// Test that a transpose of the activations gets folded into convolution.
TEST_F(TransposeFoldingTest, FoldConvTransposeLhs) {
auto builder = HloComputation::Builder("entry_computation");
HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
@ -348,18 +347,25 @@ TEST_F(TransposeFoldingTest, FoldConvTransposeLhs) {
module.AddEntryComputation(builder.Build(conv));
FoldTranspose(&module);
// Instructions after folding: transpose_x, y, and the convolution.
// Instructions after folding: x, y, and the convolution.
std::unordered_set<HloInstruction*> instruction_set(
entry_computation->instructions().begin(),
entry_computation->instructions().end());
CHECK_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation.";
CHECK_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation.";
CHECK_EQ(1, instruction_set.erase(transpose_x))
<< "transpose_x is not in entry_computation.";
CHECK_EQ(1, instruction_set.erase(conv))
<< "transpose_x is not in entry_computation.";
CHECK_EQ(0, instruction_set.size())
<< "entry_computation should contain exactly 4 instructions.";
EXPECT_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation.";
EXPECT_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation.";
EXPECT_EQ(1, instruction_set.size())
<< "entry_computation should contain exactly 3 instructions.";
HloInstruction* new_conv = *instruction_set.begin();
EXPECT_EQ(HloOpcode::kConvolution, new_conv->opcode());
EXPECT_EQ(dnums.input_feature_dimension(),
new_conv->convolution_dimension_numbers().input_batch_dimension());
EXPECT_EQ(
dnums.input_batch_dimension(),
new_conv->convolution_dimension_numbers().input_feature_dimension());
EXPECT_EQ(dnums.spatial_dimensions(0),
new_conv->convolution_dimension_numbers().spatial_dimensions(0));
EXPECT_EQ(dnums.spatial_dimensions(1),
new_conv->convolution_dimension_numbers().spatial_dimensions(1));
}
} // namespace

View File

@ -20,6 +20,7 @@ limitations under the License.
#include <stack>
#include <unordered_map>
#include <utility>
#include <vector>
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
@ -1843,10 +1844,17 @@ UserComputation::GetEmbeddedComputations(
XLA_VLOG_LINES(3, session_computation_.DebugString());
std::vector<VersionedComputationHandle> computations;
std::vector<int64> sorted_handles;
for (const auto& handle_request : session_computation_.requests()) {
int64 handle_value = handle_request.first;
sorted_handles.push_back(handle_request.first);
}
std::sort(sorted_handles.begin(), sorted_handles.end());
for (int64 handle : sorted_handles) {
const auto& handle_request = session_computation_.requests().find(handle);
CHECK(handle_request != session_computation_.requests().end());
int64 handle_value = handle_request->first;
if (handle_value <= version) {
const OperationRequest& request = handle_request.second;
const OperationRequest& request = handle_request->second;
switch (request.request().op_case()) {
case OpRequest::kCallRequest: {
CHECK_EQ(1, request.embedded_computation_versions_size());

View File

@ -102,6 +102,32 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) {
return true;
}
// Constructs and returns the new shape with the given minor_to_major order in
// its Layout.
StatusOr<Shape> MakeShapeWithLayoutInternal(
PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions,
tensorflow::gtl::ArraySlice<int64> minor_to_major) {
if (dimensions.size() != minor_to_major.size()) {
return InvalidArgument("Dimensions size is %ld, but layout size is %ld.",
dimensions.size(), minor_to_major.size());
}
if (element_type == OPAQUE || element_type == TUPLE) {
return InvalidArgument("Unsupported element type: %s",
PrimitiveType_Name(element_type).c_str());
}
Shape shape = ShapeUtil::MakeShape(element_type, dimensions);
auto min2maj = shape.mutable_layout()->mutable_minor_to_major();
min2maj->Clear();
for (int64 value : minor_to_major) {
min2maj->Add(value);
}
if (!shape.has_layout()) {
return InvalidArgument("Shape has no layout.");
}
TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(shape));
return shape;
}
} // namespace
/* static */ bool ShapeUtil::Equal(const Shape& lhs, const Shape& rhs) {
@ -152,16 +178,8 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) {
/* static */ Shape ShapeUtil::MakeShapeWithLayout(
PrimitiveType element_type, tensorflow::gtl::ArraySlice<int64> dimensions,
tensorflow::gtl::ArraySlice<int64> minor_to_major) {
CHECK_EQ(dimensions.size(), minor_to_major.size());
Shape shape = MakeShape(element_type, dimensions);
auto min2maj = shape.mutable_layout()->mutable_minor_to_major();
min2maj->Clear();
for (int64 value : minor_to_major) {
min2maj->Add(value);
}
DCHECK(shape.has_layout());
TF_DCHECK_OK(ValidateShape(shape));
return shape;
return MakeShapeWithLayoutInternal(element_type, dimensions, minor_to_major)
.ValueOrDie();
}
/* static */ Shape ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(
@ -499,11 +517,10 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
// Extract the layout minor-to-major and set it.
TF_ASSIGN_OR_RETURN(std::vector<int64> min2maj,
comma_list_to_int64s(layout_string));
TF_RET_CHECK(dimensions.size() == min2maj.size());
result =
ShapeUtil::MakeShapeWithLayout(primitive_type, dimensions, min2maj);
TF_ASSIGN_OR_RETURN(result, MakeShapeWithLayoutInternal(
primitive_type, dimensions, min2maj));
}
TF_DCHECK_OK(ShapeUtil::ValidateShape(result));
TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(result));
return std::move(result);
}
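
For context, a hedged usage sketch of the two helpers this hunk touches. It relies only on the public ShapeUtil API; the function name is illustrative and not part of the change:

```c++
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/platform/logging.h"

void ShapeLayoutExample() {
  // Build f32[2,3] with an explicit minor-to-major order of {0,1}.
  xla::Shape with_layout =
      xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 3}, {0, 1});

  // Parse the equivalent textual form; the trailing layout braces are the
  // minor-to-major order.
  xla::StatusOr<xla::Shape> parsed =
      xla::ShapeUtil::ParseShapeString("f32[2,3]{0,1}");
  if (parsed.ok()) {
    CHECK(xla::ShapeUtil::Equal(with_layout, parsed.ValueOrDie()));
  }
}
```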

View File

@ -102,28 +102,18 @@ cc_library(
deps = [
":literal_test_util",
"//tensorflow/compiler/xla:shape_layout",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/compiler/xla/service",
"//tensorflow/compiler/xla/service:backend",
"//tensorflow/compiler/xla/service:compiler",
"//tensorflow/compiler/xla/service:computation_layout",
"//tensorflow/compiler/xla/service:computation_placer",
"//tensorflow/compiler/xla/service:executable",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_execution_profile",
"//tensorflow/compiler/xla/service:hlo_graph_dumper",
"//tensorflow/compiler/xla/service:transfer_manager",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/compiler/xla/service:hlo_runner",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/core:test",
"//third_party/eigen3",
],
)

View File

@ -19,24 +19,9 @@ limitations under the License.
#include <string>
#include <utility>
#define EIGEN_USE_THREADS
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/transfer_manager.h"
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
@ -45,22 +30,6 @@ namespace se = ::perftools::gputools;
namespace xla {
// Define this in .cc file to avoid having to include eigen or forward declare
// these types in the header.
struct HloTestBase::EigenThreadPoolWrapper {
std::unique_ptr<EigenThreadPoolWrapper> pool;
std::unique_ptr<Eigen::ThreadPoolDevice> device;
};
HloTestBase::HloTestBase() {}
HloTestBase::~HloTestBase() {
// Deallocate all the memory allocated during the tests.
for (auto& allocation : allocations_) {
backend().default_stream_executor()->Deallocate(&allocation);
}
}
/* static */
std::unique_ptr<HloModule> HloTestBase::CreateNewModule() {
HloModuleConfig config;
@ -80,98 +49,25 @@ StatusOr<perftools::gputools::DeviceMemoryBase> HloTestBase::Execute(
tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
arguments,
Shape* result_shape) {
TF_ASSIGN_OR_RETURN(
std::unique_ptr<Executable> executable,
backend().compiler()->Compile(std::move(module),
backend().default_stream_executor()));
se::Stream stream(backend().default_stream_executor());
stream.Init();
ExecutableRunOptions run_options;
run_options.set_stream(&stream);
run_options.set_allocator(backend().memory_allocator());
run_options.set_inter_op_thread_pool(backend().inter_op_thread_pool());
run_options.set_intra_op_thread_pool(
backend().eigen_intra_op_thread_pool_device());
HloExecutionProfile hlo_execution_profile;
ServiceExecutableRunOptions service_run_options(
run_options, backend().StreamBorrower(),
backend().inter_op_thread_pool());
TF_ASSIGN_OR_RETURN(
se::DeviceMemoryBase result,
executable->ExecuteOnStream(&service_run_options, arguments,
&hlo_execution_profile));
TF_RET_CHECK(stream.BlockHostUntilDone());
allocations_.push_back(result);
*result_shape = executable->result_shape();
if (ShapeUtil::IsTuple(*result_shape)) {
// We must record element buffers of tuples as well to avoid leaks.
DCHECK(!ShapeUtil::IsNestedTuple(*result_shape));
TF_ASSIGN_OR_RETURN(
std::vector<se::DeviceMemoryBase> element_buffers,
backend().transfer_manager()->ShallowCopyTupleFromDevice(
backend().default_stream_executor(), result, *result_shape));
// A tuple may contain the same buffer in more than one element. Keep track
// of the buffers already added to avoid duplicates in allocations_.
std::set<void*> added_opaques;
for (auto element_buffer : element_buffers) {
if (added_opaques.count(element_buffer.opaque()) == 0) {
CHECK(element_buffer.opaque() != nullptr);
added_opaques.insert(element_buffer.opaque());
allocations_.push_back(element_buffer);
}
}
}
return result;
return runner_.Execute(std::move(module), arguments, result_shape);
}
se::DeviceMemoryBase HloTestBase::TransferToDevice(const Literal& literal) {
// Allocate memory on the device using the stream executor.
int64 allocation_size =
backend().transfer_manager()->GetByteSizeRequirement(literal.shape());
se::DeviceMemoryBase allocation =
backend().default_stream_executor()->AllocateArray<uint8>(
allocation_size);
allocations_.push_back(allocation);
TF_CHECK_OK(backend().transfer_manager()->TransferLiteralToDevice(
backend().default_stream_executor(), literal, &allocation));
return allocation;
return runner_.TransferToDevice(literal);
}
std::unique_ptr<Literal> HloTestBase::TransferFromDevice(
const Shape& shape, se::DeviceMemoryBase device_base) {
auto literal = MakeUnique<Literal>();
TF_CHECK_OK(backend().transfer_manager()->TransferLiteralFromDevice(
backend().default_stream_executor(), device_base, shape, shape,
literal.get()));
return literal;
return runner_.TransferFromDevice(shape, device_base);
}
std::unique_ptr<Literal> HloTestBase::ExecuteAndTransfer(
std::unique_ptr<HloModule> module,
tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> arguments) {
Shape result_shape;
se::DeviceMemoryBase device_base =
Execute(std::move(module), arguments, &result_shape).ValueOrDie();
return TransferFromDevice(result_shape, device_base);
return runner_.ExecuteAndTransfer(std::move(module), arguments);
}
Backend& HloTestBase::backend() {
if (!backend_) {
backend_ = Backend::CreateDefaultBackend().ConsumeValueOrDie();
VLOG(1) << "executing on platform " << backend().platform()->Name();
}
return *backend_;
}
Backend& HloTestBase::backend() { return runner_.backend(); }
/* static */
string HloTestBase::TestName() {

View File

@ -21,12 +21,12 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/service/compiler.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_runner.h"
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
@ -39,10 +39,9 @@ namespace xla {
// building a graph of HLO instructions to run.
class HloTestBase : public ::testing::Test {
protected:
struct EigenThreadPoolWrapper;
HloTestBase();
HloTestBase() {}
~HloTestBase() override;
~HloTestBase() override {}
// Creates a new HLO module for a test. The module created will have
// TestName() for its name; it will also automatically populate its debug
@ -102,23 +101,12 @@ class HloTestBase : public ::testing::Test {
static string TestName();
// Creates (if necessary) and returns the default backend. If creation fails,
// crashes the program.
//
// This creates the backend lazily so it's possible to instantiate an
// HloTestBase in a program without any backends linked in.
// Returns the backend owned by the HloRunner.
Backend& backend();
// This vector contains handles of all the device memory allocations performed
// by the test. These are deallocated on destruction of the test object.
std::vector<perftools::gputools::DeviceMemoryBase> allocations_;
HloRunner runner_;
ErrorSpec error_spec_{0.0001};
std::unique_ptr<EigenThreadPoolWrapper> thread_pool_wrapper_;
private:
std::unique_ptr<Backend> backend_; // Lazily populated. Access via backend().
};
} // namespace xla

View File

@ -210,6 +210,18 @@ tf_cc_binary(
],
)
tf_cc_binary(
name = "hlo_proto_to_json",
srcs = ["hlo_proto_to_json.cc"],
deps = [
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/service:hlo_proto",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
],
)
# -----------------------------------------------------------------------------
filegroup(

View File

@ -0,0 +1,91 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Usage:
// hlo_proto_to_json --input_file=some_binary_proto
// --output_file=path_to_dump_output
//
// Reads one serialized HloModule proto, converts it into JSON format, and
// dumps it to the output file. some_binary_proto is obtained by serializing an
// HLO module to disk using the --xla_dump_hlo_proto_to debug option.
#include <stdio.h>
#include <string>
#include <vector>
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/command_line_flags.h"
using tensorflow::Env;
using xla::string;
namespace xla {
namespace tools {
StatusOr<string> ToJson(const tensorflow::protobuf::Message& message) {
string json_output;
tensorflow::protobuf::util::JsonPrintOptions json_options;
json_options.add_whitespace = true;
json_options.always_print_primitive_fields = true;
auto status = tensorflow::protobuf::util::MessageToJsonString(
message, &json_output, json_options);
if (!status.ok()) {
return InternalError("MessageToJsonString failed: %s",
status.error_message().data());
}
return json_output;
}
void RealMain(const string& input, const string& output) {
HloProto hlo_proto;
TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), input,
&hlo_proto))
<< "Can't open, read, or parse input file " << input;
auto statusor = ToJson(hlo_proto);
QCHECK(statusor.ok()) << "Error converting " << input << " to JSON."
<< statusor.status();
TF_CHECK_OK(tensorflow::WriteStringToFile(tensorflow::Env::Default(), output,
statusor.ValueOrDie()));
}
} // namespace tools
} // namespace xla
int main(int argc, char** argv) {
string input_file, output_file;
const std::vector<tensorflow::Flag> flag_list = {
tensorflow::Flag("input_file", &input_file, "file to convert."),
tensorflow::Flag("output_file", &output_file, "converted file"),
};
const string usage = tensorflow::Flags::Usage(argv[0], flag_list);
bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
tensorflow::port::InitMain(usage.c_str(), &argc, &argv);
QCHECK(parse_ok && argc == 1) << "\n" << usage;
QCHECK(!input_file.empty()) << "--input_file is required";
QCHECK(!output_file.empty()) << "--output_file is required";
xla::tools::RealMain(input_file, output_file);
return 0;
}

View File

@ -0,0 +1,84 @@
# Build file for the Hlo parser.
licenses(["notice"]) # Apache 2.0
package(
default_visibility = [":friends"],
)
package_group(
name = "friends",
includes = [
"//tensorflow/compiler/xla:friends",
],
)
# Filegroup used to collect source files for dependency checking.
filegroup(
name = "c_srcs",
data = glob([
"**/*.cc",
"**/*.h",
]),
)
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
cc_library(
name = "hlo_lexer",
srcs = ["hlo_lexer.cc"],
hdrs = [
"hlo_lexer.h",
"hlo_token.h",
],
deps = [
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/core:lib",
"//tensorflow/core:regexp_internal",
],
)
cc_library(
name = "hlo_parser",
srcs = ["hlo_parser.cc"],
hdrs = ["hlo_parser.h"],
deps = [
":hlo_lexer",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
],
)
tf_cc_test(
name = "hlo_parser_test",
size = "small",
srcs = ["hlo_parser_test.cc"],
deps = [
":hlo_parser",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
# -----------------------------------------------------------------------------
filegroup(
name = "all_files",
srcs = glob(
["**/*"],
exclude = [
"**/METADATA",
"**/OWNERS",
],
),
visibility = ["//tensorflow:__subpackages__"],
)

View File

@ -0,0 +1,69 @@
# HloModule string syntax
TODO: Support subcomputations (for fusion, reduce, while, ...).
TODO: Support ops that require extra attributes, e.g. dimensions, strides.
```yacc
hlo_module
: 'HloModule' name computation
;
computation
: 'ENTRY' name param_list '->' shape instruction_list
;
instruction_list
: '{' instruction_list1 '}'
;
instruction_list1
: instruction
| instruction_list1 instruction
;
instruction
: name '=' shape opcode operands
;
operands
: '(' operands1 ')'
;
operands1
: /*empty*/
| operand
| operands1 ',' operand
;
operand
: shape name
;
param_list
: '(' param_list1 ')'
;
param_list1
: /*empty*/
| param
| param_list1 ',' param
;
param
: name shape
;
shape
: shape_val_
| '(' tuple_elements ')'
;
tuple_elements
: /*empty*/
| shape (',' shape)*
;
name
: identifier ':'
| '%' identifier
;
identifier
: [a-zA-Z_][a-zA-Z0-9_.-]*
;
```
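
For reference, a module in this syntax (one of the parser test cases added in this change, lightly reformatted):

```
HloModule add_constants_module:

ENTRY %add_constants () -> f32[] {
  %constant = f32[] constant(3.14)
  %add = f32[] add(f32[] %constant, f32[] %constant)
}
```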

View File

@ -0,0 +1,270 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/tools/parser/hlo_lexer.h"
#include <unordered_map>
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/platform/regexp.h"
namespace xla {
namespace tools {
using tensorflow::StringPiece;
namespace {
constexpr int kEOF = -1;
constexpr int kError = -2;
// [a-zA-Z0-9_.-]
bool IsIdentifierChar(char c) {
return isalnum(static_cast<unsigned char>(c)) || c == '-' || c == '.' ||
c == '_';
}
} // namespace
int HloLexer::GetNextChar() {
int current_char = PeekCurrentChar();
if (current_char != kEOF && current_char != kError) {
current_ptr_++;
}
return current_char;
}
int HloLexer::PeekCurrentChar() const {
if (current_ptr_ == buf_.end()) {
return kEOF;
}
char current_char = *current_ptr_;
if (current_char == 0) {
// '\0' should not appear in the middle of the string.
return kError;
}
return static_cast<unsigned char>(current_char);
}
bool HloLexer::CanDereference(const char* ptr) const {
return ptr < buf_.end() && ptr >= buf_.begin();
}
StringPiece HloLexer::StringPieceFromPointers(const char* begin,
const char* end) const {
CHECK(begin <= end);
CHECK(begin == buf_.end() || CanDereference(begin));
CHECK(end == buf_.end() || CanDereference(end));
return StringPiece(begin, end - begin);
}
tensorflow::RegexpStringPiece HloLexer::RegexpStringPieceFromPointers(
const char* begin, const char* end) const {
CHECK(begin <= end);
CHECK(begin == buf_.end() || CanDereference(begin));
CHECK(end == buf_.end() || CanDereference(end));
return tensorflow::RegexpStringPiece(begin, end - begin);
}
TokKind HloLexer::LexToken() {
while (true) {
token_start_ = current_ptr_;
int current_char = GetNextChar();
switch (current_char) {
default:
// [a-zA-Z_]
if (isalpha(static_cast<unsigned char>(current_char)) ||
current_char == '_') {
return LexIdentifier();
}
return TokKind::kError;
case kEOF:
// Hit the end of the input buffer.
return TokKind::kEof;
case kError:
// Hit an invalid character in the input buffer.
return TokKind::kError;
case ' ':
case '\t':
case '\n':
case '\r':
// Ignore whitespace.
continue;
case '0':
case '1':
case '2':
case '3':
case '4':
case '5':
case '6':
case '7':
case '8':
case '9':
case '-':
if (current_char == '-' && PeekCurrentChar() == '>') {
current_ptr_++;
return TokKind::kArrow;
}
return LexDigitOrNegative();
case '=':
return TokKind::kEqual;
case ',':
return TokKind::kComma;
case '%':
return LexPercent();
case ':':
return TokKind::kColon;
case '[':
return TokKind::kLsquare;
case ']':
return TokKind::kRsquare;
case '{':
return TokKind::kLbrace;
case '}':
return TokKind::kRbrace;
case '(':
return TokKind::kLparen;
case ')':
return TokKind::kRparen;
}
}
}
// Lex a shape, name, keyword, or opcode.
// shape ::= ([a-zA-Z0-9_]*[0-9]*)\[([0-9,]*)\](?:\s*{([0-9,]*)})?
// name ::= [a-zA-Z_][a-zA-Z0-9_.-]*:
// keyword ::= HloModule, ENTRY, ...
// opcode ::= add, greater-than, ...
TokKind HloLexer::LexIdentifier() {
{
auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end());
// 'consumable' will be advanced iff its prefix matches the pattern.
static LazyRE2 shape_pattern = {
R"(^(\w*\d*)\[([\d,]*)\](?:\s*{([\d,]*)})?)"};
if (RE2::Consume(&consumable, *shape_pattern)) {
auto status_or_shape = ShapeUtil::ParseShapeString(
StringPieceFromPointers(token_start_, consumable.begin()));
if (status_or_shape.ok()) {
// This is a shape string.
shape_val_ = status_or_shape.ValueOrDie();
current_ptr_ = consumable.begin();
return TokKind::kShape;
}
}
}
while (IsIdentifierChar(PeekCurrentChar())) {
current_ptr_++;
}
// If followed by ':', it's a name.
if (PeekCurrentChar() == ':') {
str_val_.assign(token_start_, current_ptr_);
current_ptr_++; // skip ':'
return TokKind::kName;
}
StringPiece identifier = StringPieceFromPointers(token_start_, current_ptr_);
// See if this is a keyword.
#define KEYWORD(STR) \
do { \
if (identifier == #STR) { \
return TokKind::kw_##STR; \
} \
} while (false)
KEYWORD(true);
KEYWORD(false);
KEYWORD(HloModule);
KEYWORD(ENTRY);
#undef KEYWORD
// See if this is an opcode.
auto opcode = StringToHloOpcode(identifier.ToString());
if (opcode.ok()) {
opcode_val_ = opcode.ValueOrDie();
return TokKind::kOpcode;
}
current_ptr_ = token_start_ + 1;
return TokKind::kError;
}
// Lex names after a % character.
// name ::= [a-zA-Z_][a-zA-Z0-9_.-]*
TokKind HloLexer::LexPercent() {
const char* name_start = current_ptr_;
if (isalpha(static_cast<unsigned char>(PeekCurrentChar())) ||
PeekCurrentChar() == '_') {
current_ptr_++;
while (IsIdentifierChar(PeekCurrentChar())) {
current_ptr_++;
}
str_val_.assign(name_start, current_ptr_);
return TokKind::kName;
}
return TokKind::kError;
}
// Lex integer and floating-point values.
// int [-]?[0-9]+
// fp with exp [-]?([0-9]+|[0-9]+[.][0-9]*|[0-9]*[.][0-9]+)([eE][+-]?[0-9]+)
// fp without exp [-]?([0-9]+[.][0-9]*|[0-9]*[.][0-9]+)
TokKind HloLexer::LexDigitOrNegative() {
auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end());
static LazyRE2 float_pattern = {
R"([-]?((\d+|\d+[.]\d*|\d*[.]\d+)([eE][+-]?\d+))|(\d+[.]\d*|\d*[.]\d+))"};
if (RE2::Consume(&consumable, *float_pattern)) {
current_ptr_ = consumable.begin();
tensorflow::strings::safe_strtod(string(token_start_, current_ptr_).c_str(),
&decimal_val_);
return TokKind::kDecimal;
}
static LazyRE2 int_pattern = {R"([-]?\d+)"};
if (RE2::Consume(&consumable, *int_pattern)) {
current_ptr_ = consumable.begin();
tensorflow::strings::safe_strto64(
StringPieceFromPointers(token_start_, current_ptr_), &int64_val_);
return TokKind::kInt;
}
return TokKind::kError;
}
StringPiece HloLexer::GetCurrentLine() const {
const char* start = token_start_;
const char* end = current_ptr_;
if (!CanDereference(start) || !CanDereference(end)) {
return "LINE OUT OF RANGE";
}
while (start > buf_.begin() && *start != '\n') {
start--;
}
while (end < buf_.end() && *end != '\n') {
end++;
}
return StringPieceFromPointers(start, end);
}
} // namespace tools
} // namespace xla

View File

@ -0,0 +1,108 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_LEXER_H_
#define TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_LEXER_H_
#include <string>
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/tools/parser/hlo_token.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
namespace tools {
// Lexer for the HloModule::ToString() format text.
class HloLexer {
public:
explicit HloLexer(tensorflow::StringPiece buf) : buf_(buf) {
current_ptr_ = buf_.begin();
}
TokKind Lex() { return current_kind_ = LexToken(); }
TokKind GetKind() const { return current_kind_; }
string GetStrVal() const {
CHECK(GetKind() == TokKind::kName);
return str_val_;
}
Shape GetShapeVal() const {
CHECK(GetKind() == TokKind::kShape);
return shape_val_;
}
HloOpcode GetOpcodeVal() const {
CHECK(GetKind() == TokKind::kOpcode);
return opcode_val_;
}
int64 GetInt64Val() const {
CHECK(GetKind() == TokKind::kInt);
return int64_val_;
}
double GetDecimalVal() const {
CHECK(GetKind() == TokKind::kDecimal);
return decimal_val_;
}
// Returns the line of text that is currently being lexed.
tensorflow::StringPiece GetCurrentLine() const;
private:
// Returns the current character. If it's neither the end of input buffer nor
// an invalid character, moves the pointer forward.
int GetNextChar();
// Returns the current character.
int PeekCurrentChar() const;
// Creates a StringPiece with the given begin and end. Exits if begin > end or
// if either pointer is outside the range of the current buffer.
tensorflow::StringPiece StringPieceFromPointers(const char* begin,
const char* end) const;
tensorflow::RegexpStringPiece RegexpStringPieceFromPointers(
const char* begin, const char* end) const;
// Returns true if the given ptr is dereferenceable within the range of the
// current buffer.
bool CanDereference(const char* ptr) const;
TokKind LexToken();
TokKind LexIdentifier();
TokKind LexPercent();
TokKind LexShape();
TokKind LexConstant();
TokKind LexDigitOrNegative();
const tensorflow::StringPiece buf_;
const char* current_ptr_;
// Information about the current token.
const char* token_start_;
TokKind current_kind_;
string str_val_;
Shape shape_val_;
HloOpcode opcode_val_;
int64 int64_val_;
double decimal_val_;
};
} // namespace tools
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_LEXER_H_
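
A hedged sketch of driving this lexer over a single instruction line; it uses only the accessors declared above, and the driver itself is illustrative rather than part of the change:

```c++
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tools/parser/hlo_lexer.h"
#include "tensorflow/core/platform/logging.h"

void LexExample() {
  xla::tools::HloLexer lexer("%x = f32[2,4]{1,0} parameter(0)");
  // Pull tokens until end of input (or an invalid character).
  for (xla::tools::TokKind kind = lexer.Lex();
       kind != xla::tools::TokKind::kEof &&
       kind != xla::tools::TokKind::kError;
       kind = lexer.Lex()) {
    if (kind == xla::tools::TokKind::kName) {
      LOG(INFO) << "name:  " << lexer.GetStrVal();
    } else if (kind == xla::tools::TokKind::kShape) {
      LOG(INFO) << "shape: "
                << xla::ShapeUtil::HumanString(lexer.GetShapeVal());
    }
  }
}
```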

View File

@ -0,0 +1,502 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace xla {
namespace tools {
namespace {
using tensorflow::StringPiece;
using tensorflow::strings::StrCat;
// Parser for the HloModule::ToString() format text.
class HloParser {
public:
explicit HloParser(StringPiece str) : lexer_(str) {}
// Runs the parser. Returns false if an error occurred.
bool Run();
// Returns the parsed HloModule.
std::unique_ptr<HloModule> ConsumeHloModule() { return std::move(module_); }
// Returns the error information.
string GetError() const { return tensorflow::str_util::Join(error_, "\n"); }
private:
// ParseXXX returns false if an error occurred.
bool ParseHloModule();
bool ParseComputation();
bool ParseInstructionList(HloComputation::Builder* builder);
bool ParseInstruction(HloComputation::Builder* builder);
bool ParseLiteral(std::unique_ptr<Literal>* literal, const Shape& shape);
bool ParseOperands(std::vector<HloInstruction*>* operands,
const int expected_size);
bool ParseParamList();
bool ParseName(string* result);
bool ParseShape(Shape* result);
bool ParseOpcode(HloOpcode* result);
bool ParseInt64(int64* result);
bool ParseDecimal(double* result);
bool ParseBool(bool* result);
bool ParseToken(TokKind kind, const string& msg);
// Logs the current parsing line and the given message. Always returns false.
bool TokenError(StringPiece msg);
// If the current token is 'kind', eats it (i.e. lexes the next token) and
// returns true.
bool EatIfPresent(TokKind kind);
// Adds the instruction to the pool. Returns false and emits an error if the
// instruction already exists.
bool AddInstruction(const string& name, HloInstruction* instruction);
// The map from the instruction name to the instruction. This does not own the
// instructions.
std::unordered_map<string, HloInstruction*> instruction_pool_;
HloLexer lexer_;
std::unique_ptr<HloModule> module_;
std::vector<string> error_;
};
bool HloParser::TokenError(StringPiece msg) {
error_.push_back(
StrCat("was parsing \"", lexer_.GetCurrentLine(), "\"; ", msg));
return false;
}
bool HloParser::Run() {
lexer_.Lex();
return ParseHloModule();
}
// ::= 'HloModule' name computation
bool HloParser::ParseHloModule() {
if (lexer_.GetKind() != TokKind::kw_HloModule) {
return TokenError("expects HloModule");
}
// Eat 'HloModule'
lexer_.Lex();
string name;
if (!ParseName(&name)) {
return false;
}
module_ = MakeUnique<HloModule>(name);
return ParseComputation();
}
// computation ::= 'ENTRY' name param_list '->' shape instruction_list
bool HloParser::ParseComputation() {
string name;
if (!ParseToken(TokKind::kw_ENTRY, "expects 'ENTRY'") || !ParseName(&name)) {
return false;
}
auto builder = MakeUnique<HloComputation::Builder>(name);
Shape shape;
if (!ParseParamList() || !ParseToken(TokKind::kArrow, "expects '->'") ||
!ParseShape(&shape) || !ParseInstructionList(builder.get())) {
return false;
}
module_->AddEntryComputation(builder->Build());
return true;
}
// instruction_list ::= '{' instruction_list1 '}'
// instruction_list1 ::= (instruction)+
bool HloParser::ParseInstructionList(HloComputation::Builder* builder) {
if (!ParseToken(TokKind::kLbrace,
"expects '{' at the beginning of instruction list.")) {
return false;
}
do {
if (!ParseInstruction(builder)) {
return false;
}
} while (lexer_.GetKind() != TokKind::kRbrace);
return ParseToken(TokKind::kRbrace,
"expects '}' at the end of instruction list.");
}
// instruction ::= name '=' shape opcode operands
bool HloParser::ParseInstruction(HloComputation::Builder* builder) {
string name;
Shape shape;
HloOpcode opcode;
std::vector<HloInstruction*> operands;
if (!ParseName(&name) ||
!ParseToken(TokKind::kEqual, "expects '=' in instruction") ||
!ParseShape(&shape) || !ParseOpcode(&opcode)) {
return false;
}
switch (opcode) {
case HloOpcode::kParameter: {
int64 parameter_number;
return ParseToken(TokKind::kLparen,
"expects '(' before parameter number") &&
ParseInt64(&parameter_number) &&
ParseToken(TokKind::kRparen,
"expects ')' after parameter number") &&
AddInstruction(
name, builder->AddInstruction(HloInstruction::CreateParameter(
parameter_number, shape, name)));
}
case HloOpcode::kConstant: {
std::unique_ptr<Literal> literal;
return ParseToken(TokKind::kLparen,
"expects '(' before parameter number") &&
ParseLiteral(&literal, shape) &&
ParseToken(TokKind::kRparen,
"expects ')' after parameter number") &&
AddInstruction(
name, builder->AddInstruction(
HloInstruction::CreateConstant(std::move(literal))));
}
// Unary ops.
case HloOpcode::kAbs:
case HloOpcode::kRoundNearestAfz:
case HloOpcode::kBitcast:
case HloOpcode::kCeil:
case HloOpcode::kCopy:
case HloOpcode::kCos:
case HloOpcode::kExp:
case HloOpcode::kIsFinite:
case HloOpcode::kFloor:
case HloOpcode::kLog:
case HloOpcode::kNot:
case HloOpcode::kNegate:
case HloOpcode::kSign:
case HloOpcode::kSin:
case HloOpcode::kSort:
case HloOpcode::kTanh: {
return ParseOperands(&operands, /*expected_size=*/1) &&
AddInstruction(name,
builder->AddInstruction(HloInstruction::CreateUnary(
shape, opcode, operands[0])));
}
// Binary ops.
case HloOpcode::kAdd:
case HloOpcode::kDivide:
case HloOpcode::kMultiply:
case HloOpcode::kSubtract:
case HloOpcode::kEq:
case HloOpcode::kGe:
case HloOpcode::kGt:
case HloOpcode::kLe:
case HloOpcode::kLt:
case HloOpcode::kNe:
case HloOpcode::kDot:
case HloOpcode::kMaximum:
case HloOpcode::kMinimum:
case HloOpcode::kPower:
case HloOpcode::kRemainder:
case HloOpcode::kAnd:
case HloOpcode::kOr:
case HloOpcode::kShiftLeft:
case HloOpcode::kShiftRightArithmetic:
case HloOpcode::kShiftRightLogical: {
return ParseOperands(&operands, /*expected_size=*/2) &&
AddInstruction(
name, builder->AddInstruction(HloInstruction::CreateBinary(
shape, opcode, operands[0], operands[1])));
}
// Ternary ops.
case HloOpcode::kClamp:
case HloOpcode::kSelect: {
return ParseOperands(&operands, /*expected_size=*/3) &&
AddInstruction(
name,
builder->AddInstruction(HloInstruction::CreateTernary(
shape, opcode, operands[0], operands[1], operands[2])));
}
// Other supported ops.
case HloOpcode::kConvert: {
return ParseOperands(&operands, /*expected_size=*/1) &&
AddInstruction(
name, builder->AddInstruction(
HloInstruction::CreateConvert(shape, operands[0])));
}
case HloOpcode::kCrossReplicaSum: {
return ParseOperands(&operands, /*expected_size=*/1) &&
AddInstruction(name, builder->AddInstruction(
HloInstruction::CreateCrossReplicaSum(
shape, operands[0])));
}
case HloOpcode::kReshape: {
return ParseOperands(&operands, /*expected_size=*/1) &&
AddInstruction(
name, builder->AddInstruction(
HloInstruction::CreateReshape(shape, operands[0])));
}
case HloOpcode::kBroadcast:
case HloOpcode::kCall:
case HloOpcode::kCustomCall:
case HloOpcode::kConcatenate:
case HloOpcode::kReducePrecision:
case HloOpcode::kConvolution:
case HloOpcode::kGetTupleElement:
case HloOpcode::kMap:
case HloOpcode::kPad:
case HloOpcode::kReduce:
case HloOpcode::kReduceWindow:
case HloOpcode::kSelectAndScatter:
case HloOpcode::kReverse:
case HloOpcode::kRng:
case HloOpcode::kSlice:
case HloOpcode::kDynamicSlice:
case HloOpcode::kDynamicUpdateSlice:
case HloOpcode::kTranspose:
case HloOpcode::kTuple:
case HloOpcode::kWhile:
case HloOpcode::kFusion:
case HloOpcode::kBatchNormTraining:
case HloOpcode::kBatchNormInference:
case HloOpcode::kInfeed:
case HloOpcode::kOutfeed:
case HloOpcode::kBatchNormGrad:
case HloOpcode::kRecv:
case HloOpcode::kSend:
case HloOpcode::kUpdate:
case HloOpcode::kIndex:
case HloOpcode::kTrace:
return TokenError(StrCat("parsing not yet implemented for op: ",
HloOpcodeString(opcode)));
}
}
bool HloParser::ParseLiteral(std::unique_ptr<Literal>* literal,
const Shape& shape) {
switch (shape.element_type()) {
case PRED:
bool b;
if (!ParseBool(&b)) {
return false;
}
*literal = Literal::CreateR0<bool>(b);
return true;
case S32:
int64 i;
if (!ParseInt64(&i)) {
return false;
}
*literal = Literal::CreateR0<int32>(i);
return true;
case F32:
double d;
if (!ParseDecimal(&d)) {
return false;
}
*literal = Literal::CreateR0<float>(d);
return true;
default:
return TokenError(StrCat("unsupported constant in shape: ",
ShapeUtil::HumanString(shape)));
}
}
// operands ::= '(' operands1 ')'
// operands1
// ::= /*empty*/
// ::= operand (, operand)*
// operand ::= shape name
bool HloParser::ParseOperands(std::vector<HloInstruction*>* operands,
const int expected_size) {
if (!ParseToken(TokKind::kLparen,
"expects '(' at the beginning of operands")) {
return false;
}
if (lexer_.GetKind() == TokKind::kRparen) {
// empty
} else {
do {
Shape shape;
string name;
if (!ParseShape(&shape) || !ParseName(&name)) {
return false;
}
HloInstruction* instruction =
tensorflow::gtl::FindPtrOrNull(instruction_pool_, name);
if (!instruction) {
return TokenError(StrCat("instruction does not exist: ", name));
}
operands->push_back(instruction);
} while (EatIfPresent(TokKind::kComma));
}
if (expected_size != operands->size()) {
return TokenError(StrCat("expects ", expected_size, " operands, but has ",
operands->size(), " operands"));
}
return ParseToken(TokKind::kRparen, "expects ')' at the end of operands");
}
// param_list ::= '(' param_list1 ')'
// param_list1
// ::= /*empty*/
// ::= param (',' param)*
// param ::= name shape
bool HloParser::ParseParamList() {
if (!ParseToken(TokKind::kLparen,
"expects '(' at the beginning of param list")) {
return false;
}
if (lexer_.GetKind() == TokKind::kRparen) {
// empty
} else {
do {
Shape shape;
if (!ParseToken(TokKind::kName, "expects name in parameter") ||
!ParseShape(&shape)) {
return false;
}
} while (EatIfPresent(TokKind::kComma));
}
return ParseToken(TokKind::kRparen, "expects ')' at the end of param list");
}
// shape ::= shape_val_
// shape ::= '(' tuple_elements ')'
// tuple_elements
// ::= /*empty*/
// ::= shape (',' shape)*
bool HloParser::ParseShape(Shape* result) {
if (EatIfPresent(TokKind::kLparen)) { // Tuple
std::vector<Shape> shapes;
if (lexer_.GetKind() == TokKind::kRparen) {
/*empty*/
} else {
// shape (',' shape)*
do {
shapes.emplace_back();
if (!ParseShape(&shapes.back())) {
return false;
}
} while (EatIfPresent(TokKind::kComma));
}
*result = ShapeUtil::MakeTupleShape(shapes);
return ParseToken(TokKind::kRparen, "expects ')' at the end of tuple.");
}
if (lexer_.GetKind() != TokKind::kShape) {
return TokenError("expects shape");
}
*result = lexer_.GetShapeVal();
lexer_.Lex();
return true;
}
bool HloParser::ParseName(string* result) {
VLOG(1) << "ParseName";
if (lexer_.GetKind() != TokKind::kName) {
return TokenError("expects name");
}
*result = lexer_.GetStrVal();
lexer_.Lex();
return true;
}
bool HloParser::ParseOpcode(HloOpcode* result) {
VLOG(1) << "ParseOpcode";
if (lexer_.GetKind() != TokKind::kOpcode) {
return TokenError("expects opcode");
}
*result = lexer_.GetOpcodeVal();
lexer_.Lex();
return true;
}
bool HloParser::ParseInt64(int64* result) {
VLOG(1) << "ParseInt64";
if (lexer_.GetKind() != TokKind::kInt) {
return TokenError("expects integer");
}
*result = lexer_.GetInt64Val();
lexer_.Lex();
return true;
}
bool HloParser::ParseDecimal(double* result) {
switch (lexer_.GetKind()) {
case TokKind::kDecimal:
*result = lexer_.GetDecimalVal();
break;
case TokKind::kInt:
*result = static_cast<double>(lexer_.GetInt64Val());
break;
default:
return TokenError("expects decimal or integer");
}
lexer_.Lex();
return true;
}
bool HloParser::ParseBool(bool* result) {
if (lexer_.GetKind() != TokKind::kw_true &&
lexer_.GetKind() != TokKind::kw_false) {
return TokenError("expects true or false");
}
*result = lexer_.GetKind() == TokKind::kw_true;
lexer_.Lex();
return true;
}
bool HloParser::ParseToken(TokKind kind, const string& msg) {
if (lexer_.GetKind() != kind) {
return TokenError(msg);
}
lexer_.Lex();
return true;
}
bool HloParser::EatIfPresent(TokKind kind) {
if (lexer_.GetKind() != kind) {
return false;
}
lexer_.Lex();
return true;
}
bool HloParser::AddInstruction(const string& name,
HloInstruction* instruction) {
auto result = instruction_pool_.insert({name, instruction});
if (!result.second) {
return TokenError(StrCat("instruction already exists: ", name));
}
return true;
}
} // namespace
StatusOr<std::unique_ptr<HloModule>> Parse(StringPiece str) {
HloParser parser(str);
if (!parser.Run()) {
return InvalidArgument("Syntax error: %s", parser.GetError().c_str());
}
return parser.ConsumeHloModule();
}
} // namespace tools
} // namespace xla

View File

@ -0,0 +1,37 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_PARSER_H_
#define TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_PARSER_H_
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tools/parser/hlo_lexer.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
namespace tools {
// The API of the HLO parser. Given a string in the HloModule::ToString()
// format, returns the parsed HloModule.
StatusOr<std::unique_ptr<HloModule>> Parse(tensorflow::StringPiece str);
} // namespace tools
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_PARSER_H_
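
A hedged example of calling the entry point declared here; the module text is one of the parser test cases added further below, and the wrapper function is illustrative only:

```c++
#include <memory>

#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include "tensorflow/core/platform/logging.h"

void ParseExample() {
  const char* text = R"(HloModule constant_s32_module:

ENTRY %constant_s32 () -> s32[] {
  %constant = s32[] constant(-42)
}
)";
  xla::StatusOr<std::unique_ptr<xla::HloModule>> module =
      xla::tools::Parse(text);
  if (module.ok()) {
    // Supported ops can be printed back via HloModule::ToString().
    LOG(INFO) << module.ValueOrDie()->ToString();
  } else {
    LOG(ERROR) << module.status();
  }
}
```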

View File

@ -0,0 +1,240 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
#include <string>
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace xla {
namespace tools {
namespace {
struct TestData {
string test_name;
string module_string;
};
string TestDataToString(const ::testing::TestParamInfo<TestData>& data) {
return data.param.test_name;
}
std::vector<TestData> CreateTestCases() {
// clang-format off
return std::vector<TestData>({
// ax + y
{
"AxpyParam",
R"(HloModule axpy_module:
ENTRY %axpy.v5 (alpha: f32[2,4], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
%alpha = f32[2,4]{1,0} parameter(0)
%x = f32[2,4]{1,0} parameter(1)
%multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %alpha, f32[2,4]{1,0} %x)
%y = f32[2,4]{1,0} parameter(2)
%add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
}
)"
},
// pred constant
{
"ConstantPred",
R"(HloModule constant_pred_module:
ENTRY %constant_pred () -> pred[] {
%constant = pred[] constant(true)
}
)"
},
// s32 constant
{
"ConstantS32",
R"(HloModule constant_s32_module:
ENTRY %constant_s32 () -> s32[] {
%constant = s32[] constant(-42)
}
)"
},
// f32 constant, but the value is not a decimal
{
"ConstantF32", R"(HloModule ConstantF32_module:
ENTRY %ConstantF32.v4 () -> f32[] {
%constant = f32[] constant(42)
}
)"
},
// constant + constant
{
"AddConstants",
R"(HloModule add_constants_module:
ENTRY %add_constants () -> f32[] {
%constant = f32[] constant(3.14)
%add = f32[] add(f32[] %constant, f32[] %constant)
}
)"
},
// v1 > v2 ? v1 : v2
{
"SelectR1F32",
R"(HloModule SelectR1F32WithCmpR1F32sFromParamsSmall_module:
ENTRY %SelectR1F32WithCmpR1F32sFromParamsSmall.v4 (v1: f32[4], v2: f32[4]) -> f32[4] {
%v1 = f32[4]{0} parameter(0)
%v2 = f32[4]{0} parameter(1)
%greater-than = pred[4]{0} greater-than(f32[4]{0} %v1, f32[4]{0} %v2)
%select = f32[4]{0} select(pred[4]{0} %greater-than, f32[4]{0} %v1, f32[4]{0} %v2)
}
)"
}
});
// clang-format on
}
class HloParserTest : public ::testing::Test,
public ::testing::WithParamInterface<TestData> {
protected:
void ExpectSuccess() {
const string& original = GetParam().module_string;
auto result = Parse(original);
TF_EXPECT_OK(result.status());
EXPECT_EQ(original, result.ValueOrDie()->ToString());
}
};
TEST_P(HloParserTest, Run) { ExpectSuccess(); }
INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserTest,
::testing::ValuesIn(CreateTestCases()),
TestDataToString);
TEST_F(HloParserTest, Empty) {
const string original = "";
auto result = Parse(original);
EXPECT_NE(tensorflow::Status::OK(), result.status());
}
TEST_F(HloParserTest, Garbage) {
const string original = "HloModule thi$ str1ng makes# N0 sen$e @all!*&^%$";
auto result = Parse(original);
EXPECT_NE(tensorflow::Status::OK(), result.status());
}
TEST_F(HloParserTest, WrongOpcode) {
const string original = R"(HloModule wrong_opcode:
ENTRY %blabla (x: f32[], y: f32[]) -> f32[] {
%x = f32[]{} parameter(0)
%y = f32[]{} parameter(1)
%le = pred[]{} le(f32[]{} %x, f32[]{} %y)
}
)";
auto result = Parse(original);
EXPECT_NE(tensorflow::Status::OK(), result.status());
}
TEST_F(HloParserTest, WrongShape) {
const string original = R"(HloModule wrong_opcode:
ENTRY %blabla (x: g32[]) -> g32[] {
%x = g32[]{} parameter(0)
}
)";
auto result = Parse(original);
EXPECT_NE(tensorflow::Status::OK(), result.status());
}
TEST_F(HloParserTest, WrongOperandsSize) {
const string original = R"(HloModule wrong_opcode:
ENTRY %blabla (x: f32[]) -> pred[] {
%x = f32[]{} parameter(0)
%eq = pred[]{} equal-to(f32[]{} %x)
}
)";
auto result = Parse(original);
EXPECT_NE(tensorflow::Status::OK(), result.status());
}
TEST_F(HloParserTest, OperandNotFound) {
const string original = R"(HloModule operand_not_found:
ENTRY %blabla (x: f32[]) -> pred[] {
%x = f32[]{} parameter(0)
%eq = pred[]{} equal-to(f32[]{} %x, f32[]{} %y)
}
)";
auto result = Parse(original);
EXPECT_NE(tensorflow::Status::OK(), result.status());
}
TEST_F(HloParserTest, MoreConstants) {
const string original = R"(HloModule SelectScalarS32True_module:
ENTRY %SelectScalarS32True.v4 () -> s32[] {
%constant.2 = pred[] constant(true)
%constant.1 = s32[] constant(-42)
%constant = s32[] constant(42)
%select = s32[] select(pred[] %constant.2, s32[] %constant.1, s32[] %constant)
}
)";
auto result = Parse(original);
TF_EXPECT_OK(result.status());
// Constant instructions have no name. The string will be parsed successfully
// but the constant names will not be exactly the same.
}
TEST_F(HloParserTest, ConstantWithExp) {
const string original = R"(HloModule ConstantWithExp_module:
ENTRY %ConstantWithExp.v4 () -> f32[] {
%constant.1 = f32[] constant(3e+2)
}
)";
auto result = Parse(original);
TF_EXPECT_OK(result.status());
// The string will be parsed successfully but the output strings are not
// exactly the same, because "3e2" is parsed into value 300 and will be
// printed as "300".
}
TEST_F(HloParserTest, Tuple) {
const string original = R"(HloModule EmptyTupleCreate_module:
ENTRY %EmptyTupleCreate.v1 () -> () {
%tuple = () tuple()
}
)";
auto result = Parse(original);
EXPECT_NE(tensorflow::Status::OK(), result.status());
}
} // namespace
} // namespace tools
} // namespace xla

View File

@ -0,0 +1,58 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_TOKEN_H_
#define TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_TOKEN_H_
namespace xla {
namespace tools {
// Defines the different kinds of tokens in an HLO module string.
enum class TokKind {
// Markers
kEof,
kError,
// Tokens with no info.
kEqual, // =
kComma, // ,
kColon, // :
kLsquare,
kRsquare, // [ ]
kLbrace,
kRbrace, // { }
kLparen,
kRparen, // ( )
kArrow, // ->
// Keywords
kw_HloModule,
kw_ENTRY,
kw_true,
kw_false,
// Typed tokens.
kName, // %foo
kShape, // f32[2,3]{1,0}
kOpcode, // add
kInt, // 42
kDecimal, // 4.2
};
} // namespace tools
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_TOKEN_H_

View File

@ -82,8 +82,8 @@ message DebugOptions {
// Dump all HLO modules as text into the provided directory path.
string xla_generate_hlo_text_to = 7;
// Dump compilation artifacts as JSON into this directory.
string xla_dump_debug_json_to = 8;
// Dump compilation artifacts in binary proto into this directory.
string xla_dump_hlo_proto_to = 8;
// Instrument the computation to collect per-HLO cycle counts.
bool xla_hlo_profile = 9;

View File

@ -69,6 +69,28 @@ tf_cc_test(
],
)
cc_library(
name = "adaptive_shared_batch_scheduler",
hdrs = ["adaptive_shared_batch_scheduler.h"],
deps = [
":batch_scheduler",
"//tensorflow/contrib/batching/util:periodic_function_dynamic",
"//tensorflow/core:lib",
],
)
tf_cc_test(
name = "adaptive_shared_batch_scheduler_test",
srcs = ["adaptive_shared_batch_scheduler_test.cc"],
deps = [
":adaptive_shared_batch_scheduler",
"//tensorflow/contrib/batching/test_util:fake_clock_env",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
cc_library(
name = "basic_batch_scheduler",
hdrs = ["basic_batch_scheduler.h"],

View File

@ -0,0 +1,463 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_
#define THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_
#include <functional>
#include <memory>
#include <queue>
#include <unordered_map>
#include <vector>
#include "tensorflow/contrib/batching/batch_scheduler.h"
#include "tensorflow/contrib/batching/util/periodic_function.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace serving {
namespace internal {
template <typename TaskType>
class ASBSBatch;
template <typename TaskType>
class ASBSQueue;
} // namespace internal
// Shared batch scheduler designed to minimize latency. The scheduler keeps
// track of a number of queues (one per model or model version) which are
// continuously enqueuing requests. The scheduler groups the requests into
// batches which it periodically sends off for processing (see
// shared_batch_scheduler.h for more details). The AdaptiveSharedBatchScheduler
// prioritizes batches by age (i.e. the batch's oldest request) irrespective of
// queue. The scheduler will process the oldest batch at an adjustable rate,
// regardless of batch size. The user can provide feedback to help set this rate
to achieve some goal (e.g. minimize overall latency, limit CPU usage, etc.).
//
// The rate (or rather, the corresponding period) is adjusted each time a batch
// is processed, using an exponentially weighted moving average to smooth
// potentially noisy feedback:
// ewma_feedback = ((N - 1) * ewma_feedback + feedback()) / N
// period *= (1 + K * ewma_feedback)
//
// Some potential use cases:
// Hardware Accelerators (GPUs & TPUs) - If some phase of batch processing
// involves serial processing by a device, from a latency perspective it is
// desirable to keep the device evenly loaded, avoiding the need to wait for
// the device to process prior batches.
// feedback = num_pending_on_device() - desired_pending.
// CPU utilization - If the batch processing is cpu dominated, you can reap
// latency gains when underutilized by increasing the processing rate, but
// back the rate off when the load increases to avoid overload.
// feedback = cpu_rate() - desired_cpu_rate.
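
A self-contained sketch of that update rule: N corresponds to feedback_smoothing_batches, K to the kFeedbackMultiplier constant used in ProcessOneBatch, and the clamp mirrors the min/max period options. The function and its signature are illustrative only:

```c++
#include <algorithm>

// One scheduling step: fold the latest feedback into the moving average,
// then scale the period and clamp it to the configured bounds.
void UpdatePeriod(double feedback, double n, double k, double min_period,
                  double max_period, double* ewma_feedback, double* period) {
  *ewma_feedback = ((n - 1) * *ewma_feedback + feedback) / n;
  *period *= (1 + k * *ewma_feedback);
  *period = std::min(max_period, std::max(min_period, *period));
}
```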
template <typename TaskType>
class AdaptiveSharedBatchScheduler
: public std::enable_shared_from_this<
AdaptiveSharedBatchScheduler<TaskType>> {
public:
struct Options {
// The name to use for the pool of batch threads.
string thread_pool_name = {"batch_threads"};
// Number of batch processing threads; equivalently the maximum number of
// concurrently running batches.
int64 num_batch_threads = port::NumSchedulableCPUs();
// The environment to use (typically only overridden by test code).
Env* env = Env::Default();
// Initial batch scheduling period in microseconds. Will be altered for
// non-zero scheduling_period_feedback.
double initial_scheduling_period_micros = 500;
// Minimum batch scheduling period in microseconds. Recommend setting this
// value greater than 0, otherwise it may take a while to recover from a
// sustained time of negative scheduling_period_feedback (which may occur
// under low load).
double min_scheduling_period_micros = 100;
// Maximum batch scheduling period in microseconds.
double max_scheduling_period_micros = 10000;
// Feedback function used to modify the scheduling period each time a batch
// is scheduled. Should return values roughly O(1), with positive values
// resulting in an increased period.
std::function<double()> scheduling_period_feedback = [] { return 0.; };
// To handle potentially noisy scheduling_period_feedback, the period is
// adjusted using an exponentially weighted moving average over the previous
// feedback_smoothing_batches batches. Must be greater than 0.
int64 feedback_smoothing_batches = 10;
};
// Ownership is shared between the caller of Create() and any queues created
// via AddQueue().
static Status Create(
const Options& options,
std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>>* scheduler);
struct QueueOptions {
// Maximum size of each batch.
int max_batch_size = 1000;
// Maximum number of enqueued (i.e. non-scheduled) batches.
int max_enqueued_batches = 10;
};
using BatchProcessor = std::function<void(std::unique_ptr<Batch<TaskType>>)>;
// Adds queue (and its callback) to be managed by this scheduler.
Status AddQueue(const QueueOptions& options,
BatchProcessor process_batch_callback,
std::unique_ptr<BatchScheduler<TaskType>>* queue);
private:
// access to AddBatch, RemoveQueue, GetEnv.
friend class internal::ASBSQueue<TaskType>;
explicit AdaptiveSharedBatchScheduler(const Options& options);
// Batch scheduling function which runs every scheduling_period_ microseconds.
void ProcessOneBatch();
// Notifies scheduler of non-empty batch which is eligible for processing.
void AddBatch(internal::ASBSBatch<TaskType>*);
// Removes queue from scheduler.
void RemoveQueue(const internal::ASBSQueue<TaskType>* queue);
Env* GetEnv() const { return options_.env; }
const Options options_;
struct BatchCompare {
bool operator()(const internal::ASBSBatch<TaskType>* a,
const internal::ASBSBatch<TaskType>* b);
};
// Collection of batches added by AddBatch, ordered by age. Owned by scheduler
// until they are released for processing.
std::priority_queue<const internal::ASBSBatch<TaskType>*,
std::vector<internal::ASBSBatch<TaskType>*>, BatchCompare>
batches_ GUARDED_BY(mu_);
// Unowned queues and callbacks added by AddQueue.
std::unordered_map<const internal::ASBSQueue<TaskType>*, BatchProcessor>
queues_and_callbacks_ GUARDED_BY(mu_);
mutex mu_;
// Responsible for running ProcessOneBatch. PeriodicFunction was used in order
// to check for deletion so that the thread can be shut down.
std::unique_ptr<PeriodicFunction> scheduling_thread_;
// Responsible for running the batch processing callbacks.
std::unique_ptr<thread::ThreadPool> batch_thread_pool_;
// Time interval in microseconds between successive ProcessOneBatch calls.
double scheduling_period_;
// Exponentially weighted moving average of
// options_.scheduling_period_feedback() evaluated in each ProcessOneBatch
// call.
double ewma_feedback_ = 0;
TF_DISALLOW_COPY_AND_ASSIGN(AdaptiveSharedBatchScheduler);
};
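
A hedged usage sketch of the public API above. It assumes a task type satisfying the BatchTask interface from batch_scheduler.h (a size() override); the task type, callback body, and function name are illustrative, not part of this change:

```c++
#include <memory>

#include "tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h"
#include "tensorflow/core/platform/logging.h"

namespace {

// Minimal task: every task counts as size 1.
struct MyTask : public tensorflow::serving::BatchTask {
  size_t size() const override { return 1; }
};

void ScheduleOne() {
  using Scheduler = tensorflow::serving::AdaptiveSharedBatchScheduler<MyTask>;

  Scheduler::Options options;  // defaults documented in the header above
  std::shared_ptr<Scheduler> scheduler;
  TF_CHECK_OK(Scheduler::Create(options, &scheduler));

  // Add one queue; its callback runs on the scheduler's batch thread pool.
  Scheduler::QueueOptions queue_options;
  std::unique_ptr<tensorflow::serving::BatchScheduler<MyTask>> queue;
  TF_CHECK_OK(scheduler->AddQueue(
      queue_options,
      [](std::unique_ptr<tensorflow::serving::Batch<MyTask>> batch) {
        LOG(INFO) << "processing batch of " << batch->num_tasks() << " tasks";
      },
      &queue));

  // Enqueue a task; it is grouped into a batch and processed asynchronously.
  auto task = std::unique_ptr<MyTask>(new MyTask);
  TF_CHECK_OK(queue->Schedule(&task));
}

}  // namespace
```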
//////////////////////////////////////////////////////////
// Implementation details follow. API users need not read.
namespace internal {
// Consolidates tasks into batches, passing them off to the
// AdaptiveSharedBatchScheduler for processing.
template <typename TaskType>
class ASBSQueue : public BatchScheduler<TaskType> {
public:
using QueueOptions =
typename AdaptiveSharedBatchScheduler<TaskType>::QueueOptions;
ASBSQueue(std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>> scheduler,
const QueueOptions& options);
~ASBSQueue() override;
// Adds task to current batch. Fails if the task size is larger than the batch
// size or if the current batch is full and this queue's number of outstanding
// batches is at its maximum.
Status Schedule(std::unique_ptr<TaskType>* task) override;
// Number of tasks waiting to be scheduled.
size_t NumEnqueuedTasks() const override;
// Number of size 1 tasks which could currently be scheduled without failing.
size_t SchedulingCapacity() const override;
// Notifies queue that a batch is about to be scheduled; the queue should not
// place any more tasks in this batch.
void ReleaseBatch(const ASBSBatch<TaskType>* batch);
private:
std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>> scheduler_;
const QueueOptions options_;
// Owned by scheduler_.
ASBSBatch<TaskType>* current_batch_ GUARDED_BY(mu_) = nullptr;
int64 num_enqueued_batches_ GUARDED_BY(mu_) = 0;
int64 num_enqueued_tasks_ GUARDED_BY(mu_) = 0;
mutable mutex mu_;
TF_DISALLOW_COPY_AND_ASSIGN(ASBSQueue);
};
// Batch which remembers when and by whom it was created.
template <typename TaskType>
class ASBSBatch : public Batch<TaskType> {
public:
ASBSBatch(ASBSQueue<TaskType>* queue, int64 creation_time_micros)
: queue_(queue), creation_time_micros_(creation_time_micros) {}
~ASBSBatch() override {}
ASBSQueue<TaskType>* queue() const { return queue_; }
int64 creation_time_micros() const { return creation_time_micros_; }
private:
ASBSQueue<TaskType>* queue_;
const int64 creation_time_micros_;
TF_DISALLOW_COPY_AND_ASSIGN(ASBSBatch);
};
} // namespace internal
// ---------------- AdaptiveSharedBatchScheduler ----------------
template <typename TaskType>
Status AdaptiveSharedBatchScheduler<TaskType>::Create(
const Options& options,
std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>>* scheduler) {
if (options.num_batch_threads < 1) {
return errors::InvalidArgument("num_batch_threads must be positive; was ",
options.num_batch_threads);
}
if (options.min_scheduling_period_micros < 0) {
return errors::InvalidArgument(
"min_scheduling_period_micros must be >= 0; was ",
options.min_scheduling_period_micros);
}
if (options.min_scheduling_period_micros >
options.initial_scheduling_period_micros) {
return errors::InvalidArgument(
"initial_scheduling_period_micros (",
options.initial_scheduling_period_micros,
") must be >= min_scheduling_period_micros (",
options.min_scheduling_period_micros, ")");
}
if (options.initial_scheduling_period_micros >
options.max_scheduling_period_micros) {
return errors::InvalidArgument(
"initial_scheduling_period_micros (",
options.initial_scheduling_period_micros,
") must be <= max_scheduling_period_micros (",
options.max_scheduling_period_micros, ")");
}
if (options.feedback_smoothing_batches < 1) {
return errors::InvalidArgument(
"feedback_smoothing_batches must be positive; was ",
options.feedback_smoothing_batches);
}
scheduler->reset(new AdaptiveSharedBatchScheduler<TaskType>(options));
return Status::OK();
}
template <typename TaskType>
AdaptiveSharedBatchScheduler<TaskType>::AdaptiveSharedBatchScheduler(
const Options& options)
: options_(options),
scheduling_period_(options.initial_scheduling_period_micros) {
PeriodicFunction::Options opts;
opts.thread_name_prefix = "scheduling_thread";
opts.env = GetEnv();
scheduling_thread_.reset(
new PeriodicFunction([this] { ProcessOneBatch(); }, 0, opts));
batch_thread_pool_.reset(new thread::ThreadPool(
GetEnv(), options.thread_pool_name, options.num_batch_threads));
}
template <typename TaskType>
Status AdaptiveSharedBatchScheduler<TaskType>::AddQueue(
const QueueOptions& options, BatchProcessor process_batch_callback,
std::unique_ptr<BatchScheduler<TaskType>>* queue) {
if (options.max_batch_size <= 0) {
return errors::InvalidArgument("max_batch_size must be positive; was ",
options.max_batch_size);
}
if (options.max_enqueued_batches <= 0) {
return errors::InvalidArgument(
"max_enqueued_batches must be positive; was ",
options.max_enqueued_batches);
}
internal::ASBSQueue<TaskType>* asbs_queue_raw;
queue->reset(asbs_queue_raw = new internal::ASBSQueue<TaskType>(
this->shared_from_this(), options));
mutex_lock l(mu_);
queues_and_callbacks_[asbs_queue_raw] = process_batch_callback;
return Status::OK();
}
template <typename TaskType>
void AdaptiveSharedBatchScheduler<TaskType>::AddBatch(
internal::ASBSBatch<TaskType>* batch) {
mutex_lock l(mu_);
batches_.push(batch);
}
template <typename TaskType>
void AdaptiveSharedBatchScheduler<TaskType>::RemoveQueue(
const internal::ASBSQueue<TaskType>* queue) {
mutex_lock l(mu_);
queues_and_callbacks_.erase(queue);
}
template <typename TaskType>
void AdaptiveSharedBatchScheduler<TaskType>::ProcessOneBatch() {
static const double kFeedbackMultiplier = .001;
internal::ASBSBatch<TaskType>* batch = nullptr;
BatchProcessor callback;
const int64 start_time_micros = GetEnv()->NowMicros();
{
mutex_lock l(mu_);
if (!batches_.empty()) {
batch = batches_.top();
batches_.pop();
callback = queues_and_callbacks_[batch->queue()];
}
}
if (batch != nullptr) {
double feedback = options_.scheduling_period_feedback();
const int64 N = options_.feedback_smoothing_batches;
ewma_feedback_ = ((N - 1) * ewma_feedback_ + feedback) / N;
scheduling_period_ *= (1 + kFeedbackMultiplier * ewma_feedback_);
if (scheduling_period_ < options_.min_scheduling_period_micros) {
scheduling_period_ = options_.min_scheduling_period_micros;
} else if (scheduling_period_ > options_.max_scheduling_period_micros) {
scheduling_period_ = options_.max_scheduling_period_micros;
}
// Queue may destroy itself after ReleaseBatch is called.
batch->queue()->ReleaseBatch(batch);
batch_thread_pool_->Schedule([callback, batch] {
callback(std::unique_ptr<Batch<TaskType>>(batch));
});
}
const int64 sleep_time =
scheduling_period_ - (GetEnv()->NowMicros() - start_time_micros);
if (sleep_time > 0) {
GetEnv()->SleepForMicroseconds(sleep_time);
}
}
template <typename TaskType>
bool AdaptiveSharedBatchScheduler<TaskType>::BatchCompare::operator()(
const internal::ASBSBatch<TaskType>* a,
const internal::ASBSBatch<TaskType>* b) {
return a->creation_time_micros() > b->creation_time_micros();
}
// ---------------- ASBSQueue ----------------
namespace internal {
template <typename TaskType>
ASBSQueue<TaskType>::ASBSQueue(
std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>> scheduler,
const QueueOptions& options)
: scheduler_(scheduler), options_(options) {}
template <typename TaskType>
ASBSQueue<TaskType>::~ASBSQueue() {
// Wait until last batch has been scheduled.
const int kSleepMicros = 1000;
for (;;) {
{
mutex_lock l(mu_);
if (num_enqueued_batches_ == 0) {
break;
}
}
scheduler_->GetEnv()->SleepForMicroseconds(kSleepMicros);
}
scheduler_->RemoveQueue(this);
}
template <typename TaskType>
Status ASBSQueue<TaskType>::Schedule(std::unique_ptr<TaskType>* task) {
bool added_new_batch = false;
size_t size = (*task)->size();
if (size > options_.max_batch_size) {
return errors::InvalidArgument("Task size ", size,
" is larger than maximum batch size ",
options_.max_batch_size);
}
{
mutex_lock l(mu_);
// Current batch is full, create another if allowed.
if (current_batch_ &&
current_batch_->size() + size > options_.max_batch_size) {
if (num_enqueued_batches_ >= options_.max_enqueued_batches) {
return errors::Unavailable("The batch scheduling queue is full");
}
current_batch_->Close();
current_batch_ = nullptr;
}
if (!current_batch_) {
added_new_batch = true;
num_enqueued_batches_++;
current_batch_ =
new ASBSBatch<TaskType>(this, scheduler_->GetEnv()->NowMicros());
}
current_batch_->AddTask(std::move(*task));
num_enqueued_tasks_++;
}
if (added_new_batch) scheduler_->AddBatch(current_batch_);
return Status::OK();
}
template <typename TaskType>
void ASBSQueue<TaskType>::ReleaseBatch(const ASBSBatch<TaskType>* batch) {
mutex_lock l(mu_);
num_enqueued_batches_--;
num_enqueued_tasks_ -= batch->num_tasks();
if (batch == current_batch_) {
current_batch_->Close();
current_batch_ = nullptr;
}
}
template <typename TaskType>
size_t ASBSQueue<TaskType>::NumEnqueuedTasks() const {
mutex_lock l(mu_);
return num_enqueued_tasks_;
}
template <typename TaskType>
size_t ASBSQueue<TaskType>::SchedulingCapacity() const {
mutex_lock l(mu_);
const int current_batch_capacity =
current_batch_ ? options_.max_batch_size - current_batch_->size() : 0;
const int spare_batches =
options_.max_enqueued_batches - num_enqueued_batches_;
return spare_batches * options_.max_batch_size + current_batch_capacity;
}
} // namespace internal
} // namespace serving
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_

View File

@ -0,0 +1,438 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h"
#include "tensorflow/contrib/batching/test_util/fake_clock_env.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace serving {
namespace anonymous {
class FakeTask : public BatchTask {
public:
explicit FakeTask(size_t size) : size_(size) {}
~FakeTask() override = default;
size_t size() const override { return size_; }
private:
const size_t size_;
TF_DISALLOW_COPY_AND_ASSIGN(FakeTask);
};
// Creates a FakeTask of size 'task_size', and calls 'scheduler->Schedule()' on
// that task. Returns the resulting status.
Status ScheduleTask(size_t task_size, BatchScheduler<FakeTask>* scheduler) {
std::unique_ptr<FakeTask> task(new FakeTask(task_size));
Status status = scheduler->Schedule(&task);
// Schedule() should have consumed 'task' iff it returned Status::OK.
CHECK_EQ(status.ok(), task == nullptr);
return status;
}
// Creates a thread that waits on 'start' and then advances the fake clock in
// 'env' in a loop until 'stop' is notified. Useful for allowing objects that
// use the clock to be destroyed.
std::unique_ptr<Thread> CreateFakeClockAdvancerThread(
test_util::FakeClockEnv* env, Notification* start, Notification* stop) {
return std::unique_ptr<Thread>(Env::Default()->StartThread(
{}, "FakeClockAdvancerThread", [env, start, stop] {
start->WaitForNotification();
while (!stop->HasBeenNotified()) {
env->AdvanceByMicroseconds(10);
Env::Default()->SleepForMicroseconds(10);
}
}));
}
TEST(AdaptiveSharedBatchSchedulerTest, Basic) {
for (const bool delete_scheduler_early : {false, true}) {
for (const bool delete_queue_1_early : {false, true}) {
int queue_0_tasks = 0;
auto queue_0_callback =
[&queue_0_tasks](std::unique_ptr<Batch<FakeTask>> batch) {
ASSERT_TRUE(batch->IsClosed());
EXPECT_GT(batch->num_tasks(), 0);
for (int i = 0; i < batch->num_tasks(); i++) {
queue_0_tasks += batch->task(i).size();
}
};
int queue_1_tasks = 0;
auto queue_1_callback =
[&queue_1_tasks](std::unique_ptr<Batch<FakeTask>> batch) {
ASSERT_TRUE(batch->IsClosed());
EXPECT_GT(batch->num_tasks(), 0);
for (int i = 0; i < batch->num_tasks(); i++) {
queue_1_tasks += batch->task(i).size();
}
};
{
std::shared_ptr<AdaptiveSharedBatchScheduler<FakeTask>> scheduler;
TF_ASSERT_OK(
AdaptiveSharedBatchScheduler<FakeTask>::Create({}, &scheduler));
// Create two queues.
std::unique_ptr<BatchScheduler<FakeTask>> queue_0;
TF_ASSERT_OK(scheduler->AddQueue({}, queue_0_callback, &queue_0));
std::unique_ptr<BatchScheduler<FakeTask>> queue_1;
TF_ASSERT_OK(scheduler->AddQueue({}, queue_1_callback, &queue_1));
if (delete_scheduler_early) {
// Delete our copy of the scheduler. The queues should keep it alive
// under the covers.
scheduler = nullptr;
}
// Submit tasks to the two queues, and (optionally) remove the queues.
TF_ASSERT_OK(ScheduleTask(1, queue_0.get()));
TF_ASSERT_OK(ScheduleTask(2, queue_1.get()));
TF_ASSERT_OK(ScheduleTask(3, queue_0.get()));
TF_ASSERT_OK(ScheduleTask(4, queue_1.get()));
if (delete_queue_1_early) {
queue_1 = nullptr;
}
TF_ASSERT_OK(ScheduleTask(5, queue_0.get()));
}
EXPECT_EQ(queue_0_tasks, 9);
EXPECT_EQ(queue_1_tasks, 6);
}
}
}
TEST(AdaptiveSharedBatchSchedulerTest, BadOptions) {
using Scheduler = AdaptiveSharedBatchScheduler<FakeTask>;
std::shared_ptr<Scheduler> scheduler;
Scheduler::Options options;
options.num_batch_threads = 0;
EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok());
options = Scheduler::Options();
options.min_scheduling_period_micros = 50;
options.max_scheduling_period_micros = 100;
options.initial_scheduling_period_micros = 1;
EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok());
options = Scheduler::Options();
options.min_scheduling_period_micros = 50;
options.max_scheduling_period_micros = 100;
options.initial_scheduling_period_micros = 1000;
EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok());
options = Scheduler::Options();
options.min_scheduling_period_micros = 100;
options.max_scheduling_period_micros = 50;
options.initial_scheduling_period_micros = 75;
EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok());
options = Scheduler::Options();
options.feedback_smoothing_batches = 0;
EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok());
}
TEST(AdaptiveSharedBatchSchedulerTest, ObeysQueueOptions) {
test_util::FakeClockEnv env(Env::Default());
Notification start_teardown, stop_teardown;
std::unique_ptr<Thread> teardown_thread =
CreateFakeClockAdvancerThread(&env, &start_teardown, &stop_teardown);
{
AdaptiveSharedBatchScheduler<FakeTask>::Options options;
options.initial_scheduling_period_micros = 1000;
options.env = &env;
std::shared_ptr<AdaptiveSharedBatchScheduler<FakeTask>> scheduler;
TF_ASSERT_OK(
AdaptiveSharedBatchScheduler<FakeTask>::Create(options, &scheduler));
std::unique_ptr<BatchScheduler<FakeTask>> queue_0;
std::unique_ptr<BatchScheduler<FakeTask>> queue_1;
int queue_0_tasks = 0;
int queue_1_tasks = 0;
auto queue_0_callback = [&queue_0_tasks,
&env](std::unique_ptr<Batch<FakeTask>> batch) {
ASSERT_TRUE(batch->IsClosed());
EXPECT_GT(batch->num_tasks(), 0);
for (int i = 0; i < batch->num_tasks(); i++) {
queue_0_tasks += batch->task(i).size();
}
env.SleepForMicroseconds(1);
};
auto queue_1_callback = [&queue_1_tasks,
&env](std::unique_ptr<Batch<FakeTask>> batch) {
ASSERT_TRUE(batch->IsClosed());
EXPECT_GT(batch->num_tasks(), 0);
for (int i = 0; i < batch->num_tasks(); i++) {
queue_1_tasks += batch->task(i).size();
}
env.SleepForMicroseconds(1);
};
AdaptiveSharedBatchScheduler<FakeTask>::QueueOptions queue_options;
queue_options.max_batch_size = 10;
queue_options.max_enqueued_batches = 0;
// Queue must have max_enqueued_batches > 0.
EXPECT_FALSE(
scheduler->AddQueue(queue_options, queue_0_callback, &queue_0).ok());
queue_options.max_enqueued_batches = 2;
TF_ASSERT_OK(
scheduler->AddQueue(queue_options, queue_0_callback, &queue_0));
queue_options.max_batch_size = 0;
// Queue must have max_batch_size > 0.
EXPECT_FALSE(
scheduler->AddQueue(queue_options, queue_1_callback, &queue_1).ok());
queue_options.max_batch_size = 2;
queue_options.max_enqueued_batches = 1;
TF_ASSERT_OK(
scheduler->AddQueue(queue_options, queue_1_callback, &queue_1));
// Wait for scheduling_thread to sleep.
env.BlockUntilThreadsAsleep(1);
// Task larger than max_batch_size shouldn't schedule.
EXPECT_FALSE(ScheduleTask(15, queue_0.get()).ok());
TF_ASSERT_OK(ScheduleTask(5, queue_0.get()));
TF_ASSERT_OK(ScheduleTask(5, queue_0.get()));
env.AdvanceByMicroseconds(1);
// Task larger than max_batch_size shouldn't schedule.
EXPECT_FALSE(ScheduleTask(3, queue_1.get()).ok());
TF_ASSERT_OK(ScheduleTask(1, queue_1.get()));
TF_ASSERT_OK(ScheduleTask(1, queue_1.get()));
env.AdvanceByMicroseconds(1);
// Exceeds max_enqueued_batches, shouldn't schedule.
EXPECT_FALSE(ScheduleTask(1, queue_1.get()).ok());
TF_ASSERT_OK(ScheduleTask(5, queue_0.get()));
// Exceeds max_enqueued_batches, shouldn't schedule.
EXPECT_FALSE(ScheduleTask(6, queue_0.get()).ok());
TF_ASSERT_OK(ScheduleTask(4, queue_0.get()));
// Batches should be processed in order from oldest to newest.
env.AdvanceByMicroseconds(1000);
env.BlockUntilThreadsAsleep(2);
EXPECT_EQ(queue_0_tasks, 10);
EXPECT_EQ(queue_1_tasks, 0);
env.AdvanceByMicroseconds(1000);
env.BlockUntilThreadsAsleep(2);
EXPECT_EQ(queue_0_tasks, 10);
EXPECT_EQ(queue_1_tasks, 2);
env.AdvanceByMicroseconds(1000);
env.BlockUntilThreadsAsleep(2);
EXPECT_EQ(queue_0_tasks, 19);
EXPECT_EQ(queue_1_tasks, 2);
start_teardown.Notify();
}
stop_teardown.Notify();
}
TEST(AdaptiveSharedBatchSchedulerTest, RateFeedback) {
test_util::FakeClockEnv env(Env::Default());
Notification start_teardown, stop_teardown;
std::unique_ptr<Thread> teardown_thread =
CreateFakeClockAdvancerThread(&env, &start_teardown, &stop_teardown);
{
double feedback = 0;
AdaptiveSharedBatchScheduler<FakeTask>::Options options;
options.initial_scheduling_period_micros = 1000;
options.min_scheduling_period_micros = 200;
options.max_scheduling_period_micros = 2000;
options.env = &env;
options.scheduling_period_feedback = [&feedback] { return feedback; };
options.feedback_smoothing_batches = 1;
std::shared_ptr<AdaptiveSharedBatchScheduler<FakeTask>> scheduler;
TF_ASSERT_OK(
AdaptiveSharedBatchScheduler<FakeTask>::Create(options, &scheduler));
std::unique_ptr<BatchScheduler<FakeTask>> queue;
int scheduled_items = 0;
auto queue_callback = [&scheduled_items,
&env](std::unique_ptr<Batch<FakeTask>> batch) {
ASSERT_TRUE(batch->IsClosed());
EXPECT_GT(batch->num_tasks(), 0);
scheduled_items = 0;
for (int i = 0; i < batch->num_tasks(); i++) {
scheduled_items += batch->task(i).size();
}
env.SleepForMicroseconds(1);
};
TF_ASSERT_OK(scheduler->AddQueue({}, queue_callback, &queue));
// Wait for scheduling_thread to sleep.
env.BlockUntilThreadsAsleep(1);
// Enqueue 6 batches.
for (int i = 0; i < 6; i++) {
TF_ASSERT_OK(ScheduleTask(900 + i, queue.get()));
env.AdvanceByMicroseconds(1);
}
feedback = -500;
env.AdvanceByMicroseconds(994);
env.BlockUntilThreadsAsleep(2); // scheduling period = 500 usec.
EXPECT_EQ(scheduled_items, 900);
env.AdvanceByMicroseconds(500);
env.BlockUntilThreadsAsleep(2); // scheduling period = 250 usec.
EXPECT_EQ(scheduled_items, 901);
feedback = 0;
env.AdvanceByMicroseconds(250);
env.BlockUntilThreadsAsleep(2); // scheduling period = 250 usec.
EXPECT_EQ(scheduled_items, 902);
feedback = 10000; // large feedback should hit max_scheduling_period.
env.AdvanceByMicroseconds(250);
env.BlockUntilThreadsAsleep(2); // scheduling period = 2000 usec.
EXPECT_EQ(scheduled_items, 903);
feedback = -10000; // large feedback should hit min_scheduling_period.
env.AdvanceByMicroseconds(1999);
// No callback scheduled, only scheduling thread sleeping.
env.BlockUntilThreadsAsleep(1);
EXPECT_EQ(scheduled_items, 903);
env.AdvanceByMicroseconds(1);
env.BlockUntilThreadsAsleep(2); // scheduling period = 200 usec.
EXPECT_EQ(scheduled_items, 904);
env.AdvanceByMicroseconds(200);
env.BlockUntilThreadsAsleep(2);
EXPECT_EQ(scheduled_items, 905);
start_teardown.Notify();
}
stop_teardown.Notify();
}
TEST(AdaptiveSharedBatchSchedulerTest, FeedbackSmoothing) {
test_util::FakeClockEnv env(Env::Default());
Notification start_teardown, stop_teardown;
std::unique_ptr<Thread> teardown_thread =
CreateFakeClockAdvancerThread(&env, &start_teardown, &stop_teardown);
{
double feedback = 0;
AdaptiveSharedBatchScheduler<FakeTask>::Options options;
options.initial_scheduling_period_micros = 1000;
options.env = &env;
options.scheduling_period_feedback = [&feedback] { return feedback; };
options.feedback_smoothing_batches = 3;
std::shared_ptr<AdaptiveSharedBatchScheduler<FakeTask>> scheduler;
TF_ASSERT_OK(
AdaptiveSharedBatchScheduler<FakeTask>::Create(options, &scheduler));
std::unique_ptr<BatchScheduler<FakeTask>> queue;
int scheduled_items = 0;
auto queue_callback = [&scheduled_items,
&env](std::unique_ptr<Batch<FakeTask>> batch) {
ASSERT_TRUE(batch->IsClosed());
EXPECT_GT(batch->num_tasks(), 0);
scheduled_items = 0;
for (int i = 0; i < batch->num_tasks(); i++) {
scheduled_items += batch->task(i).size();
}
env.SleepForMicroseconds(1);
};
TF_ASSERT_OK(scheduler->AddQueue({}, queue_callback, &queue));
// Wait for scheduling_thread to sleep.
env.BlockUntilThreadsAsleep(1);
// Enqueue 4 batches.
for (int i = 0; i < 4; i++) {
TF_ASSERT_OK(ScheduleTask(900 + i, queue.get()));
env.AdvanceByMicroseconds(1);
}
feedback = -300;
env.AdvanceByMicroseconds(996);
env.BlockUntilThreadsAsleep(2);
// ewma_feedback = -100, scheduling_period = 900.
EXPECT_EQ(scheduled_items, 900);
env.AdvanceByMicroseconds(899);
// No callback scheduled, only scheduling thread sleeping.
env.BlockUntilThreadsAsleep(1);
EXPECT_EQ(scheduled_items, 900);
env.AdvanceByMicroseconds(1);
env.BlockUntilThreadsAsleep(2);
// ewma_feedback = -167, scheduling_period = 750.
EXPECT_EQ(scheduled_items, 901);
env.AdvanceByMicroseconds(749);
// No callback scheduled, only scheduling thread sleeping.
env.BlockUntilThreadsAsleep(1);
EXPECT_EQ(scheduled_items, 901);
feedback = 1000 / 3.;
env.AdvanceByMicroseconds(1);
env.BlockUntilThreadsAsleep(2);
// ewma_feedback = 0, scheduling_period = 750.
EXPECT_EQ(scheduled_items, 902);
env.AdvanceByMicroseconds(749);
// No callback scheduled, only scheduling thread sleeping.
env.BlockUntilThreadsAsleep(1);
EXPECT_EQ(scheduled_items, 902);
env.AdvanceByMicroseconds(1);
env.BlockUntilThreadsAsleep(2);
EXPECT_EQ(scheduled_items, 903);
start_teardown.Notify();
}
stop_teardown.Notify();
}
TEST(AdaptiveSharedBatchSchedulerTest, QueueCapacityInfo) {
test_util::FakeClockEnv env(Env::Default());
Notification start_teardown, stop_teardown;
std::unique_ptr<Thread> teardown_thread =
CreateFakeClockAdvancerThread(&env, &start_teardown, &stop_teardown);
{
AdaptiveSharedBatchScheduler<FakeTask>::Options options;
options.initial_scheduling_period_micros = 1000;
options.env = &env;
std::shared_ptr<AdaptiveSharedBatchScheduler<FakeTask>> scheduler;
TF_ASSERT_OK(
AdaptiveSharedBatchScheduler<FakeTask>::Create(options, &scheduler));
std::unique_ptr<BatchScheduler<FakeTask>> queue;
int scheduled_items = 0;
auto queue_callback = [&scheduled_items,
&env](std::unique_ptr<Batch<FakeTask>> batch) {
ASSERT_TRUE(batch->IsClosed());
EXPECT_GT(batch->num_tasks(), 0);
scheduled_items = 0;
for (int i = 0; i < batch->num_tasks(); i++) {
scheduled_items += batch->task(i).size();
}
env.SleepForMicroseconds(1);
};
AdaptiveSharedBatchScheduler<FakeTask>::QueueOptions queue_options;
queue_options.max_batch_size = 10;
queue_options.max_enqueued_batches = 10;
TF_ASSERT_OK(scheduler->AddQueue(queue_options, queue_callback, &queue));
// Wait for scheduling_thread to sleep.
env.BlockUntilThreadsAsleep(1);
// Enqueue 3 tasks.
EXPECT_EQ(queue->NumEnqueuedTasks(), 0);
EXPECT_EQ(queue->SchedulingCapacity(), 100);
TF_ASSERT_OK(ScheduleTask(5, queue.get()));
EXPECT_EQ(queue->NumEnqueuedTasks(), 1);
EXPECT_EQ(queue->SchedulingCapacity(), 95);
env.AdvanceByMicroseconds(1);
TF_ASSERT_OK(ScheduleTask(6, queue.get()));
EXPECT_EQ(queue->NumEnqueuedTasks(), 2);
EXPECT_EQ(queue->SchedulingCapacity(), 84);
env.AdvanceByMicroseconds(1);
TF_ASSERT_OK(ScheduleTask(1, queue.get()));
EXPECT_EQ(queue->NumEnqueuedTasks(), 3);
EXPECT_EQ(queue->SchedulingCapacity(), 83);
env.AdvanceByMicroseconds(998);
env.BlockUntilThreadsAsleep(2);
EXPECT_EQ(scheduled_items, 5);
env.AdvanceByMicroseconds(1000);
env.BlockUntilThreadsAsleep(2);
EXPECT_EQ(scheduled_items, 7);
start_teardown.Notify();
}
stop_teardown.Notify();
}
} // namespace anonymous
} // namespace serving
} // namespace tensorflow
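As a quick cross-check of the scheduling-period comments in the RateFeedback and FeedbackSmoothing tests above, the period update performed by ProcessOneBatch can be replayed by hand. The following is a standalone Python sketch (illustration only, not part of this change) that mirrors the C++ update `ewma = ((N - 1) * ewma + feedback) / N` followed by `period *= 1 + kFeedbackMultiplier * ewma`; clamping to the min/max period is omitted because it is never hit for these values:

# Standalone sketch: replays the scheduling-period update of
# AdaptiveSharedBatchScheduler::ProcessOneBatch for the FeedbackSmoothing test
# (N = feedback_smoothing_batches = 3, initial period 1000 usec).
def replay_periods(feedbacks, period=1000.0, n=3):
  ewma = 0.0
  for feedback in feedbacks:
    ewma = ((n - 1) * ewma + feedback) / n   # exponentially weighted average
    period *= 1 + 0.001 * ewma               # kFeedbackMultiplier = .001
    print("feedback=%8.2f  ewma=%8.2f  period=%7.1f" % (feedback, ewma, period))

replay_periods([-300.0, -300.0, 1000 / 3.])
# Prints ewma = -100, -166.67, ~0 and period = 900, 750, 750, matching the
# expectations asserted in the FeedbackSmoothing test.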

View File

@ -78,7 +78,7 @@ template <typename TaskType>
class Batch {
public:
Batch() = default;
~Batch(); // Blocks until the batch is closed.
virtual ~Batch(); // Blocks until the batch is closed.
// Appends 'task' to the batch. After calling AddTask(), the newly-added task
// can be accessed via task(num_tasks()-1) or mutable_task(num_tasks()-1).

View File

@ -14,7 +14,7 @@
# ==============================================================================
include (ExternalProject)
set(cub_URL https://github.com/NVlabs/cub/archive/1.7.4.zip)
set(cub_URL https://mirror.bazel.build/github.com/NVlabs/cub/archive/1.7.4.zip)
set(cub_HASH SHA256=20a1a39fd97e5da7f40f5f2e7fd73fd2ea59f9dc4bb8a6c5f228aa543e727e31)
set(cub_BUILD ${CMAKE_CURRENT_BINARY_DIR}/cub/src/cub)
set(cub_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/cub/src/cub)

View File

@ -15,7 +15,7 @@
include (ExternalProject)
set(gif_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/gif_archive/giflib-5.1.4/)
set(gif_URL http://mirror.bazel.build/ufpr.dl.sourceforge.net/project/giflib/giflib-5.1.4.tar.gz)
set(gif_URL https://mirror.bazel.build/ufpr.dl.sourceforge.net/project/giflib/giflib-5.1.4.tar.gz)
set(gif_HASH SHA256=34a7377ba834397db019e8eb122e551a49c98f49df75ec3fcc92b9a794a4f6d1)
set(gif_INSTALL ${CMAKE_BINARY_DIR}/gif/install)
set(gif_BUILD ${CMAKE_BINARY_DIR}/gif/src/gif)

View File

@ -15,7 +15,7 @@
include (ExternalProject)
set(jpeg_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/jpeg_archive)
set(jpeg_URL http://mirror.bazel.build/www.ijg.org/files/jpegsrc.v9a.tar.gz)
set(jpeg_URL https://mirror.bazel.build/www.ijg.org/files/jpegsrc.v9a.tar.gz)
set(jpeg_HASH SHA256=3a753ea48d917945dd54a2d97de388aa06ca2eb1066cbfdc6652036349fe05a7)
set(jpeg_BUILD ${CMAKE_CURRENT_BINARY_DIR}/jpeg/src/jpeg)
set(jpeg_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/jpeg/install)

View File

@ -15,7 +15,7 @@
include (ExternalProject)
set(lmdb_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/lmdb)
set(lmdb_URL http://mirror.bazel.build/github.com/LMDB/lmdb/archive/LMDB_0.9.19.tar.gz)
set(lmdb_URL https://mirror.bazel.build/github.com/LMDB/lmdb/archive/LMDB_0.9.19.tar.gz)
set(lmdb_HASH SHA256=108532fb94c6f227558d45be3f3347b52539f0f58290a7bb31ec06c462d05326)
set(lmdb_BUILD ${CMAKE_BINARY_DIR}/lmdb/src/lmdb)
set(lmdb_INSTALL ${CMAKE_BINARY_DIR}/lmdb/install)

View File

@ -47,4 +47,4 @@ ExternalProject_Add(snappy
)
# actually enables snappy in the source code
add_definitions(-DSNAPPY)
add_definitions(-DTF_USE_SNAPPY)

View File

@ -86,7 +86,7 @@ cuda_py_test(
"//tensorflow/python:client",
"//tensorflow/python:client_testlib",
"//tensorflow/python/eager:graph_callable",
"//tensorflow/python:platform_test",
"//tensorflow/python/eager:test",
"//tensorflow/python:variables",
],
)
@ -132,11 +132,12 @@ py_library(
"//tensorflow/python:array_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:init_ops",
"//tensorflow/python:layers_base",
"//tensorflow/python:math_ops",
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:function",
],
)
@ -146,6 +147,10 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":metrics",
"//tensorflow/python:array_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:variables",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:test",
],
)
@ -160,6 +165,8 @@ py_library(
deps = [
":datasets",
":metrics",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:function",
],
)

View File

@ -86,7 +86,7 @@ class EvaluatorTest(test.TestCase):
for v in e.metric_variables:
p = v.name.split("/")[0]
prefix_count[p] = prefix_count.get(p, 0) + 1
self.assertEqual({"outer-mean": 2, "mean": 2}, prefix_count)
self.assertEqual({"outer_mean": 2, "mean": 2}, prefix_count)
def testDataset(self):
e = SimpleEvaluator(IdentityModel())

View File

@ -18,6 +18,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import re
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
@ -25,55 +29,69 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
_to_replace = re.compile("[^A-Za-z0-9.]")
class Metric(object):
"""A metric holds state for aggregating statistics over an evaluation run.
Users will use Evaluator.add_metric() to add Metric objects to their
evaluation, call them in each step, and then use
Evaluator.all_metric_results() at the end.
evaluation, call them in each step (treating the object as a callable),
and then use Evaluator.all_metric_results() at the end.
Descendants will implement:
* call(): Should follow this pattern:
if not self.built:
self.var = self.add_variable(...)
self.add_update(self.var.assign_add(...))
* aggregate(): Adds in the state from a list of metrics of the same type
as `self`. (Default of summing all the variables will be fine for most
descendants.)
* result(): Computes and returns a final value for the metric
* `build()`: All variables should be created in this method, by calling
`self.add_variable()` as in: `self.var = self.add_variable(...)`
build() will be called in the first invocation of `__call__()`, with
the same arguments passed `call()`.
* `call()`: Has all updates to variables, as in:
self.var.assign_add(...)
* `result()`: Computes and returns a final value for the metric
from the variables in `self`.
Descendants may override, but usually won't need to:
* `aggregate()`: Adds in the state from a list of metrics of the same type
as `self`. (Default is to sum all the variables.)
* `reset()`: Reset all variables to their initial state. (Default is to
zero all the variables.)
Note that users should not call `aggregate()` or `reset()`, they are for
use by TensorFlow infrastructure.
"""
def __init__(self, name=None):
self.built = False
self._built = False
self._vars = []
self._updates = []
self._name = name or self.__class__.__name__
# TODO(josh11b): Need some way to make sure two Metrics in the same
# Network have distinct names. Maybe we can get a unique name from
# a name/variable scope?
# TODO(josh11b): self._in_graph_mode = context.in_graph_mode()
name = name or self.__class__.__name__
# Replace things like spaces in name to create a valid scope name.
scope_name = _to_replace.sub("_", name)
# We create the variable scope now to get the unique name that will
# be used as a variable prefix when build() calls add_variable().
with variable_scope.variable_scope(
None, default_name=scope_name, use_resource=True, reuse=False) as scope:
pos = scope.name.rfind(scope_name)
self._name = name + scope.name[pos + len(scope_name):]
self._scope = scope
if context.in_graph_mode():
# We make self.call() into a graph callable here, so that we can
# return a single op that performs all of the variable updates.
self.call = function.defun(self.call)
# ---- API for users ----
def __call__(self, *args, **kwargs):
# TODO(josh11b): If self._in_graph_mode is true, make self.call() into a
# graph callable here, so that variable updates happen without requiring
# a separate fetch.
# TODO(josh11b): Do we need a separate build() method to separate
# initialization from each update? If so, how do we get the arguments
# to it? We *could* just pass in *args and **kwargs...
if not self.built:
# TODO(ashankar): Set up container isolation so there is no chance
# distinct metrics objects accidentally share variables.
# TODO(josh11b): Replace things like spaces in self._name to create
# a valid scope name.
with variable_scope.variable_scope(
self._name, use_resource=True, reuse=False):
ret = self.call(*args, **kwargs)
self.built = True
else:
ret = self.call(*args, **kwargs)
return ret
"""Returns op to execute to update this metric for these inputs.
Returns None if eager execution is enabled.
Args:
*args:
**kwargs: A mini-batch of inputs to the Metric, passed on to `call()`.
"""
if not self._built:
with variable_scope.variable_scope(self._scope):
self.build(*args, **kwargs)
self._built = True
return self.call(*args, **kwargs)
@property
def name(self):
@ -84,10 +102,43 @@ class Metric(object):
return self._vars
# ---- To be implemented by descendants ---
def build(self, *args, **kwargs):
"""Method to create variables.
Called by `__call__()` before `call()` for the first time.
Args:
*args:
**kwargs: The arguments to the first invocation of `__call__()`.
`build()` may use the shape and/or dtype of these arguments
when deciding how to create variables.
"""
raise NotImplementedError("Metrics must define a build() member function")
def call(self, *args, **kwargs):
"""Accumulates statistics for the metric."""
"""Accumulates statistics for the metric. Users should use __call__ instead.
Note: This function is executed as a graph function in graph mode.
This means:
a) Operations on the same resource are executed in textual order.
This should make it easier to do things like add the updated
value of a variable to another, for example.
b) You don't need to worry about collecting the update ops to execute.
All update ops added to the graph by this function will be executed.
As a result, code should generally work the same way with graph or
eager execution.
Args:
*args:
**kwargs: A mini-batch of inputs to the Metric, as passed to
`__call__()`.
"""
raise NotImplementedError("Metrics must define a call() member function")
def result(self): # TODO(josh11b): Add an optional summary_writer parameter.
"""Computes and returns a final value for the metric."""
raise NotImplementedError("Metrics must define a result() member function")
# We can support two different strategies for doing data-parallel
# distributed metric computations:
# * Put metric variables on the first device and rely on small
@ -123,16 +174,19 @@ class Metric(object):
self._vars[i].assign_add(math_ops.add_n([m._vars[i] for m in metrics]))
# pylint: enable=protected-access
def result(self): # TODO(josh11b): Add an optional summary_writer parameter.
"""Computes and returns a final value for the metric."""
raise NotImplementedError("Metrics must define a result() member function")
def reset(self):
"""Reset this metric to a freshly initialized state.
Default implementation zeros all the metric variables.
"""
for v in self._vars:
v.assign(math_ops.zeros_like(v))
# ---- For use by descendants ---
def add_variable(self, name, shape=None, dtype=None, initializer=None):
"""***Only for use by descendants of Metric***."""
if self.built:
raise RuntimeError("Can't call add_variable() after a Metric has been "
"built in the first call().")
if self._built:
raise RuntimeError("Can't call add_variable() except in build().")
v = variable_scope.get_variable(name, shape, dtype, initializer,
trainable=False, use_resource=True)
self._vars.append(v)
@ -144,6 +198,15 @@ class Mean(Metric):
# TODO(josh11b): Maybe have a dtype argument that defaults to tf.float64?
# Or defaults to type of the input if it is tf.float32, else tf.float64?
def build(self, values, weights=None):
del values, weights # build() does not use call's arguments
self.numer = self.add_variable(name="numer", shape=(),
dtype=dtypes.float64,
initializer=init_ops.zeros_initializer)
self.denom = self.add_variable(name="denom", shape=(),
dtype=dtypes.float64,
initializer=init_ops.zeros_initializer)
def call(self, values, weights=None):
"""Accumulate statistics for computing the mean.
@ -154,13 +217,6 @@ class Mean(Metric):
values: Tensor with the per-example value.
weights: Optional weighting of each example. Defaults to 1.
"""
if not self.built: # False only in the first call().
self.numer = self.add_variable(name="numer", shape=(),
dtype=dtypes.float64,
initializer=init_ops.zeros_initializer)
self.denom = self.add_variable(name="denom", shape=(),
dtype=dtypes.float64,
initializer=init_ops.zeros_initializer)
if weights is None:
self.denom.assign_add(
math_ops.cast(array_ops.size(values), dtypes.float64))
@ -179,6 +235,10 @@ class Mean(Metric):
class Accuracy(Mean):
"""Calculates how often `predictions` matches `labels`."""
def build(self, labels, predictions, weights=None):
del labels, predictions, weights
super(Accuracy, self).build(None) # Arguments are unused
def call(self, labels, predictions, weights=None):
"""Accumulate accuracy statistics.

View File

@ -19,7 +19,11 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.eager.python import metrics
from tensorflow.python.eager import context
from tensorflow.python.eager import test
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables
class MetricsTest(test.TestCase):
@ -56,6 +60,53 @@ class MetricsTest(test.TestCase):
m([7], [2]) # 0 correct, weight 1
self.assertEqual(2.5/5, m.result().numpy())
def testTwoMeans(self):
# Verify two metrics with the same class and name don't
# accidentally share state.
m1 = metrics.Mean()
m2 = metrics.Mean()
m1(0)
m2(2)
self.assertEqual(0, m1.result().numpy())
self.assertEqual(2, m2.result().numpy())
self.assertNotEqual(m1.name, m2.name)
def testNamesWithSpaces(self):
# Verify two metrics with the same class and name don't
# accidentally share state.
m1 = metrics.Mean("has space")
m2 = metrics.Mean("has space")
m2(2)
m1(0)
self.assertEqual(m1.name, "has space")
self.assertEqual(m1.numer.name, "has_space/numer:0")
self.assertEqual(m2.name, "has space_1")
self.assertEqual(m2.numer.name, "has_space_1/numer:0")
def testGraph(self):
with context.graph_mode(), self.test_session() as sess:
m = metrics.Mean()
p = array_ops.placeholder(dtypes.float32)
accumulate = m(p)
variables.global_variables_initializer().run()
sess.run(accumulate, feed_dict={p: [1, 10, 100]})
sess.run(accumulate, feed_dict={p: 1000})
sess.run(accumulate, feed_dict={p: [10000, 100000]})
self.assertAllEqual(m.result().eval(), 111111.0/6)
def testTwoMeansGraph(self):
# Verify two metrics with the same class and name don't
# accidentally share state.
with context.graph_mode(), self.test_session() as sess:
m1 = metrics.Mean()
m2 = metrics.Mean()
accumulate1 = m1(0)
accumulate2 = m2(2)
variables.global_variables_initializer().run()
sess.run([accumulate1, accumulate2])
self.assertEqual(0, m1.result().eval())
self.assertEqual(2, m2.result().eval())
if __name__ == "__main__":
test.main()

View File

@ -22,6 +22,7 @@ import os
from tensorflow.contrib.eager.python import saver as _saver
from tensorflow.python.eager import context
from tensorflow.python.eager import graph_callable
from tensorflow.python.eager import test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
@ -29,7 +30,6 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test
class SaverTest(test.TestCase):
@ -38,7 +38,7 @@ class SaverTest(test.TestCase):
return '/device:GPU:0' if context.num_gpus() else '/device:CPU:0'
def testBasics(self):
with context.eager_mode(), ops.device(self._dev()):
with ops.device(self._dev()):
v1 = resource_variable_ops.ResourceVariable(1.0, name='v1')
def model():
return array_ops.constant(2.0) * v1
@ -54,8 +54,42 @@ class SaverTest(test.TestCase):
saver.restore(ckpt_prefix)
self.assertEqual(v1.read_value().numpy(), 1.0)
def testRestoreOnCreate(self):
def testSameNameNoClobbering(self):
with context.eager_mode(), ops.device(self._dev()):
# Note that this test purposefully uses Graphs rather than
# IsolateTest. Users are more likely to accidentally create the same
# variable name this way.
first_graph = ops.Graph()
with first_graph.as_default():
v1_first_graph = resource_variable_ops.ResourceVariable(1.0, name='v1')
with ops.Graph().as_default():
v1_second_graph = resource_variable_ops.ResourceVariable(2.0, name='v1')
saver = _saver.Saver([v1_first_graph, v1_second_graph])
ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt')
with self.assertRaisesRegexp(ValueError, 'v1'):
saver.save(ckpt_prefix)
def testDifferentGraphError(self):
with context.eager_mode(), ops.device(self._dev()):
with ops.Graph().as_default():
v1 = resource_variable_ops.ResourceVariable(1.0, name='v1')
with ops.Graph().as_default():
saver = _saver.Saver([v1])
ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt')
with self.assertRaisesRegexp(ValueError, 'Graph'):
saver.save(ckpt_prefix)
def testSameObjectOK(self):
with context.eager_mode(), ops.device(self._dev()):
v1 = resource_variable_ops.ResourceVariable(1.0, name='v1')
# While different objects with the same shared_name are not good, passing
# in the same object multiple times is fine.
saver = _saver.Saver([v1, v1])
ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt')
saver.save(ckpt_prefix)
def testRestoreOnCreate(self):
with ops.device(self._dev()):
def model(init_val):
v1 = resource_variable_ops.ResourceVariable(init_val, name='v1')
return array_ops.constant(1.0) * v1, v1
@ -71,12 +105,9 @@ class SaverTest(test.TestCase):
# Value is from checkpoint, but not from argument.
ret, _ = model(2.0)
self.assertEqual(ret.numpy(), 1.0)
# Create it a second time won't re-assign the checkpoint value.
v1_2 = resource_variable_ops.ResourceVariable(3.0, name='v1')
self.assertEqual(v1_2.read_value().numpy(), 3.0)
def testRestoreNotFound(self):
with context.eager_mode(), ops.device(self._dev()):
with ops.device(self._dev()):
def model(v):
return array_ops.constant(1.0) * v
@ -92,7 +123,7 @@ class SaverTest(test.TestCase):
_ = model(resource_variable_ops.ResourceVariable(1.0, name='v2'))
def testSaveRestoreGraphCallable(self):
with context.eager_mode(), ops.device(self._dev()):
with ops.device(self._dev()):
@graph_callable.graph_callable(
[graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)])
def model(x):

View File

@ -53,6 +53,7 @@ To use, at program startup, call `tfe.enable_eager_execution()`.
@@in_eager_mode
@@in_graph_mode
@@IsolateTest
@@run_test_in_graph_and_eager_modes
"""
@ -84,6 +85,7 @@ from tensorflow.python.eager.execution_callbacks import nan_callback
from tensorflow.python.eager.execution_callbacks import seterr
from tensorflow.python.framework.ops import enable_eager_execution
from tensorflow.python.framework.ops import eager_run as run
from tensorflow.python.framework.test_util import IsolateTest
from tensorflow.python.framework.test_util import run_in_graph_and_eager_modes as run_test_in_graph_and_eager_modes
from tensorflow.python.ops.resource_variable_ops import ResourceVariable as Variable
from tensorflow.python.util.all_util import remove_undocumented

View File

@ -24,7 +24,11 @@ the full-batch version.
approach for computing the initial cluster assignments that is expensive but is
typically less prone to getting stuck in bad local minima.
We provide distributed implementations of both full-batch and mini-batch
K-Means algorithm. Both K-Means++ and random initialization are supported.
The user can also choose between **Cosine** and **Squared Euclidean** distance
metrics.
**[k-MC2](https://www.aaai.org/ocs/index.php/AAAI/AAAI16/paper/view/12147/11759)**
is a very fast seeding method that produces high quality centers
comparable to K-Means++ seeding. k-MC2 works particularly well if it is combined
with Mini-batch K-Means.
We provide distributed implementations of both the full-batch and mini-batch
K-Means algorithms. K-Means++, k-MC2 and random initialization are supported. The user
can also choose between **Cosine** and **Squared Euclidean** distance metrics.
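For example, seeding with k-MC2 and training with mini-batch K-Means might look like the sketch below. It is illustrative only: `inputs` stands for a (pre-shuffled) input tensor, and the import path and constructor arguments follow the `KMeans` class and the `KMC2_INIT`/`kmc2_chain_length` additions in this change:

# Illustrative sketch: k-MC2 seeding combined with mini-batch K-Means.
from tensorflow.contrib.factorization.python.ops import clustering_ops

kmeans = clustering_ops.KMeans(
    inputs,                                     # assumed randomly permuted
    num_clusters=100,
    initial_clusters=clustering_ops.KMC2_INIT,  # i.e. 'kmc2'
    use_mini_batch=True,
    kmc2_chain_length=200)                      # candidate points per new center

The resulting object is then used exactly as before, e.g. via its training graph.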

View File

@ -224,6 +224,58 @@ class KmeansPlusPlusInitializationOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("KmeansPlusPlusInitialization").Device(DEVICE_CPU),
KmeansPlusPlusInitializationOp);
// Implementation of one single Markov Chain for the k-MC^2 algorithm
class KMC2ChainInitializationOp : public OpKernel {
public:
explicit KMC2ChainInitializationOp(OpKernelConstruction* context)
: OpKernel(context) {
OP_REQUIRES_OK(context,
context->MatchSignature({DT_FLOAT, DT_INT64}, {DT_INT64}));
}
void Compute(OpKernelContext* context) override {
const Tensor& distances_tensor = context->input(0);
const Tensor& seed_tensor = context->input(1);
OP_REQUIRES(context, TensorShapeUtils::IsVector(distances_tensor.shape()),
InvalidArgument("Input distances should be a vector."));
OP_REQUIRES(context, TensorShapeUtils::IsScalar(seed_tensor.shape()),
InvalidArgument("Input seed should be a scalar."));
const int64 num_points = distances_tensor.dim_size(0);
const int64 seed = seed_tensor.scalar<int64>()();
OP_REQUIRES(context, num_points > 0,
InvalidArgument("Expected distances_tensor.size() > 0."));
random::PhiloxRandom random(seed);
random::SimplePhilox rng(&random);
auto distances = distances_tensor.flat<float>();
// Set the initial state of the Markov chain to be the first candidate.
int64 selected_index = 0;
float selected_distance = distances(selected_index);
// Build a Markov chain of length num_points.
for (int64 i = 1; i < num_points; ++i) {
const float candidate_distance = distances(i);
// Set the next state of the Markov chain to be the candidate with
// probability min(1, candidate_distance/selected_distance).
if (candidate_distance > rng.RandFloat() * selected_distance) {
selected_index = i;
selected_distance = candidate_distance;
}
}
Tensor* output_sampled_index_tensor;
OP_REQUIRES_OK(context,
context->allocate_output(0, TensorShape({}),
&output_sampled_index_tensor));
auto output = output_sampled_index_tensor->scalar<int64>();
// Return the last state of the Markov chain as the new center.
output() = selected_index;
}
};
REGISTER_KERNEL_BUILDER(Name("KMC2ChainInitialization").Device(DEVICE_CPU),
KMC2ChainInitializationOp);
// Operator for computing the nearest neighbors for a set of points.
class NearestNeighborsOp : public OpKernel {
public:

View File

@ -116,6 +116,62 @@ RUN_BM_KmeansPlusPlusInitialization(k3RetriesPerSample);
#undef RUN_BM_KmeansPlusPlusInitialization
#undef BENCHMARK_KMEANS_PLUS_PLUS
Graph* SetUpKMC2Initialization(int num_points) {
Graph* g = new Graph(OpRegistry::Global());
Tensor distances(DT_FLOAT, TensorShape({num_points}));
Tensor seed(DT_INT64, TensorShape({}));
distances.flat<float>().setRandom();
seed.flat<int64>().setConstant(12345);
TF_CHECK_OK(
NodeBuilder("KMC2ChainInitializationOp", "KMC2ChainInitialization")
.Input(test::graph::Constant(g, distances))
.Input(test::graph::Constant(g, seed))
.Finalize(g, nullptr /* node */));
return g;
}
template <int num_points, int num_to_sample, int num_dims>
void BM_KMC2Initialization(int iters) {
testing::StopTiming();
testing::ItemsProcessed(static_cast<int64>(iters) * num_points * num_dims *
num_to_sample);
testing::UseRealTime();
Graph* g = SetUpKMC2Initialization(num_points);
testing::StartTiming();
test::Benchmark("cpu", g).Run(iters);
}
#define BENCHMARK_KMC2(p, c, d) \
void BM_KMC2Initialization_##p##_##c##_##d(int iters) { \
BM_KMC2Initialization<p, c, d>(iters); \
} \
BENCHMARK(BM_KMC2Initialization_##p##_##c##_##d);
#define RUN_BM_KMC2Initialization \
BENCHMARK_KMC2(k10Points, k2Centers, k100Dim); \
BENCHMARK_KMC2(k10Points, k5Centers, k100Dim); \
BENCHMARK_KMC2(k10Points, k10Centers, k100Dim); \
BENCHMARK_KMC2(k100Points, k10Centers, k100Dim); \
BENCHMARK_KMC2(k100Points, k20Centers, k100Dim); \
BENCHMARK_KMC2(k100Points, k50Centers, k100Dim); \
BENCHMARK_KMC2(k100Points, k100Centers, k100Dim); \
BENCHMARK_KMC2(k1kPoints, k100Centers, k100Dim); \
BENCHMARK_KMC2(k1kPoints, k200Centers, k100Dim); \
BENCHMARK_KMC2(k1kPoints, k500Centers, k100Dim); \
BENCHMARK_KMC2(k1kPoints, k1kCenters, k100Dim); \
BENCHMARK_KMC2(k10kPoints, k100Centers, k100Dim); \
BENCHMARK_KMC2(k10kPoints, k200Centers, k100Dim); \
BENCHMARK_KMC2(k10kPoints, k500Centers, k100Dim); \
BENCHMARK_KMC2(k10kPoints, k1kCenters, k100Dim); \
BENCHMARK_KMC2(k1MPoints, k100Centers, k100Dim); \
BENCHMARK_KMC2(k1MPoints, k200Centers, k100Dim); \
BENCHMARK_KMC2(k1MPoints, k500Centers, k100Dim); \
BENCHMARK_KMC2(k1MPoints, k1kCenters, k100Dim)
RUN_BM_KMC2Initialization;
#undef RUN_BM_KMC2Initialization
#undef BENCHMARK_KMC2
Graph* SetUpNearestNeighbors(int num_dims, int num_points, int num_centers,
int k) {
Graph* g = new Graph(OpRegistry::Global());

View File

@ -44,6 +44,25 @@ num_retries_per_sample: Scalar. For each row that is sampled, this parameter
samples: Matrix of shape (num_to_sample, d). The sampled rows.
)");
REGISTER_OP("KMC2ChainInitialization")
.Input("distances: float32")
.Input("seed: int64")
.Output("index: int64")
.SetShapeFn(shape_inference::ScalarShape)
.Doc(R"(
Returns the index of a data point that should be added to the seed set.
Entries in distances are assumed to be squared distances of candidate points to
the already sampled centers in the seed set. The op constructs one Markov chain
of the k-MC^2 algorithm and returns the index of one candidate point to be added
as an additional cluster center.
distances: Vector with squared distances to the closest previously sampled
cluster center for each candidate point.
seed: Scalar. Seed for initializing the random number generator.
index: Scalar with the index of the sampled point.
)");
REGISTER_OP("NearestNeighbors")
.Input("points: float32")
.Input("centers: float32")

View File

@ -55,6 +55,63 @@ class KmeansPlusPlusInitializationTest(test.TestCase):
self.runTestWithSeed(seed)
class KMC2InitializationTest(test.TestCase):
def runTestWithSeed(self, seed):
with self.test_session():
distances = np.zeros(1000).astype(np.float32)
distances[6] = 10e7
distances[4] = 10e3
sampled_point = clustering_ops.kmc2_chain_initialization(distances, seed)
self.assertEquals(sampled_point.eval(), 6)
distances[6] = 0.0
sampled_point = clustering_ops.kmc2_chain_initialization(distances, seed)
self.assertEquals(sampled_point.eval(), 4)
def testBasic(self):
for seed in range(100):
self.runTestWithSeed(seed)
class KMC2InitializationLargeTest(test.TestCase):
def setUp(self):
self._distances = np.zeros(1001)
self._distances[500] = 100.0
self._distances[1000] = 50.0
def testBasic(self):
with self.test_session():
counts = {}
seed = 0
for i in range(50):
sample = clustering_ops.kmc2_chain_initialization(
self._distances, seed + i).eval()
counts[sample] = counts.get(sample, 0) + 1
self.assertEquals(len(counts), 2)
self.assertTrue(500 in counts)
self.assertTrue(1000 in counts)
self.assertGreaterEqual(counts[500], 5)
self.assertGreaterEqual(counts[1000], 5)
class KMC2InitializationCornercaseTest(test.TestCase):
def setUp(self):
self._distances = np.zeros(10)
def runTestWithSeed(self, seed):
with self.test_session():
sampled_point = clustering_ops.kmc2_chain_initialization(
self._distances, seed)
self.assertEquals(sampled_point.eval(), 0)
def testBasic(self):
for seed in range(100):
self.runTestWithSeed(seed)
# A simple test that can be verified by hand.
class NearestCentersTest(test.TestCase):

View File

@ -50,6 +50,7 @@ COSINE_DISTANCE = 'cosine'
RANDOM_INIT = 'random'
KMEANS_PLUS_PLUS_INIT = 'kmeans_plus_plus'
KMC2_INIT = 'kmc2'
# The name of the variable holding the cluster centers. Used by the Estimator.
CLUSTERS_VAR_NAME = 'clusters'
@ -66,7 +67,8 @@ class KMeans(object):
use_mini_batch=False,
mini_batch_steps_per_iteration=1,
random_seed=0,
kmeans_plus_plus_num_retries=2):
kmeans_plus_plus_num_retries=2,
kmc2_chain_length=200):
"""Creates an object for generating KMeans clustering graph.
This class implements the following variants of K-means algorithm:
@ -95,7 +97,8 @@ class KMeans(object):
exactly like a full-batch version.
Args:
inputs: An input tensor or list of input tensors
inputs: An input tensor or list of input tensors. It is assumed that the
data points have been previously randomly permuted.
num_clusters: An integer tensor specifying the number of clusters. This
argument is ignored if initial_clusters is a tensor or numpy array.
initial_clusters: Specifies the clusters used during initialization. One
@ -104,6 +107,7 @@ class KMeans(object):
- a function f(inputs, k) that returns up to k centers from `inputs`.
- "random": Choose centers randomly from `inputs`.
- "kmeans_plus_plus": Use kmeans++ to choose centers from `inputs`.
- "kmc2": Use the fast k-MC2 algorithm to choose centers from `inputs`.
In the last four cases, one batch of `inputs` may not yield
`num_clusters` centers, in which case initialization will require
multiple batches until enough centers are chosen. In the case of
@ -121,13 +125,17 @@ class KMeans(object):
additional points to draw from the current distribution before selecting
the best. If a negative value is specified, a heuristic is used to
sample O(log(num_to_sample)) additional points.
kmc2_chain_length: Determines how many candidate points are used by the
k-MC2 algorithm to produce one new cluster center. If a (mini-)batch
contains fewer points, one new cluster center is generated from the
(mini-)batch.
Raises:
ValueError: An invalid argument was passed to initial_clusters or
distance_metric.
"""
if isinstance(initial_clusters, str) and initial_clusters not in [
RANDOM_INIT, KMEANS_PLUS_PLUS_INIT
RANDOM_INIT, KMEANS_PLUS_PLUS_INIT, KMC2_INIT
]:
raise ValueError(
"Unsupported initialization algorithm '%s'" % initial_clusters)
@ -141,6 +149,7 @@ class KMeans(object):
self._mini_batch_steps_per_iteration = int(mini_batch_steps_per_iteration)
self._random_seed = random_seed
self._kmeans_plus_plus_num_retries = kmeans_plus_plus_num_retries
self._kmc2_chain_length = kmc2_chain_length
@classmethod
def _distance_graph(cls, inputs, clusters, distance_metric):
@ -302,9 +311,10 @@ class KMeans(object):
else:
cluster_centers_updated = cluster_centers
update_in_steps = None
cluster_counts = (variable_scope.variable(
array_ops.ones([num_clusters], dtype=dtypes.int64))
if self._use_mini_batch else None)
cluster_counts = (
variable_scope.variable(
array_ops.ones([num_clusters], dtype=dtypes.int64))
if self._use_mini_batch else None)
return (cluster_centers, cluster_centers_initialized, cluster_counts,
cluster_centers_updated, update_in_steps)
@ -359,7 +369,7 @@ class KMeans(object):
init_op = _InitializeClustersOpFactory(
self._inputs, num_clusters, initial_clusters, self._distance_metric,
self._random_seed, self._kmeans_plus_plus_num_retries,
cluster_centers_var, cluster_centers_updated,
self._kmc2_chain_length, cluster_centers_var, cluster_centers_updated,
cluster_centers_initialized).op()
cluster_centers = cluster_centers_var
@ -520,8 +530,9 @@ class KMeans(object):
array_ops.reshape(array_ops.shape(inp)[0], [-1])),
[-1, 1]), cluster_idx, num_clusters))
with ops.colocate_with(cluster_centers, ignore_existing=True):
new_clusters_centers = math_ops.add_n(cluster_sums) / (math_ops.cast(
math_ops.add_n(cluster_counts), cluster_sums[0].dtype) + epsilon)
new_clusters_centers = math_ops.add_n(cluster_sums) / (
math_ops.cast(math_ops.add_n(cluster_counts), cluster_sums[0].dtype) +
epsilon)
if self._clusters_l2_normalized():
new_clusters_centers = nn_impl.l2_normalize(new_clusters_centers, dim=1)
return state_ops.assign(cluster_centers, new_clusters_centers)
@ -548,9 +559,12 @@ class _InitializeClustersOpFactory(object):
cluster_centers_initialized := true
"""
# TODO(ccolby): Refactor this class so that kmc2 isn't so much a special case.
def __init__(self, inputs, num_clusters, initial_clusters, distance_metric,
random_seed, kmeans_plus_plus_num_retries, cluster_centers,
cluster_centers_updated, cluster_centers_initialized):
random_seed, kmeans_plus_plus_num_retries, kmc2_chain_length,
cluster_centers, cluster_centers_updated,
cluster_centers_initialized):
"""Creates an op factory.
Args:
@ -560,6 +574,7 @@ class _InitializeClustersOpFactory(object):
distance_metric: See KMeans constructor.
random_seed: See KMeans constructor.
kmeans_plus_plus_num_retries: See KMeans constructor.
kmc2_chain_length: See KMeans constructor.
cluster_centers: The TF variable holding the initial centers. It may
already contain some centers when the op is executed.
cluster_centers_updated: A second TF variable to hold a copy of the
@ -575,6 +590,7 @@ class _InitializeClustersOpFactory(object):
self._distance_metric = distance_metric
self._random_seed = random_seed
self._kmeans_plus_plus_num_retries = kmeans_plus_plus_num_retries
self._kmc2_chain_length = kmc2_chain_length
self._cluster_centers = cluster_centers
self._cluster_centers_updated = cluster_centers_updated
self._cluster_centers_initialized = cluster_centers_initialized
@ -604,6 +620,90 @@ class _InitializeClustersOpFactory(object):
math_ops.to_int64(self._num_remaining), self._random_seed,
self._kmeans_plus_plus_num_retries)
def _kmc2_multiple_centers(self):
"""Adds new initial cluster centers using the k-MC2 algorithm.
In each call to the op, the provided batch is split into subsets based on
the specified `kmc2_chain_length`. On each subset, a single Markov chain of
the k-MC2 algorithm is used to add *one* new cluster center. If there
are fewer than `kmc2_chain_length` points in the subset, a single center is
added using one Markov chain on the full input. It is assumed that the
provided batch has previously been randomly permuted. Otherwise, k-MC2 may
return suboptimal centers.
Returns:
An op that adds new cluster centers.
"""
# The op only operates on the first shard of data.
first_shard = self._inputs[0]
# Number of points in the input that can be used.
batch_size = array_ops.shape(first_shard)[0]
# Maximum number of subsets such that the size of each subset is at least
# `kmc2_chain_length`. Final subsets may be larger.
max_to_sample = math_ops.cast(
batch_size / self._kmc2_chain_length, dtype=dtypes.int32)
# We sample at least one new center and at most all remaining centers.
num_to_sample = math_ops.maximum(
math_ops.minimum(self._num_remaining, max_to_sample), 1)
def _cond(i, _):
"""Stopping condition for the while loop."""
return math_ops.less(i, num_to_sample)
def _body(i, _):
"""Body that adds a single new center based on a subset."""
def _sample_random():
"""Returns a random point as a cluster center."""
# By assumption the batch is reshuffled and _sample_random is always
# called for i=0. Hence, we simply return the first point.
new_center = array_ops.reshape(first_shard[0], [1, -1])
if self._distance_metric == COSINE_DISTANCE:
new_center = nn_impl.l2_normalize(new_center, dim=1)
return new_center
def _sample_kmc2_chain():
"""Returns previous centers as well as a new center sampled using k-MC2.
"""
# Extract the subset from the underlying batch.
start = i * self._kmc2_chain_length
end = start + self._kmc2_chain_length
subset = first_shard[start:end]
# Compute the distances from points in the subset to previous centers.
_, distances = gen_clustering_ops.nearest_neighbors(
subset, self._cluster_centers, 1)
# Sample index of new center using k-MC2 Markov chain.
new_center_index = gen_clustering_ops.kmc2_chain_initialization(
array_ops.squeeze(distances), self._random_seed)
# Extract actual new center.
newly_sampled_center = array_ops.reshape(subset[new_center_index],
[1, -1])
# Return concatenation with previously sampled centers.
if self._distance_metric == COSINE_DISTANCE:
newly_sampled_center = nn_impl.l2_normalize(
newly_sampled_center, dim=1)
return array_ops.concat([self._cluster_centers, newly_sampled_center],
0)
# Obtain a random point if there are no previously sampled centers.
# Otherwise, construct a k-MC2 Markov chain.
new_centers = control_flow_ops.cond(
math_ops.equal(self._num_selected, 0), _sample_random,
_sample_kmc2_chain)
# Assign new cluster centers to underlying variable.
assigned_centers = state_ops.assign(
self._cluster_centers, new_centers, validate_shape=False)
if self._cluster_centers_updated is not self._cluster_centers:
assigned_centers = state_ops.assign(
self._cluster_centers_updated,
assigned_centers,
validate_shape=False)
return i + 1, self._num_clusters - array_ops.shape(assigned_centers)[0]
# Add num_to_sample new data points.
_, num_remaining = control_flow_ops.while_loop(_cond, _body, [0, 0])
return num_remaining
def _greedy_batch_sampler(self, sampler):
# If the input dataset size is smaller than the number of centers
# remaining, choose the entire input dataset as centers. This can happen
@ -657,7 +757,10 @@ class _InitializeClustersOpFactory(object):
with ops.control_dependencies([
check_ops.assert_positive(self._num_remaining),
]):
num_now_remaining = self._add_new_centers()
if self._initial_clusters == KMC2_INIT:
num_now_remaining = self._kmc2_multiple_centers()
else:
num_now_remaining = self._add_new_centers()
return control_flow_ops.cond(
math_ops.equal(num_now_remaining, 0),
lambda: state_ops.assign(self._cluster_centers_initialized, True),
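For reference, a minimal NumPy sketch of the per-chain step that the new k-MC2 path performs: within each subset, a Markov chain accepts a candidate point with probability proportional to its squared distance to the nearest existing center. Helper names and the uniform acceptance draw are assumptions for illustration; the real op computes distances via nearest_neighbors and samples inside kmc2_chain_initialization.

import numpy as np


def kmc2_chain_step(subset, centers, rng):
  """One Markov chain over `subset`, returning a single new center (sketch)."""
  # Squared distance from every point in the subset to its nearest center,
  # mirroring the nearest_neighbors call in the op above.
  d2 = ((subset[:, None, :] - centers[None, :, :]) ** 2).sum(-1).min(axis=1)
  current = 0
  for j in range(1, len(subset)):
    # Accept candidate j with probability d2[j] / d2[current], capped at 1.
    accept = 1.0 if d2[current] <= 0 else min(1.0, d2[j] / d2[current])
    if rng.uniform() < accept:
      current = j
  return subset[current]


# Mirror of the while_loop: one new center per chain of a pre-shuffled batch.
rng = np.random.RandomState(0)
batch = rng.randn(200, 4)
chain_length = 50
centers = batch[:1].copy()                  # the "_sample_random" case
for start in range(chain_length, len(batch), chain_length):
  chain = batch[start:start + chain_length]
  centers = np.vstack([centers, kmc2_chain_step(chain, centers, rng)])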

View File

@ -37,6 +37,7 @@ See the @{$python/contrib.framework} guide.
@@arg_scope
@@add_arg_scope
@@current_arg_scope
@@has_arg_scope
@@arg_scoped_arguments

View File

@ -67,6 +67,7 @@ from tensorflow.python.util import tf_decorator
__all__ = ['arg_scope',
'add_arg_scope',
'current_arg_scope',
'has_arg_scope',
'arg_scoped_arguments']
@ -83,7 +84,7 @@ def _get_arg_stack():
return _ARGSTACK
def _current_arg_scope():
def current_arg_scope():
stack = _get_arg_stack()
return stack[-1]
@ -144,7 +145,7 @@ def arg_scope(list_ops_or_scope, **kwargs):
raise TypeError('list_ops_or_scope must either be a list/tuple or reused'
'scope (i.e. dict)')
try:
current_scope = _current_arg_scope().copy()
current_scope = current_arg_scope().copy()
for op in list_ops_or_scope:
key_op = _key_op(op)
if not has_arg_scope(op):
@ -172,7 +173,7 @@ def add_arg_scope(func):
A tuple with the decorated function func_with_args().
"""
def func_with_args(*args, **kwargs):
current_scope = _current_arg_scope()
current_scope = current_arg_scope()
current_args = kwargs
key_func = _key_op(func)
if key_func in current_scope:
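A small usage sketch of the newly exported helpers, assuming they are re-exported from tensorflow.contrib.framework as the guide entries above suggest; conv2d stands in for any @add_arg_scope-decorated function.

import tensorflow as tf
from tensorflow.contrib import layers
from tensorflow.contrib.framework import arg_scope, current_arg_scope, has_arg_scope

# conv2d is decorated with @add_arg_scope, so it can pick up scoped defaults.
assert has_arg_scope(layers.conv2d)

with arg_scope([layers.conv2d], padding='VALID', activation_fn=tf.nn.relu):
  # current_arg_scope() returns the active defaults, keyed per decorated op.
  print(current_arg_scope())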

View File

@ -442,7 +442,8 @@ def read_keyed_batch_features(file_pattern,
feature_queue_capacity=100,
num_enqueue_threads=2,
parse_fn=None,
name=None):
name=None,
read_batch_size=None):
"""Adds operations to read, queue, batch and parse `Example` protos.
Given file pattern (or list of files), will setup a queue for file names,
@ -482,6 +483,8 @@ def read_keyed_batch_features(file_pattern,
parse_fn: Parsing function, takes `Example` Tensor returns parsed
representation. If `None`, no parsing is done.
name: Name of resulting op.
read_batch_size: An int or scalar `Tensor` specifying the number of
records to read at once. If `None`, defaults to `batch_size`.
Returns:
Returns tuple of:
@ -493,6 +496,7 @@ def read_keyed_batch_features(file_pattern,
"""
with ops.name_scope(name, 'read_batch_features', [file_pattern]) as scope:
if read_batch_size is None: read_batch_size = batch_size
keys, examples = read_keyed_batch_examples(
file_pattern,
batch_size,
@ -501,7 +505,7 @@ def read_keyed_batch_features(file_pattern,
num_epochs=num_epochs,
queue_capacity=queue_capacity,
num_threads=reader_num_threads,
read_batch_size=batch_size,
read_batch_size=read_batch_size,
parse_fn=parse_fn,
name=scope)
# Parse the example.
@ -727,7 +731,8 @@ def read_batch_features(file_pattern,
reader_num_threads=1,
num_enqueue_threads=2,
parse_fn=None,
name=None):
name=None,
read_batch_size=None):
"""Adds operations to read, queue, batch and parse `Example` protos.
Given file pattern (or list of files), will setup a queue for file names,
@ -768,6 +773,8 @@ def read_batch_features(file_pattern,
parse_fn: Parsing function, takes `Example` Tensor returns parsed
representation. If `None`, no parsing is done.
name: Name of resulting op.
read_batch_size: An int or scalar `Tensor` specifying the number of
records to read at once. If `None`, defaults to `batch_size`.
Returns:
A dict of `Tensor` or `SparseTensor` objects for each in `features`.
@ -786,6 +793,7 @@ def read_batch_features(file_pattern,
reader_num_threads=reader_num_threads,
feature_queue_capacity=feature_queue_capacity,
num_enqueue_threads=num_enqueue_threads,
read_batch_size=read_batch_size,
parse_fn=parse_fn,
name=name)
return features
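A hedged sketch of calling read_batch_features with the new read_batch_size argument, which decouples how many records are read per dequeue from the parsed batch size; the file pattern and feature spec are placeholders.

import tensorflow as tf
from tensorflow.contrib import learn

features = learn.read_batch_features(
    file_pattern='/tmp/examples/part-*',                    # placeholder
    batch_size=128,                                         # parsed batch size
    features={'x': tf.FixedLenFeature([10], tf.float32)},   # placeholder spec
    reader=tf.TFRecordReader,
    read_batch_size=512)  # read 512 records at a time; defaults to batch_size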

View File

@ -502,6 +502,7 @@ $(wildcard tensorflow/core/platform/google/*) \
$(wildcard tensorflow/core/platform/google/*/*) \
$(wildcard tensorflow/core/platform/jpeg.*) \
$(wildcard tensorflow/core/platform/png.*) \
$(wildcard tensorflow/core/platform/s3/*) \
$(wildcard tensorflow/core/platform/stream_executor.*) \
$(wildcard tensorflow/core/platform/windows/*) \
$(wildcard tensorflow/core/user_ops/*.cu.cc) \

View File

@ -20,11 +20,11 @@ DOWNLOADS_DIR=tensorflow/contrib/makefile/downloads
BZL_FILE_PATH=tensorflow/workspace.bzl
EIGEN_URL="$(grep -o 'http.*bitbucket.org/eigen/eigen/get/.*tar\.gz' "${BZL_FILE_PATH}" | grep -v bazel-mirror | head -n1)"
GEMMLOWP_URL="$(grep -o 'http://mirror.bazel.build/github.com/google/gemmlowp/.*zip' "${BZL_FILE_PATH}" | head -n1)"
GEMMLOWP_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/gemmlowp/.*zip' "${BZL_FILE_PATH}" | head -n1)"
GOOGLETEST_URL="https://github.com/google/googletest/archive/release-1.8.0.tar.gz"
NSYNC_URL="$(grep -o 'http://mirror.bazel.build/github.com/google/nsync/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)"
PROTOBUF_URL="$(grep -o 'http://mirror.bazel.build/github.com/google/protobuf/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)"
RE2_URL="$(grep -o 'http://mirror.bazel.build/github.com/google/re2/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)"
NSYNC_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/nsync/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)"
PROTOBUF_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/protobuf/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)"
RE2_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/re2/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)"
FFT2D_URL="$(grep -o 'http.*fft\.tgz' "${BZL_FILE_PATH}" | grep -v bazel-mirror | head -n1)"
# TODO(petewarden): Some new code in Eigen triggers a clang bug with iOS arm64,

File diff suppressed because it is too large

View File

@ -1101,7 +1101,7 @@ class StreamingPrecisionTest(test.TestCase):
predictions = random_ops.random_uniform(
(10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1)
labels = random_ops.random_uniform(
(10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2)
(10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
precision, update_op = metrics.streaming_precision(predictions, labels)
with self.test_session() as sess:
@ -1265,7 +1265,7 @@ class StreamingRecallTest(test.TestCase):
predictions = random_ops.random_uniform(
(10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1)
labels = random_ops.random_uniform(
(10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2)
(10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
recall, update_op = metrics.streaming_recall(predictions, labels)
with self.test_session() as sess:
@ -1388,7 +1388,7 @@ class StreamingFPRTest(test.TestCase):
predictions = random_ops.random_uniform(
(10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1)
labels = random_ops.random_uniform(
(10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2)
(10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
fpr, update_op = metrics.streaming_false_positive_rate(
predictions, labels)
@ -1516,7 +1516,7 @@ class StreamingFNRTest(test.TestCase):
predictions = random_ops.random_uniform(
(10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1)
labels = random_ops.random_uniform(
(10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2)
(10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
fnr, update_op = metrics.streaming_false_negative_rate(
predictions, labels)
@ -1737,7 +1737,7 @@ class StreamingAUCTest(test.TestCase):
predictions = random_ops.random_uniform(
(10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1)
labels = random_ops.random_uniform(
(10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2)
(10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
auc, update_op = metrics.streaming_auc(predictions, labels)
with self.test_session() as sess:
@ -2009,7 +2009,7 @@ class StreamingSpecificityAtSensitivityTest(test.TestCase):
predictions = random_ops.random_uniform(
(10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1)
labels = random_ops.random_uniform(
(10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2)
(10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
specificity, update_op = metrics.streaming_specificity_at_sensitivity(
predictions, labels, sensitivity=0.7)
@ -2271,7 +2271,7 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
predictions = random_ops.random_uniform(
(10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1)
labels = random_ops.random_uniform(
(10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2)
(10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
thresholds = [0, 0.5, 1.0]
prec, prec_op = metrics.streaming_precision_at_thresholds(predictions,
labels,
@ -2282,12 +2282,14 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
with self.test_session() as sess:
sess.run(variables.local_variables_initializer())
# Run several updates, then verify idempotency.
sess.run([prec_op, rec_op])
# Run several updates.
for _ in range(10):
sess.run([prec_op, rec_op])
# Then verify idempotency.
initial_prec = prec.eval()
initial_rec = rec.eval()
for _ in range(10):
sess.run([prec_op, rec_op])
self.assertAllClose(initial_prec, prec.eval())
self.assertAllClose(initial_rec, rec.eval())
@ -2361,14 +2363,10 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
rec, rec_op = metrics.streaming_recall_at_thresholds(
predictions, labels, thresholds, weights=weights)
[prec_low, prec_high] = array_ops.split(
value=prec, num_or_size_splits=2, axis=0)
prec_low = array_ops.reshape(prec_low, shape=())
prec_high = array_ops.reshape(prec_high, shape=())
[rec_low, rec_high] = array_ops.split(
value=rec, num_or_size_splits=2, axis=0)
rec_low = array_ops.reshape(rec_low, shape=())
rec_high = array_ops.reshape(rec_high, shape=())
prec_low = prec[0]
prec_high = prec[1]
rec_low = rec[0]
rec_high = rec[1]
sess.run(variables.local_variables_initializer())
sess.run([prec_op, rec_op])
@ -2391,14 +2389,10 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
rec, rec_op = metrics.streaming_recall_at_thresholds(
predictions, labels, thresholds, weights=weights)
[prec_low, prec_high] = array_ops.split(
value=prec, num_or_size_splits=2, axis=0)
prec_low = array_ops.reshape(prec_low, shape=())
prec_high = array_ops.reshape(prec_high, shape=())
[rec_low, rec_high] = array_ops.split(
value=rec, num_or_size_splits=2, axis=0)
rec_low = array_ops.reshape(rec_low, shape=())
rec_high = array_ops.reshape(rec_high, shape=())
prec_low = prec[0]
prec_high = prec[1]
rec_low = rec[0]
rec_high = rec[1]
sess.run(variables.local_variables_initializer())
sess.run([prec_op, rec_op])
@ -2420,10 +2414,10 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase):
rec, rec_op = metrics.streaming_recall_at_thresholds(predictions, labels,
thresholds)
[prec_low, prec_high] = array_ops.split(
value=prec, num_or_size_splits=2, axis=0)
[rec_low, rec_high] = array_ops.split(
value=rec, num_or_size_splits=2, axis=0)
prec_low = prec[0]
prec_high = prec[1]
rec_low = rec[0]
rec_high = rec[1]
sess.run(variables.local_variables_initializer())
sess.run([prec_op, rec_op])
@ -2562,7 +2556,7 @@ class StreamingFPRThresholdsTest(test.TestCase):
predictions = random_ops.random_uniform(
(10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1)
labels = random_ops.random_uniform(
(10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2)
(10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
thresholds = [0, 0.5, 1.0]
fpr, fpr_op = metrics.streaming_false_positive_rate_at_thresholds(
predictions, labels, thresholds)
@ -2794,7 +2788,7 @@ class StreamingFNRThresholdsTest(test.TestCase):
predictions = random_ops.random_uniform(
(10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1)
labels = random_ops.random_uniform(
(10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2)
(10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
thresholds = [0, 0.5, 1.0]
fnr, fnr_op = metrics.streaming_false_negative_rate_at_thresholds(
predictions, labels, thresholds)
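The maxval=1 to maxval=2 changes matter because random_uniform with an integer dtype draws from [minval, maxval), so maxval=1 can only ever produce zeros and the old label tensors contained a single class. A quick check:

import tensorflow as tf

all_zeros = tf.random_uniform((10, 3), maxval=1, dtype=tf.int64, seed=2)
zeros_and_ones = tf.random_uniform((10, 3), maxval=2, dtype=tf.int64, seed=2)

with tf.Session() as sess:
  print(sess.run(all_zeros))       # every entry is 0: labels had one class
  print(sess.run(zeros_and_ones))  # mix of 0s and 1s: labels cover both classes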

View File

@ -13,6 +13,34 @@ py_library(
deps = [],
)
py_library(
name = "graph_matcher",
srcs = [
"python/graph_matcher.py",
],
srcs_version = "PY2AND3",
deps = [],
)
py_test(
name = "graph_matcher_test",
size = "small",
srcs = ["python/graph_matcher_test.py"],
srcs_version = "PY2AND3",
deps = [
":graph_matcher",
"//tensorflow/contrib/layers:layers_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:init_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:nn_ops",
"//tensorflow/python:platform_test",
],
)
py_library(
name = "input_to_ops",
srcs = ["python/input_to_ops.py"],
@ -43,6 +71,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":common",
":graph_matcher",
":input_to_ops",
"//tensorflow/contrib/graph_editor:graph_editor_py",
"//tensorflow/python:array_ops",
@ -58,6 +87,7 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":fold_batch_norms",
":graph_matcher",
"//tensorflow/contrib/layers:layers_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:dtypes",
@ -147,10 +177,11 @@ py_test(
py_test(
name = "quantize_parameterized_test",
size = "medium",
size = "large",
srcs = ["python/quantize_parameterized_test.py"],
srcs_version = "PY2AND3",
deps = [
":fold_batch_norms",
":quantize",
"//tensorflow/contrib/layers:layers_py",
"//tensorflow/python:array_ops",

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.quantized.mangle.copy_graph."""
"""Tests for copy_graph."""
from __future__ import absolute_import
from __future__ import division

View File

@ -21,7 +21,9 @@ from __future__ import print_function
import re
from tensorflow.contrib import graph_editor
from tensorflow.contrib.quantize.python import common
from tensorflow.contrib.quantize.python import graph_matcher
from tensorflow.contrib.quantize.python import input_to_ops
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
@ -29,7 +31,7 @@ from tensorflow.python.ops import nn_ops
def FoldBatchNorms(graph):
"""Finds batch norm layers in the graph, folds them into preceding layers.
"""Finds batch norm layers and folds them into preceding layers.
Folding only affects the following layers: Conv2D, fully connected, depthwise
convolution.
@ -40,10 +42,269 @@ def FoldBatchNorms(graph):
Raises:
ValueError: When batch norm folding fails.
"""
# Fail immediately when the graph contains unsupported fused batch norm ops.
if any(op for op in graph.get_operations() if op.type == 'FusedBatchNorm'):
raise ValueError('Fused batch norm is not supported')
_FoldFusedBatchNorms(graph)
_FoldUnfusedBatchNorms(graph)
def _FoldFusedBatchNorms(graph):
"""Finds fused batch norm layers and folds them into preceding layers.
Folding only affects the following layers: Conv2D, fully connected, depthwise
convolution.
Args:
graph: Graph to walk and modify.
Raises:
ValueError: When batch norm folding fails.
"""
for match in _FindFusedBatchNorms(graph):
scope, sep, _ = match.layer_op.name.rpartition('/')
# Make sure new ops are added to `graph` and put on the same device as
# `bn_op`. The '/' (i.e. `sep`) ensures that we reuse the existing scope
# named `scope`. Otherwise, TF creates a unique scope whose name starts with
# `scope`.
with graph.as_default(), graph.name_scope(scope + sep), ops.device(
match.bn_op.device):
# new weights = old weights * gamma / sqrt(variance + epsilon)
# new biases = -mean * gamma / sqrt(variance + epsilon) + beta
multiplier_tensor = match.gamma_tensor * math_ops.rsqrt(
match.variance_tensor + match.bn_op.get_attr('epsilon'))
bias_tensor = math_ops.subtract(
match.beta_tensor, match.mean_tensor * multiplier_tensor, name='bias')
# The shape of depthwise weights is different, so we need to reshape the
# multiplier_tensor to ensure that the scaled_weight_tensor has the
# expected shape.
if match.layer_op.type == 'DepthwiseConv2dNative':
new_shape = [
match.weight_tensor.get_shape().as_list()[2],
match.weight_tensor.get_shape().as_list()[3]
]
multiplier_tensor = array_ops.reshape(
multiplier_tensor, new_shape, name='scale_reshape')
# TODO(suharshs): This naming of the following ops needs to carefully
# follow the naming expected by quantize.py. Generalize the quantize code
# to not require these delicate naming conventions.
scaled_weight_tensor = math_ops.multiply(
match.weight_tensor, multiplier_tensor, name='mul_fold')
new_layer_tensor = _CloneWithNewOperands(
match.layer_op, match.input_tensor, scaled_weight_tensor)
bias_add_tensor = math_ops.add(
new_layer_tensor, bias_tensor, name='add_fold')
nodes_modified_count = graph_editor.reroute_ts(bias_add_tensor,
match.output_tensor)
if nodes_modified_count != 1:
raise ValueError(
'Unexpected inputs to op: %s' % match.output_tensor.name)
def _CloneWithNewOperands(layer_op, input_tensor, weight_tensor):
"""Clones layer_op with input_tensor and weight_tensor as new inputs."""
new_layer_name = layer_op.name.split('/')[-1] + '_Fold'
if layer_op.type == 'Conv2D':
return nn_ops.conv2d(
input_tensor,
weight_tensor,
strides=layer_op.get_attr('strides'),
padding=layer_op.get_attr('padding'),
use_cudnn_on_gpu=layer_op.get_attr('use_cudnn_on_gpu'),
data_format=layer_op.get_attr('data_format'),
name=new_layer_name)
elif layer_op.type == 'MatMul':
return math_ops.matmul(
input_tensor,
weight_tensor,
transpose_a=layer_op.get_attr('transpose_a'),
transpose_b=layer_op.get_attr('transpose_b'),
name=new_layer_name)
elif layer_op.type == 'DepthwiseConv2dNative':
return nn.depthwise_conv2d(
input_tensor,
weight_tensor,
strides=layer_op.get_attr('strides'),
padding=layer_op.get_attr('padding'),
name=new_layer_name)
else:
raise ValueError('Cannot handle operation of type: %s' % layer_op.type)
def _FindFusedBatchNorms(graph):
"""Finds all ops and tensors related to found FusedBatchNorms.
Args:
graph: Graph to inspect.
Yields:
_FusedBatchNormMatches.
"""
input_pattern = graph_matcher.OpTypePattern('*')
weight_pattern = graph_matcher.OpTypePattern('*')
gamma_pattern = graph_matcher.OpTypePattern('*')
beta_pattern = graph_matcher.OpTypePattern('*')
mean_pattern = graph_matcher.OpTypePattern('*')
variance_pattern = graph_matcher.OpTypePattern('*')
conv_pattern = graph_matcher.OpTypePattern(
'Conv2D|DepthwiseConv2dNative', inputs=[input_pattern, weight_pattern])
# MatMul has a Reshape between it and FusedBatchNorm.
matmul_pattern = graph_matcher.OpTypePattern(
'MatMul', inputs=[input_pattern, weight_pattern])
matmul_reshape_pattern = graph_matcher.OpTypePattern(
'Reshape', inputs=[matmul_pattern,
graph_matcher.OpTypePattern('*')])
conv_batch_norm_pattern = graph_matcher.OpTypePattern(
'FusedBatchNorm',
inputs=[
conv_pattern, gamma_pattern, beta_pattern, mean_pattern,
variance_pattern
])
matmul_batch_norm_pattern = graph_matcher.OpTypePattern(
'FusedBatchNorm',
inputs=[
matmul_reshape_pattern, gamma_pattern, beta_pattern, mean_pattern,
variance_pattern
])
matmul_bn_output_reshape_pattern = graph_matcher.OpTypePattern(
'Reshape',
inputs=[matmul_batch_norm_pattern,
graph_matcher.OpTypePattern('*')])
conv_matcher = graph_matcher.GraphMatcher(conv_batch_norm_pattern)
matmul_matcher = graph_matcher.GraphMatcher(matmul_bn_output_reshape_pattern)
def _GetCommonTensors(match_result):
"""Gets tensors needed for FusedBatchNormMatch from match_result."""
input_tensor = match_result.get_tensor(input_pattern)
weight_tensor = match_result.get_tensor(weight_pattern)
gamma_tensor = match_result.get_tensor(gamma_pattern)
beta_tensor = match_result.get_tensor(beta_pattern)
# FusedBatchNorm in training is different from that in inference. It takes
# empty 'mean' and empty 'variance', and produces the mean and the variance
# of the batch. Therefore, when is_training is true, mean_tensor and
# variance_tensor point to 1st and 2nd (0-based) output of bn_op,
# respectively; when is_training is false, they point to bn_op's inputs.
is_training = bn_op.get_attr('is_training')
if is_training:
mean_tensor = bn_op.outputs[1]
variance_tensor = bn_op.outputs[2]
else:
mean_tensor = match_result.get_tensor(mean_pattern)
variance_tensor = match_result.get_tensor(variance_pattern)
return (input_tensor, weight_tensor, gamma_tensor, beta_tensor, mean_tensor,
variance_tensor)
for match_result in conv_matcher.match_graph(graph):
layer_op = match_result.get_op(conv_pattern)
bn_op = match_result.get_op(conv_batch_norm_pattern)
# In the case of convolution the output_tensor is the output of bn_op.
output_tensor = bn_op.outputs[0]
(input_tensor, weight_tensor, gamma_tensor, beta_tensor, mean_tensor,
variance_tensor) = _GetCommonTensors(match_result)
yield _FusedBatchNormMatch(
layer_op=layer_op,
bn_op=bn_op,
output_tensor=output_tensor,
input_tensor=input_tensor,
weight_tensor=weight_tensor,
gamma_tensor=gamma_tensor,
beta_tensor=beta_tensor,
mean_tensor=mean_tensor,
variance_tensor=variance_tensor)
for match_result in matmul_matcher.match_graph(graph):
layer_op = match_result.get_op(matmul_pattern)
bn_op = match_result.get_op(matmul_batch_norm_pattern)
# In the MatMul case, the output of batch norm is reshaped back into a
# 2D tensor, so the output_tensor is the output of the Reshape op.
output_reshape_op = match_result.get_op(matmul_bn_output_reshape_pattern)
output_tensor = output_reshape_op.outputs[0]
(input_tensor, weight_tensor, gamma_tensor, beta_tensor, mean_tensor,
variance_tensor) = _GetCommonTensors(match_result)
yield _FusedBatchNormMatch(
layer_op=layer_op,
bn_op=bn_op,
output_tensor=output_tensor,
input_tensor=input_tensor,
weight_tensor=weight_tensor,
gamma_tensor=gamma_tensor,
beta_tensor=beta_tensor,
mean_tensor=mean_tensor,
variance_tensor=variance_tensor)
class _FusedBatchNormMatch(object):
"""Contains all information related to a found FusedBatchNorm."""
def __init__(self, layer_op, bn_op, output_tensor, input_tensor,
weight_tensor, gamma_tensor, beta_tensor, mean_tensor,
variance_tensor):
self._layer_op = layer_op
self._bn_op = bn_op
self._output_tensor = output_tensor
self._input_tensor = input_tensor
self._weight_tensor = weight_tensor
self._gamma_tensor = gamma_tensor
self._beta_tensor = beta_tensor
self._mean_tensor = mean_tensor
self._variance_tensor = variance_tensor
@property
def layer_op(self):
return self._layer_op
@property
def bn_op(self):
return self._bn_op
@property
def output_tensor(self):
return self._output_tensor
@property
def input_tensor(self):
return self._input_tensor
@property
def weight_tensor(self):
return self._weight_tensor
@property
def gamma_tensor(self):
return self._gamma_tensor
@property
def beta_tensor(self):
return self._beta_tensor
@property
def mean_tensor(self):
return self._mean_tensor
@property
def variance_tensor(self):
return self._variance_tensor
def _FoldUnfusedBatchNorms(graph):
"""Finds unfused batch norm layers and folds them into preceding layers.
Folding only affects the following layers: Conv2D, fully connected, depthwise
convolution.
Args:
graph: Graph to walk and modify.
Raises:
ValueError: When batch norm folding fails.
"""
input_to_ops_map = input_to_ops.InputToOps(graph)
for bn in common.BatchNormGroups(graph):
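The fold above is an algebraic identity: a conv/matmul followed by inference-mode batch norm equals the same layer with scaled weights plus an added bias, per the formulas in the comments. A NumPy sketch of the MatMul case (shapes are illustrative):

import numpy as np

rng = np.random.RandomState(0)
x = rng.randn(8, 16)                  # a batch of inputs
w = rng.randn(16, 32)                 # layer weights (MatMul case)
gamma, beta = rng.randn(32), rng.randn(32)
mean, var = rng.randn(32), rng.rand(32) + 0.5
eps = 1e-3

bn_out = (x.dot(w) - mean) / np.sqrt(var + eps) * gamma + beta

multiplier = gamma / np.sqrt(var + eps)   # the "mul_fold" multiplier_tensor
folded_w = w * multiplier                 # scaled_weight_tensor
folded_b = beta - mean * multiplier       # bias_tensor feeding "add_fold"
folded_out = x.dot(folded_w) + folded_b

assert np.allclose(bn_out, folded_out)    # folding preserves the layer output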

View File

@ -18,7 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
from tensorflow.contrib.layers.python.layers import layers
from tensorflow.contrib.quantize.python import fold_batch_norms
from tensorflow.python.framework import dtypes
@ -35,57 +34,32 @@ conv2d = layers.conv2d
fully_connected = layers.fully_connected
separable_conv2d = layers.separable_conv2d
_DEFAULT_BATCH_NORM_PARAMS = {
'center': True,
'scale': True,
'decay': 1.0 - 0.003,
'fused': False,
}
# TODO(suharshs): Use parameterized test once OSS TF supports it.
class FoldBatchNormsTest(test_util.TensorFlowTestCase):
def _RunTestOverParameters(self, test_fn):
parameters_list = [
# (relu, relu_op_name, with_bypass)
(nn_ops.relu6, 'Relu6', False),
(nn_ops.relu, 'Relu', False),
(nn_ops.relu6, 'Relu6', True),
(nn_ops.relu, 'Relu', True),
# (relu, relu_op_name, with_bypass, has_scaling, fused_batch_norm)
(nn_ops.relu6, 'Relu6', False, False, False),
(nn_ops.relu, 'Relu', False, False, False),
(nn_ops.relu6, 'Relu6', True, False, False),
(nn_ops.relu, 'Relu', True, False, False),
(nn_ops.relu6, 'Relu6', False, True, False),
(nn_ops.relu, 'Relu', False, True, False),
(nn_ops.relu6, 'Relu6', True, True, False),
(nn_ops.relu, 'Relu', True, True, False),
# Fused batch norm always has scaling enabled.
(nn_ops.relu6, 'Relu6', False, True, True),
(nn_ops.relu, 'Relu', False, True, True),
(nn_ops.relu6, 'Relu6', True, True, True),
(nn_ops.relu, 'Relu', True, True, True),
]
for parameters in parameters_list:
test_fn(parameters[0], parameters[1], parameters[2])
for params in parameters_list:
test_fn(params[0], params[1], params[2], params[3], params[4])
def testFailsWithFusedBatchNorm(self):
self._RunTestOverParameters(self._TestFailsWithFusedBatchNorm)
def _TestFailsWithFusedBatchNorm(self, relu, relu_op_name, with_bypass):
"""Tests that batch norm fails when fused batch norm ops are present."""
g = ops.Graph()
with g.as_default():
batch_size, height, width = 5, 128, 128
inputs = array_ops.zeros((batch_size, height, width, 3))
out_depth = 3 if with_bypass else 32
stride = 1 if with_bypass else 2
activation_fn = None if with_bypass else relu
batch_norm_params = _DEFAULT_BATCH_NORM_PARAMS.copy()
batch_norm_params['fused'] = True
scope = 'test/test2' if with_bypass else 'test'
node = conv2d(inputs, out_depth, [5, 5], stride=stride, padding='SAME',
weights_initializer=self._WeightInit(0.09),
activation_fn=activation_fn,
normalizer_fn=batch_norm,
normalizer_params=batch_norm_params,
scope=scope)
if with_bypass:
node = math_ops.add(inputs, node, name='test/Add')
relu(node, name='test/' + relu_op_name)
with self.assertRaises(ValueError):
fold_batch_norms.FoldBatchNorms(g)
def _TestFoldConv2d(self, relu, relu_op_name, with_bypass):
def _TestFoldConv2d(self, relu, relu_op_name, with_bypass, has_scaling,
fused_batch_norm):
"""Tests folding cases: inputs -> Conv2d with batch norm -> Relu*.
Args:
@ -93,6 +67,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
relu_op_name: String, name of the Relu* operation.
with_bypass: Bool, when true there is an extra connection added from
inputs to just before Relu*.
has_scaling: Bool, when true the batch norm has scaling.
fused_batch_norm: Bool, when true the batch norm is fused.
"""
g = ops.Graph()
with g.as_default():
@ -102,12 +78,17 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
stride = 1 if with_bypass else 2
activation_fn = None if with_bypass else relu
scope = 'test/test2' if with_bypass else 'test'
node = conv2d(inputs, out_depth, [5, 5], stride=stride, padding='SAME',
weights_initializer=self._WeightInit(0.09),
activation_fn=activation_fn,
normalizer_fn=batch_norm,
normalizer_params=_DEFAULT_BATCH_NORM_PARAMS,
scope=scope)
node = conv2d(
inputs,
out_depth, [5, 5],
stride=stride,
padding='SAME',
weights_initializer=self._WeightInit(0.09),
activation_fn=activation_fn,
normalizer_fn=batch_norm,
normalizer_params=self._BatchNormParams(
scale=has_scaling, fused=fused_batch_norm),
scope=scope)
if with_bypass:
node = math_ops.add(inputs, node, name='test/Add')
relu(node, name='test/' + relu_op_name)
@ -116,9 +97,10 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
folded_mul = g.get_operation_by_name(scope + '/mul_fold')
self.assertEqual(folded_mul.type, 'Mul')
self._AssertInputOpsAre(folded_mul,
[scope + '/weights/read',
scope + '/BatchNorm/batchnorm/mul'])
self._AssertInputOpsAre(folded_mul, [
scope + '/weights/read',
self._BatchNormMultiplierName(scope, has_scaling, fused_batch_norm)
])
self._AssertOutputGoesToOps(folded_mul, g, [scope + '/convolution_Fold'])
folded_conv = g.get_operation_by_name(scope + '/convolution_Fold')
@ -129,16 +111,18 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
folded_add = g.get_operation_by_name(scope + '/add_fold')
self.assertEqual(folded_add.type, 'Add')
self._AssertInputOpsAre(folded_add,
[scope + '/convolution_Fold',
scope + '/BatchNorm/batchnorm/sub'])
self._AssertInputOpsAre(folded_add, [
scope + '/convolution_Fold',
self._BathNormBiasName(scope, fused_batch_norm)
])
output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name]
self._AssertOutputGoesToOps(folded_add, g, output_op_names)
def testFoldConv2d(self):
self._RunTestOverParameters(self._TestFoldConv2d)
def _TestFoldConv2dUnknownShape(self, relu, relu_op_name, with_bypass):
def _TestFoldConv2dUnknownShape(self, relu, relu_op_name, with_bypass,
has_scaling, fused_batch_norm):
"""Tests folding cases: inputs -> Conv2d with batch norm -> Relu*.
Tests that folding works even with an input shape where some dimensions are
@ -149,6 +133,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
relu_op_name: String, name of the Relu* operation.
with_bypass: Bool, when true there is an extra connection added from
inputs to just before Relu*.
has_scaling: Bool, when true the batch norm has scaling.
fused_batch_norm: Bool, when true the batch norm is fused.
"""
g = ops.Graph()
with g.as_default():
@ -165,7 +151,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
weights_initializer=self._WeightInit(0.09),
activation_fn=activation_fn,
normalizer_fn=batch_norm,
normalizer_params=_DEFAULT_BATCH_NORM_PARAMS,
normalizer_params=self._BatchNormParams(
scale=has_scaling, fused=fused_batch_norm),
scope=scope)
if with_bypass:
node = math_ops.add(inputs, node, name='test/Add')
@ -176,7 +163,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
folded_mul = g.get_operation_by_name(scope + '/mul_fold')
self.assertEqual(folded_mul.type, 'Mul')
self._AssertInputOpsAre(folded_mul, [
scope + '/weights/read', scope + '/BatchNorm/batchnorm/mul'
scope + '/weights/read',
self._BatchNormMultiplierName(scope, has_scaling, fused_batch_norm)
])
self._AssertOutputGoesToOps(folded_mul, g, [scope + '/convolution_Fold'])
@ -188,7 +176,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
folded_add = g.get_operation_by_name(scope + '/add_fold')
self.assertEqual(folded_add.type, 'Add')
self._AssertInputOpsAre(folded_add, [
scope + '/convolution_Fold', scope + '/BatchNorm/batchnorm/sub'
scope + '/convolution_Fold',
self._BathNormBiasName(scope, fused_batch_norm)
])
output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name]
self._AssertOutputGoesToOps(folded_add, g, output_op_names)
@ -196,62 +185,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
def testFoldConv2dUnknownShape(self):
self._RunTestOverParameters(self._TestFoldConv2dUnknownShape)
def _TestFoldConv2dWithoutScale(self, relu, relu_op_name, with_bypass):
"""Tests folding cases: inputs -> Conv2d with batch norm -> Relu*.
Args:
relu: Callable that returns an Operation, a factory method for the Relu*.
relu_op_name: String, name of the Relu* operation.
with_bypass: Bool, when true there is an extra connection added from
inputs to just before Relu*.
"""
g = ops.Graph()
with g.as_default():
batch_size, height, width = 5, 128, 128
inputs = array_ops.zeros((batch_size, height, width, 3))
out_depth = 3 if with_bypass else 32
stride = 1 if with_bypass else 2
activation_fn = None if with_bypass else relu
bn_params = copy.copy(_DEFAULT_BATCH_NORM_PARAMS)
bn_params['scale'] = False
scope = 'test/test2' if with_bypass else 'test'
node = conv2d(inputs, out_depth, [5, 5], stride=stride, padding='SAME',
weights_initializer=self._WeightInit(0.09),
activation_fn=activation_fn,
normalizer_fn=batch_norm,
normalizer_params=bn_params,
scope=scope)
if with_bypass:
node = math_ops.add(inputs, node, name='test/Add')
relu(node, name='test/' + relu_op_name)
fold_batch_norms.FoldBatchNorms(g)
folded_mul = g.get_operation_by_name(scope + '/mul_fold')
self.assertEqual(folded_mul.type, 'Mul')
self._AssertInputOpsAre(folded_mul,
[scope + '/weights/read',
scope + '/BatchNorm/batchnorm/Rsqrt'])
self._AssertOutputGoesToOps(folded_mul, g, [scope + '/convolution_Fold'])
folded_conv = g.get_operation_by_name(scope + '/convolution_Fold')
self.assertEqual(folded_conv.type, 'Conv2D')
self._AssertInputOpsAre(folded_conv,
[scope + '/mul_fold', inputs.op.name])
self._AssertOutputGoesToOps(folded_conv, g, [scope + '/add_fold'])
folded_add = g.get_operation_by_name(scope + '/add_fold')
self.assertEqual(folded_add.type, 'Add')
self._AssertInputOpsAre(folded_add,
[scope + '/convolution_Fold',
scope + '/BatchNorm/batchnorm/sub'])
output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name]
self._AssertOutputGoesToOps(folded_add, g, output_op_names)
def testFoldConv2dWithoutScale(self):
self._RunTestOverParameters(self._TestFoldConv2dWithoutScale)
def _TestFoldFullyConnectedLayer(self, relu, relu_op_name, with_bypass):
def _TestFoldFullyConnectedLayer(self, relu, relu_op_name, with_bypass,
has_scaling, fused_batch_norm):
"""Tests folding cases: inputs -> FC with batch norm -> Relu*.
Args:
@ -259,6 +194,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
relu_op_name: String, name of the Relu* operation.
with_bypass: Bool, when true there is an extra connection added from
inputs to just before Relu*.
has_scaling: Bool, when true the batch norm has scaling.
fused_batch_norm: Bool, when true the batch norm is fused.
"""
g = ops.Graph()
with g.as_default():
@ -267,12 +204,15 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
out_depth = 256 if with_bypass else 128
activation_fn = None if with_bypass else relu
scope = 'test/test2' if with_bypass else 'test'
node = fully_connected(inputs, out_depth,
weights_initializer=self._WeightInit(0.03),
activation_fn=activation_fn,
normalizer_fn=batch_norm,
normalizer_params=_DEFAULT_BATCH_NORM_PARAMS,
scope=scope)
node = fully_connected(
inputs,
out_depth,
weights_initializer=self._WeightInit(0.03),
activation_fn=activation_fn,
normalizer_fn=batch_norm,
normalizer_params=self._BatchNormParams(
scale=has_scaling, fused=fused_batch_norm),
scope=scope)
if with_bypass:
node = math_ops.add(inputs, node, name='test/Add')
relu(node, name='test/' + relu_op_name)
@ -281,9 +221,10 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
folded_mul = g.get_operation_by_name(scope + '/mul_fold')
self.assertEqual(folded_mul.type, 'Mul')
self._AssertInputOpsAre(folded_mul,
[scope + '/weights/read',
scope + '/BatchNorm/batchnorm/mul'])
self._AssertInputOpsAre(folded_mul, [
scope + '/weights/read',
self._BatchNormMultiplierName(scope, has_scaling, fused_batch_norm)
])
self._AssertOutputGoesToOps(folded_mul, g, [scope + '/MatMul_Fold'])
folded_conv = g.get_operation_by_name(scope + '/MatMul_Fold')
@ -294,71 +235,18 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
folded_add = g.get_operation_by_name(scope + '/add_fold')
self.assertEqual(folded_add.type, 'Add')
self._AssertInputOpsAre(folded_add,
[scope + '/MatMul_Fold',
scope + '/BatchNorm/batchnorm/sub'])
self._AssertInputOpsAre(folded_add, [
scope + '/MatMul_Fold',
self._BathNormBiasName(scope, fused_batch_norm)
])
output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name]
self._AssertOutputGoesToOps(folded_add, g, output_op_names)
def testFoldFullyConnectedLayer(self):
self._RunTestOverParameters(self._TestFoldFullyConnectedLayer)
def _TestFoldFullyConnectedLayerWithoutScale(self, relu, relu_op_name,
with_bypass):
"""Tests folding cases: inputs -> FC with batch norm -> Relu*.
Args:
relu: Callable that returns an Operation, a factory method for the Relu*.
relu_op_name: String, name of the Relu* operation.
with_bypass: Bool, when true there is an extra connection added from
inputs to just before Relu*.
"""
g = ops.Graph()
with g.as_default():
batch_size, depth = 5, 256
inputs = array_ops.zeros((batch_size, depth))
out_depth = 256 if with_bypass else 128
activation_fn = None if with_bypass else relu
bn_params = copy.copy(_DEFAULT_BATCH_NORM_PARAMS)
bn_params['scale'] = False
scope = 'test/test2' if with_bypass else 'test'
node = fully_connected(inputs, out_depth,
weights_initializer=self._WeightInit(0.03),
activation_fn=activation_fn,
normalizer_fn=batch_norm,
normalizer_params=bn_params,
scope=scope)
if with_bypass:
node = math_ops.add(inputs, node, name='test/Add')
relu(node, name='test/' + relu_op_name)
fold_batch_norms.FoldBatchNorms(g)
folded_mul = g.get_operation_by_name(scope + '/mul_fold')
self.assertEqual(folded_mul.type, 'Mul')
self._AssertInputOpsAre(folded_mul,
[scope + '/weights/read',
scope + '/BatchNorm/batchnorm/Rsqrt'])
self._AssertOutputGoesToOps(folded_mul, g, [scope + '/MatMul_Fold'])
folded_conv = g.get_operation_by_name(scope + '/MatMul_Fold')
self.assertEqual(folded_conv.type, 'MatMul')
self._AssertInputOpsAre(folded_conv,
[scope + '/mul_fold', inputs.op.name])
self._AssertOutputGoesToOps(folded_conv, g, [scope + '/add_fold'])
folded_add = g.get_operation_by_name(scope + '/add_fold')
self.assertEqual(folded_add.type, 'Add')
self._AssertInputOpsAre(folded_add,
[scope + '/MatMul_Fold',
scope + '/BatchNorm/batchnorm/sub'])
output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name]
self._AssertOutputGoesToOps(folded_add, g, output_op_names)
def testFoldFullyConnectedLayerWithoutScale(self):
self._RunTestOverParameters(self._TestFoldFullyConnectedLayerWithoutScale)
def _TestFoldDepthwiseConv2d(self, relu, relu_op_name, with_bypass):
def _TestFoldDepthwiseConv2d(self, relu, relu_op_name, with_bypass,
has_scaling, fused_batch_norm):
"""Tests folding: inputs -> DepthwiseConv2d with batch norm -> Relu*.
Args:
@ -366,6 +254,8 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
relu_op_name: String, name of the Relu* operation.
with_bypass: Bool, when true there is an extra connection added from
inputs to just before Relu*.
has_scaling: Bool, when true the batch norm has scaling.
fused_batch_norm: Bool, when true the batch norm is fused.
"""
g = ops.Graph()
with g.as_default():
@ -374,13 +264,18 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
stride = 1 if with_bypass else 2
activation_fn = None if with_bypass else relu
scope = 'test/test2' if with_bypass else 'test'
node = separable_conv2d(inputs, None, [5, 5], stride=stride,
depth_multiplier=1.0, padding='SAME',
weights_initializer=self._WeightInit(0.09),
activation_fn=activation_fn,
normalizer_fn=batch_norm,
normalizer_params=_DEFAULT_BATCH_NORM_PARAMS,
scope=scope)
node = separable_conv2d(
inputs,
None, [5, 5],
stride=stride,
depth_multiplier=1.0,
padding='SAME',
weights_initializer=self._WeightInit(0.09),
activation_fn=activation_fn,
normalizer_fn=batch_norm,
normalizer_params=self._BatchNormParams(
scale=has_scaling, fused=fused_batch_norm),
scope=scope)
if with_bypass:
node = math_ops.add(inputs, node, name='test/Add')
relu(node, name='test/' + relu_op_name)
@ -396,9 +291,10 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
scale_reshape = g.get_operation_by_name(scope + '/scale_reshape')
self.assertEqual(scale_reshape.type, 'Reshape')
self._AssertInputOpsAre(scale_reshape,
[scope + '/BatchNorm/batchnorm/mul',
scope + '/scale_reshape/shape'])
self._AssertInputOpsAre(scale_reshape, [
self._BatchNormMultiplierName(scope, has_scaling, fused_batch_norm),
scope + '/scale_reshape/shape'
])
self._AssertOutputGoesToOps(scale_reshape, g, [scope + '/mul_fold'])
folded_conv = g.get_operation_by_name(scope + '/depthwise_Fold')
@ -409,77 +305,35 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
folded_add = g.get_operation_by_name(scope + '/add_fold')
self.assertEqual(folded_add.type, 'Add')
self._AssertInputOpsAre(folded_add,
[scope + '/depthwise_Fold',
scope + '/BatchNorm/batchnorm/sub'])
self._AssertInputOpsAre(folded_add, [
scope + '/depthwise_Fold',
self._BathNormBiasName(scope, fused_batch_norm)
])
output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name]
self._AssertOutputGoesToOps(folded_add, g, output_op_names)
def testFoldDepthwiseConv2d(self):
self._RunTestOverParameters(self._TestFoldDepthwiseConv2d)
def _TestFoldDepthwiseConv2dWithoutScale(self, relu, relu_op_name,
with_bypass):
"""Tests folding: inputs -> DepthwiseConv2d with batch norm -> Relu*.
def _BatchNormParams(self, scale=True, fused=False):
return {
'center': True,
'scale': scale,
'decay': 1.0 - 0.003,
'fused': fused
}
Args:
relu: Callable that returns an Operation, a factory method for the Relu*.
relu_op_name: String, name of the Relu* operation.
with_bypass: Bool, when true there is an extra connection added from
inputs to just before Relu*.
"""
g = ops.Graph()
with g.as_default():
batch_size, height, width = 5, 128, 128
inputs = array_ops.zeros((batch_size, height, width, 3))
stride = 1 if with_bypass else 2
activation_fn = None if with_bypass else relu
bn_params = copy.copy(_DEFAULT_BATCH_NORM_PARAMS)
bn_params['scale'] = False
scope = 'test/test2' if with_bypass else 'test'
node = separable_conv2d(inputs, None, [5, 5], stride=stride,
depth_multiplier=1.0, padding='SAME',
weights_initializer=self._WeightInit(0.09),
activation_fn=activation_fn,
normalizer_fn=batch_norm,
normalizer_params=bn_params,
scope=scope)
if with_bypass:
node = math_ops.add(inputs, node, name='test/Add')
relu(node, name='test/' + relu_op_name)
def _BatchNormMultiplierName(self, scope, has_scaling, fused):
if has_scaling:
if fused:
return scope + '/mul'
return scope + '/BatchNorm/batchnorm/mul'
return scope + '/BatchNorm/batchnorm/Rsqrt'
fold_batch_norms.FoldBatchNorms(g)
folded_mul = g.get_operation_by_name(scope + '/mul_fold')
self.assertEqual(folded_mul.type, 'Mul')
self._AssertInputOpsAre(folded_mul,
[scope + '/depthwise_weights/read',
scope + '/scale_reshape'])
self._AssertOutputGoesToOps(folded_mul, g, [scope + '/depthwise_Fold'])
scale_reshape = g.get_operation_by_name(scope + '/scale_reshape')
self.assertEqual(scale_reshape.type, 'Reshape')
self._AssertInputOpsAre(scale_reshape,
[scope + '/BatchNorm/batchnorm/Rsqrt',
scope + '/scale_reshape/shape'])
self._AssertOutputGoesToOps(scale_reshape, g, [scope + '/mul_fold'])
folded_conv = g.get_operation_by_name(scope + '/depthwise_Fold')
self.assertEqual(folded_conv.type, 'DepthwiseConv2dNative')
self._AssertInputOpsAre(folded_conv,
[scope + '/mul_fold', inputs.op.name])
self._AssertOutputGoesToOps(folded_conv, g, [scope + '/add_fold'])
folded_add = g.get_operation_by_name(scope + '/add_fold')
self.assertEqual(folded_add.type, 'Add')
self._AssertInputOpsAre(folded_add,
[scope + '/depthwise_Fold',
scope + '/BatchNorm/batchnorm/sub'])
output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name]
self._AssertOutputGoesToOps(folded_add, g, output_op_names)
def testFoldDepthwiseConv2dWithoutScale(self):
self._RunTestOverParameters(self._TestFoldDepthwiseConv2dWithoutScale)
def _BathNormBiasName(self, scope, fused):
if fused:
return scope + '/bias'
return scope + '/BatchNorm/batchnorm/sub'
def _WeightInit(self, stddev):
"""Returns a truncated normal variable initializer.

View File

@ -0,0 +1,200 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities that match patterns in a tf.Graph."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
class OpTypePattern(object):
"""A tree pattern that matches TF expressions with certain op types."""
def __init__(self, op_type, name=None, inputs=None):
"""Initializes an OpTypePattern.
Args:
op_type: string that specifies the allowed types of the root. It can be
(1) an op type, e.g. 'Conv2D',
(2) '*', i.e. wildcard, or
(3) multiple op types separated by '|', e.g., 'Relu|Relu6'.
We could use regex strings, which might be worthwhile when we have many
similar TF op types.
name: Optional string. The name of the pattern that can be looked up in
MatchResult.
inputs: Optional list of `OpTypePattern`s or strings that specify the
patterns for the inputs of a matching op. If None, this pattern accepts
any inputs of a matching op.
"""
self._op_type = op_type
self._name = name
if inputs is None:
inputs = []
self._inputs = [
input_pattern if isinstance(input_pattern, OpTypePattern) else
OpTypePattern(input_pattern) for input_pattern in inputs
]
@property
def op_type(self):
return self._op_type
@property
def inputs(self):
return self._inputs
@property
def name(self):
return self._name
class MatchResult(object):
r"""Encapsulates the result of a match done by GraphMatcher.
MatchResult contains a map from OpTypePattern to the matching op and tensor.
When the matching op has multiple output tensors, the matching tensor is the
output tensor used by the matching op of the parent pattern. E.g., when we
match graph
- +
/ \y0 y1/ \
x split z
|
y (nodes are ops; edges are going up)
against add_pattern defined as
y1_pattern = OpTypePattern('*')
z_pattern = OpTypePattern('*')
add_pattern = OpTypePattern('+', inputs=[y1_pattern, z_pattern])
the matching op of `y1_pattern` is `split`, and the matching tensor of
`y1_pattern` is `y1` not `y0`.
"""
def __init__(self):
self._pattern_to_op_tensor = {}
self._name_to_pattern = {}
def add(self, pattern, op, tensor):
self._pattern_to_op_tensor[pattern] = op, tensor
if pattern.name is not None:
if pattern.name in self._name_to_pattern:
raise ValueError(
'Name %s is already bound to another pattern' % pattern.name)
self._name_to_pattern[pattern.name] = pattern
def _to_pattern(self, pattern_or_name):
if isinstance(pattern_or_name, OpTypePattern):
return pattern_or_name
if isinstance(pattern_or_name, str):
return self._name_to_pattern[pattern_or_name]
raise ValueError('pattern_or_name has type %s. Expect OpTypePattern or str.'
% type(pattern_or_name))
def get_op(self, pattern_or_name):
return self._pattern_to_op_tensor[self._to_pattern(pattern_or_name)][0]
def get_tensor(self, pattern_or_name):
return self._pattern_to_op_tensor[self._to_pattern(pattern_or_name)][1]
class GraphMatcher(object):
"""Checks if a particular subgraph matches a given pattern."""
def __init__(self, pattern):
"""Initializes a GraphMatcher.
Args:
pattern: The `OpTypePattern` against which `GraphMatcher` matches
subgraphs.
"""
self._pattern = pattern
def _match_pattern(self, pattern, op, tensor):
"""Returns whether an TF expression rooted at `op` matches `pattern`.
If there is a match, adds to `self._match_result` the matching op and tensor
with key `pattern`.
Args:
pattern: An `OpTypePattern`.
op: A `tf.Operation` to match against the pattern.
tensor: the output `tf.Tensor` of `op` that is used by the matching op of
`pattern`'s parent. Can be None if `pattern` is already the root of the
pattern tree.
Returns:
True if a TF expression rooted at `op` matches `pattern`.
"""
if pattern.op_type != '*':
if op.type not in pattern.op_type.split('|'):
return False
self._match_result.add(pattern, op, tensor)
if not pattern.inputs:
# If pattern.inputs is empty, skips the rest and accepts all the inputs.
return True
return len(op.inputs) == len(pattern.inputs) and all([
self._match_pattern(input_pattern, input_tensor.op, input_tensor)
for input_tensor, input_pattern in zip(op.inputs, pattern.inputs)
])
def match_op(self, op):
"""Matches `op` against `self._pattern`.
Args:
op: `tf.Operation` to match against the pattern.
Returns:
Returns a `MatchResult` if `op` matches the pattern; otherwise, returns
None.
"""
self._match_result = MatchResult()
if not self._match_pattern(self._pattern, op, tensor=None):
return None
return self._match_result
def match_ops(self, ops):
"""Matches each operation in `ops` against `self._pattern`.
Args:
ops: collection of `tf.Operation` to match against the pattern.
Yields:
`MatchResult` for each `tf.Operation` that matches the pattern.
"""
for op in ops:
match_result = self.match_op(op)
if match_result:
yield match_result
def match_graph(self, graph):
"""Matches each operation in `graph` against `self._pattern`.
Args:
graph: `tf.Graph` containing operations to match.
Yields:
`MatchResult` for each `tf.Operation` in `graph` that matches the pattern.
"""
# Python 3.3.2+ implements `yield from`, but for now:
for match_result in self.match_ops(graph.get_operations()):
yield match_result

View File

@ -0,0 +1,130 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for graph_matcher."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.framework.python import ops as contrib_ops
from tensorflow.contrib.layers.python.layers import initializers
from tensorflow.contrib.layers.python.layers import layers
from tensorflow.contrib.quantize.python import graph_matcher
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import googletest
class GraphMatcherTest(test_util.TensorFlowTestCase):
def test_conv_layer(self):
g = ops.Graph()
with g.as_default():
inputs = array_ops.placeholder(dtypes.float32, shape=[8, 5, 5, 3])
with contrib_ops.arg_scope(
[layers.batch_norm], fused=True, is_training=True, trainable=True):
return layers.convolution(
inputs,
num_outputs=16,
kernel_size=3,
stride=1,
padding='VALID',
activation_fn=nn_ops.relu,
normalizer_fn=layers.batch_norm,
normalizer_params={},
weights_initializer=initializers.xavier_initializer(),
weights_regularizer=None,
biases_initializer=init_ops.zeros_initializer(),
biases_regularizer=None,
reuse=None,
trainable=True,
scope=None)
inputs_pattern = graph_matcher.OpTypePattern('*', name='inputs')
relu_pattern = graph_matcher.OpTypePattern(
'Relu',
name='relu',
inputs=[
graph_matcher.OpTypePattern(
'FusedBatchNorm',
inputs=[
graph_matcher.OpTypePattern(
'Conv2D', inputs=[inputs_pattern, '*']), '*', '*', '*',
'*'
])
])
matcher = graph_matcher.GraphMatcher(relu_pattern)
match_results = list(matcher.match_graph(g))
self.assertEqual(1, len(match_results))
match_result = match_results[0]
self.assertEqual(match_result.get_tensor(inputs_pattern), inputs)
self.assertEqual(match_result.get_tensor('inputs'), inputs)
def test_multiple_outputs(self):
# - +
# / \y0 y1/ \
# x split z
# |
# y (nodes are ops; edges are going up)
g = ops.Graph()
with g.as_default():
x = array_ops.placeholder(dtypes.float32, shape=[1], name='x')
y = array_ops.placeholder(dtypes.float32, shape=[2], name='y')
y0, y1 = array_ops.split(y, num_or_size_splits=2, axis=0)
z = array_ops.placeholder(dtypes.float32, shape=[1], name='z')
math_ops.add(x, y0)
math_ops.subtract(y1, z)
y1_pattern = graph_matcher.OpTypePattern('*')
minus_pattern = graph_matcher.OpTypePattern('Sub', inputs=[y1_pattern, '*'])
matcher = graph_matcher.GraphMatcher(minus_pattern)
match_results = list(matcher.match_graph(g))
self.assertEqual(1, len(match_results))
match_result = match_results[0]
self.assertEqual(y0.op, y1.op)
self.assertEqual(match_result.get_op(y1_pattern), y1.op)
self.assertEqual(match_result.get_tensor(y1_pattern), y1)
def test_oneof_pattern(self):
# - +
# / \ / \
# x y z
g = ops.Graph()
with g.as_default():
x = array_ops.placeholder(dtypes.float32, shape=[], name='x')
y = array_ops.placeholder(dtypes.float32, shape=[], name='y')
z = array_ops.placeholder(dtypes.float32, shape=[], name='z')
plus = x + y
minus = y - z
add_or_sub_pattern = graph_matcher.OpTypePattern(
'Add|Sub', inputs=['*', '*'])
matcher = graph_matcher.GraphMatcher(add_or_sub_pattern)
self.assertEqual([
match_result.get_op(add_or_sub_pattern)
for match_result in matcher.match_graph(g)
], [plus.op, minus.op])
if __name__ == '__main__':
googletest.main()

View File

@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.layers.python.layers import layers
from tensorflow.contrib.quantize.python import fold_batch_norms
from tensorflow.contrib.quantize.python import quantize
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
@ -35,18 +36,11 @@ conv2d = layers.conv2d
fully_connected = layers.fully_connected
separable_conv2d = layers.separable_conv2d
_DEFAULT_BATCH_NORM_PARAMS = {
'center': True,
'scale': True,
'decay': 1.0 - 0.003,
'fused': False,
}
# TODO(suharshs): Use parameterized test once OSS TF supports it.
class QuantizeTest(test_util.TensorFlowTestCase):
def _RunTestOverParameters(self, test_fn):
def _RunWithoutBatchNormTestOverParameters(self, test_fn):
# TODO(suharshs): Use parameterized test once OSS TF supports it.
parameters_list = [
# (activation, activation_op_name, with_bypass, delay)
(nn_ops.relu6, 'Relu6', False, None),
@ -60,10 +54,10 @@ class QuantizeTest(test_util.TensorFlowTestCase):
(array_ops.identity, 'Identity', True, None),
(nn_ops.relu6, 'Relu6', True, 5000),
(nn_ops.relu, 'Relu', True, 5000),
(array_ops.identity, 'Identity', True, 5000)
(array_ops.identity, 'Identity', True, 5000),
]
for parameters in parameters_list:
test_fn(parameters[0], parameters[1], parameters[2], parameters[3])
for params in parameters_list:
test_fn(params[0], params[1], params[2], params[3])
def _TestQuantize_Conv2dWithoutBatchNorm(self, activation, activation_op_name,
with_bypass, delay):
@ -137,7 +131,8 @@ class QuantizeTest(test_util.TensorFlowTestCase):
self._AssertOutputGoesToOps(act_quant, graph, [output_op_name])
def testQuantize_Conv2dWithoutBatchNorm(self):
self._RunTestOverParameters(self._TestQuantize_Conv2dWithoutBatchNorm)
self._RunWithoutBatchNormTestOverParameters(
self._TestQuantize_Conv2dWithoutBatchNorm)
def _TestQuantize_FCWithoutBatchNorm(self, activation, activation_op_name,
with_bypass, delay):
@ -210,7 +205,8 @@ class QuantizeTest(test_util.TensorFlowTestCase):
self._AssertOutputGoesToOps(act_quant, graph, [output_op_name])
def testQuantize_FCWithoutBatchNorm(self):
self._RunTestOverParameters(self._TestQuantize_FCWithoutBatchNorm)
self._RunWithoutBatchNormTestOverParameters(
self._TestQuantize_FCWithoutBatchNorm)
def _TestQuantize_DepthwiseConv2dWithoutBatchNorm(
self, activation, activation_op_name, with_bypass, delay):
@ -284,11 +280,43 @@ class QuantizeTest(test_util.TensorFlowTestCase):
self._AssertOutputGoesToOps(act_quant, graph, [output_op_name])
def testQuantize_DepthwiseConv2dWithoutBatchNorm(self):
self._RunTestOverParameters(
self._RunWithoutBatchNormTestOverParameters(
self._TestQuantize_DepthwiseConv2dWithoutBatchNorm)
def _RunBatchNormTestOverParameters(self, test_fn):
# TODO(suharshs): Use parameterized test once OSS TF supports it.
parameters_list = [
# (activation, activation_op_name, with_bypass, delay, fused_batch_norm)
(nn_ops.relu6, 'Relu6', False, None, False),
(nn_ops.relu, 'Relu', False, None, False),
(array_ops.identity, 'Identity', False, None, False),
(nn_ops.relu6, 'Relu6', False, 5000, False),
(nn_ops.relu, 'Relu', False, 5000, False),
(array_ops.identity, 'Identity', False, 5000, False),
(nn_ops.relu6, 'Relu6', True, None, False),
(nn_ops.relu, 'Relu', True, None, False),
(array_ops.identity, 'Identity', True, None, False),
(nn_ops.relu6, 'Relu6', True, 5000, False),
(nn_ops.relu, 'Relu', True, 5000, False),
(array_ops.identity, 'Identity', True, 5000, False),
(nn_ops.relu6, 'Relu6', False, None, True),
(nn_ops.relu, 'Relu', False, None, True),
(array_ops.identity, 'Identity', False, None, True),
(nn_ops.relu6, 'Relu6', False, 5000, True),
(nn_ops.relu, 'Relu', False, 5000, True),
(array_ops.identity, 'Identity', False, 5000, True),
(nn_ops.relu6, 'Relu6', True, None, True),
(nn_ops.relu, 'Relu', True, None, True),
(array_ops.identity, 'Identity', True, None, True),
(nn_ops.relu6, 'Relu6', True, 5000, True),
(nn_ops.relu, 'Relu', True, 5000, True),
(array_ops.identity, 'Identity', True, 5000, True)
]
for params in parameters_list:
test_fn(params[0], params[1], params[2], params[3], params[4])
def _TestQuantize_Conv2dWithBatchNorm(self, activation, activation_op_name,
with_bypass, delay):
with_bypass, delay, fused_batch_norm):
"""Tests quantization: inputs -> Conv2d with batch norm -> Activation.
Args:
@ -298,25 +326,29 @@ class QuantizeTest(test_util.TensorFlowTestCase):
with_bypass: Bool, when true there is an extra connection added from
inputs to just before Activation.
delay: Int (optional), delay in number of steps until quantization starts.
fused_batch_norm: Bool, when true use FusedBatchNorm.
"""
self._testQuantize_Conv2dWithBatchNorm(
activation,
activation_op_name,
with_bypass,
delay,
fused_batch_norm,
use_ema=True)
self._testQuantize_Conv2dWithBatchNorm(
activation,
activation_op_name,
with_bypass,
delay,
fused_batch_norm,
use_ema=False)
def testQuantize_Conv2dWithBatchNorm(self):
self._RunTestOverParameters(self._TestQuantize_Conv2dWithBatchNorm)
self._RunBatchNormTestOverParameters(self._TestQuantize_Conv2dWithBatchNorm)
def _testQuantize_Conv2dWithBatchNorm(self, activation, activation_op_name,
with_bypass, delay, use_ema):
with_bypass, delay, fused_batch_norm,
use_ema):
"""Tests quantization: inputs -> Conv2d with batch norm -> Activation.
Args:
@ -326,6 +358,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
with_bypass: Bool, when true there is an extra connection added from
inputs to just before Activation.
delay: Int (optional), delay in number of steps until quantization starts.
fused_batch_norm: Bool, when true use FusedBatchNorm.
use_ema: Bool, when true uses EMA quantization for BN folded weights.
"""
graph = ops.Graph()
@ -337,39 +370,29 @@ class QuantizeTest(test_util.TensorFlowTestCase):
stride = 1 if with_bypass else 2
out_depth = 3 if with_bypass else 32
scope = 'test/test2' if with_bypass else 'test'
node = conv2d(inputs, out_depth, [5, 5], stride=stride, padding='SAME',
weights_initializer=self._WeightInit(0.09),
activation_fn=None,
normalizer_fn=batch_norm,
normalizer_params=_DEFAULT_BATCH_NORM_PARAMS,
scope=scope)
# Manually fold the batch norm.
weights = graph.get_operation_by_name(scope + '/weights/read').outputs[0]
bn_mult = (graph.get_operation_by_name(scope + '/BatchNorm/batchnorm/mul')
.outputs[0])
mul_fold = math_ops.multiply(weights, bn_mult, name=scope + '/mul_fold')
stride = [stride, stride]
conv_fold = nn_ops.convolution(
input=inputs,
filter=mul_fold,
node = conv2d(
inputs,
out_depth, [5, 5],
stride=stride,
padding='SAME',
strides=stride,
data_format='NHWC',
name=scope + '/convolution_Fold')
bn_bias = (graph.get_operation_by_name(scope + '/BatchNorm/batchnorm/sub')
.outputs[0])
add_fold = math_ops.add(conv_fold, bn_bias, name=scope + '/add_fold')
weights_initializer=self._WeightInit(0.09),
activation_fn=None,
normalizer_fn=batch_norm,
normalizer_params=self._BatchNormParams(fused_batch_norm),
scope=scope)
# Manually add a bypass (optionally) and an activation.
if with_bypass:
node = math_ops.add(inputs, add_fold, name='test/Add')
else:
node = add_fold
node = math_ops.add(inputs, node, name='test/Add')
node = activation(node, name='test/' + activation_op_name)
update_barrier = control_flow_ops.no_op(name='update_barrier')
with ops.control_dependencies([update_barrier]):
array_ops.identity(node, name='control_dependency')
fold_batch_norms.FoldBatchNorms(graph)
quantize.Quantize(
graph, quant_delay=delay, quantize_folded_weights_use_ema=use_ema)
@ -413,7 +436,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
self._AssertOutputGoesToOps(act_quant, graph, [output_op_name])
def _TestQuantize_FCWithBatchNorm(self, activation, activation_op_name,
with_bypass, delay):
with_bypass, delay, fused_batch_norm):
"""Tests quantization: inputs -> FC with batch norm -> Activation.
Args:
@ -423,25 +446,29 @@ class QuantizeTest(test_util.TensorFlowTestCase):
with_bypass: Bool, when true there is an extra connection added from
inputs to just before Activation.
delay: Int (optional), delay in number of steps until quantization starts.
fused_batch_norm: Bool, when true use FusedBatchNorm.
"""
self._testQuantize_FCWithBatchNorm(
activation,
activation_op_name,
with_bypass,
delay,
fused_batch_norm,
use_ema=True)
self._testQuantize_FCWithBatchNorm(
activation,
activation_op_name,
with_bypass,
delay,
fused_batch_norm,
use_ema=False)
def testQuantize_FCWithBatchNorm(self):
self._RunTestOverParameters(self._TestQuantize_FCWithBatchNorm)
self._RunBatchNormTestOverParameters(self._TestQuantize_FCWithBatchNorm)
def _testQuantize_FCWithBatchNorm(self, activation, activation_op_name,
with_bypass, delay, use_ema):
with_bypass, delay, fused_batch_norm,
use_ema):
"""Tests quantization: inputs -> FC with batch norm -> Activation.
Args:
@ -451,6 +478,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
with_bypass: Bool, when true there is an extra connection added from
inputs to just before Activation.
delay: Int (optional), delay in number of steps until quantization starts.
fused_batch_norm: Bool, when true use FusedBatchNorm.
use_ema: Bool, when true uses EMA quantization for BN folded weights.
"""
graph = ops.Graph()
@ -461,32 +489,27 @@ class QuantizeTest(test_util.TensorFlowTestCase):
inputs = array_ops.zeros((batch_size, depth))
out_depth = 256 if with_bypass else 128
scope = 'test/test2' if with_bypass else 'test'
node = fully_connected(inputs, out_depth,
weights_initializer=self._WeightInit(0.03),
activation_fn=None,
normalizer_fn=batch_norm,
normalizer_params=_DEFAULT_BATCH_NORM_PARAMS,
scope=scope)
# Manually fold the batch norm.
weights = graph.get_operation_by_name(scope + '/weights/read').outputs[0]
bn_mult = (graph.get_operation_by_name(scope + '/BatchNorm/batchnorm/mul')
.outputs[0])
mul_fold = math_ops.multiply(weights, bn_mult, name=scope + '/mul_fold')
fc_fold = math_ops.matmul(inputs, mul_fold, name=scope + '/MatMul_Fold')
bn_bias = (graph.get_operation_by_name(scope + '/BatchNorm/batchnorm/sub')
.outputs[0])
add_fold = math_ops.add(fc_fold, bn_bias, name=scope + '/add_fold')
node = fully_connected(
inputs,
out_depth,
weights_initializer=self._WeightInit(0.03),
activation_fn=None,
normalizer_fn=batch_norm,
normalizer_params=self._BatchNormParams(fused_batch_norm),
scope=scope)
# Manually add a bypass (optionally) and an activation.
if with_bypass:
node = math_ops.add(inputs, add_fold, name='test/Add')
else:
node = add_fold
node = math_ops.add(inputs, node, name='test/Add')
node = activation(node, name='test/' + activation_op_name)
update_barrier = control_flow_ops.no_op(name='update_barrier')
with ops.control_dependencies([update_barrier]):
array_ops.identity(node, name='control_dependency')
fold_batch_norms.FoldBatchNorms(graph)
quantize.Quantize(
graph, quant_delay=delay, quantize_folded_weights_use_ema=use_ema)
@ -530,7 +553,8 @@ class QuantizeTest(test_util.TensorFlowTestCase):
self._AssertOutputGoesToOps(act_quant, graph, [output_op_name])
def _TestQuantize_DepthwiseConv2dWithBatchNorm(
self, activation, activation_op_name, with_bypass, delay):
self, activation, activation_op_name, with_bypass, delay,
fused_batch_norm):
"""Tests quantization: inputs -> DWConv2d with batch norm -> Activation.
Args:
@ -540,26 +564,30 @@ class QuantizeTest(test_util.TensorFlowTestCase):
with_bypass: Bool, when true there is an extra connection added from
inputs to just before Activation.
delay: Int (optional), delay in number of steps until quantization starts.
fused_batch_norm: Bool, when true use FusedBatchNorm.
"""
self._testQuantize_DepthwiseConv2dWithBatchNorm(
activation,
activation_op_name,
with_bypass,
delay,
fused_batch_norm,
use_ema=True)
self._testQuantize_DepthwiseConv2dWithBatchNorm(
activation,
activation_op_name,
with_bypass,
delay,
fused_batch_norm,
use_ema=False)
def testQuantize_DepthwiseConv2dWithBatchNorm(self):
self._RunTestOverParameters(
self._TestQuantize_DepthwiseConv2dWithoutBatchNorm)
self._RunBatchNormTestOverParameters(
self._TestQuantize_DepthwiseConv2dWithBatchNorm)
def _testQuantize_DepthwiseConv2dWithBatchNorm(
self, activation, activation_op_name, with_bypass, delay, use_ema):
self, activation, activation_op_name, with_bypass, delay,
fused_batch_norm, use_ema):
"""Tests quantization: inputs -> DWConv2d with batch norm -> Activation.
Args:
@ -569,6 +597,7 @@ class QuantizeTest(test_util.TensorFlowTestCase):
with_bypass: Bool, when true there is an extra connection added from
inputs to just before Activation.
delay: Int (optional), delay in number of steps until quantization starts.
fused_batch_norm: Bool, when true use FusedBatchNorm.
use_ema: Bool, when true uses EMA quantization for BN folded weights.
"""
graph = ops.Graph()
@ -579,46 +608,30 @@ class QuantizeTest(test_util.TensorFlowTestCase):
inputs = array_ops.zeros((batch_size, height, width, depth))
stride = 1 if with_bypass else 2
scope = 'test/test2' if with_bypass else 'test'
node = separable_conv2d(inputs, None, [5, 5], stride=stride,
depth_multiplier=1.0, padding='SAME',
weights_initializer=self._WeightInit(0.09),
activation_fn=None,
normalizer_fn=batch_norm,
normalizer_params=_DEFAULT_BATCH_NORM_PARAMS,
scope=scope)
# Manually fold the batch norm.
weights = (graph.get_operation_by_name(scope + '/depthwise_weights/read')
.outputs[0])
bn_mult = (graph.get_operation_by_name(scope + '/BatchNorm/batchnorm/mul')
.outputs[0])
new_shape = [
weights.get_shape().as_list()[2], weights.get_shape().as_list()[3]
]
bn_mult_reshaped = array_ops.reshape(
bn_mult, new_shape, name=scope + '/gamma_reshape')
mul_fold = math_ops.multiply(
weights, bn_mult_reshaped, name=scope + '/mul_fold')
stride = [1, stride, stride, 1]
conv_fold = nn_ops.depthwise_conv2d(
input=inputs,
filter=mul_fold,
node = separable_conv2d(
inputs,
None, [5, 5],
stride=stride,
depth_multiplier=1.0,
padding='SAME',
strides=stride,
name=scope + '/depthwise_Fold')
bn_bias = (graph.get_operation_by_name(scope + '/BatchNorm/batchnorm/sub')
.outputs[0])
add_fold = math_ops.add(conv_fold, bn_bias, name=scope + '/add_fold')
weights_initializer=self._WeightInit(0.09),
activation_fn=None,
normalizer_fn=batch_norm,
normalizer_params=self._BatchNormParams(fused_batch_norm),
scope=scope)
# Manually add a bypass (optionally) and an activation.
if with_bypass:
node = math_ops.add(inputs, add_fold, name='test/Add')
else:
node = add_fold
node = math_ops.add(inputs, node, name='test/Add')
node = activation(node, name='test/' + activation_op_name)
update_barrier = control_flow_ops.no_op(name='update_barrier')
with ops.control_dependencies([update_barrier]):
array_ops.identity(node, name='control_dependency')
fold_batch_norms.FoldBatchNorms(graph)
quantize.Quantize(
graph, quant_delay=delay, quantize_folded_weights_use_ema=use_ema)
quantization_node_name = 'FakeQuantWithMinMaxVars'
@ -660,6 +673,9 @@ class QuantizeTest(test_util.TensorFlowTestCase):
if delay else 'control_dependency')
self._AssertOutputGoesToOps(act_quant, graph, [output_op_name])
def _BatchNormParams(self, fused=False):
return {'center': True, 'scale': True, 'decay': 1.0 - 0.003, 'fused': fused}
def _WeightInit(self, stddev):
"""Returns truncated normal variable initializer.

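A minimal sketch of the flow these rewritten tests now exercise (the shapes, the omitted initializer, and the 'test' scope below are simplified stand-ins, not the exact test values): build an ordinary conv + batch-norm graph, fold the batch norm with FoldBatchNorms, then apply the fake-quantization rewrite with Quantize, instead of folding the batch norm by hand as the deleted lines did.

from tensorflow.contrib.layers.python.layers import layers
from tensorflow.contrib.quantize.python import fold_batch_norms
from tensorflow.contrib.quantize.python import quantize
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops

graph = ops.Graph()
with graph.as_default():
  inputs = array_ops.zeros((5, 128, 128, 3))
  layers.conv2d(
      inputs, 32, [5, 5],
      activation_fn=None,
      normalizer_fn=layers.batch_norm,
      # Same dictionary that _BatchNormParams(fused=False) returns above.
      normalizer_params={'center': True, 'scale': True,
                         'decay': 1.0 - 0.003, 'fused': False},
      scope='test')
  # Fold the batch norm into the conv weights, then rewrite the graph with
  # fake-quantization ops, mirroring the calls made by the tests above.
  fold_batch_norms.FoldBatchNorms(graph)
  quantize.Quantize(
      graph, quant_delay=None, quantize_folded_weights_use_ema=True)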
View File

@ -156,6 +156,7 @@ cuda_py_tests(
"//tensorflow/python:client_testlib",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:gradients",
"//tensorflow/python:init_ops",
"//tensorflow/python:math_ops",
@ -165,6 +166,7 @@ cuda_py_tests(
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
"//tensorflow/python/eager:context",
],
shard_count = 10,
)

View File

@ -25,10 +25,12 @@ from six.moves import xrange # pylint: disable=redefined-builtin
from tensorflow.contrib import rnn as rnn_lib
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops as ops_lib
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients_impl
@ -881,6 +883,7 @@ class LSTMTest(test.TestCase):
# Smoke test, this should not raise an error
rnn.dynamic_rnn(cell, inputs, dtype=dtypes.float32)
@test_util.run_in_graph_and_eager_modes()
def testDynamicRNNWithTupleStates(self):
num_units = 3
input_size = 5
@ -888,13 +891,20 @@ class LSTMTest(test.TestCase):
num_proj = 4
max_length = 8
sequence_length = [4, 6]
in_graph_mode = context.in_graph_mode()
with self.test_session(graph=ops_lib.Graph()) as sess:
initializer = init_ops.random_uniform_initializer(
-0.01, 0.01, seed=self._seed)
inputs = max_length * [
array_ops.placeholder(
dtypes.float32, shape=(None, input_size))
]
if in_graph_mode:
inputs = max_length * [
array_ops.placeholder(
dtypes.float32, shape=(None, input_size))
]
else:
inputs = max_length * [
constant_op.constant(
np.random.randn(batch_size, input_size).astype(np.float32))
]
inputs_c = array_ops.stack(inputs)
cell = rnn_cell.LSTMCell(
num_units,
@ -924,21 +934,34 @@ class LSTMTest(test.TestCase):
self.assertEqual(state_dynamic[0], state_dynamic.c)
self.assertEqual(state_dynamic[1], state_dynamic.h)
variables_lib.global_variables_initializer().run()
if in_graph_mode:
variables_lib.global_variables_initializer().run()
input_value = np.random.randn(batch_size, input_size)
outputs_static = sess.run(
outputs_static, feed_dict={
inputs[0]: input_value
})
outputs_dynamic = sess.run(
outputs_dynamic, feed_dict={
inputs[0]: input_value
})
state_static = sess.run(
state_static, feed_dict={
inputs[0]: input_value
})
state_dynamic = sess.run(
state_dynamic, feed_dict={
inputs[0]: input_value
})
input_value = np.random.randn(batch_size, input_size)
outputs_static_v = sess.run(outputs_static,
feed_dict={inputs[0]: input_value})
outputs_dynamic_v = sess.run(outputs_dynamic,
feed_dict={inputs[0]: input_value})
self.assertAllEqual(outputs_static_v, outputs_dynamic_v)
state_static_v = sess.run(state_static,
feed_dict={inputs[0]: input_value})
state_dynamic_v = sess.run(state_dynamic,
feed_dict={inputs[0]: input_value})
self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_dynamic_v))
if in_graph_mode:
self.assertAllEqual(outputs_static, outputs_dynamic)
else:
self.assertAllEqual(
array_ops.stack(outputs_static).numpy(), outputs_dynamic.numpy())
self.assertAllEqual(np.hstack(state_static), np.hstack(state_dynamic))
@test_util.run_in_graph_and_eager_modes()
def testDynamicRNNWithNestedTupleStates(self):
num_units = 3
input_size = 5
@ -946,13 +969,20 @@ class LSTMTest(test.TestCase):
num_proj = 4
max_length = 8
sequence_length = [4, 6]
in_graph_mode = context.in_graph_mode()
with self.test_session(graph=ops_lib.Graph()) as sess:
initializer = init_ops.random_uniform_initializer(
-0.01, 0.01, seed=self._seed)
inputs = max_length * [
array_ops.placeholder(
dtypes.float32, shape=(None, input_size))
]
if in_graph_mode:
inputs = max_length * [
array_ops.placeholder(
dtypes.float32, shape=(None, input_size))
]
else:
inputs = max_length * [
constant_op.constant(
np.random.randn(batch_size, input_size).astype(np.float32))
]
inputs_c = array_ops.stack(inputs)
def _cell(i):
@ -993,20 +1023,34 @@ class LSTMTest(test.TestCase):
sequence_length=sequence_length,
scope=scope)
variables_lib.global_variables_initializer().run()
if in_graph_mode:
input_value = np.random.randn(batch_size, input_size)
variables_lib.global_variables_initializer().run()
outputs_static = sess.run(
outputs_static, feed_dict={
inputs[0]: input_value
})
outputs_dynamic = sess.run(
outputs_dynamic, feed_dict={
inputs[0]: input_value
})
state_static = sess.run(
nest.flatten(state_static), feed_dict={
inputs[0]: input_value
})
state_dynamic = sess.run(
nest.flatten(state_dynamic), feed_dict={
inputs[0]: input_value
})
input_value = np.random.randn(batch_size, input_size)
outputs_static_v = sess.run(outputs_static,
feed_dict={inputs[0]: input_value})
outputs_dynamic_v = sess.run(outputs_dynamic,
feed_dict={inputs[0]: input_value})
self.assertAllEqual(outputs_static_v, outputs_dynamic_v)
state_static_v = sess.run(nest.flatten(state_static),
feed_dict={inputs[0]: input_value})
state_dynamic_v = sess.run(nest.flatten(state_dynamic),
feed_dict={inputs[0]: input_value})
self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_dynamic_v))
if in_graph_mode:
self.assertAllEqual(outputs_static, outputs_dynamic)
else:
self.assertAllEqual(
array_ops.stack(outputs_static).numpy(), outputs_dynamic.numpy())
state_static = [s.numpy() for s in nest.flatten(state_static)]
state_dynamic = [s.numpy() for s in nest.flatten(state_dynamic)]
self.assertAllEqual(np.hstack(state_static), np.hstack(state_dynamic))
def _testDynamicEquivalentToStaticRNN(self, use_gpu, use_sequence_length):
time_steps = 8
@ -1015,21 +1059,22 @@ class LSTMTest(test.TestCase):
input_size = 5
batch_size = 2
input_values = np.random.randn(time_steps, batch_size, input_size)
input_values = np.random.randn(time_steps, batch_size, input_size).astype(
np.float32)
if use_sequence_length:
sequence_length = np.random.randint(0, time_steps, size=batch_size)
else:
sequence_length = None
########### Step 1: Run static graph and generate readouts
with self.test_session(use_gpu=use_gpu, graph=ops_lib.Graph()) as sess:
concat_inputs = array_ops.placeholder(
dtypes.float32, shape=(time_steps, batch_size, input_size))
inputs = array_ops.unstack(concat_inputs)
in_graph_mode = context.in_graph_mode()
# TODO(b/68017812): Eager ignores operation seeds, so we need to create a
# single cell and reuse it across the static and dynamic RNNs. Remove this
# special case once it is fixed.
if not in_graph_mode:
initializer = init_ops.random_uniform_initializer(
-0.01, 0.01, seed=self._seed)
cell = rnn_cell.LSTMCell(
num_units,
use_peepholes=True,
@ -1037,63 +1082,85 @@ class LSTMTest(test.TestCase):
num_proj=num_proj,
state_is_tuple=False)
########### Step 1: Run static graph and generate readouts
with self.test_session(use_gpu=use_gpu, graph=ops_lib.Graph()) as sess:
if in_graph_mode:
concat_inputs = array_ops.placeholder(
dtypes.float32, shape=(time_steps, batch_size, input_size))
else:
concat_inputs = constant_op.constant(input_values)
inputs = array_ops.unstack(concat_inputs)
initializer = init_ops.random_uniform_initializer(
-0.01, 0.01, seed=self._seed)
# TODO(akshayka): Remove special case once b/68017812 is fixed.
if in_graph_mode:
cell = rnn_cell.LSTMCell(
num_units,
use_peepholes=True,
initializer=initializer,
num_proj=num_proj,
state_is_tuple=False)
with variable_scope.variable_scope("dynamic_scope"):
outputs_static, state_static = rnn.static_rnn(
cell, inputs, sequence_length=sequence_length, dtype=dtypes.float32)
feeds = {concat_inputs: input_values}
if in_graph_mode:
# Generate gradients and run sessions to obtain outputs
feeds = {concat_inputs: input_values}
# Initialize
variables_lib.global_variables_initializer().run(feed_dict=feeds)
# Generate gradients of sum of outputs w.r.t. inputs
static_gradients = gradients_impl.gradients(
outputs_static + [state_static], [concat_inputs])
# Generate gradients of individual outputs w.r.t. inputs
static_individual_gradients = nest.flatten([
gradients_impl.gradients(y, [concat_inputs])
for y in [outputs_static[0], outputs_static[-1], state_static]
])
# Generate gradients of individual variables w.r.t. inputs
trainable_variables = ops_lib.get_collection(
ops_lib.GraphKeys.TRAINABLE_VARIABLES)
assert len(trainable_variables) > 1, (
"Count of trainable variables: %d" % len(trainable_variables))
# pylint: disable=bad-builtin
static_individual_variable_gradients = nest.flatten([
gradients_impl.gradients(y, trainable_variables)
for y in [outputs_static[0], outputs_static[-1], state_static]
])
# Test forward pass
values_static = sess.run(outputs_static, feed_dict=feeds)
(state_value_static,) = sess.run((state_static,), feed_dict=feeds)
# Initialize
variables_lib.global_variables_initializer().run(feed_dict=feeds)
# Test gradients to inputs and variables w.r.t. outputs & final state
static_grad_values = sess.run(static_gradients, feed_dict=feeds)
# Generate gradients of sum of outputs w.r.t. inputs
static_gradients = gradients_impl.gradients(
outputs_static + [state_static], [concat_inputs])
static_individual_grad_values = sess.run(static_individual_gradients,
feed_dict=feeds)
# Generate gradients of individual outputs w.r.t. inputs
static_individual_gradients = nest.flatten([
gradients_impl.gradients(y, [concat_inputs])
for y in [outputs_static[0], outputs_static[-1], state_static]
])
# Generate gradients of individual variables w.r.t. inputs
trainable_variables = ops_lib.get_collection(
ops_lib.GraphKeys.TRAINABLE_VARIABLES)
assert len(trainable_variables) > 1, ("Count of trainable variables: %d" %
len(trainable_variables))
# pylint: disable=bad-builtin
static_individual_variable_gradients = nest.flatten([
gradients_impl.gradients(y, trainable_variables)
for y in [outputs_static[0], outputs_static[-1], state_static]
])
# Test forward pass
values_static = sess.run(outputs_static, feed_dict=feeds)
(state_value_static,) = sess.run((state_static,), feed_dict=feeds)
# Test gradients to inputs and variables w.r.t. outputs & final state
static_grad_values = sess.run(static_gradients, feed_dict=feeds)
static_individual_grad_values = sess.run(static_individual_gradients,
feed_dict=feeds)
static_individual_var_grad_values = sess.run(
static_individual_variable_gradients, feed_dict=feeds)
static_individual_var_grad_values = sess.run(
static_individual_variable_gradients, feed_dict=feeds)
########## Step 2: Run dynamic graph and generate readouts
with self.test_session(use_gpu=use_gpu, graph=ops_lib.Graph()) as sess:
concat_inputs = array_ops.placeholder(
dtypes.float32, shape=(time_steps, batch_size, input_size))
inputs = array_ops.unstack(concat_inputs)
if in_graph_mode:
concat_inputs = array_ops.placeholder(
dtypes.float32, shape=(time_steps, batch_size, input_size))
else:
concat_inputs = constant_op.constant(input_values)
initializer = init_ops.random_uniform_initializer(
-0.01, 0.01, seed=self._seed)
cell = rnn_cell.LSTMCell(
num_units,
use_peepholes=True,
initializer=initializer,
num_proj=num_proj,
state_is_tuple=False)
# TODO(akshayka): Remove this special case once b/68017812 is
# fixed.
if in_graph_mode:
cell = rnn_cell.LSTMCell(
num_units,
use_peepholes=True,
initializer=initializer,
num_proj=num_proj,
state_is_tuple=False)
with variable_scope.variable_scope("dynamic_scope"):
outputs_dynamic, state_dynamic = rnn.dynamic_rnn(
@ -1104,72 +1171,83 @@ class LSTMTest(test.TestCase):
dtype=dtypes.float32)
split_outputs_dynamic = array_ops.unstack(outputs_dynamic, time_steps)
feeds = {concat_inputs: input_values}
if in_graph_mode:
feeds = {concat_inputs: input_values}
# Initialize
variables_lib.global_variables_initializer().run(feed_dict=feeds)
# Initialize
variables_lib.global_variables_initializer().run(feed_dict=feeds)
# Generate gradients of sum of outputs w.r.t. inputs
dynamic_gradients = gradients_impl.gradients(
split_outputs_dynamic + [state_dynamic], [concat_inputs])
# Generate gradients of sum of outputs w.r.t. inputs
dynamic_gradients = gradients_impl.gradients(
split_outputs_dynamic + [state_dynamic], [concat_inputs])
# Generate gradients of several individual outputs w.r.t. inputs
dynamic_individual_gradients = nest.flatten([
gradients_impl.gradients(y, [concat_inputs])
for y in
[split_outputs_dynamic[0], split_outputs_dynamic[-1], state_dynamic]
])
# Generate gradients of several individual outputs w.r.t. inputs
dynamic_individual_gradients = nest.flatten([
gradients_impl.gradients(y, [concat_inputs])
for y in
[split_outputs_dynamic[0], split_outputs_dynamic[-1], state_dynamic]
])
# Generate gradients of individual variables w.r.t. inputs
trainable_variables = ops_lib.get_collection(
ops_lib.GraphKeys.TRAINABLE_VARIABLES)
assert len(trainable_variables) > 1, ("Count of trainable variables: %d" %
len(trainable_variables))
dynamic_individual_variable_gradients = nest.flatten([
gradients_impl.gradients(y, trainable_variables)
for y in
[split_outputs_dynamic[0], split_outputs_dynamic[-1], state_dynamic]
])
# Generate gradients of individual variables w.r.t. inputs
trainable_variables = ops_lib.get_collection(
ops_lib.GraphKeys.TRAINABLE_VARIABLES)
assert len(trainable_variables) > 1, (
"Count of trainable variables: %d" % len(trainable_variables))
dynamic_individual_variable_gradients = nest.flatten([
gradients_impl.gradients(y, trainable_variables)
for y in
[split_outputs_dynamic[0], split_outputs_dynamic[-1], state_dynamic]
])
# Test forward pass
values_dynamic = sess.run(split_outputs_dynamic, feed_dict=feeds)
(state_value_dynamic,) = sess.run((state_dynamic,), feed_dict=feeds)
# Test forward pass
values_dynamic = sess.run(split_outputs_dynamic, feed_dict=feeds)
(state_value_dynamic,) = sess.run((state_dynamic,), feed_dict=feeds)
# Test gradients to inputs and variables w.r.t. outputs & final state
dynamic_grad_values = sess.run(dynamic_gradients, feed_dict=feeds)
# Test gradients to inputs and variables w.r.t. outputs & final state
dynamic_grad_values = sess.run(dynamic_gradients, feed_dict=feeds)
dynamic_individual_grad_values = sess.run(dynamic_individual_gradients,
feed_dict=feeds)
dynamic_individual_grad_values = sess.run(dynamic_individual_gradients,
feed_dict=feeds)
dynamic_individual_var_grad_values = sess.run(
dynamic_individual_variable_gradients, feed_dict=feeds)
dynamic_individual_var_grad_values = sess.run(
dynamic_individual_variable_gradients, feed_dict=feeds)
######### Step 3: Comparisons
if not in_graph_mode:
values_static = outputs_static
values_dynamic = split_outputs_dynamic
state_value_static = state_static
state_value_dynamic = state_dynamic
self.assertEqual(len(values_static), len(values_dynamic))
for (value_static, value_dynamic) in zip(values_static, values_dynamic):
self.assertAllEqual(value_static, value_dynamic)
self.assertAllEqual(state_value_static, state_value_dynamic)
self.assertAllEqual(static_grad_values, dynamic_grad_values)
if in_graph_mode:
self.assertEqual(
len(static_individual_grad_values), len(dynamic_individual_grad_values))
self.assertEqual(
len(static_individual_var_grad_values),
len(dynamic_individual_var_grad_values))
self.assertAllEqual(static_grad_values, dynamic_grad_values)
for i, (a, b) in enumerate(
zip(static_individual_grad_values, dynamic_individual_grad_values)):
tf_logging.info("Comparing individual gradients iteration %d" % i)
self.assertAllEqual(a, b)
self.assertEqual(
len(static_individual_grad_values),
len(dynamic_individual_grad_values))
self.assertEqual(
len(static_individual_var_grad_values),
len(dynamic_individual_var_grad_values))
for i, (a, b) in enumerate(
zip(static_individual_var_grad_values,
dynamic_individual_var_grad_values)):
tf_logging.info("Comparing individual variable gradients iteration %d" %
i)
self.assertAllEqual(a, b)
for i, (a, b) in enumerate(
zip(static_individual_grad_values, dynamic_individual_grad_values)):
tf_logging.info("Comparing individual gradients iteration %d" % i)
self.assertAllEqual(a, b)
for i, (a, b) in enumerate(
zip(static_individual_var_grad_values,
dynamic_individual_var_grad_values)):
tf_logging.info("Comparing individual variable gradients iteration %d" %
i)
self.assertAllEqual(a, b)
@test_util.run_in_graph_and_eager_modes()
def testDynamicEquivalentToStaticRNN(self):
self._testDynamicEquivalentToStaticRNN(
use_gpu=False, use_sequence_length=False)
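
A condensed, self-contained sketch of the pattern the rewritten LSTM tests above follow (the class and test names here are hypothetical): one test body decorated with run_in_graph_and_eager_modes that feeds placeholders through a session in graph mode, and uses constants plus .numpy() in eager mode.

import numpy as np

from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops as ops_lib
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test


class GraphAndEagerPatternTest(test.TestCase):

  @test_util.run_in_graph_and_eager_modes()
  def testSumMatchesNumpy(self):
    in_graph_mode = context.in_graph_mode()
    data = np.random.randn(2, 3).astype(np.float32)
    with self.test_session(graph=ops_lib.Graph()) as sess:
      if in_graph_mode:
        # Graph mode: build a placeholder and feed it through the session.
        x = array_ops.placeholder(dtypes.float32, shape=(2, 3))
      else:
        # Eager mode: placeholders are unavailable, so use a constant.
        x = constant_op.constant(data)
      total = math_ops.reduce_sum(x)
      if in_graph_mode:
        total = sess.run(total, feed_dict={x: data})
      else:
        total = total.numpy()
      self.assertAllClose(data.sum(), total)


if __name__ == '__main__':
  test.main()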

View File

@ -112,7 +112,7 @@ struct GatherTree<CPUDevice, int32> {
const int32 max_time = parent_ids.dimension(0);
const int32 batch_size = parent_ids.dimension(1);
const int32 beam_width = parent_ids.dimension(2);
beams.setConstant(-1);
beams.setConstant(end_token);
auto DoWork = [&, ctx, end_token](int start_batch_beam,
int limit_batch_beam) {
@ -138,10 +138,13 @@ struct GatherTree<CPUDevice, int32> {
beams(level, batch, beam) = step_ids(level, batch, parent);
parent = parent_ids(level, batch, parent);
}
// Not necessary when using a BeamSearchDecoder, but necessary
// when a user feeds in possibly broken trajectory (i.e., non-eos
// entries in a beam following eos entries).
bool finished = false;
for (int32 time = 0; time < max_seq_len_b; ++time) {
if (finished) {
beams(time, batch, beam) = -1;
beams(time, batch, beam) = end_token;
} else if (beams(time, batch, beam) == end_token) {
finished = true;
}

View File

@ -46,24 +46,31 @@ __global__ void GatherTreeOpKernel(const int32 batch_size, const int32 max_time,
const int32 initial_beam_ix = GET_IX(max_seq_len_b - 1, beam);
beams[initial_beam_ix] = ldg(step_ids + initial_beam_ix);
int32 parent = ldg(parent_ids + initial_beam_ix);
bool found_bad = false;
for (int32 level = max_seq_len_b - 2; level >= 0; --level) {
const int32 level_beam_ix = GET_IX(level, beam);
const int32 level_parent_ix = GET_IX(level, parent);
if (parent < 0 || parent > beam_width) {
beams[level_beam_ix] = -1;
parent = -1;
found_bad = true;
} else {
beams[level_beam_ix] = ldg(step_ids + level_parent_ix);
parent = ldg(parent_ids + level_parent_ix);
}
}
bool finished = false;
for (int32 time = 0; time < max_seq_len_b; ++time) {
const int32 level_beam_ix = GET_IX(time, beam);
if (finished) {
beams[level_beam_ix] = -1;
} else if (beams[level_beam_ix] == end_token) {
finished = true;
// Not necessary when using a BeamSearchDecoder, but necessary
// when a user feeds in possibly broken trajectory (i.e., non-eos
// entries in a beam following eos entries).
if (!found_bad) {
bool finished = false;
for (int32 time = 0; time < max_seq_len_b; ++time) {
const int32 level_beam_ix = GET_IX(time, beam);
if (finished) {
beams[level_beam_ix] = end_token;
} else if (beams[level_beam_ix] == end_token) {
finished = true;
}
}
}
#undef GET_IX
@ -80,8 +87,8 @@ struct GatherTree<GPUDevice, T> {
const int32 max_time = parent_ids.dimension(0);
const int32 batch_size = parent_ids.dimension(1);
const int32 beam_width = parent_ids.dimension(2);
// First kernel launch to zero things out
beams.device(d) = beams.constant(T(-1));
// First kernel launch to "zero" things out
beams.device(d) = beams.constant(end_token);
CudaLaunchConfig config = GetCudaLaunchConfig(batch_size * beam_width, d);
// clang-format off

View File

@ -53,11 +53,14 @@ REGISTER_OP("GatherTree")
.Doc(R"doc(
Calculates the full beams from the per-step ids and parent beam ids.
This op implements the following mathematical equations:
On CPU, if an out of bound parent id is found, an error is returned.
On GPU, if an out of bound parent id is found, a -1 is stored in the
corresponding output value and the execution for that beam returns early.
```python
TODO(ebrevdo): fill in
```
For a given beam, past the time step containing the first decoded `end_token`
all values are filled in with `end_token`.
TODO(ebrevdo): fill in the remainder of this docstring.
step_ids: `[max_time, batch_size, beam_width]`.
parent_ids: `[max_time, batch_size, beam_width]`.
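
A minimal runnable sketch of the fill behavior described above, using the same values as the updated testGatherTreeOne further below (the beam_search_ops import path from the contrib seq2seq package is assumed): past the first decoded end_token in a beam, and past max_sequence_lengths, entries now come back as end_token rather than -1.

import numpy as np
import tensorflow as tf

from tensorflow.contrib.seq2seq.python.ops import beam_search_ops


def _transpose_batch_time(x):
  # The op expects [max_time, batch_size, beam_width]; the literals below are
  # written batch-major for readability, as in the unit test.
  return np.transpose(np.asarray(x, dtype=np.int32), [1, 0, 2])


step_ids = _transpose_batch_time(
    [[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]])
parent_ids = _transpose_batch_time(
    [[[0, 0, 0], [0, 1, 1], [2, 1, 2], [-1, -1, -1]]])

beams = beam_search_ops.gather_tree(
    step_ids=step_ids,
    parent_ids=parent_ids,
    max_sequence_lengths=[3],
    end_token=10)

with tf.Session() as sess:
  # Batch-major result: [[2, 2, 2], [6, 5, 6], [7, 8, 9], [10, 10, 10]].
  # Time step 3 lies past max_sequence_lengths[0], so every beam entry there
  # is filled with end_token (10) instead of -1.
  print(np.transpose(sess.run(beams), [1, 0, 2]))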

View File

@ -36,24 +36,26 @@ class GatherTreeTest(test.TestCase):
def testGatherTreeOne(self):
# (max_time = 4, batch_size = 1, beams = 3)
end_token = 10
step_ids = _transpose_batch_time(
[[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]])
parent_ids = _transpose_batch_time(
[[[0, 0, 0], [0, 1, 1], [2, 1, 2], [-1, -1, -1]]])
max_sequence_lengths = [3]
expected_result = _transpose_batch_time(
[[[2, 2, 2], [6, 5, 6], [7, 8, 9], [-1, -1, -1]]])
expected_result = _transpose_batch_time([[[2, 2, 2], [6, 5, 6], [7, 8, 9],
[10, 10, 10]]])
beams = beam_search_ops.gather_tree(
step_ids=step_ids,
parent_ids=parent_ids,
max_sequence_lengths=max_sequence_lengths,
end_token=10)
end_token=end_token)
with self.test_session(use_gpu=True):
self.assertAllEqual(expected_result, beams.eval())
def testBadParentValuesOnCPU(self):
# (batch_size = 1, max_time = 4, beams = 3)
# bad parent in beam 1 time 1
end_token = 10
step_ids = _transpose_batch_time(
[[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]])
parent_ids = _transpose_batch_time(
@ -64,7 +66,7 @@ class GatherTreeTest(test.TestCase):
step_ids=step_ids,
parent_ids=parent_ids,
max_sequence_lengths=max_sequence_lengths,
end_token=10)
end_token=end_token)
with self.test_session():
with self.assertRaisesOpError(
r"parent id -1 at \(batch, time, beam\) == \(0, 0, 1\)"):
@ -77,19 +79,20 @@ class GatherTreeTest(test.TestCase):
return
# (max_time = 4, batch_size = 1, beams = 3)
# bad parent in beam 1 time 1; appears as a negative index at time 0
end_token = 10
step_ids = _transpose_batch_time(
[[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]])
parent_ids = _transpose_batch_time(
[[[0, 0, 0], [0, -1, 1], [2, 1, 2], [-1, -1, -1]]])
max_sequence_lengths = [3]
expected_result = _transpose_batch_time(
[[[2, -1, 2], [6, 5, 6], [7, 8, 9], [-1, -1, -1]]])
expected_result = _transpose_batch_time([[[2, -1, 2], [6, 5, 6], [7, 8, 9],
[10, 10, 10]]])
with ops.device("/device:GPU:0"):
beams = beam_search_ops.gather_tree(
step_ids=step_ids,
parent_ids=parent_ids,
max_sequence_lengths=max_sequence_lengths,
end_token=10)
end_token=end_token)
with self.test_session(use_gpu=True):
self.assertAllEqual(expected_result, beams.eval())
@ -115,24 +118,24 @@ class GatherTreeTest(test.TestCase):
self.assertEqual((max_time, batch_size, beam_width), beams.shape)
beams_value = beams.eval()
for b in range(batch_size):
# Past max_sequence_lengths[b], we emit all -1s.
# Past max_sequence_lengths[b], we emit all end tokens.
b_value = beams_value[max_sequence_lengths[b]:, b, :]
self.assertAllClose(b_value, -1. * np.ones_like(b_value))
self.assertAllClose(b_value, end_token * np.ones_like(b_value))
for batch, beam in itertools.product(
range(batch_size), range(beam_width)):
v = np.squeeze(beams_value[:, batch, beam])
if end_token in v:
found_bad = np.where(v == -1)[0]
self.assertEqual(0, len(found_bad))
found = np.where(v == end_token)[0]
# Should be up to 1 instance of end_token per beam.
self.assertEqual(len(found), 1)
found = found[0]
found = found[0] # First occurrence of end_token.
# If an end_token is found, everything before it should be a
# valid id and everything after it should be -1.
if found > 0:
self.assertAllEqual(
v[:found - 1] >= 0, np.ones_like(v[:found - 1], dtype=bool))
self.assertAllClose(
v[found + 1:], -1 * np.ones_like(v[found + 1:]))
self.assertAllClose(v[found + 1:],
end_token * np.ones_like(v[found + 1:]))
if __name__ == "__main__":

File diff suppressed because it is too large

View File

@ -31,6 +31,7 @@ from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.layers import utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
@ -47,7 +48,6 @@ _dtypes = input_py._dtypes
_store_sparse_tensors = input_py._store_sparse_tensors
_validate_keep_input = input_py._validate_keep_input
_shapes = input_py._shapes
_smart_cond = input_py._smart_cond
_which_queue = input_py._which_queue
# pylint: enable=protected-access
@ -239,7 +239,7 @@ def bucket(tensors,
]
return control_flow_ops.group(*enqueues, name="group_enqueues")
maybe_enqueue = _smart_cond(
maybe_enqueue = utils.smart_cond(
keep_input,
enqueue_which,
control_flow_ops.no_op)
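
A small sketch of why utils.smart_cond is a drop-in replacement for the removed private helper (the _enqueue stand-in below is hypothetical): with a plain Python bool, or a predicate whose value is statically known, it picks a branch at graph-construction time; with a non-constant tensor predicate it builds a tf.cond that is resolved at run time.

from tensorflow.python.framework import dtypes
from tensorflow.python.layers import utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops


def _enqueue():
  # Stand-in for the real group of enqueue ops built above.
  return control_flow_ops.no_op(name="enqueue")


# Plain Python bool: the branch is selected while the graph is being built.
maybe_enqueue_static = utils.smart_cond(True, _enqueue, control_flow_ops.no_op)

# Non-constant tensor predicate: a tf.cond is built and resolved at run time.
keep_input = array_ops.placeholder(dtypes.bool, shape=[])
maybe_enqueue_dynamic = utils.smart_cond(
    keep_input, _enqueue, control_flow_ops.no_op)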

View File

@ -1411,7 +1411,7 @@ cc_library(
hdrs = LIB_INTERNAL_PUBLIC_HEADERS,
copts = tf_copts(),
defines = tf_additional_lib_defines() + [
"SNAPPY",
"TF_USE_SNAPPY",
] + tf_additional_verbs_lib_defines() +
tf_additional_mpi_lib_defines() +
tf_additional_gdr_lib_defines(),

View File

@ -51,7 +51,8 @@ message ApiDef {
// endpoints are deprecated).
message Endpoint {
// Name should be either like "CamelCaseName" or
// "Package.CamelCaseName".
// "Package.CamelCaseName". Client-language-specific ApiDefs may
// use a snake_case convention instead of CamelCase.
string name = 1;
// First GraphDef version at which the op is disallowed.
@ -74,7 +75,7 @@ message ApiDef {
}
repeated Arg in_arg = 4;
repeated Arg out_arg = 5;
// List of post-rename in_arg names to specify new argument order.
// List of original in_arg names to specify new argument order.
// Length of arg_order should be either empty to keep current order
// or match size of in_arg.
repeated string arg_order = 11;

View File

@ -412,6 +412,8 @@ void InitApiDefFromOpDef(const OpDef& op_def, ApiDef* api_def) {
api_in_arg->set_name(op_in_arg.name());
api_in_arg->set_rename_to(op_in_arg.name());
api_in_arg->set_description(op_in_arg.description());
*api_def->add_arg_order() = op_in_arg.name();
}
for (const auto& op_out_arg : op_def.output_arg()) {
auto* api_out_arg = api_def->add_out_arg();
@ -503,6 +505,22 @@ Status MergeApiDefs(ApiDef* base_api_def, const ApiDef& new_api_def) {
}
// Merge arg order
if (new_api_def.arg_order_size() > 0) {
// Validate that new arg_order is correct.
if (new_api_def.arg_order_size() != base_api_def->arg_order_size()) {
return errors::FailedPrecondition(
"Invalid number of arguments ", new_api_def.arg_order_size(), " for ",
base_api_def->graph_op_name(),
". Expected: ", base_api_def->arg_order_size());
}
if (!std::is_permutation(new_api_def.arg_order().begin(),
new_api_def.arg_order().end(),
base_api_def->arg_order().begin())) {
return errors::FailedPrecondition(
"Invalid arg_order: ", str_util::Join(new_api_def.arg_order(), ", "),
" for ", base_api_def->graph_op_name(),
". All elements in arg_order override must match base arg_order: ",
str_util::Join(base_api_def->arg_order(), ", "));
}
base_api_def->clear_arg_order();
std::copy(
new_api_def.arg_order().begin(), new_api_def.arg_order().end(),

View File

@ -207,6 +207,8 @@ attr {
name: "attr_a"
rename_to: "attr_a"
}
arg_order: "arg_a"
arg_order: "arg_b"
)";
OpList op_list;
protobuf::TextFormat::ParseFromString(kTestOpList, &op_list); // NOLINT
@ -331,8 +333,8 @@ op {
name: "arg_c"
rename_to: "arg_cc"
}
arg_order: "arg_aa"
arg_order: "arg_b"
arg_order: "arg_a"
}
)";
OpList op_list;
@ -351,8 +353,8 @@ op {
EXPECT_EQ("arg_cc", api_def->out_arg(0).rename_to());
ASSERT_EQ(2, api_def->arg_order_size());
EXPECT_EQ("arg_aa", api_def->arg_order(0));
EXPECT_EQ("arg_b", api_def->arg_order(1));
EXPECT_EQ("arg_b", api_def->arg_order(0));
EXPECT_EQ("arg_a", api_def->arg_order(1));
}
TEST(OpGenLibTest, ApiDefOverrideDescriptions) {
@ -411,5 +413,47 @@ op {
auto status = api_map.LoadApiDef(api_def1);
ASSERT_EQ(tensorflow::error::FAILED_PRECONDITION, status.code());
}
TEST(OpGenLibTest, ApiDefInvalidArgOrder) {
const string api_def1 = R"(
op {
graph_op_name: "testop"
arg_order: "arg_a"
arg_order: "unexpected_arg"
}
)";
const string api_def2 = R"(
op {
graph_op_name: "testop"
arg_order: "arg_a"
}
)";
const string api_def3 = R"(
op {
graph_op_name: "testop"
arg_order: "arg_a"
arg_order: "arg_a"
}
)";
OpList op_list;
protobuf::TextFormat::ParseFromString(kTestOpList, &op_list); // NOLINT
ApiDefMap api_map(op_list);
TF_CHECK_OK(api_map.LoadApiDef(kTestApiDef));
// Loading with incorrect arg name in arg_order should fail.
auto status = api_map.LoadApiDef(api_def1);
ASSERT_EQ(tensorflow::error::FAILED_PRECONDITION, status.code());
// Loading with incorrect number of args in arg_order should fail.
status = api_map.LoadApiDef(api_def2);
ASSERT_EQ(tensorflow::error::FAILED_PRECONDITION, status.code());
// Loading with the same argument twice in arg_order should fail.
status = api_map.LoadApiDef(api_def3);
ASSERT_EQ(tensorflow::error::FAILED_PRECONDITION, status.code());
}
} // namespace
} // namespace tensorflow

View File

@ -1068,10 +1068,16 @@ Status ImportGraphDef(const ImportGraphDefOptions& opts, const GraphDef& gdef,
refiner->set_graph_def_version(
std::min(refiner->graph_def_version(), gdef.versions().producer()));
return GraphConstructor::Construct(
opts, gdef.node(), &gdef.versions(), &gdef.library(), g, refiner,
&results->return_tensors, &results->return_nodes,
&results->unused_input_map_keys);
if (results == nullptr) {
return GraphConstructor::Construct(opts, gdef.node(), &gdef.versions(),
&gdef.library(), g, refiner, nullptr,
nullptr, nullptr);
} else {
return GraphConstructor::Construct(
opts, gdef.node(), &gdef.versions(), &gdef.library(), g, refiner,
&results->return_tensors, &results->return_nodes,
&results->unused_input_map_keys);
}
}
void CopyGraph(const Graph& src, Graph* dest) {

View File

@ -450,12 +450,16 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
}
// Optimize the graph (function inlining, l1 optimizations, etc).
VLOG(1) << "Number of nodes in graph before OptimizeGraph: "
<< new_item->graph.node_size();
Status optimize_status =
OptimizeGraph(new_item->graph, &new_item->graph, cfg);
if (!optimize_status.ok()) {
LOG(ERROR) << "Graph preprocessing failed: " << optimize_status;
return nullptr;
}
VLOG(1) << "Number of nodes in graph after OptimizeGraph: "
<< new_item->graph.node_size();
if (cfg.prune_graph) {
VLOG(1) << "Pruning graph...";
@ -464,7 +468,8 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
LOG(ERROR) << "Pruning failed: " << status.error_message();
return nullptr;
}
VLOG(1) << "Pruning ran succesfully.";
VLOG(1) << "Number of nodes in graph after pruning: "
<< new_item->graph.node_size();
}
// Validate feed, fetch and init nodes

View File

@ -18,6 +18,9 @@ limitations under the License.
#ifdef INTEL_MKL
#define EIGEN_USE_THREADS
#include "tensorflow/core/framework/numeric_types.h"
#define MKL_Complex8 tensorflow::complex64
#define MKL_Complex16 tensorflow::complex128
#include "mkl_trans.h"
#include "tensorflow/core/kernels/transpose_functor.h"
#include "tensorflow/core/kernels/transpose_op.h"
@ -41,7 +44,7 @@ namespace tensorflow {
namespace {
template <typename T>
void MKLTranspose2D(const char trans, const Tensor& in, Tensor* out) {}
Status MKLTranspose2D(const char trans, const Tensor& in, Tensor* out);
// Documentation here: https://software.intel.com/en-us/node/520863
// Parameters: (ordering:row-major, operation:transpose, num_rows, num_cols,
@ -54,70 +57,73 @@ void MKLTranspose2D(const char trans, const Tensor& in, Tensor* out) {}
mkl_##PREFIX##omatcopy('R', trans, in.dim_size(0), in.dim_size(1), 1, \
in.flat<T>().data(), in.dim_size(1), \
out->flat<T>().data(), in.dim_size(0)); \
return Status::OK();
return Status::OK(); \
}
INSTANTIATE(float, s)
INSTANTIATE(double, d)
INSTANTIATE(complex64, c)
INSTANTIATE(complex128, z)
INSTANTIATE(float, s)
INSTANTIATE(double, d)
INSTANTIATE(complex64, c)
INSTANTIATE(complex128, z)
#undef INSTANTIATE
static const char kMKLTranspose = 'T';
static const char kMKLConjugateTranspose = 'C';
static const char kMKLTranspose = 'T';
static const char kMKLConjugateTranspose = 'C';
} // namespace tensorflow
} // namespace
Status MklTransposeCpuOp::DoTranspose(OpKernelContext* ctx, const Tensor& in,
gtl::ArraySlice<int32> perm,
Tensor* out) {
if (in.dims() == 2) {
switch (in.dtype()) {
case DT_FLOAT:
return MKLTranspose2D<float>(kMKLTranspose, in, out);
case DT_DOUBLE:
return MKLTranspose2D<double>(kMKLTranspose, in, out);
case DT_COMPLEX64:
return MKLTranspose2D<complex64>(kMKLTranspose, in, out);
case DT_COMPLEX128:
return MKLTranspose2D<complex128>(kMKLTranspose, in, out);
default:
break;
}
Status MklTransposeCpuOp::DoTranspose(OpKernelContext* ctx, const Tensor& in,
gtl::ArraySlice<int32> perm,
Tensor* out) {
if (in.dims() == 2) {
if (perm[0] == 0 && perm[1] == 1) {
return Status::OK();
}
// Fallback to eigen if transpose parameters not supported by MKL
typedef Eigen::ThreadPoolDevice CPUDevice;
return ::tensorflow::DoTranspose(ctx->eigen_device<CPUDevice>(), in, perm,
out);
}
Status MklConjugateTransposeCpuOp::DoTranspose(OpKernelContext* ctx,
const Tensor& in,
gtl::ArraySlice<int32> perm,
Tensor* out) {
if (in.dims() == 2) {
// TODO(rmlarsen): By setting lda and ldb, we could use the MKL kernels
// for any transpose that can be reduced to swapping the last two
// dimensions in a rank-3 tensor. We can even run each outer dimension in
// a separate thread.
switch (in.dtype()) {
case DT_FLOAT:
return MKLTranspose2D<float>(kMKLTranspose, in, out);
case DT_DOUBLE:
return MKLTranspose2D<double>(kMKLTranspose, in, out);
case DT_COMPLEX64:
return MKLTranspose2D<complex64>(kMKLConjugateTranspose, in, out);
case DT_COMPLEX128:
return MKLTranspose2D<complex128>(kMKLConjugateTranspose, in, out);
default:
break;
}
switch (in.dtype()) {
case DT_FLOAT:
return MKLTranspose2D<float>(kMKLTranspose, in, out);
case DT_DOUBLE:
return MKLTranspose2D<double>(kMKLTranspose, in, out);
case DT_COMPLEX64:
return MKLTranspose2D<complex64>(kMKLTranspose, in, out);
case DT_COMPLEX128:
return MKLTranspose2D<complex128>(kMKLTranspose, in, out);
default:
break;
}
// Fallback to eigen if transpose parameters not supported by MKL
typedef Eigen::ThreadPoolDevice CPUDevice;
return ::tensorflow::DoConjugateTranspose(ctx->eigen_device<CPUDevice>(),
in, perm, out);
}
// Fallback to eigen if transpose parameters not supported by MKL
typedef Eigen::ThreadPoolDevice CPUDevice;
return ::tensorflow::DoTranspose(ctx->eigen_device<CPUDevice>(), in, perm,
out);
}
Status MklConjugateTransposeCpuOp::DoTranspose(OpKernelContext* ctx,
const Tensor& in,
gtl::ArraySlice<int32> perm,
Tensor* out) {
if (in.dims() == 2 && perm[0] == 1 && perm[1] == 0) {
// TODO(rmlarsen): By setting lda and ldb, we could use the MKL kernels
// for any transpose that can be reduced to swapping the last two
// dimensions in a rank-3 tensor. We can even run each outer dimension in
// a separate thread.
switch (in.dtype()) {
case DT_FLOAT:
return MKLTranspose2D<float>(kMKLTranspose, in, out);
case DT_DOUBLE:
return MKLTranspose2D<double>(kMKLTranspose, in, out);
case DT_COMPLEX64:
return MKLTranspose2D<complex64>(kMKLConjugateTranspose, in, out);
case DT_COMPLEX128:
return MKLTranspose2D<complex128>(kMKLConjugateTranspose, in, out);
default:
break;
}
}
// Fallback to eigen if transpose parameters not supported by MKL
typedef Eigen::ThreadPoolDevice CPUDevice;
return ::tensorflow::DoConjugateTranspose(ctx->eigen_device<CPUDevice>(), in,
perm, out);
}
} // namespace tensorflow

View File

@ -201,17 +201,26 @@ Status DoTransposeImpl(const Device& d, const Tensor& in,
case DT_COMPLEX64:
if (conjugate) {
Transpose<Device, complex64, true>::run(d, in, perm, out);
#if defined(__ANDROID__) and !defined(__clang__)
// Workaround for GCC compiler bug in Android toolchain.
return errors::Unimplemented(
"Conjugate transpose of complex64 not supported for GCC on "
"Android.");
#else
Transpose<Device, complex64, /*conjugate=*/true>::run(d, in, perm, out);
#endif
} else {
Transpose<Device, complex64, false>::run(d, in, perm, out);
Transpose<Device, uint64>::run(d, in, perm, out);
}
break;
case DT_COMPLEX128:
if (conjugate) {
Transpose<Device, complex128, true>::run(d, in, perm, out);
Transpose<Device, complex128, /*conjugate=*/true>::run(d, in, perm,
out);
} else {
Transpose<Device, complex128, false>::run(d, in, perm, out);
Transpose<Device, complex128, /*conjugate=*/false>::run(d, in, perm,
out);
}
break;

View File

@ -467,7 +467,7 @@ def tf_additional_core_deps():
"//conditions:default": [],
}) + select({
"//tensorflow:with_s3_support": [
"//tensorflow/contrib/s3:s3_file_system",
"//tensorflow/core/platform/s3:s3_file_system",
],
"//conditions:default": [],
})

View File

@ -29,7 +29,7 @@ limitations under the License.
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#ifdef SNAPPY
#ifdef TF_USE_SNAPPY
#include "snappy.h"
#endif
#if (defined(__APPLE__) && defined(__MACH__)) || defined(__FreeBSD__)
@ -126,7 +126,7 @@ void AdjustFilenameForLogging(string* filename) {
}
bool Snappy_Compress(const char* input, size_t length, string* output) {
#ifdef SNAPPY
#ifdef TF_USE_SNAPPY
output->resize(snappy::MaxCompressedLength(length));
size_t outlen;
snappy::RawCompress(input, length, &(*output)[0], &outlen);
@ -139,7 +139,7 @@ bool Snappy_Compress(const char* input, size_t length, string* output) {
bool Snappy_GetUncompressedLength(const char* input, size_t length,
size_t* result) {
#ifdef SNAPPY
#ifdef TF_USE_SNAPPY
return snappy::GetUncompressedLength(input, length, result);
#else
return false;
@ -147,7 +147,7 @@ bool Snappy_GetUncompressedLength(const char* input, size_t length,
}
bool Snappy_Uncompress(const char* input, size_t length, char* output) {
#ifdef SNAPPY
#ifdef TF_USE_SNAPPY
return snappy::RawUncompress(input, length, output);
#else
return false;

View File

@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/s3/s3_crypto.h"
#include "tensorflow/core/platform/s3/s3_crypto.h"
#include <openssl/hmac.h>
#include <openssl/sha.h>

View File

@ -12,10 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/s3/s3_file_system.h"
#include "tensorflow/contrib/s3/s3_crypto.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/s3/s3_file_system.h"
#include "tensorflow/core/platform/s3/s3_crypto.h"
#include <aws/core/Aws.h>
#include <aws/core/utils/FileSystemUtils.h>

View File

@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/s3/s3_file_system.h"
#include "tensorflow/core/platform/s3/s3_file_system.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/gtl/stl_util.h"

View File

@ -20,7 +20,7 @@ limitations under the License.
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#ifdef SNAPPY
#ifdef TF_USE_SNAPPY
#include "snappy.h"
#endif
@ -118,7 +118,7 @@ void AdjustFilenameForLogging(string* filename) {
}
bool Snappy_Compress(const char* input, size_t length, string* output) {
#ifdef SNAPPY
#ifdef TF_USE_SNAPPY
output->resize(snappy::MaxCompressedLength(length));
size_t outlen;
snappy::RawCompress(input, length, &(*output)[0], &outlen);
@ -131,7 +131,7 @@ bool Snappy_Compress(const char* input, size_t length, string* output) {
bool Snappy_GetUncompressedLength(const char* input, size_t length,
size_t* result) {
#ifdef SNAPPY
#ifdef TF_USE_SNAPPY
return snappy::GetUncompressedLength(input, length, result);
#else
return false;
@ -139,7 +139,7 @@ bool Snappy_GetUncompressedLength(const char* input, size_t length,
}
bool Snappy_Uncompress(const char* input, size_t length, char* output) {
#ifdef SNAPPY
#ifdef TF_USE_SNAPPY
return snappy::RawUncompress(input, length, output);
#else
return false;

View File

@ -17,47 +17,94 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from sklearn import datasets
from sklearn import metrics
from sklearn import model_selection
import os
import urllib
import tensorflow as tf
# Data sets
IRIS_TRAINING = 'iris_training.csv'
IRIS_TRAINING_URL = 'http://download.tensorflow.org/data/iris_training.csv'
X_FEATURE = 'x' # Name of the input feature.
IRIS_TEST = 'iris_test.csv'
IRIS_TEST_URL = 'http://download.tensorflow.org/data/iris_test.csv'
FEATURE_KEYS = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']
def maybe_download_iris_data(file_name, download_url):
"""Downloads the file and returns the number of data."""
if not os.path.exists(file_name):
raw = urllib.urlopen(download_url).read()
with open(file_name, 'w') as f:
f.write(raw)
# The first line is a comma-separated string whose first field is the total
# number of records in the file.
with open(file_name, 'r') as f:
first_line = f.readline()
num_elements = first_line.split(',')[0]
return int(num_elements)
def input_fn(file_name, num_data, batch_size, is_training):
"""Creates an input_fn required by Estimator train/evaluate."""
# If the data sets aren't stored locally, download them.
def _parse_csv(rows_string_tensor):
"""Takes the string input tensor and returns tuple of (features, labels)."""
# Last dim is the label.
num_features = len(FEATURE_KEYS)
num_columns = num_features + 1
columns = tf.decode_csv(rows_string_tensor,
record_defaults=[[]] * num_columns)
features = dict(zip(FEATURE_KEYS, columns[:num_features]))
labels = tf.cast(columns[num_features], tf.int32)
return features, labels
def _input_fn():
"""The input_fn."""
dataset = tf.data.TextLineDataset([file_name])
# Skip the first line (which does not have data).
dataset = dataset.skip(1)
dataset = dataset.map(_parse_csv)
if is_training:
# This dataset is small enough to fit in memory, so the shuffle buffer
# size is set to the total number of elements for a fully random shuffle.
dataset = dataset.shuffle(num_data)
dataset = dataset.repeat()
dataset = dataset.batch(batch_size)
iterator = dataset.make_one_shot_iterator()
features, labels = iterator.get_next()
return features, labels
return _input_fn
def main(unused_argv):
# Load dataset.
iris = datasets.load_iris()
x_train, x_test, y_train, y_test = model_selection.train_test_split(
iris.data, iris.target, test_size=0.2, random_state=42)
tf.logging.set_verbosity(tf.logging.INFO)
num_training_data = maybe_download_iris_data(
IRIS_TRAINING, IRIS_TRAINING_URL)
num_test_data = maybe_download_iris_data(IRIS_TEST, IRIS_TEST_URL)
# Build 3 layer DNN with 10, 20, 10 units respectively.
feature_columns = [
tf.feature_column.numeric_column(
X_FEATURE, shape=np.array(x_train).shape[1:])]
tf.feature_column.numeric_column(key, shape=1) for key in FEATURE_KEYS]
classifier = tf.estimator.DNNClassifier(
feature_columns=feature_columns, hidden_units=[10, 20, 10], n_classes=3)
# Train.
train_input_fn = tf.estimator.inputs.numpy_input_fn(
x={X_FEATURE: x_train}, y=y_train, num_epochs=None, shuffle=True)
classifier.train(input_fn=train_input_fn, steps=200)
train_input_fn = input_fn(IRIS_TRAINING, num_training_data, batch_size=32,
is_training=True)
classifier.train(input_fn=train_input_fn, steps=400)
# Predict.
test_input_fn = tf.estimator.inputs.numpy_input_fn(
x={X_FEATURE: x_test}, y=y_test, num_epochs=1, shuffle=False)
predictions = classifier.predict(input_fn=test_input_fn)
y_predicted = np.array(list(p['class_ids'] for p in predictions))
y_predicted = y_predicted.reshape(np.array(y_test).shape)
# Score with sklearn.
score = metrics.accuracy_score(y_test, y_predicted)
print('Accuracy (sklearn): {0:f}'.format(score))
# Score with tensorflow.
# Eval.
test_input_fn = input_fn(IRIS_TEST, num_test_data, batch_size=32,
is_training=False)
scores = classifier.evaluate(input_fn=test_input_fn)
print('Accuracy (tensorflow): {0:f}'.format(scores['accuracy']))
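
For reference, a short sketch (assuming the data file has already been downloaded, as in main) of what the closure returned by input_fn produces when called: a (features, labels) pair, where features is a dict keyed by FEATURE_KEYS and labels is an int32 tensor of shape [batch_size], which is the structure tf.estimator expects from an input_fn.

train_input_fn = input_fn(
    IRIS_TRAINING, num_training_data, batch_size=32, is_training=True)
features, labels = train_input_fn()
# features['sepal_length'], ..., features['petal_width']: float32 [batch_size]
# labels: int32 [batch_size]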

Some files were not shown because too many files have changed in this diff